From 98e62a7c5762d3f8ab91e3c1926ee4192ead6d82 Mon Sep 17 00:00:00 2001
From: Zangwei Zheng
Date: Sat, 23 Mar 2024 20:28:34 +0800
Subject: [PATCH] update inference z

---
 ...{16x256x256_long.py => 16x256x256-long.py} |   3 +-
 opensora/datasets/__init__.py                 |   4 +-
 opensora/datasets/datasets.py                 |  27 +--
 opensora/datasets/utils.py                    |  29 +++
 opensora/schedulers/dpms/__init__.py          |   3 +-
 opensora/schedulers/iddpm/__init__.py         |   3 +-
 scripts/inference.py                          |  14 +-
 scripts/inference_long.py                     | 208 ++++++++++++++++++
 8 files changed, 254 insertions(+), 37 deletions(-)
 rename configs/opensora/inference_long/{16x256x256_long.py => 16x256x256-long.py} (93%)

diff --git a/configs/opensora/inference_long/16x256x256_long.py b/configs/opensora/inference_long/16x256x256-long.py
similarity index 93%
rename from configs/opensora/inference_long/16x256x256_long.py
rename to configs/opensora/inference_long/16x256x256-long.py
index c4aea9d..42967e7 100644
--- a/configs/opensora/inference_long/16x256x256_long.py
+++ b/configs/opensora/inference_long/16x256x256-long.py
@@ -18,11 +18,12 @@ vae = dict(
 )
 text_encoder = dict(
     type="t5",
-    from_pretrained="./pretrained_models/t5_ckpts",
+    from_pretrained="DeepFloyd/t5-v1_1-xxl",
     model_max_length=120,
 )
 scheduler = dict(
     type="iddpm",
+    # type="dpm-solver",
     num_sampling_steps=100,
     cfg_scale=7.0,
 )
diff --git a/opensora/datasets/__init__.py b/opensora/datasets/__init__.py
index c9b3395..94c5447 100644
--- a/opensora/datasets/__init__.py
+++ b/opensora/datasets/__init__.py
@@ -1,2 +1,2 @@
-from .datasets import DatasetFromCSV, get_transforms_image, get_transforms_video
-from .utils import prepare_dataloader, save_sample
+from .datasets import DatasetFromCSV
+from .utils import get_transforms_image, get_transforms_video, prepare_dataloader, save_sample
diff --git a/opensora/datasets/datasets.py b/opensora/datasets/datasets.py
index d302186..7e4baf5 100644
--- a/opensora/datasets/datasets.py
+++ b/opensora/datasets/datasets.py
@@ -4,35 +4,10 @@ import os
 import numpy as np
 import torch
 import torchvision
-import torchvision.transforms as transforms
 from torchvision.datasets.folder import IMG_EXTENSIONS, pil_loader

 from . import video_transforms
-from .utils import center_crop_arr, VID_EXTENSIONS
-
-
-def get_transforms_video(resolution=256):
-    transform_video = transforms.Compose(
-        [
-            video_transforms.ToTensorVideo(),  # TCHW
-            # video_transforms.RandomHorizontalFlipVideo(),
-            video_transforms.UCFCenterCropVideo(resolution),
-            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
-        ]
-    )
-    return transform_video
-
-
-def get_transforms_image(image_size=256):
-    transform = transforms.Compose(
-        [
-            transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, image_size)),
-            # transforms.RandomHorizontalFlip(),
-            transforms.ToTensor(),
-            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
-        ]
-    )
-    return transform
+from .utils import VID_EXTENSIONS


 class DatasetFromCSV(torch.utils.data.Dataset):
diff --git a/opensora/datasets/utils.py b/opensora/datasets/utils.py
index 843e1e3..206d29d 100644
--- a/opensora/datasets/utils.py
+++ b/opensora/datasets/utils.py
@@ -3,17 +3,46 @@ from typing import Iterator, Optional

 import numpy as np
 import torch
+import torchvision
+import torchvision.transforms as transforms
 from PIL import Image
 from torch.distributed import ProcessGroup
 from torch.distributed.distributed_c10d import _get_default_group
 from torch.utils.data import DataLoader, Dataset
 from torch.utils.data.distributed import DistributedSampler
+from torchvision.datasets.folder import IMG_EXTENSIONS, pil_loader
 from torchvision.io import write_video
 from torchvision.utils import save_image

+from . import video_transforms
+
 VID_EXTENSIONS = ("mp4", "avi", "mov", "mkv")


+def get_transforms_video(resolution=256):
+    transform_video = transforms.Compose(
+        [
+            video_transforms.ToTensorVideo(),  # TCHW
+            # video_transforms.RandomHorizontalFlipVideo(),
+            video_transforms.UCFCenterCropVideo(resolution),
+            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
+        ]
+    )
+    return transform_video
+
+
+def get_transforms_image(image_size=256):
+    transform = transforms.Compose(
+        [
+            transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, image_size)),
+            # transforms.RandomHorizontalFlip(),
+            transforms.ToTensor(),
+            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
+        ]
+    )
+    return transform
+
+
 def read_image_from_path(path, transform=None, num_frames=1, image_size=256):
     image = pil_loader(path)
     if transform is None:
diff --git a/opensora/schedulers/dpms/__init__.py b/opensora/schedulers/dpms/__init__.py
index f0cebbc..084224b 100644
--- a/opensora/schedulers/dpms/__init__.py
+++ b/opensora/schedulers/dpms/__init__.py
@@ -17,13 +17,12 @@ class DPM_SOLVER:
         self,
         model,
         text_encoder,
-        z_size,
+        z,
         prompts,
         device,
         additional_args=None,
     ):
         n = len(prompts)
-        z = torch.randn(n, *z_size, device=device)
         model_args = text_encoder.encode(prompts)
         y = model_args.pop("y")
         null_y = text_encoder.null(n)
diff --git a/opensora/schedulers/iddpm/__init__.py b/opensora/schedulers/iddpm/__init__.py
index b9806ad..14d6027 100644
--- a/opensora/schedulers/iddpm/__init__.py
+++ b/opensora/schedulers/iddpm/__init__.py
@@ -54,13 +54,12 @@ class IDDPM(SpacedDiffusion):
         self,
         model,
         text_encoder,
-        z_size,
+        z,
         prompts,
         device,
         additional_args=None,
     ):
         n = len(prompts)
-        z = torch.randn(n, *z_size, device=device)
         z = torch.cat([z, z], 0)
         model_args = text_encoder.encode(prompts)
         y_null = text_encoder.null(n)
diff --git a/scripts/inference.py b/scripts/inference.py
index 7f492aa..f1323d2 100644
--- a/scripts/inference.py
+++ b/scripts/inference.py
@@ -1,16 +1,16 @@
 import os

-import torch
 import colossalai
+import torch
 import torch.distributed as dist
+from colossalai.cluster import DistCoordinator
 from mmengine.runner import set_random_seed

+from opensora.acceleration.parallel_states import set_sequence_parallel_group
 from opensora.datasets import save_sample
 from opensora.registry import MODELS, SCHEDULERS, build_module
 from opensora.utils.config_utils import parse_configs
 from opensora.utils.misc import to_torch_dtype
-from opensora.acceleration.parallel_states import set_sequence_parallel_group
-from colossalai.cluster import DistCoordinator


 def main():
@@ -82,12 +82,18 @@ def main():
     sample_idx = 0
     save_dir = cfg.save_dir
     os.makedirs(save_dir, exist_ok=True)
+
+    # 4.1. batch generation
     for i in range(0, len(prompts), cfg.batch_size):
+        # 4.2. sample in hidden space
         batch_prompts = prompts[i : i + cfg.batch_size]
+        z = torch.randn(len(batch_prompts), vae.out_channels, *latent_size, device=device, dtype=dtype)
+
+        # 4.3. diffusion sampling
         samples = scheduler.sample(
             model,
             text_encoder,
-            z_size=(vae.out_channels, *latent_size),
+            z=z,
             prompts=batch_prompts,
             device=device,
             additional_args=model_args,
diff --git a/scripts/inference_long.py b/scripts/inference_long.py
index e69de29..7f59f6c 100644
--- a/scripts/inference_long.py
+++ b/scripts/inference_long.py
@@ -0,0 +1,208 @@
+import os
+
+import colossalai
+import torch
+import torch.distributed as dist
+from colossalai.cluster import DistCoordinator
+from mmengine.runner import set_random_seed
+
+from opensora.acceleration.parallel_states import set_sequence_parallel_group
+from opensora.datasets import save_sample
+from opensora.datasets.utils import read_from_path
+from opensora.registry import MODELS, SCHEDULERS, build_module
+from opensora.utils.config_utils import parse_configs
+from opensora.utils.misc import to_torch_dtype
+
+
+def collect_references_batch(reference_paths, vae, image_size):
+    refs_x = []
+    for reference_path in reference_paths:
+        ref_path = reference_path.split(";")
+        ref = []
+        for r_path in ref_path:
+            r = read_from_path(r_path, image_size)
+            r_x = vae.encode(r.unsqueeze(0).to(vae.device, vae.dtype))
+            r_x = r_x.squeeze(0)
+            ref.append(r_x)
+        refs_x.append(ref)
+    # refs_x: [batch, ref_num, C, T, H, W]
+    return refs_x
+
+
+def apply_mask_strategy(z, refs_x, mask_strategies, loop_i):
+    masks = []
+    for i, mask_strategy in enumerate(mask_strategies):
+        mask_strategy = mask_strategy.split(";")
+        mask = torch.ones(z.shape[2], dtype=torch.bool, device=z.device)
+        for mst in mask_strategy:
+            loop_id, m_id, m_ref_start, m_length, m_target_start = mst.split(",")
+            loop_id = int(loop_id)
+            if loop_id != loop_i:
+                continue
+            m_id = int(m_id)
+            m_ref_start = int(m_ref_start)
+            m_length = int(m_length)
+            m_target_start = int(m_target_start)
+            ref = refs_x[i][m_id]  # [C, T, H, W]
+            if m_ref_start < 0:
+                m_ref_start = ref.shape[1] + m_ref_start
+            if m_target_start < 0:
+                # z: [B, C, T, H, W]
+                m_target_start = z.shape[2] + m_target_start
+            z[i, :, m_target_start : m_target_start + m_length] = ref[:, m_ref_start : m_ref_start + m_length]
+            mask[m_target_start : m_target_start + m_length] = 0
+        masks.append(mask)
+    masks = torch.stack(masks)
+    return masks
+
+
+def process_prompts(prompts, num_loop):
+    ret_prompts = []
+    for prompt in prompts:
+        if prompt.startswith("|0|"):
+            prompt_list = prompt.split("|")[1:]
+            text_list = []
+            for i in range(0, len(prompt_list), 2):
+                start_loop = int(prompt_list[i])
+                text = prompt_list[i + 1]
+                end_loop = int(prompt_list[i + 2]) if i + 2 < len(prompt_list) else num_loop
+                text_list.extend([text] * (end_loop - start_loop))
+            assert len(text_list) == num_loop
+            ret_prompts.append(text_list)
+        else:
+            ret_prompts.append([prompt] * num_loop)
+    return ret_prompts
+
+
+def main():
+    # ======================================================
+    # 1. cfg and init distributed env
+    # ======================================================
+    cfg = parse_configs(training=False)
+    print(cfg)
+
+    # init distributed
+    colossalai.launch_from_torch({})
+    coordinator = DistCoordinator()
+
+    if coordinator.world_size > 1:
+        set_sequence_parallel_group(dist.group.WORLD)
+        enable_sequence_parallelism = True
+    else:
+        enable_sequence_parallelism = False
+
+    # ======================================================
+    # 2. runtime variables
+    # ======================================================
+    torch.set_grad_enabled(False)
+    torch.backends.cuda.matmul.allow_tf32 = True
+    torch.backends.cudnn.allow_tf32 = True
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    dtype = to_torch_dtype(cfg.dtype)
+    set_random_seed(seed=cfg.seed)
+    prompts = cfg.prompt
+
+    # ======================================================
+    # 3. build model & load weights
+    # ======================================================
+    # 3.1. build model
+    input_size = (cfg.num_frames, *cfg.image_size)
+    vae = build_module(cfg.vae, MODELS)
+    latent_size = vae.get_latent_size(input_size)
+    text_encoder = build_module(cfg.text_encoder, MODELS, device=device)  # T5 must be fp32
+    model = build_module(
+        cfg.model,
+        MODELS,
+        input_size=latent_size,
+        in_channels=vae.out_channels,
+        caption_channels=text_encoder.output_dim,
+        model_max_length=text_encoder.model_max_length,
+        dtype=dtype,
+        enable_sequence_parallelism=enable_sequence_parallelism,
+    )
+    text_encoder.y_embedder = model.y_embedder  # hack for classifier-free guidance
+
+    # 3.2. move to device & eval
+    vae = vae.to(device, dtype).eval()
+    model = model.to(device, dtype).eval()
+
+    # 3.3. build scheduler
+    scheduler = build_module(cfg.scheduler, SCHEDULERS)
+
+    # 3.4. support for multi-resolution
+    model_args = dict()
+    if cfg.multi_resolution:
+        image_size = cfg.image_size
+        hw = torch.tensor([image_size], device=device, dtype=dtype).repeat(cfg.batch_size, 1)
+        ar = torch.tensor([[image_size[0] / image_size[1]]], device=device, dtype=dtype).repeat(cfg.batch_size, 1)
+        model_args["data_info"] = dict(ar=ar, hw=hw)
+
+    # 3.5. reference
+    if cfg.reference_path is not None:
+        assert len(cfg.reference_path) == len(prompts)
+        assert len(cfg.reference_path) == len(cfg.mask_strategy)
+
+    # ======================================================
+    # 4. inference
+    # ======================================================
+    sample_idx = 0
+    save_dir = cfg.save_dir
+    os.makedirs(save_dir, exist_ok=True)
+
+    # 4.1. batch generation
+    for i in range(0, len(prompts), cfg.batch_size):
+        batch_prompts_loops = process_prompts(prompts[i : i + cfg.batch_size], cfg.loop)
+        video_clips = []
+
+        # 4.2. load reference videos & images
+        if cfg.reference_path is not None:
+            refs_x = collect_references_batch(cfg.reference_path[i : i + cfg.batch_size], vae, cfg.image_size[0])
+            mask_strategy = cfg.mask_strategy[i : i + cfg.batch_size]
+
+        # 4.3. long video generation
+        for loop_i in range(cfg.loop):
+            # 4.4. sample in hidden space
+            batch_prompts = [prompt[loop_i] for prompt in batch_prompts_loops]
+            z = torch.randn(len(batch_prompts), vae.out_channels, *latent_size, device=device, dtype=dtype)
+
+            # 4.5. apply mask strategy
+            masks = None
+            if cfg.reference_path is not None:
+                if loop_i > 0:
+                    ref_x = vae.encode(video_clips[-1])
+                    for j, refs in enumerate(refs_x):
+                        refs.append(ref_x[j])
+                        mask_strategy[
+                            j
+                        ] += f";{loop_i},{len(refs)-1},-{cfg.condition_frame_length},{cfg.condition_frame_length},0"
+                masks = apply_mask_strategy(z, refs_x, mask_strategy, loop_i)
+
+            # 4.6. diffusion sampling
+            samples = scheduler.sample(
+                model,
+                text_encoder,
+                z=z,
+                prompts=batch_prompts,
+                device=device,
+                additional_args=model_args,
+                mask=masks,  # scheduler must support mask
+            )
+            samples = vae.decode(samples.to(dtype))
+            video_clips.append(samples)
+
+            # 4.7. save video
+            if loop_i == cfg.loop - 1:
+                if coordinator.is_master():
+                    for idx in range(len(video_clips[0])):
+                        video_clips_i = [video_clips[0][idx]] + [
+                            video_clips[k][idx][:, cfg.condition_frame_length :] for k in range(1, cfg.loop)
+                        ]
+                        video = torch.cat(video_clips_i, dim=1)
+                        print(f"Prompt: {prompts[i + idx]}")
+                        save_path = os.path.join(save_dir, f"sample_{sample_idx}")
+                        save_sample(video, fps=cfg.fps, save_path=save_path)
+                        sample_idx += 1
+
+
+if __name__ == "__main__":
+    main()
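
Note for reviewers: scripts/inference_long.py introduces two stringly-typed formats, so a short self-contained sketch may help. Every concrete value below (tensor shapes, the "0,0,0,4,0" entry, the sample prompt) is hypothetical and chosen only to illustrate the parsing; the indexing mirrors apply_mask_strategy above.

# Each mask_strategy entry reads "loop_id,m_id,m_ref_start,m_length,m_target_start";
# entries are joined with ";" and filtered by loop_id on every loop iteration.
# Negative m_ref_start / m_target_start count from the end, as in apply_mask_strategy.
import torch

z = torch.randn(1, 4, 16, 32, 32)    # [B, C, T, H, W] dummy latent
refs = [torch.randn(4, 16, 32, 32)]  # one [C, T, H, W] dummy encoded reference

# "In loop 0, copy 4 frames of reference 0 (from its frame 0) into z at frame 0."
entry = "0,0,0,4,0"
loop_id, m_id, m_ref_start, m_length, m_target_start = map(int, entry.split(","))
assert loop_id == 0  # apply_mask_strategy skips entries whose loop_id != loop_i

ref = refs[m_id]
mask = torch.ones(z.shape[2], dtype=torch.bool)  # True = frame is generated
z[0, :, m_target_start : m_target_start + m_length] = ref[:, m_ref_start : m_ref_start + m_length]
mask[m_target_start : m_target_start + m_length] = 0  # conditioned frames stay fixed

print(mask)  # first 4 entries False (conditioned), remaining 12 True (generated)

Similarly, process_prompts accepts either a plain string, repeated for every loop, or the form "|0|first scene|2|second scene" with no trailing "|": with cfg.loop = 4 that expands to the first text for loops 0-1 and the second for loops 2-3.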