diff --git a/configs/opensora-v1-2/inference/sample.py b/configs/opensora-v1-2/inference/sample.py index 3e2c623..49efca2 100644 --- a/configs/opensora-v1-2/inference/sample.py +++ b/configs/opensora-v1-2/inference/sample.py @@ -19,6 +19,7 @@ model = dict( qk_norm=True, enable_flash_attn=True, enable_layernorm_kernel=True, + force_huggingface=True, ) vae = dict( type="OpenSoraVAE_V1_2", diff --git a/opensora/models/stdit/stdit3.py b/opensora/models/stdit/stdit3.py index e857a37..e6dc8b4 100644 --- a/opensora/models/stdit/stdit3.py +++ b/opensora/models/stdit/stdit3.py @@ -458,9 +458,9 @@ def STDiT3_XL_2(from_pretrained=None, **kwargs): @MODELS.register_module("STDiT3-3B/2") -def STDiT3_3B_2(from_pretrained=None, **kwargs): +def STDiT3_3B_2(from_pretrained=None, force_huggingface=True, **kwargs): # check if from_pretrained is a path - if from_pretrained is not None and not os.path.isdir(from_pretrained): + if force_huggingface or (from_pretrained is not None and not os.path.isdir(from_pretrained)): model = STDiT3.from_pretrained(from_pretrained, **kwargs) else: config = STDiT3Config(depth=28, hidden_size=1872, patch_size=(1, 2, 2), num_heads=26, **kwargs)