mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-10 21:01:26 +02:00
Merge pull request #531 from hpcaitech/hotfix/hf-load
[fix] better support local ckpt
This commit is contained in:
commit
ea44eb6b9e
|
|
@ -448,7 +448,7 @@ class STDiT3(PreTrainedModel):
|
|||
@MODELS.register_module("STDiT3-XL/2")
|
||||
def STDiT3_XL_2(from_pretrained=None, **kwargs):
|
||||
force_huggingface = kwargs.pop("force_huggingface", False)
|
||||
if force_huggingface or from_pretrained is not None and not os.path.isdir(from_pretrained):
|
||||
if force_huggingface or from_pretrained is not None and not os.path.exists(from_pretrained):
|
||||
model = STDiT3.from_pretrained(from_pretrained, **kwargs)
|
||||
else:
|
||||
config = STDiT3Config(depth=28, hidden_size=1152, patch_size=(1, 2, 2), num_heads=16, **kwargs)
|
||||
|
|
@ -460,7 +460,8 @@ def STDiT3_XL_2(from_pretrained=None, **kwargs):
|
|||
|
||||
@MODELS.register_module("STDiT3-3B/2")
|
||||
def STDiT3_3B_2(from_pretrained=None, **kwargs):
|
||||
if from_pretrained is not None and not os.path.isdir(from_pretrained):
|
||||
force_huggingface = kwargs.pop("force_huggingface", False)
|
||||
if force_huggingface or from_pretrained is not None and not os.path.exists(from_pretrained):
|
||||
model = STDiT3.from_pretrained(from_pretrained, **kwargs)
|
||||
else:
|
||||
config = STDiT3Config(depth=28, hidden_size=1872, patch_size=(1, 2, 2), num_heads=26, **kwargs)
|
||||
|
|
|
|||
|
|
@ -277,7 +277,7 @@ def OpenSoraVAE_V1_2(
|
|||
scale=scale,
|
||||
)
|
||||
|
||||
if force_huggingface or (from_pretrained is not None and not os.path.isdir(from_pretrained)):
|
||||
if force_huggingface or (from_pretrained is not None and not os.path.exists(from_pretrained)):
|
||||
model = VideoAutoencoderPipeline.from_pretrained(from_pretrained, **kwargs)
|
||||
else:
|
||||
config = VideoAutoencoderPipelineConfig(**kwargs)
|
||||
|
|
|
|||
Loading…
Reference in a new issue