mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-10 04:37:45 +02:00
[gradio] updated inference logic (#138)
* [gradio] updated inference logic * polish
This commit is contained in:
parent
5fa086cbcb
commit
be61f44a29
110
gradio/app.py
110
gradio/app.py
|
|
@ -26,8 +26,8 @@ CONFIG_MAP = {
|
|||
}
|
||||
HF_STDIT_MAP = {
|
||||
"v1.2-stage3": {
|
||||
"ema": "/mnt/jfs-hdd/sora/checkpoints/outputs/042-STDiT3-XL-2/epoch1-global_step11000/ema.pt",
|
||||
"model": "/mnt/jfs-hdd/sora/checkpoints/outputs/042-STDiT3-XL-2/epoch1-global_step11000/model"
|
||||
"ema": "/mnt/jfs-hdd/sora/checkpoints/outputs/042-STDiT3-XL-2/epoch1-global_step18800/ema.pt",
|
||||
"model": "/mnt/jfs-hdd/sora/checkpoints/outputs/042-STDiT3-XL-2/epoch1-global_step18800/model"
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -184,9 +184,12 @@ from opensora.utils.inference_utils import (
|
|||
prepare_multi_resolution_info,
|
||||
dframe_to_frame,
|
||||
append_score_to_prompts,
|
||||
has_openai_key,
|
||||
refine_prompts_by_openai,
|
||||
add_watermark,
|
||||
get_random_prompt_by_openai
|
||||
get_random_prompt_by_openai,
|
||||
split_prompt,
|
||||
merge_prompt
|
||||
)
|
||||
from opensora.models.text_encoder.t5 import text_preprocessing
|
||||
from opensora.datasets.aspect import get_image_size, get_num_frames
|
||||
|
|
@ -199,7 +202,7 @@ device = torch.device("cuda")
|
|||
vae, text_encoder, stdit, scheduler = build_models(args.model_type, config, enable_optimization=args.enable_optimization)
|
||||
|
||||
|
||||
def run_inference(mode, prompt_text, resolution, aspect_ratio, length, motion_strength, aesthetic_score, use_motion_strength, use_aesthetic_score, camera_motion, reference_image, enhance_prompt, fps, num_loop, seed, sampling_steps, cfg_scale):
|
||||
def run_inference(mode, prompt_text, resolution, aspect_ratio, length, motion_strength, aesthetic_score, use_motion_strength, use_aesthetic_score, camera_motion, reference_image, refine_prompt, fps, num_loop, seed, sampling_steps, cfg_scale):
|
||||
torch.manual_seed(seed)
|
||||
with torch.inference_mode():
|
||||
# ======================
|
||||
|
|
@ -218,13 +221,14 @@ def run_inference(mode, prompt_text, resolution, aspect_ratio, length, motion_st
|
|||
num_frames = get_num_frames(length)
|
||||
|
||||
condition_frame_length = int(num_frames / 17 * 5 / 3)
|
||||
condition_frame_edit = 0.0
|
||||
|
||||
input_size = (num_frames, *image_size)
|
||||
latent_size = vae.get_latent_size(input_size)
|
||||
multi_resolution = "OpenSora"
|
||||
align = 5
|
||||
|
||||
# prepare reference
|
||||
# == prepare mask strategy ==
|
||||
if mode == "Text2Image":
|
||||
mask_strategy = [None]
|
||||
elif mode == "Text2Video":
|
||||
|
|
@ -235,7 +239,7 @@ def run_inference(mode, prompt_text, resolution, aspect_ratio, length, motion_st
|
|||
else:
|
||||
raise ValueError(f"Invalid mode: {mode}")
|
||||
|
||||
# prepare refs
|
||||
# == prepare reference ==
|
||||
if mode == "Text2Image":
|
||||
refs = [""]
|
||||
elif mode == "Text2Video":
|
||||
|
|
@ -251,37 +255,60 @@ def run_inference(mode, prompt_text, resolution, aspect_ratio, length, motion_st
|
|||
else:
|
||||
raise ValueError(f"Invalid mode: {mode}")
|
||||
|
||||
# refine the user prompt with gpt4o
|
||||
# == get json from prompts ==
|
||||
batch_prompts = [prompt_text]
|
||||
if enhance_prompt:
|
||||
# check if openai key is provided
|
||||
if "OPENAI_API_KEY" not in os.environ:
|
||||
gr.Warning("OpenAI API key is not provided, the prompt will not be enhanced.")
|
||||
else:
|
||||
batch_prompts = refine_prompts_by_openai(batch_prompts)
|
||||
|
||||
batch_prompts, refs, mask_strategy = extract_json_from_prompts(batch_prompts, refs, mask_strategy)
|
||||
|
||||
# == get reference for condition ==
|
||||
refs = collect_references_batch(refs, vae, image_size)
|
||||
|
||||
# process scores
|
||||
use_motion_strength = use_motion_strength and mode != "Text2Image"
|
||||
if camera_motion != "none":
|
||||
batch_prompts = [
|
||||
f"{prompt} camera motion: {camera_motion}."
|
||||
for prompt in batch_prompts
|
||||
]
|
||||
batch_prompts = append_score_to_prompts(
|
||||
batch_prompts,
|
||||
aes=aesthetic_score if use_aesthetic_score else None,
|
||||
flow=motion_strength if use_motion_strength else None
|
||||
)
|
||||
|
||||
# multi-resolution info
|
||||
# == multi-resolution info ==
|
||||
model_args = prepare_multi_resolution_info(
|
||||
multi_resolution, len(batch_prompts), image_size, num_frames, fps, device, dtype
|
||||
)
|
||||
|
||||
# == process prompts step by step ==
|
||||
# 0. split prompt
|
||||
# each element in the list is [prompt_segment_list, loop_idx_list]
|
||||
batched_prompt_segment_list = []
|
||||
batched_loop_idx_list = []
|
||||
for prompt in batch_prompts:
|
||||
prompt_segment_list, loop_idx_list = split_prompt(prompt)
|
||||
batched_prompt_segment_list.append(prompt_segment_list)
|
||||
batched_loop_idx_list.append(loop_idx_list)
|
||||
|
||||
# 1. refine prompt by openai
|
||||
if refine_prompt:
|
||||
# check if openai key is provided
|
||||
if not has_openai_key():
|
||||
gr.Warning("OpenAI API key is not provided, the prompt will not be enhanced.")
|
||||
else:
|
||||
for idx, prompt_segment_list in enumerate(batched_prompt_segment_list):
|
||||
batched_prompt_segment_list[idx] = refine_prompts_by_openai(prompt_segment_list)
|
||||
|
||||
# process scores
|
||||
aesthetic_score = aesthetic_score if use_aesthetic_score else None
|
||||
motion_strength = motion_strength if use_motion_strength and mode != "Text2Image" else None
|
||||
camera_motion = None if camera_motion == "none" or mode == "Text2Image" else camera_motion
|
||||
# 2. append score
|
||||
for idx, prompt_segment_list in enumerate(batched_prompt_segment_list):
|
||||
batched_prompt_segment_list[idx] = append_score_to_prompts(
|
||||
prompt_segment_list,
|
||||
aes=aesthetic_score,
|
||||
flow=motion_strength,
|
||||
camera_motion=camera_motion,
|
||||
)
|
||||
|
||||
# 3. clean prompt with T5
|
||||
for idx, prompt_segment_list in enumerate(batched_prompt_segment_list):
|
||||
batched_prompt_segment_list[idx] = [text_preprocessing(prompt) for prompt in prompt_segment_list]
|
||||
|
||||
# 4. merge to obtain the final prompt
|
||||
batch_prompts = []
|
||||
for prompt_segment_list, loop_idx_list in zip(batched_prompt_segment_list, batched_loop_idx_list):
|
||||
batch_prompts.append(merge_prompt(prompt_segment_list, loop_idx_list))
|
||||
|
||||
|
||||
# =========================
|
||||
# Generate image/video
|
||||
# =========================
|
||||
|
|
@ -290,11 +317,18 @@ def run_inference(mode, prompt_text, resolution, aspect_ratio, length, motion_st
|
|||
for loop_i in range(num_loop):
|
||||
# 4.4 sample in hidden space
|
||||
batch_prompts_loop = extract_prompts_loop(batch_prompts, loop_i)
|
||||
batch_prompts_cleaned = [text_preprocessing(prompt) for prompt in batch_prompts_loop]
|
||||
|
||||
# == loop ==
|
||||
if loop_i > 0:
|
||||
refs, mask_strategy = append_generated(vae, video_clips[-1], refs, mask_strategy, loop_i, condition_frame_length)
|
||||
refs, mask_strategy = append_generated(
|
||||
vae,
|
||||
video_clips[-1],
|
||||
refs,
|
||||
mask_strategy,
|
||||
loop_i,
|
||||
condition_frame_length,
|
||||
condition_frame_edit
|
||||
)
|
||||
|
||||
# == sampling ==
|
||||
z = torch.randn(len(batch_prompts), vae.out_channels, *latent_size, device=device, dtype=dtype)
|
||||
|
|
@ -314,7 +348,7 @@ def run_inference(mode, prompt_text, resolution, aspect_ratio, length, motion_st
|
|||
stdit,
|
||||
text_encoder,
|
||||
z=z,
|
||||
prompts=batch_prompts_cleaned,
|
||||
prompts=batch_prompts_loop,
|
||||
device=device,
|
||||
additional_args=model_args,
|
||||
progress=True,
|
||||
|
|
@ -361,7 +395,7 @@ def run_image_inference(
|
|||
use_aesthetic_score,
|
||||
camera_motion,
|
||||
reference_image,
|
||||
enhance_prompt,
|
||||
refine_prompt,
|
||||
fps,
|
||||
num_loop,
|
||||
seed,
|
||||
|
|
@ -379,7 +413,7 @@ def run_image_inference(
|
|||
use_aesthetic_score,
|
||||
camera_motion,
|
||||
reference_image,
|
||||
enhance_prompt,
|
||||
refine_prompt,
|
||||
fps,
|
||||
num_loop,
|
||||
seed,
|
||||
|
|
@ -398,7 +432,7 @@ def run_video_inference(
|
|||
use_aesthetic_score,
|
||||
camera_motion,
|
||||
reference_image,
|
||||
enhance_prompt,
|
||||
refine_prompt,
|
||||
fps,
|
||||
num_loop,
|
||||
seed,
|
||||
|
|
@ -420,7 +454,7 @@ def run_video_inference(
|
|||
use_aesthetic_score,
|
||||
camera_motion,
|
||||
reference_image,
|
||||
enhance_prompt,
|
||||
refine_prompt,
|
||||
fps,
|
||||
num_loop,
|
||||
seed,
|
||||
|
|
@ -471,7 +505,7 @@ def main():
|
|||
info="Empty prompt will mean random prompt from OpenAI.",
|
||||
lines=4,
|
||||
)
|
||||
enhance_prompt = gr.Checkbox(value=True, label="Enhance prompt with GPT4o")
|
||||
refine_prompt = gr.Checkbox(value=True, label="Refine prompt with GPT4o")
|
||||
random_prompt_btn = gr.Button("Random Prompt")
|
||||
|
||||
gr.Markdown("## Basic Settings")
|
||||
|
|
@ -594,12 +628,12 @@ def main():
|
|||
|
||||
image_gen_button.click(
|
||||
fn=run_image_inference,
|
||||
inputs=[prompt_text, resolution, aspect_ratio, length, motion_strength, aesthetic_score, use_motion_strength, use_aesthetic_score, camera_motion, reference_image, enhance_prompt, fps, num_loop, seed, sampling_steps, cfg_scale],
|
||||
inputs=[prompt_text, resolution, aspect_ratio, length, motion_strength, aesthetic_score, use_motion_strength, use_aesthetic_score, camera_motion, reference_image, refine_prompt, fps, num_loop, seed, sampling_steps, cfg_scale],
|
||||
outputs=reference_image
|
||||
)
|
||||
video_gen_button.click(
|
||||
fn=run_video_inference,
|
||||
inputs=[prompt_text, resolution, aspect_ratio, length, motion_strength, aesthetic_score, use_motion_strength, use_aesthetic_score, camera_motion, reference_image, enhance_prompt, fps, num_loop, seed, sampling_steps, cfg_scale],
|
||||
inputs=[prompt_text, resolution, aspect_ratio, length, motion_strength, aesthetic_score, use_motion_strength, use_aesthetic_score, camera_motion, reference_image, refine_prompt, fps, num_loop, seed, sampling_steps, cfg_scale],
|
||||
outputs=output_video
|
||||
)
|
||||
random_prompt_btn.click(
|
||||
|
|
|
|||
|
|
@ -115,6 +115,32 @@ def extract_prompts_loop(prompts, num_loop):
|
|||
ret_prompts.append(prompt)
|
||||
return ret_prompts
|
||||
|
||||
def split_prompt(prompt_text):
    """Split a prompt into its text segments and their loop indices.

    A prompt that starts with ``|0|`` is treated as a loop-annotated prompt
    of the form ``|0| a beautiful day |1| a sunny day |2| a rainy day``.
    It is cut on ``|`` and read as alternating (loop index, text) pairs.

    Args:
        prompt_text: The raw prompt string.

    Returns:
        A ``(text_list, loop_idx)`` tuple. For a loop-annotated prompt,
        ``text_list`` holds the whitespace-stripped segments and ``loop_idx``
        the matching integer loop indices. For any other prompt the result
        is ``([prompt_text], None)``.

    Raises:
        ValueError: If a loop marker is not a valid integer.
        IndexError: If the last loop marker is not followed by a text field.
    """
    # Plain prompt: no loop markers to parse.
    if not prompt_text.startswith("|0|"):
        return [prompt_text], None

    # Drop the empty field that precedes the leading "|".
    fields = prompt_text.split("|")[1:]

    segments = []
    start_loops = []
    cursor = 0
    while cursor < len(fields):
        # Fields alternate: loop index first, then the text it applies to.
        loop_start = int(fields[cursor])
        segment = fields[cursor + 1].strip()
        segments.append(segment)
        start_loops.append(loop_start)
        cursor += 2
    return segments, start_loops
|
||||
|
||||
def merge_prompt(text_list, loop_idx_list=None):
    """Reassemble prompt segments into a single prompt string.

    Args:
        text_list: Prompt segments, in order.
        loop_idx_list: Loop index for each segment, or ``None`` when the
            prompt carries no loop markers.

    Returns:
        ``text_list[0]`` when ``loop_idx_list`` is ``None``. Otherwise every
        segment prefixed with ``|<loop index>|`` and concatenated, e.g.
        ``"|0|a beautiful day|1|a sunny day"``.

    Raises:
        IndexError: If ``text_list`` is empty while ``loop_idx_list`` is
            ``None``, or if ``loop_idx_list`` has fewer entries than
            ``text_list``.
    """
    if loop_idx_list is None:
        return text_list[0]

    # Index loop_idx_list explicitly (instead of zip) so that a too-short
    # index list still raises rather than silently dropping segments.
    return "".join(
        f"|{loop_idx_list[i]}|{text}" for i, text in enumerate(text_list)
    )
|
||||
|
||||
|
||||
MASK_DEFAULT = ["0", "0", "0", "0", "1", "0"]
|
||||
|
||||
|
|
@ -259,6 +285,8 @@ def refine_prompt_by_openai(prompt):
|
|||
response = get_openai_response(REFINE_PROMPTS, prompt)
|
||||
return response
|
||||
|
||||
def has_openai_key():
    """Return True when the ``OPENAI_API_KEY`` environment variable is set.

    Only the presence of the variable is checked; an empty value still
    counts as set.
    """
    api_key = os.environ.get("OPENAI_API_KEY")
    return api_key is not None
|
||||
|
||||
def refine_prompts_by_openai(prompts):
|
||||
new_prompts = []
|
||||
|
|
@ -286,4 +314,5 @@ def add_watermark(
|
|||
output_video_path = input_video_path.replace(".mp4", "_watermark.mp4")
|
||||
cmd = f'ffmpeg -y -i {input_video_path} -i {watermark_image_path} -filter_complex "[1][0]scale2ref=oh*mdar:ih*0.1[logo][video];[video][logo]overlay" {output_video_path}'
|
||||
exit_code = os.system(cmd)
|
||||
return exit_code == 0
|
||||
is_success = exit_code == 0
|
||||
return is_success
|
||||
|
|
|
|||
|
|
@ -28,6 +28,8 @@ from opensora.utils.inference_utils import (
|
|||
load_prompts,
|
||||
prepare_multi_resolution_info,
|
||||
refine_prompts_by_openai,
|
||||
split_prompt,
|
||||
merge_prompt
|
||||
)
|
||||
from opensora.utils.misc import all_exists, create_logger, is_distributed, is_main_process, to_torch_dtype
|
||||
|
||||
|
|
@ -177,31 +179,79 @@ def main():
|
|||
if prompt_as_path and all_exists(save_paths):
|
||||
continue
|
||||
|
||||
# == process prompts step by step ==
|
||||
# 0. split prompt
|
||||
# each element in the list is [prompt_segment_list, loop_idx_list]
|
||||
batched_prompt_segment_list = []
|
||||
batched_loop_idx_list = []
|
||||
for prompt in batch_prompts:
|
||||
prompt_segment_list, loop_idx_list = split_prompt(prompt)
|
||||
batched_prompt_segment_list.append(prompt_segment_list)
|
||||
batched_loop_idx_list.append(loop_idx_list)
|
||||
|
||||
# 1. refine prompt by openai
|
||||
if cfg.get("llm_refine", False):
|
||||
# only call openai API when
|
||||
# 1. seq parallel is not enabled
|
||||
# 2. seq parallel is enabled and the process is rank 0
|
||||
if not enable_sequence_parallelism or (enable_sequence_parallelism and is_main_process()):
|
||||
for idx, prompt_segment_list in enumerate(batched_prompt_segment_list):
|
||||
batched_prompt_segment_list[idx] = refine_prompts_by_openai(prompt_segment_list)
|
||||
|
||||
# sync the prompt if using seq parallel
|
||||
if enable_sequence_parallelism:
|
||||
coordinator.block_all()
|
||||
prompt_segment_length = [len(prompt_segment_list) for prompt_segment_list in batched_prompt_segment_list]
|
||||
|
||||
# flatten the prompt segment list
|
||||
batched_prompt_segment_list = [prompt_segment for prompt_segment_list in batched_prompt_segment_list for prompt_segment in prompt_segment_list]
|
||||
|
||||
# create a list of size equal to world size
|
||||
broadcast_obj_list = [batched_prompt_segment_list] * coordinator.world_size
|
||||
dist.broadcast_object_list(broadcast_obj_list, 0)
|
||||
|
||||
# recover the prompt list
|
||||
batched_prompt_segment_list = []
|
||||
start_idx = 0
|
||||
all_prompts = broadcast_obj_list[0]
|
||||
for num_segment in prompt_segment_length:
|
||||
batched_prompt_segment_list.append(all_prompts[start_idx:start_idx+num_segment])
|
||||
start_idx += num_segment
|
||||
|
||||
# 2. append score
|
||||
for idx, prompt_segment_list in enumerate(batched_prompt_segment_list):
|
||||
batched_prompt_segment_list[idx] = append_score_to_prompts(
|
||||
prompt_segment_list,
|
||||
aes=cfg.get("aes", None),
|
||||
flow=cfg.get("flow", None),
|
||||
camera_motion=cfg.get("camera_motion", None),
|
||||
)
|
||||
|
||||
# 3. clean prompt with T5
|
||||
for idx, prompt_segment_list in enumerate(batched_prompt_segment_list):
|
||||
batched_prompt_segment_list[idx] = [text_preprocessing(prompt) for prompt in prompt_segment_list]
|
||||
|
||||
# 4. merge to obtain the final prompt
|
||||
batch_prompts = []
|
||||
for prompt_segment_list, loop_idx_list in zip(batched_prompt_segment_list, batched_loop_idx_list):
|
||||
batch_prompts.append(merge_prompt(prompt_segment_list, loop_idx_list))
|
||||
|
||||
# == Iter over loop generation ==
|
||||
video_clips = []
|
||||
for loop_i in range(loop):
|
||||
# == get prompt for loop i ==
|
||||
batch_prompts_loop = extract_prompts_loop(batch_prompts, loop_i)
|
||||
|
||||
# == refine prompt by openai ==
|
||||
if cfg.get("llm_refine", False):
|
||||
batch_prompts_loop = refine_prompts_by_openai(batch_prompts_loop)
|
||||
|
||||
# == add score to prompt ==
|
||||
batch_prompts_loop = append_score_to_prompts(
|
||||
batch_prompts_loop,
|
||||
aes=cfg.get("aes", None),
|
||||
flow=cfg.get("flow", None),
|
||||
camera_motion=cfg.get("camera_motion", None),
|
||||
)
|
||||
|
||||
# == clean prompt for t5 ==
|
||||
batch_prompts_cleaned = [text_preprocessing(prompt) for prompt in batch_prompts_loop]
|
||||
|
||||
# == add condition frames for loop ==
|
||||
if loop_i > 0:
|
||||
refs, ms = append_generated(
|
||||
vae, video_clips[-1], refs, ms, loop_i, condition_frame_length, condition_frame_edit
|
||||
vae,
|
||||
video_clips[-1],
|
||||
refs,
|
||||
ms,
|
||||
loop_i,
|
||||
condition_frame_length,
|
||||
condition_frame_edit
|
||||
)
|
||||
|
||||
# == sampling ==
|
||||
|
|
@ -211,7 +261,7 @@ def main():
|
|||
model,
|
||||
text_encoder,
|
||||
z=z,
|
||||
prompts=batch_prompts_cleaned,
|
||||
prompts=batch_prompts_loop,
|
||||
device=device,
|
||||
additional_args=model_args,
|
||||
progress=verbose >= 2,
|
||||
|
|
|
|||
Loading…
Reference in a new issue