From d0ebf731128a9d7734720cb8c9e3f60987df013d Mon Sep 17 00:00:00 2001 From: Shen-Chenhui Date: Sat, 4 May 2024 21:00:41 +0800 Subject: [PATCH 1/7] allow variable length training --- scripts/train-vae.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/scripts/train-vae.py b/scripts/train-vae.py index 60400d7..43fecaa 100644 --- a/scripts/train-vae.py +++ b/scripts/train-vae.py @@ -231,8 +231,10 @@ def main(): ) as pbar: for step, batch in pbar: x = batch["video"].to(device, dtype) # [B, C, T, H, W] - if random.random() < cfg.get("mixed_image_ratio", 0.0): - x = x[:, :, :1, :, :] + # if random.random() < cfg.get("mixed_image_ratio", 0.0): + # x = x[:, :, :1, :, :] + length = random.randint(1, x.size(2)) + x = x[:, :, :length, :, :] # ===== VAE ===== x_rec, x_z_rec, z, posterior, x_z = model(x) From 4b54a897de091e43d4d7a23b50e45a9d5d6863bf Mon Sep 17 00:00:00 2001 From: Shen-Chenhui Date: Sat, 4 May 2024 21:06:35 +0800 Subject: [PATCH 2/7] typo --- scripts/train-vae.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/train-vae.py b/scripts/train-vae.py index 43fecaa..925528a 100644 --- a/scripts/train-vae.py +++ b/scripts/train-vae.py @@ -234,7 +234,7 @@ def main(): # if random.random() < cfg.get("mixed_image_ratio", 0.0): # x = x[:, :, :1, :, :] length = random.randint(1, x.size(2)) - x = x[:, :, :length, :, :] + x = x[:, :, length, :, :] # ===== VAE ===== x_rec, x_z_rec, z, posterior, x_z = model(x) From 40fd8a6b7be7958486f8ad9fb6912825225bbd16 Mon Sep 17 00:00:00 2001 From: Shen-Chenhui Date: Sat, 4 May 2024 21:21:35 +0800 Subject: [PATCH 3/7] typo --- scripts/train-vae.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/train-vae.py b/scripts/train-vae.py index 925528a..43fecaa 100644 --- a/scripts/train-vae.py +++ b/scripts/train-vae.py @@ -234,7 +234,7 @@ def main(): # if random.random() < cfg.get("mixed_image_ratio", 0.0): # x = x[:, :, :1, :, :] length = random.randint(1, x.size(2)) - x = x[:, :, length, :, :] + x = x[:, :, :length, :, :] # ===== VAE ===== x_rec, x_z_rec, z, posterior, x_z = model(x) From 041b14f1385f7c55134b3f3c32957b167b0d647b Mon Sep 17 00:00:00 2001 From: Shen-Chenhui Date: Mon, 6 May 2024 07:47:12 +0000 Subject: [PATCH 4/7] fix padding error;add chunck inference script --- opensora/models/vae/vae_temporal.py | 2 + scripts/inference-vae-chunked-enc.py | 135 +++++++++++++++++++++++++++ 2 files changed, 137 insertions(+) create mode 100644 scripts/inference-vae-chunked-enc.py diff --git a/opensora/models/vae/vae_temporal.py b/opensora/models/vae/vae_temporal.py index b6d96be..d4e10ba 100644 --- a/opensora/models/vae/vae_temporal.py +++ b/opensora/models/vae/vae_temporal.py @@ -374,6 +374,7 @@ class VAE_Temporal(nn.Module): return input_size def encode(self, x): + # time_padding = 0 if (x.shape[2] % self.time_downsample_factor == 0) else self.time_downsample_factor - x.shape[2] % self.time_downsample_factor time_padding = self.time_downsample_factor - x.shape[2] % self.time_downsample_factor x = pad_at_dim(x, (time_padding, 0), dim=2) encoded_feature = self.encoder(x) @@ -382,6 +383,7 @@ class VAE_Temporal(nn.Module): return posterior def decode(self, z, num_frames=None): + # time_padding = 0 if (num_frames % self.time_downsample_factor == 0) else self.time_downsample_factor - num_frames % self.time_downsample_factor time_padding = self.time_downsample_factor - num_frames % self.time_downsample_factor z = self.post_quant_conv(z) x = self.decoder(z) diff --git a/scripts/inference-vae-chunked-enc.py b/scripts/inference-vae-chunked-enc.py new file mode 100644 index 0000000..4d06e59 --- /dev/null +++ b/scripts/inference-vae-chunked-enc.py @@ -0,0 +1,135 @@ +import os + +import colossalai +import torch +import torch.distributed as dist +from colossalai.cluster import DistCoordinator +from mmengine.runner import set_random_seed +from tqdm import tqdm + +from opensora.acceleration.parallel_states import get_data_parallel_group +from opensora.datasets import prepare_dataloader, save_sample +from opensora.models.vae.losses import VAELoss +from opensora.registry import DATASETS, MODELS, build_module +from opensora.utils.config_utils import parse_configs +from opensora.utils.misc import to_torch_dtype + + +def main(): + # ====================================================== + # 1. cfg and init distributed env + # ====================================================== + cfg = parse_configs(training=False) + print(cfg) + + # init distributed + if os.environ.get("WORLD_SIZE", None): + use_dist = True + colossalai.launch_from_torch({}) + coordinator = DistCoordinator() + else: + use_dist = False + + # ====================================================== + # 2. runtime variables + # ====================================================== + torch.set_grad_enabled(False) + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + device = "cuda" if torch.cuda.is_available() else "cpu" + dtype = to_torch_dtype(cfg.dtype) + set_random_seed(seed=cfg.seed) + + # ====================================================== + # 3. build dataset and dataloader + # ====================================================== + dataset = build_module(cfg.dataset, DATASETS) + dataloader = prepare_dataloader( + dataset, + batch_size=cfg.batch_size, + num_workers=cfg.num_workers, + shuffle=False, + drop_last=True, + pin_memory=True, + process_group=get_data_parallel_group(), + ) + print(f"Dataset contains {len(dataset):,} videos ({cfg.dataset.data_path})") + total_batch_size = cfg.batch_size * dist.get_world_size() + print(f"Total batch size: {total_batch_size}") + + # ====================================================== + # 4. build model & load weights + # ====================================================== + # 4.1. build model + model = build_module(cfg.model, MODELS) + model.to(device, dtype).eval() + + # ====================================================== + # 5. inference + # ====================================================== + save_dir = cfg.save_dir + + # define loss function + vae_loss_fn = VAELoss( + logvar_init=cfg.get("logvar_init", 0.0), + perceptual_loss_weight=cfg.perceptual_loss_weight, + kl_loss_weight=cfg.kl_loss_weight, + device=device, + dtype=dtype, + ) + + # get total number of steps + total_steps = len(dataloader) + if cfg.max_test_samples is not None: + total_steps = min(int(cfg.max_test_samples // cfg.batch_size), total_steps) + print(f"limiting test dataset to {int(cfg.max_test_samples//cfg.batch_size) * cfg.batch_size}") + dataloader_iter = iter(dataloader) + + running_loss = running_nll = running_nll_z = 0.0 + loss_steps = 0 + with tqdm( + range(total_steps), + disable=not coordinator.is_master(), + total=total_steps, + initial=0, + ) as pbar: + for step in pbar: + batch = next(dataloader_iter) + x = batch["video"].to(device, dtype) # [B, C, T, H, W] + input_size = x.shape[2:] + + half_frame = int(x.size(2) // 2) + x_front = x[:,:, :half_frame, :, :] + x_back = x[:, :, half_frame:, :, :] + + # ===== VAE ===== + z_front, posterior_front, x_z_front = model.encode(x_front) + z_back, posterior_back, x_z_back = model.encode(x_back) + + dummy, _, _ = model.encode(x) + latent_size = list(dummy.shape) + + + z = torch.cat((z_front, z_back[:,:, 1:, :, :]), dim=2) + x_z = torch.cat((x_z_front, x_z_back[:,:, 1:, :, :]), dim=2) + assert list(z.shape) == latent_size, f"z shape: {z.shape}, latent_size: {latent_size}" + x_rec, x_z_rec = model.decode(z, num_frames=x.size(2)) + x_ref = model.spatial_vae.decode(x_z) + + if not use_dist or coordinator.is_master(): + ori_dir = f"{save_dir}_ori" + rec_dir = f"{save_dir}_rec" + ref_dir = f"{save_dir}_ref" + os.makedirs(ori_dir, exist_ok=True) + os.makedirs(rec_dir, exist_ok=True) + os.makedirs(ref_dir, exist_ok=True) + for idx, vid in enumerate(x): + pos = step * cfg.batch_size + idx + save_sample(vid, fps=cfg.fps, save_path=f"{ori_dir}/{pos:03d}") + save_sample(x_rec[idx], fps=cfg.fps, save_path=f"{rec_dir}/{pos:03d}") + save_sample(x_ref[idx], fps=cfg.fps, save_path=f"{ref_dir}/{pos:03d}") + + + +if __name__ == "__main__": + main() From 1522f683b8668650ccce904ed32d71d938be8016 Mon Sep 17 00:00:00 2001 From: Shen-Chenhui Date: Mon, 6 May 2024 07:49:56 +0000 Subject: [PATCH 5/7] fix padding --- opensora/models/vae/vae_temporal.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/opensora/models/vae/vae_temporal.py b/opensora/models/vae/vae_temporal.py index d4e10ba..fd5c7e8 100644 --- a/opensora/models/vae/vae_temporal.py +++ b/opensora/models/vae/vae_temporal.py @@ -374,8 +374,7 @@ class VAE_Temporal(nn.Module): return input_size def encode(self, x): - # time_padding = 0 if (x.shape[2] % self.time_downsample_factor == 0) else self.time_downsample_factor - x.shape[2] % self.time_downsample_factor - time_padding = self.time_downsample_factor - x.shape[2] % self.time_downsample_factor + time_padding = 0 if (x.shape[2] % self.time_downsample_factor == 0) else self.time_downsample_factor - x.shape[2] % self.time_downsample_factor x = pad_at_dim(x, (time_padding, 0), dim=2) encoded_feature = self.encoder(x) moments = self.quant_conv(encoded_feature).to(x.dtype) @@ -383,8 +382,7 @@ class VAE_Temporal(nn.Module): return posterior def decode(self, z, num_frames=None): - # time_padding = 0 if (num_frames % self.time_downsample_factor == 0) else self.time_downsample_factor - num_frames % self.time_downsample_factor - time_padding = self.time_downsample_factor - num_frames % self.time_downsample_factor + time_padding = 0 if (num_frames % self.time_downsample_factor == 0) else self.time_downsample_factor - num_frames % self.time_downsample_factor z = self.post_quant_conv(z) x = self.decoder(z) x = x[:, :, time_padding:] From fb18b4f274b5c01ce268a90f9ed2615038a4c8cb Mon Sep 17 00:00:00 2001 From: Shen-Chenhui Date: Tue, 7 May 2024 03:30:35 +0000 Subject: [PATCH 6/7] fixed bug, remove temporaroy fix --- scripts/inference-vae-chunked-enc.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/scripts/inference-vae-chunked-enc.py b/scripts/inference-vae-chunked-enc.py index 4d06e59..9192e60 100644 --- a/scripts/inference-vae-chunked-enc.py +++ b/scripts/inference-vae-chunked-enc.py @@ -109,9 +109,8 @@ def main(): dummy, _, _ = model.encode(x) latent_size = list(dummy.shape) - - z = torch.cat((z_front, z_back[:,:, 1:, :, :]), dim=2) - x_z = torch.cat((x_z_front, x_z_back[:,:, 1:, :, :]), dim=2) + z = torch.cat((z_front, z_back), dim=2) + x_z = torch.cat((x_z_front, x_z_back), dim=2) assert list(z.shape) == latent_size, f"z shape: {z.shape}, latent_size: {latent_size}" x_rec, x_z_rec = model.decode(z, num_frames=x.size(2)) x_ref = model.spatial_vae.decode(x_z) From c91a988fbe453e797772d3730089cbd8871a3b78 Mon Sep 17 00:00:00 2001 From: Shen-Chenhui Date: Tue, 7 May 2024 11:39:52 +0800 Subject: [PATCH 7/7] use 33 frames --- configs/vae/train/video.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/vae/train/video.py b/configs/vae/train/video.py index 02bbb24..df50632 100644 --- a/configs/vae/train/video.py +++ b/configs/vae/train/video.py @@ -1,4 +1,4 @@ -num_frames = 17 +num_frames = 33 image_size = (256, 256) # Define dataset