def dframe_to_frame(num):
    """Scale a count given in groups of 5 to the matching count in groups of 17.

    Computes ``num // 5 * 17``; ``num`` must be an exact multiple of 5 so the
    floor division loses nothing.

    NOTE(review): presumably 5 "d-frames" (downsampled/latent frames)
    correspond to 17 decoded video frames -- confirm against the temporal
    compression of the VAE in use.

    Args:
        num: Count to convert; must be divisible by 5.

    Returns:
        ``num // 5 * 17``.

    Raises:
        AssertionError: If ``num`` is not a multiple of 5.
    """
    if num % 5 != 0:
        # Raised explicitly instead of via ``assert`` so the check still runs
        # under ``python -O``; without it a bad value would be silently
        # truncated by the floor division below. Exception type and message
        # are the same as the assert would have produced.
        raise AssertionError(f"Invalid num: {num}")
    return num // 5 * 17
create_logger, is_distributed, is_main_process, to_torch_dtype @@ -210,7 +211,7 @@ def main(): save_path = save_paths[idx] video = [video_clips[i][idx] for i in range(loop)] for i in range(1, loop): - video[i] = video[i][:, condition_frame_length:] + video[i] = video[i][:, dframe_to_frame(condition_frame_length) :] video = torch.cat(video, dim=1) save_sample( video,