From a36f80f8470b4a388318c76f82b00eea113adef3 Mon Sep 17 00:00:00 2001 From: zhengzangw Date: Tue, 14 May 2024 07:24:56 +0000 Subject: [PATCH] [feat] notebook for inference --- notebooks/inference.ipynb | 352 +++++++++++++++++++++++++++++++++ opensora/datasets/__init__.py | 2 +- opensora/datasets/aspect.py | 15 ++ opensora/datasets/utils.py | 10 + opensora/utils/config_utils.py | 9 +- scripts/inference.py | 4 +- 6 files changed, 387 insertions(+), 5 deletions(-) diff --git a/notebooks/inference.ipynb b/notebooks/inference.ipynb index e69de29..9a9a7d7 100644 --- a/notebooks/inference.ipynb +++ b/notebooks/inference.ipynb @@ -0,0 +1,352 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Inference for OpenSora" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Define global variables." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "# global variables\n", + "ROOT = \"..\"\n", + "cfg_path = f\"{ROOT}/configs/opensora-v1-2/inference/sample.py\"\n", + "ckpt_path = \"/home/lishenggui/projects/sora/Open-Sora-dev/outputs/207-STDiT3-XL-2/epoch0-global_step9000/\"\n", + "vae_path = f\"{ROOT}/pretrained_models/vae-pipeline\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Import necessary libraries and load the models." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "from pprint import pformat\n", + "\n", + "import colossalai\n", + "import torch\n", + "import torch.distributed as dist\n", + "from colossalai.cluster import DistCoordinator\n", + "from mmengine.runner import set_random_seed\n", + "from tqdm.notebook import tqdm\n", + "\n", + "from opensora.acceleration.parallel_states import set_sequence_parallel_group\n", + "from opensora.datasets import save_sample, is_img\n", + "from opensora.datasets.aspect import get_image_size, get_num_frames\n", + "from opensora.models.text_encoder.t5 import text_preprocessing\n", + "from opensora.registry import MODELS, SCHEDULERS, build_module\n", + "from opensora.utils.config_utils import read_config\n", + "from opensora.utils.inference_utils import (\n", + " append_generated,\n", + " apply_mask_strategy,\n", + " collect_references_batch,\n", + " extract_json_from_prompts,\n", + " extract_prompts_loop,\n", + " get_save_path_name,\n", + " load_prompts,\n", + " prepare_multi_resolution_info,\n", + ")\n", + "from opensora.utils.misc import all_exists, create_logger, is_distributed, is_main_process, to_torch_dtype" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "42" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "torch.set_grad_enabled(False)\n", + "\n", + "# == parse configs ==\n", + "cfg = read_config(cfg_path)\n", + "cfg.model.from_pretrained = ckpt_path\n", + "cfg.vae.from_pretrained = vae_path\n", + "\n", + "# == device and dtype ==\n", + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "cfg_dtype = cfg.get(\"dtype\", \"fp32\")\n", + "assert cfg_dtype in [\"fp16\", \"bf16\", \"fp32\"], f\"Unknown mixed precision {cfg_dtype}\"\n", + "dtype = to_torch_dtype(cfg.get(\"dtype\", \"bf16\"))\n", + "torch.backends.cuda.matmul.allow_tf32 = True\n", + "torch.backends.cudnn.allow_tf32 = True\n", + "\n", + "set_random_seed(seed=cfg.get(\"seed\", 1024))" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "29ca38c42a38453aa65784e1ee89a61a", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Loading checkpoint shards: 0%| | 0/2 [00:00 0:\n", + " refs, ms = append_generated(vae, video_clips[-1], refs, ms, loop_i, condition_frame_length)\n", + "\n", + " # == sampling ==\n", + " z = torch.randn(len(batch_prompts), vae.out_channels, *latent_size, device=device, dtype=dtype)\n", + " masks = apply_mask_strategy(z, refs, ms, loop_i, align=align)\n", + " samples = scheduler.sample(\n", + " model,\n", + " text_encoder,\n", + " z=z,\n", + " prompts=batch_prompts_cleaned,\n", + " device=device,\n", + " additional_args=model_args,\n", + " progress=False,\n", + " mask=masks,\n", + " )\n", + " samples = vae.decode(samples.to(dtype), num_frames=num_frames)\n", + " video_clips.append(samples)\n", + "\n", + " # == save samples ==\n", + " if is_main_process():\n", + " for idx, batch_prompt in enumerate(batch_prompts):\n", + " save_path = save_paths[idx]\n", + " video = [video_clips[i][idx] for i in range(loop)]\n", + " for i in range(1, loop):\n", + " video[i] = video[i][:, condition_frame_length:]\n", + " video = torch.cat(video, dim=1)\n", + " path = save_sample(\n", + " video,\n", + " fps=save_fps,\n", + " save_path=save_path,\n", + " verbose=False,\n", + " )\n", + " ret_path.append(path)\n", + " start_idx += len(batch_prompts)\n", + " return ret_path" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "from IPython.display import Video, Image, display\n", + "\n", + "def display_results(paths):\n", + " for path in paths:\n", + " if is_img(path):\n", + " display(Image(path))\n", + " else:\n", + " display(Video(path, embed=True))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "paths = inference(\n", + " [\"a man.\", \"a woman\"],\n", + " resolution=\"240p\",\n", + " aspect_ratio=\"1:1\",\n", + " num_frames=\"1x\",\n", + " num_sampling_steps=30,\n", + " cfg_scale=7.0,\n", + ")\n", + "display_results(paths)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "opensora", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.9" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/opensora/datasets/__init__.py b/opensora/datasets/__init__.py index 545eb17..aab6bba 100644 --- a/opensora/datasets/__init__.py +++ b/opensora/datasets/__init__.py @@ -1,3 +1,3 @@ from .dataloader import prepare_dataloader, prepare_variable_dataloader from .datasets import IMG_FPS, VariableVideoTextDataset, VideoTextDataset -from .utils import get_transforms_image, get_transforms_video, save_sample +from .utils import get_transforms_image, get_transforms_video, is_img, is_vid, save_sample diff --git a/opensora/datasets/aspect.py b/opensora/datasets/aspect.py index 952e7a4..f6f6065 100644 --- a/opensora/datasets/aspect.py +++ b/opensora/datasets/aspect.py @@ -468,3 +468,18 @@ def get_image_size(resolution, ar_ratio): rs_dict = ASPECT_RATIOS[resolution][1] assert ar_key in rs_dict, f"Aspect ratio {ar_ratio} not found for resolution {resolution}" return rs_dict[ar_key] + + +NUM_FRAMES_MAP = { + "1x": 51, + "2x": 102, + "4x": 204, + "8x": 408, +} + + +def get_num_frames(num_frames): + if num_frames in NUM_FRAMES_MAP: + return NUM_FRAMES_MAP[num_frames] + else: + return int(num_frames) diff --git a/opensora/datasets/utils.py b/opensora/datasets/utils.py index c5b712a..a85cf94 100644 --- a/opensora/datasets/utils.py +++ b/opensora/datasets/utils.py @@ -27,6 +27,16 @@ regex = re.compile( ) +def is_img(path): + ext = os.path.splitext(path)[-1].lower() + return ext in IMG_EXTENSIONS + + +def is_vid(path): + ext = os.path.splitext(path)[-1].lower() + return ext in VID_EXTENSIONS + + def is_url(url): return re.match(regex, url) is not None diff --git a/opensora/utils/config_utils.py b/opensora/utils/config_utils.py index b8f2af6..cc9e838 100644 --- a/opensora/utils/config_utils.py +++ b/opensora/utils/config_utils.py @@ -46,7 +46,7 @@ def parse_args(training=False): parser.add_argument("--prompt", default=None, type=str, nargs="+", help="prompt list") # image/video - parser.add_argument("--num-frames", default=None, type=int, help="number of frames") + parser.add_argument("--num-frames", default=None, type=str, help="number of frames") parser.add_argument("--fps", default=None, type=int, help="fps") parser.add_argument("--image-size", default=None, type=int, nargs=2, help="image size") parser.add_argument("--frame-interval", default=None, type=int, help="frame interval") @@ -111,9 +111,14 @@ def merge_args(cfg, args, training=False): return cfg +def read_config(config_path): + cfg = Config.fromfile(config_path) + return cfg + + def parse_configs(training=False): args = parse_args(training) - cfg = Config.fromfile(args.config) + cfg = read_config(args.config) cfg = merge_args(cfg, args, training) return cfg diff --git a/scripts/inference.py b/scripts/inference.py index af591b0..68f3862 100644 --- a/scripts/inference.py +++ b/scripts/inference.py @@ -10,7 +10,7 @@ from tqdm import tqdm from opensora.acceleration.parallel_states import set_sequence_parallel_group from opensora.datasets import save_sample -from opensora.datasets.aspect import get_image_size +from opensora.datasets.aspect import get_image_size, get_num_frames from opensora.models.text_encoder.t5 import text_preprocessing from opensora.registry import MODELS, SCHEDULERS, build_module from opensora.utils.config_utils import parse_configs @@ -78,7 +78,7 @@ def main(): resolution is not None and aspect_ratio is not None ), "resolution and aspect_ratio must be provided if image_size is not provided" image_size = get_image_size(resolution, aspect_ratio) - num_frames = cfg.num_frames + num_frames = get_num_frames(cfg.num_frames) # == build diffusion model == input_size = (num_frames, *image_size)