Mirror of https://github.com/hpcaitech/Open-Sora.git (synced 2026-04-11 05:13:31 +02:00)
[feat] notebook for inference
This commit is contained in:
parent
aec8c60036
commit
a36f80f847
@@ -0,0 +1,352 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Inference for OpenSora"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Define global variables."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# global variables\n",
    "ROOT = \"..\"\n",
    "cfg_path = f\"{ROOT}/configs/opensora-v1-2/inference/sample.py\"\n",
    "ckpt_path = \"/home/lishenggui/projects/sora/Open-Sora-dev/outputs/207-STDiT3-XL-2/epoch0-global_step9000/\"\n",
    "vae_path = f\"{ROOT}/pretrained_models/vae-pipeline\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Import the necessary libraries and load the models."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from pprint import pformat\n",
    "\n",
    "import colossalai\n",
    "import torch\n",
    "import torch.distributed as dist\n",
    "from colossalai.cluster import DistCoordinator\n",
    "from mmengine.runner import set_random_seed\n",
    "from tqdm.notebook import tqdm\n",
    "\n",
    "from opensora.acceleration.parallel_states import set_sequence_parallel_group\n",
    "from opensora.datasets import save_sample, is_img\n",
    "from opensora.datasets.aspect import get_image_size, get_num_frames\n",
    "from opensora.models.text_encoder.t5 import text_preprocessing\n",
    "from opensora.registry import MODELS, SCHEDULERS, build_module\n",
    "from opensora.utils.config_utils import read_config\n",
    "from opensora.utils.inference_utils import (\n",
    "    append_generated,\n",
    "    apply_mask_strategy,\n",
    "    collect_references_batch,\n",
    "    extract_json_from_prompts,\n",
    "    extract_prompts_loop,\n",
    "    get_save_path_name,\n",
    "    load_prompts,\n",
    "    prepare_multi_resolution_info,\n",
    ")\n",
    "from opensora.utils.misc import all_exists, create_logger, is_distributed, is_main_process, to_torch_dtype"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "42"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.set_grad_enabled(False)\n",
    "\n",
    "# == parse configs ==\n",
    "cfg = read_config(cfg_path)\n",
    "cfg.model.from_pretrained = ckpt_path\n",
    "cfg.vae.from_pretrained = vae_path\n",
    "\n",
    "# == device and dtype ==\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "cfg_dtype = cfg.get(\"dtype\", \"fp32\")\n",
    "assert cfg_dtype in [\"fp16\", \"bf16\", \"fp32\"], f\"Unknown mixed precision {cfg_dtype}\"\n",
    "dtype = to_torch_dtype(cfg_dtype)\n",
    "torch.backends.cuda.matmul.allow_tf32 = True\n",
    "torch.backends.cudnn.allow_tf32 = True\n",
    "\n",
    "set_random_seed(seed=cfg.get(\"seed\", 1024))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "29ca38c42a38453aa65784e1ee89a61a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# == build text-encoder and vae ==\n",
    "text_encoder = build_module(cfg.text_encoder, MODELS, device=device)\n",
    "vae = build_module(cfg.vae, MODELS).to(device, dtype).eval()\n",
    "\n",
    "# == build diffusion model ==\n",
    "input_size = (None, None, None)\n",
    "latent_size = vae.get_latent_size(input_size)\n",
    "model = (\n",
    "    build_module(\n",
    "        cfg.model,\n",
    "        MODELS,\n",
    "        input_size=latent_size,\n",
    "        in_channels=vae.out_channels,\n",
    "        caption_channels=text_encoder.output_dim,\n",
    "        model_max_length=text_encoder.model_max_length,\n",
    "    )\n",
    "    .to(device, dtype)\n",
    "    .eval()\n",
    ")\n",
    "text_encoder.y_embedder = model.y_embedder  # HACK: for classifier-free guidance"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Define the inference function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "start_idx = 0\n",
    "multi_resolution = cfg.get(\"multi_resolution\", None)\n",
    "batch_size = cfg.get(\"batch_size\", 1)\n",
    "\n",
    "\n",
    "def inference(\n",
    "    prompts=cfg.get(\"prompt\", None),\n",
    "    image_size=None,\n",
    "    num_frames=None,\n",
    "    resolution=None,\n",
    "    aspect_ratio=None,\n",
    "    mask_strategy=None,\n",
    "    reference_path=None,\n",
    "    num_sampling_steps=None,\n",
    "    cfg_scale=None,\n",
    "    seed=None,\n",
    "    fps=cfg.fps,\n",
    "    num_sample=cfg.get(\"num_sample\", 1),\n",
    "    loop=cfg.get(\"loop\", 1),\n",
    "    condition_frame_length=cfg.get(\"condition_frame_length\", 5),\n",
    "    align=cfg.get(\"align\", None),\n",
    "    save_dir=os.path.join(ROOT, cfg.save_dir),\n",
    "    sample_name=cfg.get(\"sample_name\", None),\n",
    "    prompt_as_path=cfg.get(\"prompt_as_path\", False),\n",
    "):\n",
    "    global start_idx\n",
    "    os.makedirs(save_dir, exist_ok=True)\n",
    "    if seed is not None:\n",
    "        set_random_seed(seed=seed)\n",
    "    if not isinstance(prompts, list):\n",
    "        prompts = [prompts]\n",
    "    if mask_strategy is None:\n",
    "        mask_strategy = [\"\"] * len(prompts)\n",
    "    if reference_path is None:\n",
    "        reference_path = [\"\"] * len(prompts)\n",
    "    save_fps = fps // cfg.get(\"frame_interval\", 1)\n",
    "    if num_sampling_steps is not None:\n",
    "        cfg.scheduler[\"num_sampling_steps\"] = num_sampling_steps\n",
    "    if cfg_scale is not None:\n",
    "        cfg.scheduler[\"scale\"] = cfg_scale\n",
    "    scheduler = build_module(cfg.scheduler, SCHEDULERS)\n",
    "    ret_path = []\n",
    "\n",
    "    # == prepare video size ==\n",
    "    if image_size is None:\n",
    "        assert (\n",
    "            resolution is not None and aspect_ratio is not None\n",
    "        ), \"resolution and aspect_ratio must be provided if image_size is not provided\"\n",
    "        image_size = get_image_size(resolution, aspect_ratio)\n",
    "    num_frames = get_num_frames(num_frames if num_frames is not None else cfg.num_frames)\n",
    "    input_size = (num_frames, *image_size)\n",
    "    latent_size = vae.get_latent_size(input_size)\n",
    "\n",
    "    # == iterate over all samples ==\n",
    "    for i in tqdm(range(0, len(prompts), batch_size)):\n",
    "        # == prepare batch prompts ==\n",
    "        batch_prompts = prompts[i : i + batch_size]\n",
    "        ms = mask_strategy[i : i + batch_size]\n",
    "        refs = reference_path[i : i + batch_size]\n",
    "\n",
    "        batch_prompts, refs, ms = extract_json_from_prompts(batch_prompts, refs, ms)\n",
    "        refs = collect_references_batch(refs, vae, image_size)\n",
    "\n",
    "        # == multi-resolution info ==\n",
    "        model_args = prepare_multi_resolution_info(\n",
    "            multi_resolution, len(batch_prompts), image_size, num_frames, fps, device, dtype\n",
    "        )\n",
    "\n",
    "        # == iterate over the number of samples per prompt ==\n",
    "        for k in range(num_sample):\n",
    "            # == prepare save paths ==\n",
    "            save_paths = [\n",
    "                get_save_path_name(\n",
    "                    save_dir,\n",
    "                    sample_name=sample_name,\n",
    "                    sample_idx=start_idx + idx,\n",
    "                    prompt=batch_prompts[idx],\n",
    "                    prompt_as_path=prompt_as_path,\n",
    "                    num_sample=num_sample,\n",
    "                    k=k,\n",
    "                )\n",
    "                for idx in range(len(batch_prompts))\n",
    "            ]\n",
    "\n",
    "            # NOTE: skip if the sample already exists;\n",
    "            # this is useful for resuming sampling on VBench\n",
    "            if prompt_as_path and all_exists(save_paths):\n",
    "                continue\n",
    "\n",
    "            # == iterate over loop generation ==\n",
    "            video_clips = []\n",
    "            for loop_i in range(loop):\n",
    "                batch_prompts_loop = extract_prompts_loop(batch_prompts, loop_i)\n",
    "                batch_prompts_cleaned = [text_preprocessing(prompt) for prompt in batch_prompts_loop]\n",
    "\n",
    "                # == loop ==\n",
    "                if loop_i > 0:\n",
    "                    refs, ms = append_generated(vae, video_clips[-1], refs, ms, loop_i, condition_frame_length)\n",
    "\n",
    "                # == sampling ==\n",
    "                z = torch.randn(len(batch_prompts), vae.out_channels, *latent_size, device=device, dtype=dtype)\n",
    "                masks = apply_mask_strategy(z, refs, ms, loop_i, align=align)\n",
    "                samples = scheduler.sample(\n",
    "                    model,\n",
    "                    text_encoder,\n",
    "                    z=z,\n",
    "                    prompts=batch_prompts_cleaned,\n",
    "                    device=device,\n",
    "                    additional_args=model_args,\n",
    "                    progress=False,\n",
    "                    mask=masks,\n",
    "                )\n",
    "                samples = vae.decode(samples.to(dtype), num_frames=num_frames)\n",
    "                video_clips.append(samples)\n",
    "\n",
    "            # == save samples ==\n",
    "            if is_main_process():\n",
    "                for idx, batch_prompt in enumerate(batch_prompts):\n",
    "                    save_path = save_paths[idx]\n",
    "                    video = [video_clips[clip_i][idx] for clip_i in range(loop)]\n",
    "                    for clip_i in range(1, loop):\n",
    "                        video[clip_i] = video[clip_i][:, condition_frame_length:]\n",
    "                    video = torch.cat(video, dim=1)\n",
    "                    path = save_sample(\n",
    "                        video,\n",
    "                        fps=save_fps,\n",
    "                        save_path=save_path,\n",
    "                        verbose=False,\n",
    "                    )\n",
    "                    ret_path.append(path)\n",
    "        start_idx += len(batch_prompts)\n",
    "    return ret_path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import Video, Image, display\n",
    "\n",
    "def display_results(paths):\n",
    "    for path in paths:\n",
    "        if is_img(path):\n",
    "            display(Image(path))\n",
    "        else:\n",
    "            display(Video(path, embed=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "paths = inference(\n",
    "    [\"a man.\", \"a woman\"],\n",
    "    resolution=\"240p\",\n",
    "    aspect_ratio=\"1:1\",\n",
    "    num_frames=\"1x\",\n",
    "    num_sampling_steps=30,\n",
    "    cfg_scale=7.0,\n",
    ")\n",
    "display_results(paths)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "opensora",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
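The notebook's final cell generates two single-clip samples. The same `inference` helper also supports looped generation, where each new clip is conditioned on the tail of the previous one. A minimal sketch of such a call, assuming the notebook cells above have been executed (the prompt and settings are illustrative; all keyword arguments come from the function signature above):

```python
# Illustrative only: assumes `inference` and `display_results` are defined
# by running the notebook cells above.
paths = inference(
    ["a timelapse of clouds over a mountain lake"],  # illustrative prompt
    resolution="240p",
    aspect_ratio="16:9",
    num_frames="2x",            # resolved to 102 frames by get_num_frames
    num_sampling_steps=30,
    cfg_scale=7.0,
    loop=2,                     # generate two clips back to back,
    condition_frame_length=5,   # conditioning on the tail of the first clip
    seed=42,
)
display_results(paths)
```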
@@ -1,3 +1,3 @@
 from .dataloader import prepare_dataloader, prepare_variable_dataloader
 from .datasets import IMG_FPS, VariableVideoTextDataset, VideoTextDataset
-from .utils import get_transforms_image, get_transforms_video, save_sample
+from .utils import get_transforms_image, get_transforms_video, is_img, is_vid, save_sample
@@ -468,3 +468,18 @@ def get_image_size(resolution, ar_ratio):
     rs_dict = ASPECT_RATIOS[resolution][1]
     assert ar_key in rs_dict, f"Aspect ratio {ar_ratio} not found for resolution {resolution}"
     return rs_dict[ar_key]
+
+
+NUM_FRAMES_MAP = {
+    "1x": 51,
+    "2x": 102,
+    "4x": 204,
+    "8x": 408,
+}
+
+
+def get_num_frames(num_frames):
+    if num_frames in NUM_FRAMES_MAP:
+        return NUM_FRAMES_MAP[num_frames]
+    else:
+        return int(num_frames)
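In words: the presets are multiples of a base clip length of 51 frames (per the map above), and any other string is parsed as a literal frame count. A quick sketch of the behavior, with the import path taken from the notebook above:

```python
# Sketch of get_num_frames; import path matches the notebook's imports.
from opensora.datasets.aspect import get_num_frames

assert get_num_frames("1x") == 51    # preset: base clip length
assert get_num_frames("2x") == 102   # preset: double length
assert get_num_frames("64") == 64    # any other string is parsed as an int
```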
@@ -27,6 +27,16 @@ regex = re.compile(
 )


+def is_img(path):
+    ext = os.path.splitext(path)[-1].lower()
+    return ext in IMG_EXTENSIONS
+
+
+def is_vid(path):
+    ext = os.path.splitext(path)[-1].lower()
+    return ext in VID_EXTENSIONS
+
+
 def is_url(url):
     return re.match(regex, url) is not None
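These helpers are what the notebook's `display_results` uses to decide between image and video display. A minimal sketch of their behavior (the file names are illustrative; the re-export from `opensora.datasets` is added by the hunk above):

```python
# Illustrative usage of the new extension checks.
from opensora.datasets import is_img, is_vid

print(is_img("sample_0000.png"))  # True if ".png" is in IMG_EXTENSIONS
print(is_vid("sample_0001.mp4"))  # True if ".mp4" is in VID_EXTENSIONS
print(is_img("sample_0001.mp4"))  # False
```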
@@ -46,7 +46,7 @@ def parse_args(training=False):
     parser.add_argument("--prompt", default=None, type=str, nargs="+", help="prompt list")

     # image/video
-    parser.add_argument("--num-frames", default=None, type=int, help="number of frames")
+    parser.add_argument("--num-frames", default=None, type=str, help="number of frames")
     parser.add_argument("--fps", default=None, type=int, help="fps")
     parser.add_argument("--image-size", default=None, type=int, nargs=2, help="image size")
     parser.add_argument("--frame-interval", default=None, type=int, help="frame interval")
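Switching `--num-frames` from `int` to `str` lets the flag carry either a raw frame count or one of the new duration presets, which `get_num_frames` resolves later. A self-contained sketch of why the type change matters (the standalone parser here is illustrative, mirroring only the changed flag):

```python
# Illustrative: a standalone argparse parser mirroring the changed flag.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--num-frames", default=None, type=str, help="number of frames")

args = parser.parse_args(["--num-frames", "2x"])
print(args.num_frames)  # "2x" survives parsing; get_num_frames resolves it to 102
```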
@@ -111,9 +111,14 @@ def merge_args(cfg, args, training=False):
     return cfg


+def read_config(config_path):
+    cfg = Config.fromfile(config_path)
+    return cfg
+
+
 def parse_configs(training=False):
     args = parse_args(training)
-    cfg = Config.fromfile(args.config)
+    cfg = read_config(args.config)
     cfg = merge_args(cfg, args, training)
     return cfg
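Factoring out `read_config` is what enables the notebook workflow: a config can be loaded and patched programmatically without going through argparse. A minimal sketch, following the pattern of the notebook's config cell (the checkpoint path is hypothetical):

```python
# Illustrative: load and patch a config, as the notebook's second cell does.
from opensora.utils.config_utils import read_config

cfg = read_config("configs/opensora-v1-2/inference/sample.py")
cfg.model.from_pretrained = "/path/to/checkpoint"  # hypothetical path
cfg.scheduler["num_sampling_steps"] = 30
```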
@@ -10,7 +10,7 @@ from tqdm import tqdm

 from opensora.acceleration.parallel_states import set_sequence_parallel_group
 from opensora.datasets import save_sample
-from opensora.datasets.aspect import get_image_size
+from opensora.datasets.aspect import get_image_size, get_num_frames
 from opensora.models.text_encoder.t5 import text_preprocessing
 from opensora.registry import MODELS, SCHEDULERS, build_module
 from opensora.utils.config_utils import parse_configs
@@ -78,7 +78,7 @@ def main():
             resolution is not None and aspect_ratio is not None
         ), "resolution and aspect_ratio must be provided if image_size is not provided"
         image_size = get_image_size(resolution, aspect_ratio)
-    num_frames = cfg.num_frames
+    num_frames = get_num_frames(cfg.num_frames)

     # == build diffusion model ==
     input_size = (num_frames, *image_size)