mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-11 05:13:31 +02:00
[fix] condition sampling
This commit is contained in:
parent
0532a4c962
commit
a271faa131
|
|
@ -169,14 +169,14 @@ function run_video_h() { # 61min
|
|||
eval $CMD --ckpt-path $CKPT --save-dir $OUTPUT --sample-name ref_L10C4_${NUM_FRAMES}x240x426 \
|
||||
--prompt-path assets/texts/t2v_ref.txt --start-index 0 --end-index 3 \
|
||||
--num-frames $NUM_FRAMES --image-size 240 426 \
|
||||
--loop 5 --condition-frame-length 15 \
|
||||
--loop 5 --condition-frame-length 5 \
|
||||
--reference-path assets/images/condition/cliff.png assets/images/condition/wave.png assets/images/condition/ship.png \
|
||||
--mask-strategy "0" "0" "0" --batch-size $DEFAULT_BS
|
||||
|
||||
eval $CMD --ckpt-path $CKPT --save-dir $OUTPUT --sample-name ref_L10C4_${QUAD_FRAMES}x240x426 \
|
||||
--prompt-path assets/texts/t2v_ref.txt --start-index 0 --end-index 3 \
|
||||
--num-frames $QUAD_FRAMES --image-size 240 426 \
|
||||
--loop 5 --condition-frame-length 60 \
|
||||
--loop 5 --condition-frame-length 10 \
|
||||
--reference-path assets/images/condition/cliff.png assets/images/condition/wave.png assets/images/condition/ship.png \
|
||||
--mask-strategy "0" "0" "0" --batch-size $DEFAULT_BS
|
||||
|
||||
|
|
|
|||
|
|
@ -175,3 +175,8 @@ def append_generated(vae, generated_video, refs_x, mask_strategy, loop_i, condit
|
|||
mask_strategy[j] += ";"
|
||||
mask_strategy[j] += f"{loop_i},{len(refs)-1},-{condition_frame_length},0,{condition_frame_length}"
|
||||
return refs_x, mask_strategy
|
||||
|
||||
|
||||
def dframe_to_frame(num):
|
||||
assert num % 5 == 0, f"Invalid num: {num}"
|
||||
return num // 5 * 17
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ from opensora.utils.inference_utils import (
|
|||
get_save_path_name,
|
||||
load_prompts,
|
||||
prepare_multi_resolution_info,
|
||||
dframe_to_frame,
|
||||
)
|
||||
from opensora.utils.misc import all_exists, create_logger, is_distributed, is_main_process, to_torch_dtype
|
||||
|
||||
|
|
@ -210,7 +211,7 @@ def main():
|
|||
save_path = save_paths[idx]
|
||||
video = [video_clips[i][idx] for i in range(loop)]
|
||||
for i in range(1, loop):
|
||||
video[i] = video[i][:, condition_frame_length:]
|
||||
video[i] = video[i][:, dframe_to_frame(condition_frame_length) :]
|
||||
video = torch.cat(video, dim=1)
|
||||
save_sample(
|
||||
video,
|
||||
|
|
|
|||
Loading…
Reference in a new issue