From 7d27f5553e2e351fca006901808344f779b7b19e Mon Sep 17 00:00:00 2001 From: Zangwei Zheng Date: Sat, 23 Mar 2024 16:32:51 +0800 Subject: [PATCH] merge mask-related utils --- configs/opensora/inference/16x256x256.py | 5 +- .../inference_long/16x256x256_long.py | 46 +++++++++++++++++++ opensora/datasets/datasets.py | 8 ++-- opensora/datasets/utils.py | 30 ++++++++++++ opensora/models/vae/vae.py | 16 +++++++ opensora/utils/config_utils.py | 20 +++++++- scripts/inference.py | 10 +--- scripts/inference_long.py | 0 8 files changed, 120 insertions(+), 15 deletions(-) create mode 100644 configs/opensora/inference_long/16x256x256_long.py create mode 100644 scripts/inference_long.py diff --git a/configs/opensora/inference/16x256x256.py b/configs/opensora/inference/16x256x256.py index db6f2e4..ce01bc5 100644 --- a/configs/opensora/inference/16x256x256.py +++ b/configs/opensora/inference/16x256x256.py @@ -28,8 +28,11 @@ scheduler = dict( ) dtype = "fp16" +# Condition +prompt_path = "./assets/texts/t2v_samples.txt" +prompt = None # prompt has higher priority than prompt_path + # Others batch_size = 1 seed = 42 -prompt_path = "./assets/texts/t2v_samples.txt" save_dir = "./outputs/samples/" diff --git a/configs/opensora/inference_long/16x256x256_long.py b/configs/opensora/inference_long/16x256x256_long.py new file mode 100644 index 0000000..c4aea9d --- /dev/null +++ b/configs/opensora/inference_long/16x256x256_long.py @@ -0,0 +1,46 @@ +num_frames = 16 +fps = 24 // 3 +image_size = (256, 256) + +# Define model +model = dict( + type="STDiT-XL/2", + space_scale=0.5, + time_scale=1.0, + enable_flashattn=True, + enable_layernorm_kernel=True, + from_pretrained="PRETRAINED_MODEL", +) +vae = dict( + type="VideoAutoencoderKL", + from_pretrained="stabilityai/sd-vae-ft-ema", + micro_batch_size=4, +) +text_encoder = dict( + type="t5", + from_pretrained="./pretrained_models/t5_ckpts", + model_max_length=120, +) +scheduler = dict( + type="iddpm", + num_sampling_steps=100, + cfg_scale=7.0, +) +dtype = "fp16" + +# Condition +prompt_path = None +prompt = [ + "In an ornate, historical hall, a massive tidal wave peaks and begins to crash. Two surfers, seizing the moment, skillfully navigate the face of the wave." +] + +loop = 10 +condition_frame_length = 4 +reference_path = ["assets/images/condition/wave.png"] +mask_strategy = ["0,0,0,1,0"] # valid when reference_path is not None +# loop id, ref id, ref start, length, target start + +# Others +batch_size = 2 +seed = 42 +save_dir = "./samples/" diff --git a/opensora/datasets/datasets.py b/opensora/datasets/datasets.py index 9d93172..d302186 100644 --- a/opensora/datasets/datasets.py +++ b/opensora/datasets/datasets.py @@ -8,14 +8,14 @@ import torchvision.transforms as transforms from torchvision.datasets.folder import IMG_EXTENSIONS, pil_loader from . import video_transforms -from .utils import center_crop_arr +from .utils import center_crop_arr, VID_EXTENSIONS def get_transforms_video(resolution=256): transform_video = transforms.Compose( [ video_transforms.ToTensorVideo(), # TCHW - video_transforms.RandomHorizontalFlipVideo(), + # video_transforms.RandomHorizontalFlipVideo(), video_transforms.UCFCenterCropVideo(resolution), transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), ] @@ -27,7 +27,7 @@ def get_transforms_image(image_size=256): transform = transforms.Compose( [ transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, image_size)), - transforms.RandomHorizontalFlip(), + # transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), ] @@ -58,7 +58,7 @@ class DatasetFromCSV(torch.utils.data.Dataset): self.samples = list(reader) ext = self.samples[0][0].split(".")[-1] - if ext.lower() in ("mp4", "avi", "mov", "mkv"): + if ext.lower() in VID_EXTENSIONS: self.is_video = True else: assert f".{ext.lower()}" in IMG_EXTENSIONS, f"Unsupported file format: {ext}" diff --git a/opensora/datasets/utils.py b/opensora/datasets/utils.py index 0c4b3b8..843e1e3 100644 --- a/opensora/datasets/utils.py +++ b/opensora/datasets/utils.py @@ -11,6 +11,36 @@ from torch.utils.data.distributed import DistributedSampler from torchvision.io import write_video from torchvision.utils import save_image +VID_EXTENSIONS = ("mp4", "avi", "mov", "mkv") + + +def read_image_from_path(path, transform=None, num_frames=1, image_size=256): + image = pil_loader(path) + if transform is None: + transform = get_transforms_image(image_size) + image = transform(image) + video = image.unsqueeze(0).repeat(num_frames, 1, 1, 1) + video = video.permute(1, 0, 2, 3) + return video + + +def read_video_from_path(path, transform=None, image_size=256): + vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW") + if transform is None: + transform = get_transforms_video(image_size) + video = transform(vframes) # T C H W + video = video.permute(1, 0, 2, 3) + return video + + +def read_from_path(path, image_size): + ext = path.split(".")[-1] + if ext.lower() in VID_EXTENSIONS: + return read_video_from_path(path, image_size=image_size) + else: + assert f".{ext.lower()}" in IMG_EXTENSIONS, f"Unsupported file format: {ext}" + return read_image_from_path(path, image_size=image_size) + def save_sample(x, fps=8, save_path=None, normalize=True, value_range=(-1, 1)): """ diff --git a/opensora/models/vae/vae.py b/opensora/models/vae/vae.py index 363bbfe..a26d169 100644 --- a/opensora/models/vae/vae.py +++ b/opensora/models/vae/vae.py @@ -56,6 +56,14 @@ class VideoAutoencoderKL(nn.Module): input_size = [input_size[i] // self.patch_size[i] for i in range(3)] return input_size + @property + def device(self): + return next(self.parameters()).device + + @property + def dtype(self): + return next(self.parameters()).dtype + @MODELS.register_module() class VideoAutoencoderKLTemporalDecoder(nn.Module): @@ -80,3 +88,11 @@ class VideoAutoencoderKLTemporalDecoder(nn.Module): assert input_size[i] % self.patch_size[i] == 0, "Input size must be divisible by patch size" input_size = [input_size[i] // self.patch_size[i] for i in range(3)] return input_size + + @property + def device(self): + return next(self.parameters()).device + + @property + def dtype(self): + return next(self.parameters()).dtype diff --git a/opensora/utils/config_utils.py b/opensora/utils/config_utils.py index 5ef8150..89c7fe8 100644 --- a/opensora/utils/config_utils.py +++ b/opensora/utils/config_utils.py @@ -7,6 +7,12 @@ from mmengine.config import Config from torch.utils.tensorboard import SummaryWriter +def load_prompts(prompt_path): + with open(prompt_path, "r") as f: + prompts = [line.strip() for line in f.readlines()] + return prompts + + def parse_args(training=False): parser = argparse.ArgumentParser() @@ -47,12 +53,22 @@ def merge_args(cfg, args, training=False): cfg.scheduler["cfg_scale"] = args.cfg_scale args.cfg_scale = None - if "multi_resolution" not in cfg: - cfg["multi_resolution"] = False for k, v in vars(args).items(): if k in cfg and v is not None: cfg[k] = v + if "reference_path" not in cfg: + cfg["reference_path"] = None + if "loop" not in cfg: + cfg["loop"] = 1 + if "multi_resolution" not in cfg: + cfg["multi_resolution"] = False + if "mask_ratios" not in cfg: + cfg["mask_ratios"] = None + if "prompt" not in cfg or cfg["prompt"] is None: + assert cfg["prompt_path"] is not None, "prompt or prompt_path must be provided" + cfg["prompt"] = load_prompts(cfg["prompt_path"]) + return cfg diff --git a/scripts/inference.py b/scripts/inference.py index 900870b..7f492aa 100644 --- a/scripts/inference.py +++ b/scripts/inference.py @@ -13,12 +13,6 @@ from opensora.acceleration.parallel_states import set_sequence_parallel_group from colossalai.cluster import DistCoordinator -def load_prompts(prompt_path): - with open(prompt_path, "r") as f: - prompts = [line.strip() for line in f.readlines()] - return prompts - - def main(): # ====================================================== # 1. cfg and init distributed env @@ -31,7 +25,7 @@ def main(): coordinator = DistCoordinator() if coordinator.world_size > 1: - set_sequence_parallel_group(dist.group.WORLD) + set_sequence_parallel_group(dist.group.WORLD) enable_sequence_parallelism = True else: enable_sequence_parallelism = False @@ -45,7 +39,7 @@ def main(): device = "cuda" if torch.cuda.is_available() else "cpu" dtype = to_torch_dtype(cfg.dtype) set_random_seed(seed=cfg.seed) - prompts = load_prompts(cfg.prompt_path) + prompts = cfg.prompt # ====================================================== # 3. build model & load weights diff --git a/scripts/inference_long.py b/scripts/inference_long.py new file mode 100644 index 0000000..e69de29