diff --git a/gradio/app.py b/gradio/app.py index 6632579..9a741fe 100644 --- a/gradio/app.py +++ b/gradio/app.py @@ -26,8 +26,8 @@ CONFIG_MAP = { } HF_STDIT_MAP = { "v1.2-stage3": { - "ema": "/mnt/jfs-hdd/sora/checkpoints/outputs/042-STDiT3-XL-2/epoch1-global_step11000/ema.pt", - "model": "/mnt/jfs-hdd/sora/checkpoints/outputs/042-STDiT3-XL-2/epoch1-global_step11000/model" + "ema": "/mnt/jfs-hdd/sora/checkpoints/outputs/042-STDiT3-XL-2/epoch1-global_step18800/ema.pt", + "model": "/mnt/jfs-hdd/sora/checkpoints/outputs/042-STDiT3-XL-2/epoch1-global_step18800/model" } } @@ -184,9 +184,12 @@ from opensora.utils.inference_utils import ( prepare_multi_resolution_info, dframe_to_frame, append_score_to_prompts, + has_openai_key, refine_prompts_by_openai, add_watermark, - get_random_prompt_by_openai + get_random_prompt_by_openai, + split_prompt, + merge_prompt ) from opensora.models.text_encoder.t5 import text_preprocessing from opensora.datasets.aspect import get_image_size, get_num_frames @@ -199,7 +202,7 @@ device = torch.device("cuda") vae, text_encoder, stdit, scheduler = build_models(args.model_type, config, enable_optimization=args.enable_optimization) -def run_inference(mode, prompt_text, resolution, aspect_ratio, length, motion_strength, aesthetic_score, use_motion_strength, use_aesthetic_score, camera_motion, reference_image, enhance_prompt, fps, num_loop, seed, sampling_steps, cfg_scale): +def run_inference(mode, prompt_text, resolution, aspect_ratio, length, motion_strength, aesthetic_score, use_motion_strength, use_aesthetic_score, camera_motion, reference_image, refine_prompt, fps, num_loop, seed, sampling_steps, cfg_scale): torch.manual_seed(seed) with torch.inference_mode(): # ====================== @@ -218,13 +221,14 @@ def run_inference(mode, prompt_text, resolution, aspect_ratio, length, motion_st num_frames = get_num_frames(length) condition_frame_length = int(num_frames / 17 * 5 / 3) + condition_frame_edit = 0.0 input_size = (num_frames, *image_size) latent_size = vae.get_latent_size(input_size) multi_resolution = "OpenSora" align = 5 - # prepare reference + # == prepare mask strategy == if mode == "Text2Image": mask_strategy = [None] elif mode == "Text2Video": @@ -235,7 +239,7 @@ def run_inference(mode, prompt_text, resolution, aspect_ratio, length, motion_st else: raise ValueError(f"Invalid mode: {mode}") - # prepare refs + # == prepare reference == if mode == "Text2Image": refs = [""] elif mode == "Text2Video": @@ -251,36 +255,59 @@ def run_inference(mode, prompt_text, resolution, aspect_ratio, length, motion_st else: raise ValueError(f"Invalid mode: {mode}") - # refine the user prompt with gpt4o + # == get json from prompts == batch_prompts = [prompt_text] - if enhance_prompt: - # check if openai key is provided - if "OPENAI_API_KEY" not in os.environ: - gr.Warning("OpenAI API key is not provided, the prompt will not be enhanced.") - else: - batch_prompts = refine_prompts_by_openai(batch_prompts) - batch_prompts, refs, mask_strategy = extract_json_from_prompts(batch_prompts, refs, mask_strategy) - + + # == get reference for condition == refs = collect_references_batch(refs, vae, image_size) - - # process scores - use_motion_strength = use_motion_strength and mode != "Text2Image" - if camera_motion != "none": - batch_prompts = [ - f"{prompt} camera motion: {camera_motion}." - for prompt in batch_prompts - ] - batch_prompts = append_score_to_prompts( - batch_prompts, - aes=aesthetic_score if use_aesthetic_score else None, - flow=motion_strength if use_motion_strength else None - ) - - # multi-resolution info + + # == multi-resolution info == model_args = prepare_multi_resolution_info( multi_resolution, len(batch_prompts), image_size, num_frames, fps, device, dtype ) + + # == process prompts step by step == + # 0. split prompt + # each element in the list is [prompt_segment_list, loop_idx_list] + batched_prompt_segment_list = [] + batched_loop_idx_list = [] + for prompt in batch_prompts: + prompt_segment_list, loop_idx_list = split_prompt(prompt) + batched_prompt_segment_list.append(prompt_segment_list) + batched_loop_idx_list.append(loop_idx_list) + + # 1. refine prompt by openai + if refine_prompt: + # check if openai key is provided + if not has_openai_key(): + gr.Warning("OpenAI API key is not provided, the prompt will not be enhanced.") + else: + for idx, prompt_segment_list in enumerate(batched_prompt_segment_list): + batched_prompt_segment_list[idx] = refine_prompts_by_openai(prompt_segment_list) + + # process scores + aesthetic_score = aesthetic_score if use_aesthetic_score else None + motion_strength = motion_strength if use_motion_strength and mode != "Text2Image" else None + camera_motion = None if camera_motion == "none" or mode == "Text2Image" else camera_motion + # 2. append score + for idx, prompt_segment_list in enumerate(batched_prompt_segment_list): + batched_prompt_segment_list[idx] = append_score_to_prompts( + prompt_segment_list, + aes=aesthetic_score, + flow=motion_strength, + camera_motion=camera_motion, + ) + + # 3. clean prompt with T5 + for idx, prompt_segment_list in enumerate(batched_prompt_segment_list): + batched_prompt_segment_list[idx] = [text_preprocessing(prompt) for prompt in prompt_segment_list] + + # 4. merge to obtain the final prompt + batch_prompts = [] + for prompt_segment_list, loop_idx_list in zip(batched_prompt_segment_list, batched_loop_idx_list): + batch_prompts.append(merge_prompt(prompt_segment_list, loop_idx_list)) + # ========================= # Generate image/video @@ -290,11 +317,18 @@ def run_inference(mode, prompt_text, resolution, aspect_ratio, length, motion_st for loop_i in range(num_loop): # 4.4 sample in hidden space batch_prompts_loop = extract_prompts_loop(batch_prompts, loop_i) - batch_prompts_cleaned = [text_preprocessing(prompt) for prompt in batch_prompts_loop] # == loop == if loop_i > 0: - refs, mask_strategy = append_generated(vae, video_clips[-1], refs, mask_strategy, loop_i, condition_frame_length) + refs, mask_strategy = append_generated( + vae, + video_clips[-1], + refs, + mask_strategy, + loop_i, + condition_frame_length, + condition_frame_edit + ) # == sampling == z = torch.randn(len(batch_prompts), vae.out_channels, *latent_size, device=device, dtype=dtype) @@ -314,7 +348,7 @@ def run_inference(mode, prompt_text, resolution, aspect_ratio, length, motion_st stdit, text_encoder, z=z, - prompts=batch_prompts_cleaned, + prompts=batch_prompts_loop, device=device, additional_args=model_args, progress=True, @@ -361,7 +395,7 @@ def run_image_inference( use_aesthetic_score, camera_motion, reference_image, - enhance_prompt, + refine_prompt, fps, num_loop, seed, @@ -379,7 +413,7 @@ def run_image_inference( use_aesthetic_score, camera_motion, reference_image, - enhance_prompt, + refine_prompt, fps, num_loop, seed, @@ -398,7 +432,7 @@ def run_video_inference( use_aesthetic_score, camera_motion, reference_image, - enhance_prompt, + refine_prompt, fps, num_loop, seed, @@ -420,7 +454,7 @@ def run_video_inference( use_aesthetic_score, camera_motion, reference_image, - enhance_prompt, + refine_prompt, fps, num_loop, seed, @@ -471,7 +505,7 @@ def main(): info="Empty prompt will mean random prompt from OpenAI.", lines=4, ) - enhance_prompt = gr.Checkbox(value=True, label="Enhance prompt with GPT4o") + refine_prompt = gr.Checkbox(value=True, label="Refine prompt with GPT4o") random_prompt_btn = gr.Button("Random Prompt") gr.Markdown("## Basic Settings") @@ -594,12 +628,12 @@ def main(): image_gen_button.click( fn=run_image_inference, - inputs=[prompt_text, resolution, aspect_ratio, length, motion_strength, aesthetic_score, use_motion_strength, use_aesthetic_score, camera_motion, reference_image, enhance_prompt, fps, num_loop, seed, sampling_steps, cfg_scale], + inputs=[prompt_text, resolution, aspect_ratio, length, motion_strength, aesthetic_score, use_motion_strength, use_aesthetic_score, camera_motion, reference_image, refine_prompt, fps, num_loop, seed, sampling_steps, cfg_scale], outputs=reference_image ) video_gen_button.click( fn=run_video_inference, - inputs=[prompt_text, resolution, aspect_ratio, length, motion_strength, aesthetic_score, use_motion_strength, use_aesthetic_score, camera_motion, reference_image, enhance_prompt, fps, num_loop, seed, sampling_steps, cfg_scale], + inputs=[prompt_text, resolution, aspect_ratio, length, motion_strength, aesthetic_score, use_motion_strength, use_aesthetic_score, camera_motion, reference_image, refine_prompt, fps, num_loop, seed, sampling_steps, cfg_scale], outputs=output_video ) random_prompt_btn.click( diff --git a/opensora/utils/inference_utils.py b/opensora/utils/inference_utils.py index a00dc22..56575d3 100644 --- a/opensora/utils/inference_utils.py +++ b/opensora/utils/inference_utils.py @@ -115,6 +115,32 @@ def extract_prompts_loop(prompts, num_loop): ret_prompts.append(prompt) return ret_prompts +def split_prompt(prompt_text): + if prompt_text.startswith("|0|"): + # this is for prompts which look like + # |0| a beautiful day |1| a sunny day |2| a rainy day + # we want to parse it into a list of prompts with the loop index + prompt_list = prompt_text.split("|")[1:] + text_list = [] + loop_idx = [] + for i in range(0, len(prompt_list), 2): + start_loop = int(prompt_list[i]) + text = prompt_list[i + 1].strip() + text_list.append(text) + loop_idx.append(start_loop) + return text_list, loop_idx + else: + return [prompt_text], None + +def merge_prompt(text_list, loop_idx_list=None): + if loop_idx_list is None: + return text_list[0] + else: + prompt = "" + for i, text in enumerate(text_list): + prompt += f"|{loop_idx_list[i]}|{text}" + return prompt + MASK_DEFAULT = ["0", "0", "0", "0", "1", "0"] @@ -259,6 +285,8 @@ def refine_prompt_by_openai(prompt): response = get_openai_response(REFINE_PROMPTS, prompt) return response +def has_openai_key(): + return "OPENAI_API_KEY" in os.environ def refine_prompts_by_openai(prompts): new_prompts = [] @@ -286,4 +314,5 @@ def add_watermark( output_video_path = input_video_path.replace(".mp4", "_watermark.mp4") cmd = f'ffmpeg -y -i {input_video_path} -i {watermark_image_path} -filter_complex "[1][0]scale2ref=oh*mdar:ih*0.1[logo][video];[video][logo]overlay" {output_video_path}' exit_code = os.system(cmd) - return exit_code == 0 + is_success = exit_code == 0 + return is_success diff --git a/scripts/inference.py b/scripts/inference.py index 5058ec9..532792d 100644 --- a/scripts/inference.py +++ b/scripts/inference.py @@ -28,6 +28,8 @@ from opensora.utils.inference_utils import ( load_prompts, prepare_multi_resolution_info, refine_prompts_by_openai, + split_prompt, + merge_prompt ) from opensora.utils.misc import all_exists, create_logger, is_distributed, is_main_process, to_torch_dtype @@ -177,31 +179,79 @@ def main(): if prompt_as_path and all_exists(save_paths): continue + # == process prompts step by step == + # 0. split prompt + # each element in the list is [prompt_segment_list, loop_idx_list] + batched_prompt_segment_list = [] + batched_loop_idx_list = [] + for prompt in batch_prompts: + prompt_segment_list, loop_idx_list = split_prompt(prompt) + batched_prompt_segment_list.append(prompt_segment_list) + batched_loop_idx_list.append(loop_idx_list) + + # 1. refine prompt by openai + if cfg.get("llm_refine", False): + # only call openai API when + # 1. seq parallel is not enabled + # 2. seq parallel is enabled and the process is rank 0 + if not enable_sequence_parallelism or (enable_sequence_parallelism and is_main_process()): + for idx, prompt_segment_list in enumerate(batched_prompt_segment_list): + batched_prompt_segment_list[idx] = refine_prompts_by_openai(prompt_segment_list) + + # sync the prompt if using seq parallel + if enable_sequence_parallelism: + coordinator.block_all() + prompt_segment_length = [len(prompt_segment_list) for prompt_segment_list in batched_prompt_segment_list] + + # flatten the prompt segment list + batched_prompt_segment_list = [prompt_segment for prompt_segment_list in batched_prompt_segment_list for prompt_segment in prompt_segment_list] + + # create a list of size equal to world size + broadcast_obj_list = [batched_prompt_segment_list] * coordinator.world_size + dist.broadcast_object_list(broadcast_obj_list, 0) + + # recover the prompt list + batched_prompt_segment_list = [] + start_idx = 0 + all_prompts = broadcast_obj_list[0] + for num_segment in prompt_segment_length: + batched_prompt_segment_list.append(all_prompts[start_idx:start_idx+num_segment]) + start_idx += num_segment + + # 2. append score + for idx, prompt_segment_list in enumerate(batched_prompt_segment_list): + batched_prompt_segment_list[idx] = append_score_to_prompts( + prompt_segment_list, + aes=cfg.get("aes", None), + flow=cfg.get("flow", None), + camera_motion=cfg.get("camera_motion", None), + ) + + # 3. clean prompt with T5 + for idx, prompt_segment_list in enumerate(batched_prompt_segment_list): + batched_prompt_segment_list[idx] = [text_preprocessing(prompt) for prompt in prompt_segment_list] + + # 4. merge to obtain the final prompt + batch_prompts = [] + for prompt_segment_list, loop_idx_list in zip(batched_prompt_segment_list, batched_loop_idx_list): + batch_prompts.append(merge_prompt(prompt_segment_list, loop_idx_list)) + # == Iter over loop generation == video_clips = [] for loop_i in range(loop): # == get prompt for loop i == batch_prompts_loop = extract_prompts_loop(batch_prompts, loop_i) - # == refine prompt by openai == - if cfg.get("llm_refine", False): - batch_prompts_loop = refine_prompts_by_openai(batch_prompts_loop) - - # == add score to prompt == - batch_prompts_loop = append_score_to_prompts( - batch_prompts_loop, - aes=cfg.get("aes", None), - flow=cfg.get("flow", None), - camera_motion=cfg.get("camera_motion", None), - ) - - # == clean prompt for t5 == - batch_prompts_cleaned = [text_preprocessing(prompt) for prompt in batch_prompts_loop] - # == add condition frames for loop == if loop_i > 0: refs, ms = append_generated( - vae, video_clips[-1], refs, ms, loop_i, condition_frame_length, condition_frame_edit + vae, + video_clips[-1], + refs, + ms, + loop_i, + condition_frame_length, + condition_frame_edit ) # == sampling == @@ -211,7 +261,7 @@ def main(): model, text_encoder, z=z, - prompts=batch_prompts_cleaned, + prompts=batch_prompts_loop, device=device, additional_args=model_args, progress=verbose >= 2,