[fix] HF loading

This commit is contained in:
zhengzangw 2024-06-22 15:41:32 +00:00
parent 6893e9464b
commit 7115864314
5 changed files with 49 additions and 6 deletions

View file

@ -19,14 +19,12 @@ model = dict(
qk_norm=True,
enable_flash_attn=True,
enable_layernorm_kernel=True,
force_huggingface=True,
)
vae = dict(
type="OpenSoraVAE_V1_2",
from_pretrained="hpcai-tech/OpenSora-VAE-v1.2",
micro_frame_size=17,
micro_batch_size=4,
force_huggingface=True,
)
text_encoder = dict(
type="t5",

View file

@ -0,0 +1,44 @@
# Sampling preset: 240p at 9:16, 51 frames, 24 fps.
# The diffusion model and VAE both set force_huggingface=True; in the model
# builders this flag selects the `from_pretrained(...)` loading branch.

# --- Output clip -----------------------------------------------------------
resolution = "240p"
aspect_ratio = "9:16"
num_frames = 51
fps = 24
frame_interval = 1
save_fps = 24

# --- Run settings ----------------------------------------------------------
save_dir = "./samples/samples/"
seed = 42
batch_size = 1
multi_resolution = "STDiT2"
dtype = "bf16"

# NOTE(review): presumably used for long-video / conditioned generation;
# confirm semantics against the inference script.
condition_frame_length = 5
align = 5

# --- Diffusion backbone ----------------------------------------------------
model = {
    "type": "STDiT3-XL/2",
    # Not a local directory; presumably a Hugging Face Hub repo id.
    "from_pretrained": "hpcai-tech/OpenSora-STDiT-v3",
    "qk_norm": True,
    "enable_flash_attn": True,
    "enable_layernorm_kernel": True,
    "force_huggingface": True,
}

# --- Video autoencoder -----------------------------------------------------
vae = {
    "type": "OpenSoraVAE_V1_2",
    "from_pretrained": "hpcai-tech/OpenSora-VAE-v1.2",
    "micro_frame_size": 17,
    "micro_batch_size": 4,
    "force_huggingface": True,
}

# --- Text encoder ----------------------------------------------------------
text_encoder = {
    "type": "t5",
    "from_pretrained": "DeepFloyd/t5-v1_1-xxl",
    "model_max_length": 300,
}

# --- Sampler ---------------------------------------------------------------
scheduler = {
    "type": "rflow",
    "use_timestep_transform": True,
    "num_sampling_steps": 30,
    "cfg_scale": 7.0,
}

# NOTE(review): look like optional conditioning scores (aesthetic / motion);
# `None` leaves `flow` unset — confirm how the sampler consumes them.
aes = 6.5
flow = None

View file

@ -39,7 +39,7 @@ DEFAULT_BS=1
# called inside run_video_b
function run_image() { # 14min
# 1.1 1024x1024
eval $CMD --ckpt-path $CKPT --prompt-path assets/texts/t2i_samples.txt --save-dir $OUTPUT --num-frames 1 --resolution 1024 --aspect_ratio 1:1 --sample-name image_1024_1_1 --batch-size $DEFAULT_BS
eval $CMD --ckpt-path $CKPT --prompt-path assets/texts/t2i_samples.txt --save-dir $OUTPUT --num-frames 1 --resolution 1024 --aspect-ratio 1:1 --sample-name image_1024_1_1 --batch-size $DEFAULT_BS
# 1.2 240x426
eval $CMD --ckpt-path $CKPT --prompt-path assets/texts/t2i_samples.txt --save-dir $OUTPUT --num-frames 1 --resolution 240p --aspect-ratio 9:16 --sample-name image_240p_9_16 --end-index 3 --batch-size $DEFAULT_BS

View file

@ -448,7 +448,7 @@ class STDiT3(PreTrainedModel):
@MODELS.register_module("STDiT3-XL/2")
def STDiT3_XL_2(from_pretrained=None, **kwargs):
force_huggingface = kwargs.pop("force_huggingface", False)
if force_huggingface or from_pretrained is not None and not os.path.isdir(from_pretrained):
if force_huggingface or from_pretrained is not None and not os.path.exists(from_pretrained):
model = STDiT3.from_pretrained(from_pretrained, **kwargs)
else:
config = STDiT3Config(depth=28, hidden_size=1152, patch_size=(1, 2, 2), num_heads=16, **kwargs)
@ -460,7 +460,8 @@ def STDiT3_XL_2(from_pretrained=None, **kwargs):
@MODELS.register_module("STDiT3-3B/2")
def STDiT3_3B_2(from_pretrained=None, **kwargs):
if from_pretrained is not None and not os.path.isdir(from_pretrained):
force_huggingface = kwargs.pop("force_huggingface", False)
if force_huggingface or from_pretrained is not None and not os.path.exists(from_pretrained):
model = STDiT3.from_pretrained(from_pretrained, **kwargs)
else:
config = STDiT3Config(depth=28, hidden_size=1872, patch_size=(1, 2, 2), num_heads=26, **kwargs)

View file

@ -277,7 +277,7 @@ def OpenSoraVAE_V1_2(
scale=scale,
)
if force_huggingface or (from_pretrained is not None and not os.path.isdir(from_pretrained)):
if force_huggingface or (from_pretrained is not None and not os.path.exists(from_pretrained)):
model = VideoAutoencoderPipeline.from_pretrained(from_pretrained, **kwargs)
else:
config = VideoAutoencoderPipelineConfig(**kwargs)