From 6b2715e40cac1ae746493d4e26825689782efb09 Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Fri, 22 Mar 2024 15:07:04 +0800 Subject: [PATCH] Feature/gradio demo (#190) * [gradio] added demo app * polish --- .gitignore | 1 + README.md | 10 +- configs/opensora/inference/16x256x256.py | 2 +- configs/opensora/inference/16x512x512.py | 2 +- configs/opensora/inference/64x512x512.py | 2 +- configs/opensora/train/16x256x256.py | 2 +- opensora/datasets/utils.py | 1 + opensora/models/text_encoder/t5.py | 61 +----- scripts/demo.py | 254 +++++++++++++++++++++++ 9 files changed, 280 insertions(+), 55 deletions(-) create mode 100644 scripts/demo.py diff --git a/.gitignore b/.gitignore index a21188e..c8a1724 100644 --- a/.gitignore +++ b/.gitignore @@ -175,3 +175,4 @@ pretrained_models/ # Secret files hostfile +gradio_cached_examples/ diff --git a/README.md b/README.md index ec8b467..4e4d8e7 100644 --- a/README.md +++ b/README.md @@ -134,7 +134,15 @@ Our model's weight is partially initialized from [PixArt-α](https://github.com/ ## Inference -To run inference with our provided weights, first download [T5](https://huggingface.co/DeepFloyd/t5-v1_1-xxl/tree/main) weights into `pretrained_models/t5_ckpts/t5-v1_1-xxl`. Then download the model weights from [huggingface](https://huggingface.co/hpcai-tech/Open-Sora/tree/main). Run the following commands to generate samples. To change sampling prompts, modify the txt file passed to `--prompt-path`. See [here](docs/structure.md#inference-config-demos) to customize the configuration. +We have provided a Gradio application in this repository, you can use the following the command to start an interactive web application to experience video generation with Open-Sora. + +```bash +python scripts/demo.py +``` + +This will launch a Gradio application on your localhost. + +Besides, we have also provided an offline inference script. To run inference with our provided weights, first download [T5](https://huggingface.co/DeepFloyd/t5-v1_1-xxl/tree/main) weights into `pretrained_models/t5_ckpts/t5-v1_1-xxl`. Then download the model weights from [huggingface](https://huggingface.co/hpcai-tech/Open-Sora/tree/main). Run the following commands to generate samples. To change sampling prompts, modify the txt file passed to `--prompt-path`. See [here](docs/structure.md#inference-config-demos) to customize the configuration. ```bash # Sample 16x256x256 (5s/sample, 100 time steps, 22 GB memory) diff --git a/configs/opensora/inference/16x256x256.py b/configs/opensora/inference/16x256x256.py index fc9a468..db6f2e4 100644 --- a/configs/opensora/inference/16x256x256.py +++ b/configs/opensora/inference/16x256x256.py @@ -18,7 +18,7 @@ vae = dict( ) text_encoder = dict( type="t5", - from_pretrained="./pretrained_models/t5_ckpts", + from_pretrained="DeepFloyd/t5-v1_1-xxl", model_max_length=120, ) scheduler = dict( diff --git a/configs/opensora/inference/16x512x512.py b/configs/opensora/inference/16x512x512.py index afc224c..2064074 100644 --- a/configs/opensora/inference/16x512x512.py +++ b/configs/opensora/inference/16x512x512.py @@ -18,7 +18,7 @@ vae = dict( ) text_encoder = dict( type="t5", - from_pretrained="./pretrained_models/t5_ckpts", + from_pretrained="DeepFloyd/t5-v1_1-xxl", model_max_length=120, ) scheduler = dict( diff --git a/configs/opensora/inference/64x512x512.py b/configs/opensora/inference/64x512x512.py index e15649a..d0fb3e4 100644 --- a/configs/opensora/inference/64x512x512.py +++ b/configs/opensora/inference/64x512x512.py @@ -18,7 +18,7 @@ vae = dict( ) text_encoder = dict( type="t5", - from_pretrained="./pretrained_models/t5_ckpts", + from_pretrained="DeepFloyd/t5-v1_1-xxl", model_max_length=120, ) scheduler = dict( diff --git a/configs/opensora/train/16x256x256.py b/configs/opensora/train/16x256x256.py index a64a318..c8d8b02 100644 --- a/configs/opensora/train/16x256x256.py +++ b/configs/opensora/train/16x256x256.py @@ -29,7 +29,7 @@ vae = dict( ) text_encoder = dict( type="t5", - from_pretrained="./pretrained_models/t5_ckpts", + from_pretrained="DeepFloyd/t5-v1_1-xxl", model_max_length=120, shardformer=True, ) diff --git a/opensora/datasets/utils.py b/opensora/datasets/utils.py index cd268ae..0c4b3b8 100644 --- a/opensora/datasets/utils.py +++ b/opensora/datasets/utils.py @@ -33,6 +33,7 @@ def save_sample(x, fps=8, save_path=None, normalize=True, value_range=(-1, 1)): x = x.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 3, 0).to("cpu", torch.uint8) write_video(save_path, x, fps=fps, video_codec="h264") print(f"Saved to {save_path}") + return save_path class StatefulDistributedSampler(DistributedSampler): diff --git a/opensora/models/text_encoder/t5.py b/opensora/models/text_encoder/t5.py index f93612a..7f55ef9 100644 --- a/opensora/models/text_encoder/t5.py +++ b/opensora/models/text_encoder/t5.py @@ -37,7 +37,7 @@ from opensora.registry import MODELS class T5Embedder: - available_models = ["t5-v1_1-xxl"] + available_models = ["DeepFloyd/t5-v1_1-xxl"] bad_punct_regex = re.compile( r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}" ) # noqa @@ -45,9 +45,8 @@ class T5Embedder: def __init__( self, device, - dir_or_name="t5-v1_1-xxl", + from_pretrained=None, *, - local_cache=False, cache_dir=None, hf_token=None, use_text_preprocessing=True, @@ -58,8 +57,11 @@ class T5Embedder: ): self.device = torch.device(device) self.torch_dtype = torch_dtype or torch.bfloat16 + self.cache_dir = cache_dir + if t5_model_kwargs is None: t5_model_kwargs = {"low_cpu_mem_usage": True, "torch_dtype": self.torch_dtype} + if use_offload_folder is not None: t5_model_kwargs["offload_folder"] = use_offload_folder t5_model_kwargs["device_map"] = { @@ -97,51 +99,10 @@ class T5Embedder: self.use_text_preprocessing = use_text_preprocessing self.hf_token = hf_token - self.cache_dir = cache_dir or os.path.expanduser("~/.cache/IF_") - self.dir_or_name = dir_or_name - tokenizer_path, path = dir_or_name, dir_or_name - if local_cache: - cache_dir = os.path.join(self.cache_dir, dir_or_name) - tokenizer_path, path = cache_dir, cache_dir - elif dir_or_name in self.available_models: - cache_dir = os.path.join(self.cache_dir, dir_or_name) - for filename in [ - "config.json", - "special_tokens_map.json", - "spiece.model", - "tokenizer_config.json", - "pytorch_model.bin.index.json", - "pytorch_model-00001-of-00002.bin", - "pytorch_model-00002-of-00002.bin", - ]: - hf_hub_download( - repo_id=f"DeepFloyd/{dir_or_name}", - filename=filename, - cache_dir=cache_dir, - force_filename=filename, - token=self.hf_token, - ) - tokenizer_path, path = cache_dir, cache_dir - else: - cache_dir = os.path.join(self.cache_dir, "t5-v1_1-xxl") - for filename in [ - "config.json", - "special_tokens_map.json", - "spiece.model", - "tokenizer_config.json", - ]: - hf_hub_download( - repo_id="DeepFloyd/t5-v1_1-xxl", - filename=filename, - cache_dir=cache_dir, - force_filename=filename, - token=self.hf_token, - ) - tokenizer_path = cache_dir - print(tokenizer_path) - self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) - self.model = T5EncoderModel.from_pretrained(path, **t5_model_kwargs).eval() + assert from_pretrained in self.available_models + self.tokenizer = AutoTokenizer.from_pretrained(from_pretrained, cache_dir=cache_dir) + self.model = T5EncoderModel.from_pretrained(from_pretrained, cache_dir=cache_dir, **t5_model_kwargs).eval() self.model_max_length = model_max_length def get_text_embeddings(self, texts): @@ -304,7 +265,7 @@ class T5Encoder: model_max_length=120, device="cuda", dtype=torch.float, - local_cache=True, + cache_dir=None, shardformer=False, ): assert from_pretrained is not None, "Please specify the path to the T5 model" @@ -312,8 +273,8 @@ class T5Encoder: self.t5 = T5Embedder( device=device, torch_dtype=dtype, - local_cache=local_cache, - cache_dir=from_pretrained, + from_pretrained=from_pretrained, + cache_dir=cache_dir, model_max_length=model_max_length, ) self.t5.model.to(dtype=dtype) diff --git a/scripts/demo.py b/scripts/demo.py new file mode 100644 index 0000000..7e03fa3 --- /dev/null +++ b/scripts/demo.py @@ -0,0 +1,254 @@ +#!/usr/bin/env python +""" +This script runs a Gradio App for the Open-Sora model. + +Usage: + python demo.py +""" + +import argparse +import importlib +import os +import subprocess +import sys +from functools import partial + +import gradio as gr +import torch + +MODEL_TYPES = ["v1-16x256x256", "v1-HQ-16x256x256", "v1-HQ-16x512x512"] +CONFIG_MAP = { + "v1-16x256x256": "configs/opensora/inference/16x256x256.py", + "v1-HQ-16x256x256": "configs/opensora/inference/16x512x512.py", + "v1-HQ-16x512x512": "configs/opensora/inference/16x512x512.py", +} +HF_STDIT_MAP = { + "v1-16x256x256": "hpcai-tech/OpenSora-STDiT-v1-16x256x256", + "v1-HQ-16x256x256": "hpcai-tech/OpenSora-STDiT-v1-HQ-16x256x256", + "v1-HQ-16x512x512": "hpcai-tech/OpenSora-STDiT-v1-HQ-16x512x512", +} + + +def install_dependencies(): + """ + Install the required dependencies for the demo if they are not already installed. + """ + + def _is_package_available(name) -> bool: + try: + importlib.import_module(name) + return True + except (ImportError, ModuleNotFoundError): + return False + + # install flash attention + if not _is_package_available("flash_attn"): + subprocess.run( + f"{sys.executable} -m pip install flash-attn --no-build-isolation", + env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"}, + shell=True, + ) + + # install apex + if not _is_package_available("apex"): + subprocess.run( + f'{sys.executable} -m pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" git+https://github.com/NVIDIA/apex.git', + shell=True, + ) + + # install ninja + if not _is_package_available("ninja"): + subprocess.run(f"{sys.executable} -m pip install ninja", shell=True) + + # install xformers + if not _is_package_available("xformers"): + subprocess.run( + f"{sys.executable} -m pip install -v -U git+https://github.com/facebookresearch/xformers.git@main#egg=xformers", + shell=True, + ) + + # install opensora + if not _is_package_available("opensora"): + subprocess.run(f"{sys.executable} -m pip install git+https://github.com/hpcaitech/Open-Sora.git", shell=True) + + +def set_up_torch(): + """ + Configure PyTorch for the demo. + """ + torch.set_grad_enabled(False) + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + + +def get_device(): + """ + Get the default device to run the model. Hugging Face space might provide CPU only, so we need to check for that. + """ + device = "cuda" if torch.cuda.is_available() else "cpu" + return device + + +def read_config(config_path): + """ + Read the configuration file. + """ + from mmengine.config import Config + + return Config.fromfile(config_path) + + +def build_models(model_type, config): + """ + Build the models for the given model type and configuration. + """ + # build vae + from opensora.registry import MODELS, build_module + + vae = build_module(config.vae, MODELS) + + # build text encoder + text_encoder = build_module(config.text_encoder, MODELS, device=get_device()) # T5 must be fp32 + + # build stdit + # we load model from HuggingFace directly so that we don't need to + # handle model download logic in HuggingFace Space + from transformers import AutoModel + + stdit = AutoModel.from_pretrained( + HF_STDIT_MAP[model_type], enable_flash_attn=True, enable_layernorm_kernel=True, trust_remote_code=True + ) + + # build scheduler + from opensora.registry import SCHEDULERS + + scheduler = build_module(config.scheduler, SCHEDULERS) + + # hack for classifier-free guidance + text_encoder.y_embedder = stdit.y_embedder + + # move modelst to device + vae = vae.to(get_device()).to(torch.float16).eval() + text_encoder.t5.model = text_encoder.t5.model.to(get_device()).eval() # t5 must be in fp32 + stdit = stdit.to(get_device()).to(torch.float16).eval() + + return vae, text_encoder, stdit, scheduler + + +def get_latent_size(config, vae): + input_size = (config.num_frames, *config.image_size) + latent_size = vae.get_latent_size(input_size) + return latent_size + + +# @spaces.GPU(duration=200) +def run_inference(prompt_text, config, scheduler, vae, text_encoder, stdit, latent_size, output): + from opensora.datasets import save_sample + + samples = scheduler.sample( + stdit, + text_encoder, + z_size=(vae.out_channels, *latent_size), + prompts=[prompt_text], + device=get_device(), + ) + samples = vae.decode(samples.to(torch.float16)) + filename = f"{output}/sample" + saved_path = save_sample(samples[0], fps=config.fps, save_path=filename) + return saved_path + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--model-type", + default="v1-HQ-16x512x512", + choices=MODEL_TYPES, + help=f"The type of model to run for the Gradio App, can only be {MODEL_TYPES}", + ) + parser.add_argument("--output", default="./outputs", type=str, help="The path to the output folder") + parser.add_argument("--port", default=8000, type=int, help="The port to run the Gradio App on.") + parser.add_argument("--host", default="127.0.0.1", type=str, help="The host to run the Gradio App on.") + parser.add_argument("--share", action="store_true", help="Whether to share this gradio demo.") + return parser.parse_args() + + +def main(): + # read config + args = parse_args() + config = read_config(CONFIG_MAP[args.model_type]) + + # set up + set_up_torch() + install_dependencies() + + # build model + vae, text_encoder, stdit, scheduler = build_models(args.model_type, config) + + # wrap inference function to accept 1 input only + run_inference_func = partial( + run_inference, + config=config, + scheduler=scheduler, + vae=vae, + text_encoder=text_encoder, + stdit=stdit, + latent_size=get_latent_size(config, vae), + output=args.output, + ) + + # make outputs dir + os.makedirs(args.output, exist_ok=True) + + with gr.Blocks() as demo: + with gr.Row(): + with gr.Column(): + gr.HTML( + """ +
+

+ +

+
+ + + + + + + +
+

Open-Sora: Democratizing Efficient Video Production for All

+
+ """ + ) + + with gr.Row(): + with gr.Column(): + prompt_text = gr.Textbox(show_label=False, placeholder="Describe your video here", lines=4) + submit_button = gr.Button("Generate video") + + with gr.Column(): + output_video = gr.Video() + + submit_button.click(fn=run_inference_func, inputs=[prompt_text], outputs=output_video) + + gr.Examples( + examples=[ + [ + "The video captures the majestic beauty of a waterfall cascading down a cliff into a serene lake. The waterfall, with its powerful flow, is the central focus of the video. The surrounding landscape is lush and green, with trees and foliage adding to the natural beauty of the scene. The camera angle provides a bird's eye view of the waterfall, allowing viewers to appreciate the full height and grandeur of the waterfall. The video is a stunning representation of nature's power and beauty.", + ], + ], + fn=run_inference_func, + inputs=[ + prompt_text, + ], + outputs=[output_video], + cache_examples=True, + ) + + demo.launch(server_port=args.port, server_name=args.host, share=args.share) + + +if __name__ == "__main__": + main()