update config: enable qk_norm, rebalance mask ratios, switch scheduler to iddpm, adjust eval scripts

This commit is contained in:
zhengzangw 2024-04-22 03:43:31 +00:00
parent 9686395ade
commit 8931efccd7
7 changed files with 34 additions and 30 deletions

View file

@ -32,6 +32,7 @@ model = dict(
type="STDiT2-XL/2",
from_pretrained=None,
input_sq_size=512,
qk_norm=True,
enable_flashattn=True,
enable_layernorm_kernel=True,
)

View file

@ -9,7 +9,7 @@ model = dict(
type="STDiT2-XL/2",
from_pretrained=None,
input_sq_size=512,
# qk_norm=True,
qk_norm=True,
enable_flashattn=True,
enable_layernorm_kernel=True,
)

View file

@ -17,15 +17,15 @@ bucket_config = { # 6s/it
"1080p": {1: (0.4, 8)},
}
mask_ratios = {
"mask_no": 0.9,
"mask_quarter_random": 0.01,
"mask_quarter_head": 0.01,
"mask_quarter_tail": 0.01,
"mask_quarter_head_tail": 0.02,
"mask_image_random": 0.01,
"mask_image_head": 0.01,
"mask_image_tail": 0.01,
"mask_image_head_tail": 0.02,
"mask_no": 0.75,
"mask_quarter_random": 0.025,
"mask_quarter_head": 0.025,
"mask_quarter_tail": 0.025,
"mask_quarter_head_tail": 0.05,
"mask_image_random": 0.025,
"mask_image_head": 0.025,
"mask_image_tail": 0.025,
"mask_image_head_tail": 0.05,
}
# Define acceleration
@ -41,6 +41,7 @@ model = dict(
type="STDiT2-XL/2",
from_pretrained=None,
input_sq_size=512, # pretrained model is trained on 512x512
qk_norm=True,
enable_flashattn=True,
enable_layernorm_kernel=True,
)
@ -48,15 +49,17 @@ vae = dict(
type="VideoAutoencoderKL",
from_pretrained="stabilityai/sd-vae-ft-ema",
micro_batch_size=4,
local_files_only=True,
)
text_encoder = dict(
type="t5",
from_pretrained="DeepFloyd/t5-v1_1-xxl",
model_max_length=200,
shardformer=True,
local_files_only=True,
)
scheduler = dict(
type="iddpm-speed",
type="iddpm",
timestep_respacing="",
)

View file

@ -14,23 +14,23 @@ LOG_BASE=logs/sample/$CKPT_BASE
echo "Logging to $LOG_BASE"
# == sample & human evaluation ==
# CUDA_VISIBLE_DEVICES=0 bash eval/sample.sh $CKPT -1 >${LOG_BASE}_1.log 2>&1 &
# CUDA_VISIBLE_DEVICES=1 bash eval/sample.sh $CKPT -3 >${LOG_BASE}_3.log 2>&1 &
# CUDA_VISIBLE_DEVICES=2 bash eval/sample.sh $CKPT -2a >${LOG_BASE}_2a.log 2>&1 &
# CUDA_VISIBLE_DEVICES=3 bash eval/sample.sh $CKPT -2b >${LOG_BASE}_2b.log 2>&1 &
# CUDA_VISIBLE_DEVICES=4 bash eval/sample.sh $CKPT -2c >${LOG_BASE}_2c.log 2>&1 &
# CUDA_VISIBLE_DEVICES=5 bash eval/sample.sh $CKPT -2d >${LOG_BASE}_2d.log 2>&1 &
# CUDA_VISIBLE_DEVICES=6 bash eval/sample.sh $CKPT -2e >${LOG_BASE}_2e.log 2>&1 &
# CUDA_VISIBLE_DEVICES=7 bash eval/sample.sh $CKPT -2f >${LOG_BASE}_2f.log 2>&1 &
CUDA_VISIBLE_DEVICES=0 bash eval/sample.sh $CKPT -1 >${LOG_BASE}_1.log 2>&1 &
CUDA_VISIBLE_DEVICES=1 bash eval/sample.sh $CKPT -3 >${LOG_BASE}_3.log 2>&1 &
CUDA_VISIBLE_DEVICES=2 bash eval/sample.sh $CKPT -2a >${LOG_BASE}_2a.log 2>&1 &
CUDA_VISIBLE_DEVICES=3 bash eval/sample.sh $CKPT -2b >${LOG_BASE}_2b.log 2>&1 &
CUDA_VISIBLE_DEVICES=4 bash eval/sample.sh $CKPT -2c >${LOG_BASE}_2c.log 2>&1 &
CUDA_VISIBLE_DEVICES=5 bash eval/sample.sh $CKPT -2d >${LOG_BASE}_2d.log 2>&1 &
CUDA_VISIBLE_DEVICES=6 bash eval/sample.sh $CKPT -2e >${LOG_BASE}_2e.log 2>&1 &
CUDA_VISIBLE_DEVICES=7 bash eval/sample.sh $CKPT -2f >${LOG_BASE}_2f.log 2>&1 &
# == vbench ==
CUDA_VISIBLE_DEVICES=0 bash eval/sample.sh $CKPT -4a >${LOG_BASE}_4a.log 2>&1 &
CUDA_VISIBLE_DEVICES=1 bash eval/sample.sh $CKPT -4b >${LOG_BASE}_4b.log 2>&1 &
CUDA_VISIBLE_DEVICES=2 bash eval/sample.sh $CKPT -4c >${LOG_BASE}_4c.log 2>&1 &
CUDA_VISIBLE_DEVICES=3 bash eval/sample.sh $CKPT -4d >${LOG_BASE}_4d.log 2>&1 &
CUDA_VISIBLE_DEVICES=4 bash eval/sample.sh $CKPT -4e >${LOG_BASE}_4e.log 2>&1 &
CUDA_VISIBLE_DEVICES=5 bash eval/sample.sh $CKPT -4f >${LOG_BASE}_4f.log 2>&1 &
CUDA_VISIBLE_DEVICES=6 bash eval/sample.sh $CKPT -4g >${LOG_BASE}_4g.log 2>&1 &
CUDA_VISIBLE_DEVICES=7 bash eval/sample.sh $CKPT -4h >${LOG_BASE}_4h.log 2>&1 &
# CUDA_VISIBLE_DEVICES=0 bash eval/sample.sh $CKPT -4a >${LOG_BASE}_4a.log 2>&1 &
# CUDA_VISIBLE_DEVICES=1 bash eval/sample.sh $CKPT -4b >${LOG_BASE}_4b.log 2>&1 &
# CUDA_VISIBLE_DEVICES=2 bash eval/sample.sh $CKPT -4c >${LOG_BASE}_4c.log 2>&1 &
# CUDA_VISIBLE_DEVICES=3 bash eval/sample.sh $CKPT -4d >${LOG_BASE}_4d.log 2>&1 &
# CUDA_VISIBLE_DEVICES=4 bash eval/sample.sh $CKPT -4e >${LOG_BASE}_4e.log 2>&1 &
# CUDA_VISIBLE_DEVICES=5 bash eval/sample.sh $CKPT -4f >${LOG_BASE}_4f.log 2>&1 &
# CUDA_VISIBLE_DEVICES=6 bash eval/sample.sh $CKPT -4g >${LOG_BASE}_4g.log 2>&1 &
# CUDA_VISIBLE_DEVICES=7 bash eval/sample.sh $CKPT -4h >${LOG_BASE}_4h.log 2>&1 &
# kill all by: pkill -f "inference"

View file

@ -1,6 +1,6 @@
#!/bin/bash
set -x
# set -x
set -e
CKPT=$1
@ -86,7 +86,7 @@ function run_video_c() { # 30min
eval $CMD --ckpt-path $CKPT --prompt-path assets/texts/t2v_sora.txt --save-dir $OUTPUT --num-frames 16 --image-size 240 426 --sample-name sora_16x240x426
# 2.3.2 128x240x426
eval $CMD --ckpt-path $CKPT --prompt-path assets/texts/t2v_short.txt --save-dir $OUTPUT --num-frames 128 --image-size 240 426 --sample-name short_128x240x426
eval $CMD --ckpt-path $CKPT --prompt-path assets/texts/t2v_sora.txt --save-dir $OUTPUT --num-frames 48 --image-size 256 256 --sample-name sora_48x256x256
}
function run_video_d() { # 30min

View file

@ -256,7 +256,7 @@ def create_logger(logging_dir):
return logger
def load_checkpoint(model, ckpt_path, save_as_pt=True):
def load_checkpoint(model, ckpt_path, save_as_pt=False):
if ckpt_path.endswith(".pt") or ckpt_path.endswith(".pth"):
state_dict = find_model(ckpt_path, model=model)
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)