Docs/fix zangwei (#474)

* [docs] fix training data num

* [docs] update sp

* add support for issue #470
This commit is contained in:
Zheng Zangwei (Alex Zheng) 2024-06-19 16:53:53 +08:00 committed by GitHub
parent a0ff255d28
commit 403772eee1
2 changed files with 3 additions and 2 deletions

View file

@ -19,6 +19,7 @@ model = dict(
qk_norm=True,
enable_flash_attn=True,
enable_layernorm_kernel=True,
force_huggingface=True,
)
vae = dict(
type="OpenSoraVAE_V1_2",

View file

@ -458,9 +458,9 @@ def STDiT3_XL_2(from_pretrained=None, **kwargs):
@MODELS.register_module("STDiT3-3B/2")
def STDiT3_3B_2(from_pretrained=None, **kwargs):
def STDiT3_3B_2(from_pretrained=None, force_huggingface=True, **kwargs):
# check if from_pretrained is a path
if from_pretrained is not None and not os.path.isdir(from_pretrained):
if force_huggingface or (from_pretrained is not None and not os.path.isdir(from_pretrained)):
model = STDiT3.from_pretrained(from_pretrained, **kwargs)
else:
config = STDiT3Config(depth=28, hidden_size=1872, patch_size=(1, 2, 2), num_heads=26, **kwargs)