[gradio] hotfix stdit model init

This commit is contained in:
ver217 2024-06-21 10:39:18 +08:00
parent 578438e0ee
commit b3f746126c

View file

@ -100,7 +100,7 @@ def build_models(model_type, config, enable_optimization=False):
# handle model download logic in HuggingFace Space
from opensora.models.stdit.stdit3 import STDiT3
model_kwargs = {k: v for k, v in config.model.items() if k not in ("type", "from_pretrained")}
model_kwargs = {k: v for k, v in config.model.items() if k not in ("type", "from_pretrained", "force_huggingface")}
stdit = STDiT3.from_pretrained(HF_STDIT_MAP[model_type], **model_kwargs)
stdit = stdit.cuda()