Merge pull request #531 from hpcaitech/hotfix/hf-load

[fix] better support local ckpt
This commit is contained in:
Zheng Zangwei (Alex Zheng) 2024-06-22 23:54:47 +08:00 committed by GitHub
commit ea44eb6b9e
2 changed files with 4 additions and 3 deletions

View file

@ -448,7 +448,7 @@ class STDiT3(PreTrainedModel):
@MODELS.register_module("STDiT3-XL/2")
def STDiT3_XL_2(from_pretrained=None, **kwargs):
force_huggingface = kwargs.pop("force_huggingface", False)
if force_huggingface or from_pretrained is not None and not os.path.isdir(from_pretrained):
if force_huggingface or from_pretrained is not None and not os.path.exists(from_pretrained):
model = STDiT3.from_pretrained(from_pretrained, **kwargs)
else:
config = STDiT3Config(depth=28, hidden_size=1152, patch_size=(1, 2, 2), num_heads=16, **kwargs)
@ -460,7 +460,8 @@ def STDiT3_XL_2(from_pretrained=None, **kwargs):
@MODELS.register_module("STDiT3-3B/2")
def STDiT3_3B_2(from_pretrained=None, **kwargs):
if from_pretrained is not None and not os.path.isdir(from_pretrained):
force_huggingface = kwargs.pop("force_huggingface", False)
if force_huggingface or from_pretrained is not None and not os.path.exists(from_pretrained):
model = STDiT3.from_pretrained(from_pretrained, **kwargs)
else:
config = STDiT3Config(depth=28, hidden_size=1872, patch_size=(1, 2, 2), num_heads=26, **kwargs)

View file

@ -277,7 +277,7 @@ def OpenSoraVAE_V1_2(
scale=scale,
)
if force_huggingface or (from_pretrained is not None and not os.path.isdir(from_pretrained)):
if force_huggingface or (from_pretrained is not None and not os.path.exists(from_pretrained)):
model = VideoAutoencoderPipeline.from_pretrained(from_pretrained, **kwargs)
else:
config = VideoAutoencoderPipelineConfig(**kwargs)