{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Inference for OpenSora"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Define global variables. Adjust the following variables to match your setup."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# global variables\n",
    "ROOT = \"..\"\n",
    "cfg_path = f\"{ROOT}/configs/opensora-v1-2/inference/sample.py\"\n",
    "ckpt_path = \"/home/lishenggui/projects/sora/Open-Sora-dev/outputs/207-STDiT3-XL-2/epoch0-global_step9000/\"  # replace with your own checkpoint\n",
    "vae_path = f\"{ROOT}/pretrained_models/vae-pipeline\"\n",
    "save_dir = f\"{ROOT}/samples/samples_notebook/\"\n",
    "device = \"cuda:0\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Import necessary libraries and load the models."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from pprint import pformat\n",
    "\n",
    "import colossalai\n",
    "import torch\n",
    "import torch.distributed as dist\n",
    "from colossalai.cluster import DistCoordinator\n",
    "from mmengine.runner import set_random_seed\n",
    "from tqdm.notebook import tqdm\n",
    "\n",
    "from opensora.acceleration.parallel_states import set_sequence_parallel_group\n",
    "from opensora.datasets import save_sample, is_img\n",
    "from opensora.datasets.aspect import get_image_size, get_num_frames\n",
    "from opensora.models.text_encoder.t5 import text_preprocessing\n",
    "from opensora.registry import MODELS, SCHEDULERS, build_module\n",
    "from opensora.utils.config_utils import read_config\n",
    "from opensora.utils.inference_utils import (\n",
    "    append_generated,\n",
    "    apply_mask_strategy,\n",
    "    collect_references_batch,\n",
    "    extract_json_from_prompts,\n",
    "    extract_prompts_loop,\n",
    "    get_save_path_name,\n",
    "    load_prompts,\n",
    "    prepare_multi_resolution_info,\n",
    ")\n",
    "from opensora.utils.misc import all_exists, create_logger, is_distributed, is_main_process, to_torch_dtype"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.set_grad_enabled(False)\n",
    "\n",
    "# == parse configs ==\n",
    "cfg = read_config(cfg_path)\n",
    "cfg.model.from_pretrained = ckpt_path\n",
    "cfg.vae.from_pretrained = vae_path\n",
    "\n",
    "# == device and dtype ==\n",
    "cfg_dtype = cfg.get(\"dtype\", \"fp32\")\n",
    "assert cfg_dtype in [\"fp16\", \"bf16\", \"fp32\"], f\"Unknown mixed precision {cfg_dtype}\"\n",
    "dtype = to_torch_dtype(cfg_dtype)\n",
    "torch.backends.cuda.matmul.allow_tf32 = True\n",
    "torch.backends.cudnn.allow_tf32 = True\n",
    "\n",
    "set_random_seed(seed=cfg.get(\"seed\", 1024))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# == build text-encoder and vae ==\n",
    "text_encoder = build_module(cfg.text_encoder, MODELS, device=device)\n",
    "vae = build_module(cfg.vae, MODELS).to(device, dtype).eval()\n",
    "\n",
    "# == build diffusion model ==\n",
    "input_size = (None, None, None)\n",
    "latent_size = vae.get_latent_size(input_size)\n",
    "model = (\n",
    "    build_module(\n",
    "        cfg.model,\n",
    "        MODELS,\n",
    "        input_size=latent_size,\n",
    "        in_channels=vae.out_channels,\n",
    "        caption_channels=text_encoder.output_dim,\n",
    "        model_max_length=text_encoder.model_max_length,\n",
    "    )\n",
    "    .to(device, dtype)\n",
    "    .eval()\n",
    ")\n",
    "text_encoder.y_embedder = model.y_embedder  # HACK: for classifier-free guidance"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Define the inference function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "start_idx = 0\n",
    "multi_resolution = cfg.get(\"multi_resolution\", None)\n",
    "batch_size = cfg.get(\"batch_size\", 1)\n",
    "\n",
    "\n",
    "def inference(\n",
    "    prompts=cfg.get(\"prompt\", None),\n",
    "    image_size=None,\n",
    "    num_frames=None,\n",
    "    resolution=None,\n",
    "    aspect_ratio=None,\n",
    "    mask_strategy=None,\n",
    "    reference_path=None,\n",
    "    num_sampling_steps=None,\n",
    "    cfg_scale=None,\n",
    "    seed=None,\n",
    "    fps=cfg.fps,\n",
    "    num_sample=cfg.get(\"num_sample\", 1),\n",
    "    loop=cfg.get(\"loop\", 1),\n",
    "    condition_frame_length=cfg.get(\"condition_frame_length\", 5),\n",
    "    align=cfg.get(\"align\", None),\n",
    "    sample_name=cfg.get(\"sample_name\", None),\n",
    "    prompt_as_path=cfg.get(\"prompt_as_path\", False),\n",
    "    disable_progress=False,\n",
    "):\n",
    "    global start_idx\n",
    "    os.makedirs(save_dir, exist_ok=True)\n",
    "    if seed is not None:\n",
    "        set_random_seed(seed=seed)\n",
    "    if not isinstance(prompts, list):\n",
    "        prompts = [prompts]\n",
    "    if mask_strategy is None:\n",
    "        mask_strategy = [\"\"] * len(prompts)\n",
    "    if reference_path is None:\n",
    "        reference_path = [\"\"] * len(prompts)\n",
    "    save_fps = cfg.fps // cfg.get(\"frame_interval\", 1)\n",
    "    if num_sampling_steps is not None:\n",
    "        cfg.scheduler[\"num_sampling_steps\"] = num_sampling_steps\n",
    "    if cfg_scale is not None:\n",
    "        cfg.scheduler[\"scale\"] = cfg_scale\n",
    "    scheduler = build_module(cfg.scheduler, SCHEDULERS)\n",
    "    ret_path = []\n",
    "\n",
    "    # == prepare video size ==\n",
    "    if image_size is None:\n",
    "        assert (\n",
    "            resolution is not None and aspect_ratio is not None\n",
    "        ), \"resolution and aspect_ratio must be provided if image_size is not provided\"\n",
    "        image_size = get_image_size(resolution, aspect_ratio)\n",
    "    num_frames = get_num_frames(num_frames)\n",
    "    input_size = (num_frames, *image_size)\n",
    "    latent_size = vae.get_latent_size(input_size)\n",
    "\n",
    "    # == iterate over all samples ==\n",
    "    for i in tqdm(range(0, len(prompts), batch_size), disable=disable_progress):\n",
    "        # == prepare batch prompts ==\n",
    "        batch_prompts = prompts[i : i + batch_size]\n",
    "        ms = mask_strategy[i : i + batch_size]\n",
    "        refs = reference_path[i : i + batch_size]\n",
    "\n",
    "        batch_prompts, refs, ms = extract_json_from_prompts(batch_prompts, refs, ms)\n",
    "        refs = collect_references_batch(refs, vae, image_size)\n",
    "\n",
    "        # == multi-resolution info ==\n",
    "        model_args = prepare_multi_resolution_info(\n",
    "            multi_resolution, len(batch_prompts), image_size, num_frames, fps, device, dtype\n",
    "        )\n",
    "\n",
    "        # == iterate over the number of samples for one prompt ==\n",
    "        for k in range(num_sample):\n",
    "            # == prepare save paths ==\n",
    "            save_paths = [\n",
    "                get_save_path_name(\n",
    "                    save_dir,\n",
    "                    sample_name=sample_name,\n",
    "                    sample_idx=start_idx + idx,\n",
    "                    prompt=batch_prompts[idx],\n",
    "                    prompt_as_path=prompt_as_path,\n",
    "                    num_sample=num_sample,\n",
    "                    k=k,\n",
    "                )\n",
    "                for idx in range(len(batch_prompts))\n",
    "            ]\n",
    "\n",
    "            # NOTE: skip if the sample already exists\n",
    "            # This is useful for resuming VBench sampling\n",
    "            if prompt_as_path and all_exists(save_paths):\n",
    "                continue\n",
    "\n",
    "            # == iterate over loop generation ==\n",
    "            video_clips = []\n",
    "            for loop_i in range(loop):\n",
    "                batch_prompts_loop = extract_prompts_loop(batch_prompts, loop_i)\n",
    "                batch_prompts_cleaned = [text_preprocessing(prompt) for prompt in batch_prompts_loop]\n",
    "\n",
    "                # == loop ==\n",
    "                if loop_i > 0:\n",
    "                    refs, ms = append_generated(vae, video_clips[-1], refs, ms, loop_i, condition_frame_length)\n",
    "\n",
    "                # == sampling ==\n",
    "                z = torch.randn(len(batch_prompts), vae.out_channels, *latent_size, device=device, dtype=dtype)\n",
    "                masks = apply_mask_strategy(z, refs, ms, loop_i, align=align)\n",
    "                samples = scheduler.sample(\n",
    "                    model,\n",
    "                    text_encoder,\n",
    "                    z=z,\n",
    "                    prompts=batch_prompts_cleaned,\n",
    "                    device=device,\n",
    "                    additional_args=model_args,\n",
    "                    progress=False,\n",
    "                    mask=masks,\n",
    "                )\n",
    "                samples = vae.decode(samples.to(dtype), num_frames=num_frames)\n",
    "                video_clips.append(samples)\n",
    "\n",
    "            # == save samples ==\n",
    "            if is_main_process():\n",
    "                for idx in range(len(batch_prompts)):\n",
    "                    save_path = save_paths[idx]\n",
    "                    # use a dedicated loop variable so the batch index i is not shadowed\n",
    "                    video = [video_clips[li][idx] for li in range(loop)]\n",
    "                    # drop the overlapping condition frames of every clip after the first\n",
    "                    for li in range(1, loop):\n",
    "                        video[li] = video[li][:, condition_frame_length:]\n",
    "                    video = torch.cat(video, dim=1)\n",
    "                    path = save_sample(\n",
    "                        video,\n",
    "                        fps=save_fps,\n",
    "                        save_path=save_path,\n",
    "                        verbose=False,\n",
    "                    )\n",
    "                    ret_path.append(path)\n",
    "        start_idx += len(batch_prompts)\n",
    "    return ret_path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import Video, Image, display\n",
    "\n",
    "\n",
    "def display_results(paths):\n",
    "    for path in paths:\n",
    "        if is_img(path):\n",
    "            display(Image(path))\n",
    "        else:\n",
    "            display(Video(path, embed=True))\n",
    "\n",
    "\n",
    "def reset_start_idx():\n",
    "    global start_idx\n",
    "    start_idx = 0\n",
    "\n",
    "\n",
    "ALL_ASPECT_RATIO = [\"1:1\", \"16:9\", \"9:16\", \"3:4\", \"4:3\", \"1:2\", \"2:1\"]\n",
    "\n",
    "\n",
    "def inference_all_aspects(prompts, resolution, num_frames, *args, **kwargs):\n",
    "    paths = []\n",
    "    for aspect_ratio in tqdm(ALL_ASPECT_RATIO):\n",
    "        paths.extend(\n",
    "            inference(\n",
    "                prompts,\n",
    "                *args,\n",
    "                resolution=resolution,\n",
    "                num_frames=num_frames,\n",
    "                aspect_ratio=aspect_ratio,\n",
    "                disable_progress=True,\n",
    "                **kwargs\n",
    "            )\n",
    "        )\n",
    "    return paths"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Inference examples"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A basic sampling call for OpenSora."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "paths = inference(\n",
    "    [\"a man.\", \"a woman.\"],\n",
    "    resolution=\"240p\",\n",
    "    aspect_ratio=\"1:1\",\n",
    "    num_frames=\"1x\",\n",
    "    num_sampling_steps=30,\n",
    "    cfg_scale=7.0,\n",
    ")\n",
    "display_results(paths)"
   ]
  },
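  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Condition generation on a reference image via the `reference_path` and `mask_strategy` arguments of `inference`. This is a minimal sketch: the image path below is a hypothetical placeholder (replace it with a real file), and mask strategy `\"0\"` conditions the clip on the first reference."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Minimal image-conditioned sketch; REF_IMAGE is a hypothetical path, replace with your own file.\n",
    "REF_IMAGE = f\"{ROOT}/assets/reference.png\"\n",
    "paths = inference(\n",
    "    [\"a breathtaking ocean wave.\"],\n",
    "    resolution=\"240p\",\n",
    "    aspect_ratio=\"16:9\",\n",
    "    num_frames=\"2x\",\n",
    "    num_sampling_steps=30,\n",
    "    cfg_scale=7.0,\n",
    "    reference_path=[REF_IMAGE],  # one reference per prompt\n",
    "    mask_strategy=[\"0\"],  # condition on the first reference\n",
    ")\n",
    "display_results(paths)"
   ]
  },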
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Sample all aspect ratios."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "PROMPT = \"a boy.\"\n",
    "paths = inference_all_aspects(\n",
    "    PROMPT,\n",
    "    resolution=\"240p\",\n",
    "    num_frames=\"1x\",\n",
    "    num_sampling_steps=30,\n",
    "    cfg_scale=7.0,\n",
    ")\n",
    "display_results(paths)"
   ]
  },
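  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Draw several samples per prompt with a fixed seed, using the `seed` and `num_sample` arguments of `inference` (a minimal sketch; seed 42 is an arbitrary choice)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A fixed seed makes the run reproducible; num_sample=2 draws two samples per prompt.\n",
    "paths = inference(\n",
    "    \"a boy.\",\n",
    "    resolution=\"240p\",\n",
    "    aspect_ratio=\"1:1\",\n",
    "    num_frames=\"1x\",\n",
    "    num_sampling_steps=30,\n",
    "    cfg_scale=7.0,\n",
    "    seed=42,\n",
    "    num_sample=2,\n",
    ")\n",
    "display_results(paths)"
   ]
  },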
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Sample all resolutions and lengths."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "PROMPT = \"a boy.\"\n",
    "sample_cfg = {\n",
    "    \"144p\": [1, \"1x\", \"2x\", \"4x\", \"8x\"],\n",
    "    \"240p\": [1, \"1x\", \"2x\", \"4x\", \"8x\"],\n",
    "    \"360p\": [1, \"1x\", \"2x\", \"4x\"],\n",
    "    \"480p\": [1, \"1x\", \"2x\", \"4x\"],\n",
    "    \"720p\": [1, \"1x\", \"2x\"],\n",
    "}\n",
    "all_paths = []\n",
    "for resolution, num_frames in sample_cfg.items():\n",
    "    for num_frame in num_frames:\n",
    "        print(f\"Resolution: {resolution}, Num Frames: {num_frame}\")\n",
    "        paths = inference(\n",
    "            PROMPT,\n",
    "            resolution=resolution,\n",
    "            num_frames=num_frame,\n",
    "            aspect_ratio=\"9:16\",\n",
    "            num_sampling_steps=30,\n",
    "            cfg_scale=7.0,\n",
    "            disable_progress=True,\n",
    "        )\n",
    "        display_results(paths)\n",
    "        all_paths.extend(paths)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Sample all resolutions, lengths, and aspect ratios."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "PROMPT = \"a boy.\"\n",
    "sample_cfg = {\n",
    "    \"144p\": [1, \"1x\", \"2x\", \"4x\", \"8x\"],\n",
    "    \"240p\": [1, \"1x\", \"2x\", \"4x\", \"8x\"],\n",
    "    \"360p\": [1, \"1x\", \"2x\", \"4x\"],\n",
    "    \"480p\": [1, \"1x\", \"2x\", \"4x\"],\n",
    "    \"720p\": [1, \"1x\", \"2x\"],\n",
    "}\n",
    "all_paths = []\n",
    "for resolution, num_frames in sample_cfg.items():\n",
    "    for num_frame in num_frames:\n",
    "        paths = inference_all_aspects(\n",
    "            PROMPT,\n",
    "            resolution=resolution,\n",
    "            num_frames=num_frame,\n",
    "            num_sampling_steps=30,\n",
    "            cfg_scale=7.0,\n",
    "        )\n",
    "        display_results(paths)\n",
    "        all_paths.extend(paths)"
   ]
  },
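  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Chain clips into a longer video with the `loop` and `condition_frame_length` arguments of `inference`: each new clip is conditioned on the last `condition_frame_length` frames of the previous one, and the overlap is trimmed before saving. A minimal sketch with an illustrative prompt:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Two chained clips; the second continues from the last 5 frames of the first.\n",
    "paths = inference(\n",
    "    \"a boy running on the beach.\",\n",
    "    resolution=\"240p\",\n",
    "    aspect_ratio=\"16:9\",\n",
    "    num_frames=\"2x\",\n",
    "    num_sampling_steps=30,\n",
    "    cfg_scale=7.0,\n",
    "    loop=2,\n",
    "    condition_frame_length=5,\n",
    ")\n",
    "display_results(paths)"
   ]
  },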
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "opensora",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}