[fix] better support local ckpt

This commit is contained in:
zhengzangw 2024-06-22 15:54:27 +00:00
parent 3cd29bd488
commit 9a9a6c2f3e
2 changed files with 4 additions and 3 deletions

View file

@ -448,7 +448,7 @@ class STDiT3(PreTrainedModel):
@MODELS.register_module("STDiT3-XL/2")
def STDiT3_XL_2(from_pretrained=None, **kwargs):
force_huggingface = kwargs.pop("force_huggingface", False)
if force_huggingface or from_pretrained is not None and not os.path.isdir(from_pretrained):
if force_huggingface or from_pretrained is not None and not os.path.exists(from_pretrained):
model = STDiT3.from_pretrained(from_pretrained, **kwargs)
else:
config = STDiT3Config(depth=28, hidden_size=1152, patch_size=(1, 2, 2), num_heads=16, **kwargs)
@ -460,7 +460,8 @@ def STDiT3_XL_2(from_pretrained=None, **kwargs):
@MODELS.register_module("STDiT3-3B/2")
def STDiT3_3B_2(from_pretrained=None, **kwargs):
if from_pretrained is not None and not os.path.isdir(from_pretrained):
force_huggingface = kwargs.pop("force_huggingface", False)
if force_huggingface or from_pretrained is not None and not os.path.exists(from_pretrained):
model = STDiT3.from_pretrained(from_pretrained, **kwargs)
else:
config = STDiT3Config(depth=28, hidden_size=1872, patch_size=(1, 2, 2), num_heads=26, **kwargs)

View file

@ -277,7 +277,7 @@ def OpenSoraVAE_V1_2(
scale=scale,
)
if force_huggingface or (from_pretrained is not None and not os.path.isdir(from_pretrained)):
if force_huggingface or (from_pretrained is not None and not os.path.exists(from_pretrained)):
model = VideoAutoencoderPipeline.from_pretrained(from_pretrained, **kwargs)
else:
config = VideoAutoencoderPipelineConfig(**kwargs)