[fix] condition sampling

This commit is contained in:
zhengzangw 2024-06-11 07:49:00 +00:00
parent 0532a4c962
commit a271faa131
3 changed files with 9 additions and 3 deletions

View file

@ -169,14 +169,14 @@ function run_video_h() { # 61min
eval $CMD --ckpt-path $CKPT --save-dir $OUTPUT --sample-name ref_L10C4_${NUM_FRAMES}x240x426 \
--prompt-path assets/texts/t2v_ref.txt --start-index 0 --end-index 3 \
--num-frames $NUM_FRAMES --image-size 240 426 \
--loop 5 --condition-frame-length 15 \
--loop 5 --condition-frame-length 5 \
--reference-path assets/images/condition/cliff.png assets/images/condition/wave.png assets/images/condition/ship.png \
--mask-strategy "0" "0" "0" --batch-size $DEFAULT_BS
eval $CMD --ckpt-path $CKPT --save-dir $OUTPUT --sample-name ref_L10C4_${QUAD_FRAMES}x240x426 \
--prompt-path assets/texts/t2v_ref.txt --start-index 0 --end-index 3 \
--num-frames $QUAD_FRAMES --image-size 240 426 \
--loop 5 --condition-frame-length 60 \
--loop 5 --condition-frame-length 10 \
--reference-path assets/images/condition/cliff.png assets/images/condition/wave.png assets/images/condition/ship.png \
--mask-strategy "0" "0" "0" --batch-size $DEFAULT_BS

View file

@ -175,3 +175,8 @@ def append_generated(vae, generated_video, refs_x, mask_strategy, loop_i, condit
mask_strategy[j] += ";"
mask_strategy[j] += f"{loop_i},{len(refs)-1},-{condition_frame_length},0,{condition_frame_length}"
return refs_x, mask_strategy
def dframe_to_frame(num):
assert num % 5 == 0, f"Invalid num: {num}"
return num // 5 * 17

View file

@ -23,6 +23,7 @@ from opensora.utils.inference_utils import (
get_save_path_name,
load_prompts,
prepare_multi_resolution_info,
dframe_to_frame,
)
from opensora.utils.misc import all_exists, create_logger, is_distributed, is_main_process, to_torch_dtype
@ -210,7 +211,7 @@ def main():
save_path = save_paths[idx]
video = [video_clips[i][idx] for i in range(loop)]
for i in range(1, loop):
video[i] = video[i][:, condition_frame_length:]
video[i] = video[i][:, dframe_to_frame(condition_frame_length) :]
video = torch.cat(video, dim=1)
save_sample(
video,