diff --git a/configs/vae/train/video.py b/configs/vae/train/video.py index dea41e5..8dd96c2 100644 --- a/configs/vae/train/video.py +++ b/configs/vae/train/video.py @@ -1,4 +1,4 @@ -num_frames = 17 +num_frames = 33 image_size = (256, 256) # Define dataset diff --git a/opensora/models/vae/vae_temporal.py b/opensora/models/vae/vae_temporal.py index 473d057..bdf336f 100644 --- a/opensora/models/vae/vae_temporal.py +++ b/opensora/models/vae/vae_temporal.py @@ -382,7 +382,7 @@ class VAE_Temporal(nn.Module): return latent_size def encode(self, x): - time_padding = self.time_downsample_factor - x.shape[2] % self.time_downsample_factor + time_padding = 0 if (x.shape[2] % self.time_downsample_factor == 0) else self.time_downsample_factor - x.shape[2] % self.time_downsample_factor x = pad_at_dim(x, (time_padding, 0), dim=2) encoded_feature = self.encoder(x) moments = self.quant_conv(encoded_feature).to(x.dtype) @@ -390,7 +390,7 @@ class VAE_Temporal(nn.Module): return posterior def decode(self, z, num_frames=None): - time_padding = self.time_downsample_factor - num_frames % self.time_downsample_factor + time_padding = 0 if (num_frames % self.time_downsample_factor == 0) else self.time_downsample_factor - num_frames % self.time_downsample_factor z = self.post_quant_conv(z) x = self.decoder(z) x = x[:, :, time_padding:] diff --git a/scripts/inference-vae-chunked-enc.py b/scripts/inference-vae-chunked-enc.py new file mode 100644 index 0000000..9192e60 --- /dev/null +++ b/scripts/inference-vae-chunked-enc.py @@ -0,0 +1,134 @@ +import os + +import colossalai +import torch +import torch.distributed as dist +from colossalai.cluster import DistCoordinator +from mmengine.runner import set_random_seed +from tqdm import tqdm + +from opensora.acceleration.parallel_states import get_data_parallel_group +from opensora.datasets import prepare_dataloader, save_sample +from opensora.models.vae.losses import VAELoss +from opensora.registry import DATASETS, MODELS, build_module +from 
opensora.utils.config_utils import parse_configs +from opensora.utils.misc import to_torch_dtype + + +def main(): + # ====================================================== + # 1. cfg and init distributed env + # ====================================================== + cfg = parse_configs(training=False) + print(cfg) + + # init distributed + if os.environ.get("WORLD_SIZE", None): + use_dist = True + colossalai.launch_from_torch({}) + coordinator = DistCoordinator() + else: + use_dist = False + + # ====================================================== + # 2. runtime variables + # ====================================================== + torch.set_grad_enabled(False) + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + device = "cuda" if torch.cuda.is_available() else "cpu" + dtype = to_torch_dtype(cfg.dtype) + set_random_seed(seed=cfg.seed) + + # ====================================================== + # 3. build dataset and dataloader + # ====================================================== + dataset = build_module(cfg.dataset, DATASETS) + dataloader = prepare_dataloader( + dataset, + batch_size=cfg.batch_size, + num_workers=cfg.num_workers, + shuffle=False, + drop_last=True, + pin_memory=True, + process_group=get_data_parallel_group(), + ) + print(f"Dataset contains {len(dataset):,} videos ({cfg.dataset.data_path})") + total_batch_size = cfg.batch_size * (dist.get_world_size() if use_dist else 1) + print(f"Total batch size: {total_batch_size}") + + # ====================================================== + # 4. build model & load weights + # ====================================================== + # 4.1. build model + model = build_module(cfg.model, MODELS) + model.to(device, dtype).eval() + + # ====================================================== + # 5. 
inference + # ====================================================== + save_dir = cfg.save_dir + + # define loss function + vae_loss_fn = VAELoss( + logvar_init=cfg.get("logvar_init", 0.0), + perceptual_loss_weight=cfg.perceptual_loss_weight, + kl_loss_weight=cfg.kl_loss_weight, + device=device, + dtype=dtype, + ) + + # get total number of steps + total_steps = len(dataloader) + if cfg.max_test_samples is not None: + total_steps = min(int(cfg.max_test_samples // cfg.batch_size), total_steps) + print(f"limiting test dataset to {int(cfg.max_test_samples//cfg.batch_size) * cfg.batch_size}") + dataloader_iter = iter(dataloader) + + running_loss = running_nll = running_nll_z = 0.0 + loss_steps = 0 + with tqdm( + range(total_steps), + disable=use_dist and not coordinator.is_master(), + total=total_steps, + initial=0, + ) as pbar: + for step in pbar: + batch = next(dataloader_iter) + x = batch["video"].to(device, dtype) # [B, C, T, H, W] + input_size = x.shape[2:] + + half_frame = int(x.size(2) // 2) + x_front = x[:,:, :half_frame, :, :] + x_back = x[:, :, half_frame:, :, :] + + # ===== VAE ===== + z_front, posterior_front, x_z_front = model.encode(x_front) + z_back, posterior_back, x_z_back = model.encode(x_back) + + dummy, _, _ = model.encode(x) + latent_size = list(dummy.shape) + + z = torch.cat((z_front, z_back), dim=2) + x_z = torch.cat((x_z_front, x_z_back), dim=2) + assert list(z.shape) == latent_size, f"z shape: {z.shape}, latent_size: {latent_size}" + x_rec, x_z_rec = model.decode(z, num_frames=x.size(2)) + x_ref = model.spatial_vae.decode(x_z) + + if not use_dist or coordinator.is_master(): + ori_dir = f"{save_dir}_ori" + rec_dir = f"{save_dir}_rec" + ref_dir = f"{save_dir}_ref" + os.makedirs(ori_dir, exist_ok=True) + os.makedirs(rec_dir, exist_ok=True) + os.makedirs(ref_dir, exist_ok=True) + for idx, vid in enumerate(x): + pos = step * cfg.batch_size + idx + save_sample(vid, fps=cfg.fps, save_path=f"{ori_dir}/{pos:03d}") + save_sample(x_rec[idx], fps=cfg.fps, 
save_path=f"{rec_dir}/{pos:03d}") + save_sample(x_ref[idx], fps=cfg.fps, save_path=f"{ref_dir}/{pos:03d}") + + + +if __name__ == "__main__": + main() diff --git a/scripts/train-vae.py b/scripts/train-vae.py index 60400d7..43fecaa 100644 --- a/scripts/train-vae.py +++ b/scripts/train-vae.py @@ -231,8 +231,10 @@ def main(): ) as pbar: for step, batch in pbar: x = batch["video"].to(device, dtype) # [B, C, T, H, W] - if random.random() < cfg.get("mixed_image_ratio", 0.0): - x = x[:, :, :1, :, :] + # if random.random() < cfg.get("mixed_image_ratio", 0.0): + # x = x[:, :, :1, :, :] + length = random.randint(1, x.size(2)) + x = x[:, :, :length, :, :] # ===== VAE ===== x_rec, x_z_rec, z, posterior, x_z = model(x)