From a8a34d56b95d2e7b0210bb385cf3ebbdee36684e Mon Sep 17 00:00:00 2001 From: zhengzangw Date: Thu, 13 Jun 2024 02:37:45 +0000 Subject: [PATCH] [fix] inference with score --- configs/opensora-v1-2/inference/sample.py | 2 +- eval/sample.sh | 12 ++++++------ opensora/utils/inference_utils.py | 18 ++++++++++++++++++ scripts/inference.py | 19 ++++++++----------- scripts/inference_vae.py | 2 +- 5 files changed, 34 insertions(+), 19 deletions(-) diff --git a/configs/opensora-v1-2/inference/sample.py b/configs/opensora-v1-2/inference/sample.py index 68956f0..67f0e01 100644 --- a/configs/opensora-v1-2/inference/sample.py +++ b/configs/opensora-v1-2/inference/sample.py @@ -49,5 +49,5 @@ scheduler = dict( cfg_scale=7.0, ) -aes = None +aes = 6.5 flow = None diff --git a/eval/sample.sh b/eval/sample.sh index ed509d9..4df92a6 100644 --- a/eval/sample.sh +++ b/eval/sample.sh @@ -79,20 +79,20 @@ function run_video_a() { # ~ 30min ? # # sample, 240p, 9:16, 8s # eval $CMD --ckpt-path $CKPT --prompt-path assets/texts/t2v_samples.txt --save-dir $OUTPUT --num-frames 8s --resolution 240p --aspect-ratio 9:16 --sample-name sample_8s_240p_9_16 --batch-size $DEFAULT_BS # # sample, 480p, 9:16, 2s - # eval $CMD --ckpt-path $CKPT --prompt-path assets/texts/t2v_samples.txt --save-dir $OUTPUT --num-frames 2s --resoluton 480p --aspect-ratio 9:16 --sample-name sample_2s_480p_9_16 --batch-size $DEFAULT_BS + # eval $CMD --ckpt-path $CKPT --prompt-path assets/texts/t2v_samples.txt --save-dir $OUTPUT --num-frames 2s --resolution 480p --aspect-ratio 9:16 --sample-name sample_2s_480p_9_16 --batch-size $DEFAULT_BS # # sample, 480p, 9:16, 4s - # eval $CMD --ckpt-path $CKPT --prompt-path assets/texts/t2v_samples.txt --save-dir $OUTPUT --num-frames 4s --resoluton 480p --aspect-ratio 9:16 --sample-name sample_4s_480p_9_16 --batch-size $DEFAULT_BS + # eval $CMD --ckpt-path $CKPT --prompt-path assets/texts/t2v_samples.txt --save-dir $OUTPUT --num-frames 4s --resolution 480p --aspect-ratio 9:16 --sample-name sample_4s_480p_9_16 --batch-size $DEFAULT_BS # # sample, 720p, 9:16, 2s - # eval $CMD --ckpt-path $CKPT --prompt-path assets/texts/t2v_samples.txt --save-dir $OUTPUT --num-frames 2s --resoluton 720p --aspect-ratio 9:16 --sample-name sample_2s_720p_9_16 --batch-size $DEFAULT_BS + # eval $CMD --ckpt-path $CKPT --prompt-path assets/texts/t2v_samples.txt --save-dir $OUTPUT --num-frames 2s --resolution 720p --aspect-ratio 9:16 --sample-name sample_2s_720p_9_16 --batch-size $DEFAULT_BS # sample, 720p, 9:16, 2s - eval $CMD --ckpt-path $CKPT --prompt-path assets/texts/t2v_samples.txt --save-dir $OUTPUT --num-frames 2s --resoluton 720p --aspect-ratio 9:16 --sample-name sample_2s_720p_9_16 --batch-size $DEFAULT_BS + eval $CMD --ckpt-path $CKPT --prompt-path assets/texts/t2v_samples.txt --save-dir $OUTPUT --num-frames 2s --resolution 720p --aspect-ratio 9:16 --sample-name sample_2s_720p_9_16 --batch-size $DEFAULT_BS # sample, 480p, 9:16, 8s - eval $CMD --ckpt-path $CKPT --prompt-path assets/texts/t2v_samples.txt --save-dir $OUTPUT --num-frames 8s --resoluton 480p --aspect-ratio 9:16 --sample-name sample_8s_480p_9_16 --batch-size $DEFAULT_BS + eval $CMD --ckpt-path $CKPT --prompt-path assets/texts/t2v_samples.txt --save-dir $OUTPUT --num-frames 8s --resolution 480p --aspect-ratio 9:16 --sample-name sample_8s_480p_9_16 --batch-size $DEFAULT_BS # sample, 240p, 9:16, 16s - eval $CMD --ckpt-path $CKPT --prompt-path assets/texts/t2v_samples.txt --save-dir $OUTPUT --num-frames 16s --resoluton 240p --aspect-ratio 9:16 --sample-name sample_16s_240p_9_16 --batch-size $DEFAULT_BS + eval $CMD --ckpt-path $CKPT --prompt-path assets/texts/t2v_samples.txt --save-dir $OUTPUT --num-frames 16s --resolution 240p --aspect-ratio 9:16 --sample-name sample_16s_240p_9_16 --batch-size $DEFAULT_BS } # for (short, sora, sample) diff --git a/opensora/utils/inference_utils.py b/opensora/utils/inference_utils.py index d2ec784..d13c5ea 100644 --- a/opensora/utils/inference_utils.py +++ b/opensora/utils/inference_utils.py @@ -52,6 +52,24 @@ def get_save_path_name( return save_path +def append_score_to_prompts(prompts, aes=None, flow=None): + score_prompts = [] + if aes is not None: + score_prompts.append(f"aesthetic score: {aes:.1f}") + if flow is not None: + score_prompts.append(f"motion score: {flow:.1f}") + if len(score_prompts) > 0: + score_text = ", ".join(score_prompts) + new_prompts = [] + for prompt in prompts: + if "score:" in prompt: + new_prompts.append(prompt) + else: + new_prompts.append(f"{prompt} {score_text}") + return new_prompts + return prompts + + def extract_json_from_prompts(prompts, reference, mask_strategy): ret_prompts = [] for i, prompt in enumerate(prompts): diff --git a/scripts/inference.py b/scripts/inference.py index aeac0df..6317701 100644 --- a/scripts/inference.py +++ b/scripts/inference.py @@ -16,14 +16,15 @@ from opensora.registry import MODELS, SCHEDULERS, build_module from opensora.utils.config_utils import parse_configs from opensora.utils.inference_utils import ( append_generated, + append_score_to_prompts, apply_mask_strategy, collect_references_batch, + dframe_to_frame, extract_json_from_prompts, extract_prompts_loop, get_save_path_name, load_prompts, prepare_multi_resolution_info, - dframe_to_frame, ) from opensora.utils.misc import all_exists, create_logger, is_distributed, is_main_process, to_torch_dtype @@ -58,7 +59,7 @@ def main(): # == init logger == logger = create_logger() - logger.info("Training configuration:\n %s", pformat(cfg.to_dict())) + logger.info("Inference configuration:\n %s", pformat(cfg.to_dict())) verbose = cfg.get("verbose", 1) progress_wrap = tqdm if verbose == 1 else (lambda x: x) @@ -111,14 +112,6 @@ def main(): if prompts is None: assert cfg.get("prompt_path", None) is not None, "Prompt or prompt_path must be provided" prompts = load_prompts(cfg.prompt_path, start_idx, cfg.get("end_index", None)) - score_prompts = [] - if cfg.get("aes", None) is not None: - score_prompts.append(f"{prompt} aesthetic score: {cfg.aes:.1f}" for prompt in prompts) - if cfg.get("flow", None) is not None: - score_prompts.append(f"{prompt} motion score: {cfg.flow:.1f}" for prompt in prompts) - if len(score_prompts) > 0: - score_text = ", ".join(score_prompts) - prompts = [f"{prompt} [{score_text}]" for prompt in prompts] # == prepare reference == reference_path = cfg.get("reference_path", [""] * len(prompts)) @@ -149,8 +142,12 @@ def main(): refs = reference_path[i : i + batch_size] batch_prompts, refs, ms = extract_json_from_prompts(batch_prompts, refs, ms) + original_prompts = batch_prompts refs = collect_references_batch(refs, vae, image_size) + # == score == + batch_prompts = append_score_to_prompts(batch_prompts, aes=cfg.get("aes", None), flow=cfg.get("flow", None)) + # == multi-resolution info == model_args = prepare_multi_resolution_info( multi_resolution, len(batch_prompts), image_size, num_frames, fps, device, dtype @@ -164,7 +161,7 @@ def main(): save_dir, sample_name=sample_name, sample_idx=start_idx + idx, - prompt=batch_prompts[idx], + prompt=original_prompts[idx], prompt_as_path=prompt_as_path, num_sample=num_sample, k=k, diff --git a/scripts/inference_vae.py b/scripts/inference_vae.py index f4542e3..3580b73 100644 --- a/scripts/inference_vae.py +++ b/scripts/inference_vae.py @@ -38,7 +38,7 @@ def main(): # == init logger == logger = create_logger() - logger.info("Training configuration:\n %s", pformat(cfg.to_dict())) + logger.info("Inference configuration:\n %s", pformat(cfg.to_dict())) verbose = cfg.get("verbose", 1) # ======================================================