mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-10 12:49:38 +02:00
Docs/fix zangwei (#474)
* [docs] fix training data num * [docs] update sp * add support for issue #470
This commit is contained in:
parent
a0ff255d28
commit
403772eee1
|
|
@ -19,6 +19,7 @@ model = dict(
|
|||
qk_norm=True,
|
||||
enable_flash_attn=True,
|
||||
enable_layernorm_kernel=True,
|
||||
force_huggingface=True,
|
||||
)
|
||||
vae = dict(
|
||||
type="OpenSoraVAE_V1_2",
|
||||
|
|
|
|||
|
|
@ -458,9 +458,9 @@ def STDiT3_XL_2(from_pretrained=None, **kwargs):
|
|||
|
||||
|
||||
@MODELS.register_module("STDiT3-3B/2")
|
||||
def STDiT3_3B_2(from_pretrained=None, **kwargs):
|
||||
def STDiT3_3B_2(from_pretrained=None, force_huggingface=True, **kwargs):
|
||||
# check if from_pretrained is a path
|
||||
if from_pretrained is not None and not os.path.isdir(from_pretrained):
|
||||
if force_huggingface or (from_pretrained is not None and not os.path.isdir(from_pretrained)):
|
||||
model = STDiT3.from_pretrained(from_pretrained, **kwargs)
|
||||
else:
|
||||
config = STDiT3Config(depth=28, hidden_size=1872, patch_size=(1, 2, 2), num_heads=26, **kwargs)
|
||||
|
|
|
|||
Loading…
Reference in a new issue