From 11f1822cd2a39e2c2461a6ea40787e5992401010 Mon Sep 17 00:00:00 2001 From: zhengzangw Date: Sat, 20 Apr 2024 15:45:42 +0000 Subject: [PATCH] merge video edit --- .gitignore | 1 + assets/texts/t2v_ref.txt | 1 + configs/opensora-v1-1/inference/sample-ref.py | 21 +- .../inference-long/16x256x256-sdedit.py | 20 +- eval/sample.sh | 6 +- opensora/datasets/utils.py | 31 +++ .../schedulers/iddpm/gaussian_diffusion.py | 20 +- scripts/inference-long.py | 20 +- scripts/inference-sdedit.py | 220 ------------------ 9 files changed, 79 insertions(+), 261 deletions(-) delete mode 100644 scripts/inference-sdedit.py diff --git a/.gitignore b/.gitignore index 2b9e525..b9f8121 100644 --- a/.gitignore +++ b/.gitignore @@ -174,6 +174,7 @@ samples logs/ pretrained_models/ evaluation_results/ +cache/ *.swp # Secret files diff --git a/assets/texts/t2v_ref.txt b/assets/texts/t2v_ref.txt index 49252b6..c0debe5 100644 --- a/assets/texts/t2v_ref.txt +++ b/assets/texts/t2v_ref.txt @@ -3,3 +3,4 @@ In an ornate, historical hall, a massive tidal wave peaks and begins to crash. T Pirate ship in a cosmic maelstrom nebula. Drone view of waves crashing against the rugged cliffs along Big Sur’s garay point beach. The crashing blue waters create white-tipped waves, while the golden light of the setting sun illuminates the rocky shore. A small island with a lighthouse sits in the distance, and green shrubbery covers the cliff’s edge. The steep drop from the road down to the beach is a dramatic feat, with the cliff's edges jutting out over the sea. This is a view that captures the raw beauty of the coast and the rugged landscape of the Pacific Coast Highway. A sad small cactus with in the Sahara desert becomes happy. +A car driving on a road in the middle of a desert. diff --git a/configs/opensora-v1-1/inference/sample-ref.py b/configs/opensora-v1-1/inference/sample-ref.py index 5099164..6e8e805 100644 --- a/configs/opensora-v1-1/inference/sample-ref.py +++ b/configs/opensora-v1-1/inference/sample-ref.py @@ -6,19 +6,28 @@ multi_resolution = "STDiT2" # Condition prompt_path = None -prompt = None +prompt = [ + "A car driving on a road in the middle of a desert.", + # "A man smiling", + # "Drone view of waves crashing against the rugged cliffs along Big Sur's garay point beach. The crashing blue waters create white-tipped waves, while the golden light of the setting sun illuminates the rocky shore. A small island with a lighthouse sits in the distance, and green shrubbery covers the cliff’s edge. The steep drop from the road down to the beach is a dramatic feat, with the cliff's edges jutting out over the sea. This is a view that captures the raw beauty of the coast and the rugged landscape of the Pacific Coast Highway.", + # "In an ornate, historical hall, a massive tidal wave peaks and begins to crash. Two surfers, seizing the moment, skillfully navigate the face of the wave.", +] -loop = 10 +loop = 1 condition_frame_length = 4 reference_path = [ - "assets/images/condition/cliff.png", - "assets/images/condition/wave.png", + "https://cdn.openai.com/tmp/s/interp/d0.mp4", + # "https://www.comp.nus.edu.sg/~youy/index_files/yangyou3.png", + # "assets/images/condition/cliff.png", + # "assets/images/condition/wave.png", ] # valid when reference_path is not None # (loop id, ref id, ref start, length, target start) mask_strategy = [ - "0,0,0,1,0", - "0,0,0,1,0", + "0,0,0,8,0,0.5", + # "0,0,0,1,0", + # "0,0,0,1,0", + # "0,0,0,1,0", ] # Define model diff --git a/configs/opensora/inference-long/16x256x256-sdedit.py b/configs/opensora/inference-long/16x256x256-sdedit.py index 4447455..4ade991 100644 --- a/configs/opensora/inference-long/16x256x256-sdedit.py +++ b/configs/opensora/inference-long/16x256x256-sdedit.py @@ -24,7 +24,6 @@ text_encoder = dict( ) scheduler = dict( type="iddpm", - # type="dpm-solver", num_sampling_steps=100, cfg_scale=7.0, ) @@ -33,31 +32,16 @@ dtype = "fp16" # Condition prompt_path = None prompt = [ - # "Drone view of waves crashing against the rugged cliffs along Big Sur’s garay point beach. The crashing blue waters create white-tipped waves, while the golden light of the setting sun illuminates the rocky shore. A small island with a lighthouse sits in the distance, and green shrubbery covers the cliff’s edge. The steep drop from the road down to the beach is a dramatic feat, with the cliff's edges jutting out over the sea. This is a view that captures the raw beauty of the coast and the rugged landscape of the Pacific Coast Highway.", - "In an ornate, historical hall, a massive tidal wave peaks and begins to crash. Two surfers, seizing the moment, skillfully navigate the face of the wave.", - # "put the video in space with a rainbow road", - # "make it have dinosaurs", - # "make it in claymation animation style", - # "make it go underwater" + "A car driving on a road in the middle of a desert.", ] loop = 1 condition_frame_length = 4 reference_path = [ - # "assets/images/condition/cliff.png", - "assets/images/condition/wave.png", - # "/home/zhouyukun/Open-Sora-dev/assets/videos/base.mp4", - # "/home/zhouyukun/Open-Sora-dev/assets/videos/base.mp4", - # "/home/zhouyukun/Open-Sora-dev/assets/videos/base.mp4", - # "/home/zhouyukun/Open-Sora-dev/assets/videos/base.mp4", + "https://cdn.openai.com/tmp/s/interp/d0.mp4", ] mask_strategy = [ - # "0,0,0,1,0,0", "0,0,0,1,0,0", - # "0,0,0,12,0,0", # 噪声率 - # "0,0,0,12,0,0", # 噪声率 - # "0,0,0,12,0,0", # 噪声率 - # "0,0,0,12,0,0", # 噪声率 ] # valid when reference_path is not None # (loop id, ref id, ref start, length, target start) diff --git a/eval/sample.sh b/eval/sample.sh index 297205e..37320b0 100644 --- a/eval/sample.sh +++ b/eval/sample.sh @@ -120,11 +120,11 @@ function run_video_edit() { # 23min # 3.2 eval $CMD_REF --ckpt-path $CKPT --save-dir $OUTPUT --sample-name ref_L1_128x240x426 \ - --prompt-path assets/texts/t2v_ref.txt --start-index 3 --end-index 5 \ + --prompt-path assets/texts/t2v_ref.txt --start-index 3 --end-index 6 \ --num-frames 128 --image-size 240 426 \ --loop 1 \ - --reference-path assets/images/condition/cliff.png "assets/images/condition/cactus-sad.png\;assets/images/condition/cactus-happy.png" \ - --mask-strategy "0,0,0,1,0\;0,0,0,1,-1" "0,0,0,1,0\;0,1,0,1,-1" + --reference-path assets/images/condition/cliff.png "assets/images/condition/cactus-sad.png\;assets/images/condition/cactus-happy.png" https://cdn.openai.com/tmp/s/interp/d0.mp4 \ + --mask-strategy "0,0,0,1,0\;0,0,0,1,-1" "0,0,0,1,0\;0,1,0,1,-1" "0,0,0,64,0,0.5" } # vbench has 950 samples diff --git a/opensora/datasets/utils.py b/opensora/datasets/utils.py index 907ba6a..0e0ba6a 100644 --- a/opensora/datasets/utils.py +++ b/opensora/datasets/utils.py @@ -1,7 +1,9 @@ import os +import re import numpy as np import pandas as pd +import requests import torch import torchvision import torchvision.transforms as transforms @@ -14,6 +16,20 @@ from . import video_transforms VID_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv") +regex = re.compile( + r"^(?:http|ftp)s?://" # http:// or https:// + r"(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)|" # domain... + r"localhost|" # localhost... + r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})" # ...or ip + r"(?::\d+)?" # optional port + r"(?:/?|[/?]\S+)$", + re.IGNORECASE, +) + + +def is_url(url): + return re.match(regex, url) is not None + def read_file(input_path): if input_path.endswith(".csv"): @@ -24,6 +40,19 @@ def read_file(input_path): raise NotImplementedError(f"Unsupported file format: {input_path}") +def download_url(input_path): + output_dir = "cache" + if not os.path.exists(output_dir): + os.makedirs(output_dir) + base_name = os.path.basename(input_path) + output_path = os.path.join(output_dir, base_name) + img_data = requests.get(input_path).content + with open(output_path, "wb") as handler: + handler.write(img_data) + print(f"URL {input_path} downloaded to {output_path}") + return output_path + + def temporal_random_crop(vframes, num_frames, frame_interval): temporal_sample = video_transforms.TemporalRandomCrop(num_frames * frame_interval) total_frames = len(vframes) @@ -106,6 +135,8 @@ def read_video_from_path(path, transform=None, transform_name="center", image_si def read_from_path(path, image_size, transform_name="center"): + if is_url(path): + path = download_url(path) ext = os.path.splitext(path)[-1].lower() if ext.lower() in VID_EXTENSIONS: return read_video_from_path(path, image_size=image_size, transform_name=transform_name) diff --git a/opensora/schedulers/iddpm/gaussian_diffusion.py b/opensora/schedulers/iddpm/gaussian_diffusion.py index fa734f4..55292b2 100644 --- a/opensora/schedulers/iddpm/gaussian_diffusion.py +++ b/opensora/schedulers/iddpm/gaussian_diffusion.py @@ -408,15 +408,21 @@ class GaussianDiffusion: if mask is not None: if mask.shape[0] != x.shape[0]: mask = mask.repeat(2, 1) # HACK - # copy unchanged x values to x0 + mask_t = (mask * len(self.betas) - 1).to(torch.int) + + # x0: copy unchanged x values + # x_noise: add noise to x values x0 = x.clone() - mask_t = (mask * len(self.betas)).to(torch.int) - mask_t_equall = (mask_t == t.unsqueeze(1))[:, None, :, None, None] - mask_t_upper = (mask_t > t.unsqueeze(1))[:, None, :, None, None] - x_noise = x0 * _extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) + torch.randn_like(x) * _extract_into_tensor( - self.sqrt_one_minus_alphas_cumprod, t, x.shape) + x_noise = x0 * _extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) + torch.randn_like( + x + ) * _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) + # active noise addition - x = torch.where(mask_t_equall, x_noise, x) + mask_t_equall = (mask_t == t.unsqueeze(1))[:, None, :, None, None] + x = torch.where(mask_t_equall, x_noise, x0) + + # create x_mask + mask_t_upper = (mask_t > t.unsqueeze(1))[:, None, :, None, None] batch_size = x.shape[0] model_kwargs["x_mask"] = mask_t_upper.reshape(batch_size, -1).to(torch.bool) diff --git a/scripts/inference-long.py b/scripts/inference-long.py index 1ab74d3..69060dc 100644 --- a/scripts/inference-long.py +++ b/scripts/inference-long.py @@ -34,9 +34,11 @@ def apply_mask_strategy(z, refs_x, mask_strategys, loop_i): masks = [] for i, mask_strategy in enumerate(mask_strategys): mask_strategy = mask_strategy.split(";") - mask = torch.ones(z.shape[2], dtype=torch.bool, device=z.device) + mask = torch.ones(z.shape[2], dtype=torch.float, device=z.device) for mst in mask_strategy: - loop_id, m_id, m_ref_start, m_length, m_target_start = mst.split(",") + mask_batch = mst.split(",") + loop_id, m_id, m_ref_start, m_length, m_target_start = mask_batch[:5] + edit_ratio = mask_batch[5] if len(mask_batch) == 6 else 0.0 loop_id = int(loop_id) if loop_id != loop_i: continue @@ -44,6 +46,7 @@ def apply_mask_strategy(z, refs_x, mask_strategys, loop_i): m_ref_start = int(m_ref_start) m_length = int(m_length) m_target_start = int(m_target_start) + edit_ratio = float(edit_ratio) ref = refs_x[i][m_id] # [C, T, H, W] if m_ref_start < 0: m_ref_start = ref.shape[1] + m_ref_start @@ -51,7 +54,7 @@ def apply_mask_strategy(z, refs_x, mask_strategys, loop_i): # z: [B, C, T, H, W] m_target_start = z.shape[2] + m_target_start z[i, :, m_target_start : m_target_start + m_length] = ref[:, m_ref_start : m_ref_start + m_length] - mask[m_target_start : m_target_start + m_length] = 0 + mask[m_target_start : m_target_start + m_length] = edit_ratio masks.append(mask) masks = torch.stack(masks) return masks @@ -69,7 +72,7 @@ def process_prompts(prompts, num_loop): text = text_preprocessing(text) end_loop = int(prompt_list[i + 2]) if i + 2 < len(prompt_list) else num_loop text_list.extend([text] * (end_loop - start_loop)) - assert len(text_list) == num_loop + assert len(text_list) == num_loop, f"Prompt loop mismatch: {len(text_list)} != {num_loop}" ret_prompts.append(text_list) else: prompt = text_preprocessing(prompt) @@ -161,8 +164,12 @@ def main(): # 3.5 reference if cfg.reference_path is not None: - assert len(cfg.reference_path) == len(prompts) - assert len(cfg.reference_path) == len(cfg.mask_strategy) + assert len(cfg.reference_path) == len( + prompts + ), f"Reference path mismatch: {len(cfg.reference_path)} != {len(prompts)}" + assert len(cfg.reference_path) == len( + cfg.mask_strategy + ), f"Mask strategy mismatch: {len(cfg.mask_strategy)} != {len(prompts)}" # ====================================================== # 4. inference @@ -204,7 +211,6 @@ def main(): j ] += f";{loop_i},{len(refs)-1},-{cfg.condition_frame_length},{cfg.condition_frame_length},0" masks = apply_mask_strategy(z, refs_x, mask_strategy, loop_i) - model_args["x_mask"] = masks # 4.6. diffusion sampling old_sample_idx = sample_idx diff --git a/scripts/inference-sdedit.py b/scripts/inference-sdedit.py deleted file mode 100644 index 9adafcf..0000000 --- a/scripts/inference-sdedit.py +++ /dev/null @@ -1,220 +0,0 @@ -import os - -import colossalai -import torch -import torch.distributed as dist -from colossalai.cluster import DistCoordinator -from mmengine.runner import set_random_seed - -from opensora.acceleration.parallel_states import set_sequence_parallel_group -from opensora.datasets import save_sample -from opensora.datasets.utils import read_from_path -from opensora.registry import MODELS, SCHEDULERS, build_module -from opensora.utils.config_utils import parse_configs -from opensora.utils.misc import to_torch_dtype - - -def collect_references_batch(reference_paths, vae, image_size): - refs_x = [] - for reference_path in reference_paths: - ref_path = reference_path.split(";") - ref = [] - for r_path in ref_path: - r = read_from_path(r_path, image_size) - r_x = vae.encode(r.unsqueeze(0).to(vae.device, vae.dtype)) - r_x = r_x.squeeze(0) - ref.append(r_x) - refs_x.append(ref) - # refs_x: [batch, ref_num, C, T, H, W] - return refs_x - - -def apply_mask_strategy(z, refs_x, mask_strategys, loop_i): - masks = [] - for i, mask_strategy in enumerate(mask_strategys): - mask_strategy = mask_strategy.split(";") - mask = torch.ones(z.shape[2], dtype=torch.float, device=z.device) - for mst in mask_strategy: - mask_batch = mst.split(",") - loop_id, m_id, m_ref_start, m_length, m_target_start = mask_batch[:5] - edit_ratio = mask_batch[5] if len(mask_batch) == 6 else 0.0 - loop_id = int(loop_id) - if loop_id != loop_i: - continue - m_id = int(m_id) - m_ref_start = int(m_ref_start) - m_length = int(m_length) - m_target_start = int(m_target_start) - edit_ratio = float(edit_ratio) - ref = refs_x[i][m_id] # [C, T, H, W] - if m_ref_start < 0: - m_ref_start = ref.shape[1] + m_ref_start - if m_target_start < 0: - # z: [B, C, T, H, W] - m_target_start = z.shape[2] + m_target_start - z[i, :, m_target_start : m_target_start + m_length] = ref[:, m_ref_start : m_ref_start + m_length] - mask[m_target_start : m_target_start + m_length] = edit_ratio - masks.append(mask) - masks = torch.stack(masks) - return masks, z - - -def process_prompts(prompts, num_loop): - ret_prompts = [] - for prompt in prompts: - if prompt.startswith("|0|"): - prompt_list = prompt.split("|")[1:] - text_list = [] - for i in range(0, len(prompt_list), 2): - start_loop = int(prompt_list[i]) - text = prompt_list[i + 1] - end_loop = int(prompt_list[i + 2]) if i + 2 < len(prompt_list) else num_loop - text_list.extend([text] * (end_loop - start_loop)) - assert len(text_list) == num_loop - ret_prompts.append(text_list) - else: - ret_prompts.append([prompt] * num_loop) - return ret_prompts - - -def main(): - # ====================================================== - # 1. cfg and init distributed env - # ====================================================== - cfg = parse_configs(training=False) - print(cfg) - - # init distributed - colossalai.launch_from_torch({}) - coordinator = DistCoordinator() - - if coordinator.world_size > 1: - set_sequence_parallel_group(dist.group.WORLD) - enable_sequence_parallelism = True - else: - enable_sequence_parallelism = False - - # ====================================================== - # 2. runtime variables - # ====================================================== - torch.set_grad_enabled(False) - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - device = "cuda" if torch.cuda.is_available() else "cpu" - dtype = to_torch_dtype(cfg.dtype) - set_random_seed(seed=cfg.seed) - prompts = cfg.prompt - - # ====================================================== - # 3. build model & load weights - # ====================================================== - # 3.1. build model - input_size = (cfg.num_frames, *cfg.image_size) - vae = build_module(cfg.vae, MODELS) - latent_size = vae.get_latent_size(input_size) - text_encoder = build_module(cfg.text_encoder, MODELS, device=device) # T5 must be fp32 - model = build_module( - cfg.model, - MODELS, - input_size=latent_size, - in_channels=vae.out_channels, - caption_channels=text_encoder.output_dim, - model_max_length=text_encoder.model_max_length, - dtype=dtype, - enable_sequence_parallelism=enable_sequence_parallelism, - ) - text_encoder.y_embedder = model.y_embedder # hack for classifier-free guidance - - # 3.2. move to device & eval - vae = vae.to(device, dtype).eval() - model = model.to(device, dtype).eval() - - # 3.3. build scheduler - scheduler = build_module(cfg.scheduler, SCHEDULERS) - - # 3.4. support for multi-resolution - model_args = dict() - if cfg.multi_resolution == "PixArtMS": - image_size = cfg.image_size - hw = torch.tensor([image_size], device=device, dtype=dtype).repeat(cfg.batch_size, 1) - ar = torch.tensor([[image_size[0] / image_size[1]]], device=device, dtype=dtype).repeat(cfg.batch_size, 1) - model_args["data_info"] = dict(ar=ar, hw=hw) - elif cfg.multi_resolution == "STDiT2": - image_size = cfg.image_size - height = torch.tensor([image_size[0]], device=device, dtype=dtype).repeat(cfg.batch_size) - width = torch.tensor([image_size[1]], device=device, dtype=dtype).repeat(cfg.batch_size) - num_frames = torch.tensor([cfg.num_frames], device=device, dtype=dtype).repeat(cfg.batch_size) - ar = torch.tensor([image_size[0] / image_size[1]], device=device, dtype=dtype).repeat(cfg.batch_size) - model_args["height"] = height - model_args["width"] = width - model_args["num_frames"] = num_frames - model_args["ar"] = ar - - # 3.5 reference - if cfg.reference_path is not None: - assert len(cfg.reference_path) == len(prompts) - assert len(cfg.reference_path) == len(cfg.mask_strategy) - - # ====================================================== - # 4. inference - # ====================================================== - sample_idx = 0 - save_dir = cfg.save_dir - os.makedirs(save_dir, exist_ok=True) - - # 4.1. batch generation - for i in range(0, len(prompts), cfg.batch_size): - batch_prompts_loops = process_prompts(prompts[i : i + cfg.batch_size], cfg.loop) - video_clips = [] - - # 4.2. load reference videos & images - if cfg.reference_path is not None: - refs_x = collect_references_batch(cfg.reference_path[i : i + cfg.batch_size], vae, cfg.image_size) - mask_strategy = cfg.mask_strategy[i : i + cfg.batch_size] - - # 4.3. long video generation - for loop_i in range(cfg.loop): - # 4.4 sample in hidden space - batch_prompts = [prompt[loop_i] for prompt in batch_prompts_loops] - z = torch.randn(len(batch_prompts), vae.out_channels, *latent_size, device=device, dtype=dtype) - - # 4.5. apply mask strategy - masks = None - if cfg.reference_path is not None: - if loop_i > 0: - ref_x = vae.encode(video_clips[-1]) - for j, refs in enumerate(refs_x): - refs.append(ref_x[j]) - mask_strategy[ - j - ] += f";{loop_i},{len(refs)-1},-{cfg.condition_frame_length},{cfg.condition_frame_length},0,0" - masks, z = apply_mask_strategy(z, refs_x, mask_strategy, loop_i) - # 4.6. diffusion sampling - samples = scheduler.sample( - model, - text_encoder, - z=z, - prompts=batch_prompts, - device=device, - additional_args=model_args, - mask=masks, # scheduler must support mask - ) - samples = vae.decode(samples.to(dtype)) - video_clips.append(samples) - - # 4.7. save video - if loop_i == cfg.loop - 1: - if coordinator.is_master(): - for idx in range(len(video_clips[0])): - video_clips_i = [video_clips[0][idx]] + [ - video_clips[i][idx][:, cfg.condition_frame_length :] for i in range(1, cfg.loop) - ] - video = torch.cat(video_clips_i, dim=1) - print(f"Prompt: {prompts[i + idx]}") - save_path = os.path.join(save_dir, f"sample_{prompts[i + idx]}") - save_sample(video, fps=cfg.fps, save_path=save_path) - sample_idx += 1 - - -if __name__ == "__main__": - main()