[fix] inference with score

This commit is contained in:
zhengzangw 2024-06-13 02:37:45 +00:00
parent 51cdd377cc
commit a8a34d56b9
5 changed files with 34 additions and 19 deletions

View file

@ -49,5 +49,5 @@ scheduler = dict(
cfg_scale=7.0,
)
aes = None
aes = 6.5
flow = None

View file

@ -79,20 +79,20 @@ function run_video_a() { # ~ 30min ?
# # sample, 240p, 9:16, 8s
# eval $CMD --ckpt-path $CKPT --prompt-path assets/texts/t2v_samples.txt --save-dir $OUTPUT --num-frames 8s --resolution 240p --aspect-ratio 9:16 --sample-name sample_8s_240p_9_16 --batch-size $DEFAULT_BS
# # sample, 480p, 9:16, 2s
# eval $CMD --ckpt-path $CKPT --prompt-path assets/texts/t2v_samples.txt --save-dir $OUTPUT --num-frames 2s --resoluton 480p --aspect-ratio 9:16 --sample-name sample_2s_480p_9_16 --batch-size $DEFAULT_BS
# eval $CMD --ckpt-path $CKPT --prompt-path assets/texts/t2v_samples.txt --save-dir $OUTPUT --num-frames 2s --resolution 480p --aspect-ratio 9:16 --sample-name sample_2s_480p_9_16 --batch-size $DEFAULT_BS
# # sample, 480p, 9:16, 4s
# eval $CMD --ckpt-path $CKPT --prompt-path assets/texts/t2v_samples.txt --save-dir $OUTPUT --num-frames 4s --resoluton 480p --aspect-ratio 9:16 --sample-name sample_4s_480p_9_16 --batch-size $DEFAULT_BS
# eval $CMD --ckpt-path $CKPT --prompt-path assets/texts/t2v_samples.txt --save-dir $OUTPUT --num-frames 4s --resolution 480p --aspect-ratio 9:16 --sample-name sample_4s_480p_9_16 --batch-size $DEFAULT_BS
# # sample, 720p, 9:16, 2s
# eval $CMD --ckpt-path $CKPT --prompt-path assets/texts/t2v_samples.txt --save-dir $OUTPUT --num-frames 2s --resoluton 720p --aspect-ratio 9:16 --sample-name sample_2s_720p_9_16 --batch-size $DEFAULT_BS
# eval $CMD --ckpt-path $CKPT --prompt-path assets/texts/t2v_samples.txt --save-dir $OUTPUT --num-frames 2s --resolution 720p --aspect-ratio 9:16 --sample-name sample_2s_720p_9_16 --batch-size $DEFAULT_BS
# sample, 720p, 9:16, 2s
eval $CMD --ckpt-path $CKPT --prompt-path assets/texts/t2v_samples.txt --save-dir $OUTPUT --num-frames 2s --resoluton 720p --aspect-ratio 9:16 --sample-name sample_2s_720p_9_16 --batch-size $DEFAULT_BS
eval $CMD --ckpt-path $CKPT --prompt-path assets/texts/t2v_samples.txt --save-dir $OUTPUT --num-frames 2s --resolution 720p --aspect-ratio 9:16 --sample-name sample_2s_720p_9_16 --batch-size $DEFAULT_BS
# sample, 480p, 9:16, 8s
eval $CMD --ckpt-path $CKPT --prompt-path assets/texts/t2v_samples.txt --save-dir $OUTPUT --num-frames 8s --resoluton 480p --aspect-ratio 9:16 --sample-name sample_8s_480p_9_16 --batch-size $DEFAULT_BS
eval $CMD --ckpt-path $CKPT --prompt-path assets/texts/t2v_samples.txt --save-dir $OUTPUT --num-frames 8s --resolution 480p --aspect-ratio 9:16 --sample-name sample_8s_480p_9_16 --batch-size $DEFAULT_BS
# sample, 240p, 9:16, 16s
eval $CMD --ckpt-path $CKPT --prompt-path assets/texts/t2v_samples.txt --save-dir $OUTPUT --num-frames 16s --resoluton 240p --aspect-ratio 9:16 --sample-name sample_16s_240p_9_16 --batch-size $DEFAULT_BS
eval $CMD --ckpt-path $CKPT --prompt-path assets/texts/t2v_samples.txt --save-dir $OUTPUT --num-frames 16s --resolution 240p --aspect-ratio 9:16 --sample-name sample_16s_240p_9_16 --batch-size $DEFAULT_BS
}
# for (short, sora, sample)

View file

@ -52,6 +52,24 @@ def get_save_path_name(
return save_path
def append_score_to_prompts(prompts, aes=None, flow=None):
    """Tag prompts with aesthetic and/or motion score text.

    Args:
        prompts: List of prompt strings.
        aes: Aesthetic score; formatted with one decimal place. Ignored
            when ``None``.
        flow: Motion score; formatted with one decimal place. Ignored
            when ``None``.

    Returns:
        The ``prompts`` object itself when neither score is supplied.
        Otherwise a new list in which each prompt that does not already
        contain the substring ``"score:"`` has the comma-joined score text
        appended after a single space; prompts that already contain
        ``"score:"`` are carried over unchanged.
    """
    labelled_scores = (("aesthetic score", aes), ("motion score", flow))
    tags = [
        f"{label}: {value:.1f}"
        for label, value in labelled_scores
        if value is not None
    ]
    if not tags:
        return prompts

    suffix = ", ".join(tags)
    # A prompt that already carries a "score:" marker is passed through as-is.
    return [
        text if "score:" in text else f"{text} {suffix}"
        for text in prompts
    ]
def extract_json_from_prompts(prompts, reference, mask_strategy):
ret_prompts = []
for i, prompt in enumerate(prompts):

View file

@ -16,14 +16,15 @@ from opensora.registry import MODELS, SCHEDULERS, build_module
from opensora.utils.config_utils import parse_configs
from opensora.utils.inference_utils import (
append_generated,
append_score_to_prompts,
apply_mask_strategy,
collect_references_batch,
dframe_to_frame,
extract_json_from_prompts,
extract_prompts_loop,
get_save_path_name,
load_prompts,
prepare_multi_resolution_info,
dframe_to_frame,
)
from opensora.utils.misc import all_exists, create_logger, is_distributed, is_main_process, to_torch_dtype
@ -58,7 +59,7 @@ def main():
# == init logger ==
logger = create_logger()
logger.info("Training configuration:\n %s", pformat(cfg.to_dict()))
logger.info("Inference configuration:\n %s", pformat(cfg.to_dict()))
verbose = cfg.get("verbose", 1)
progress_wrap = tqdm if verbose == 1 else (lambda x: x)
@ -111,14 +112,6 @@ def main():
if prompts is None:
assert cfg.get("prompt_path", None) is not None, "Prompt or prompt_path must be provided"
prompts = load_prompts(cfg.prompt_path, start_idx, cfg.get("end_index", None))
score_prompts = []
if cfg.get("aes", None) is not None:
score_prompts.append(f"{prompt} aesthetic score: {cfg.aes:.1f}" for prompt in prompts)
if cfg.get("flow", None) is not None:
score_prompts.append(f"{prompt} motion score: {cfg.flow:.1f}" for prompt in prompts)
if len(score_prompts) > 0:
score_text = ", ".join(score_prompts)
prompts = [f"{prompt} [{score_text}]" for prompt in prompts]
# == prepare reference ==
reference_path = cfg.get("reference_path", [""] * len(prompts))
@ -149,8 +142,12 @@ def main():
refs = reference_path[i : i + batch_size]
batch_prompts, refs, ms = extract_json_from_prompts(batch_prompts, refs, ms)
original_prompts = batch_prompts
refs = collect_references_batch(refs, vae, image_size)
# == score ==
batch_prompts = append_score_to_prompts(batch_prompts, aes=cfg.get("aes", None), flow=cfg.get("flow", None))
# == multi-resolution info ==
model_args = prepare_multi_resolution_info(
multi_resolution, len(batch_prompts), image_size, num_frames, fps, device, dtype
@ -164,7 +161,7 @@ def main():
save_dir,
sample_name=sample_name,
sample_idx=start_idx + idx,
prompt=batch_prompts[idx],
prompt=original_prompts[idx],
prompt_as_path=prompt_as_path,
num_sample=num_sample,
k=k,

View file

@ -38,7 +38,7 @@ def main():
# == init logger ==
logger = create_logger()
logger.info("Training configuration:\n %s", pformat(cfg.to_dict()))
logger.info("Inference configuration:\n %s", pformat(cfg.to_dict()))
verbose = cfg.get("verbose", 1)
# ======================================================