From 246eddecddc8a41d6577c61bac79cd1ed79e2536 Mon Sep 17 00:00:00 2001
From: Shen-Chenhui
Date: Thu, 18 Apr 2024 14:13:56 +0800
Subject: [PATCH] enable pipeline vae training

---
 .../vae_magvit_v2/train/pipeline_16x128x128.py |  4 ++--
 opensora/models/vae/vae_3d_v2.py               | 18 ++++++++++++------
 scripts/train-vae-v2.py                        |  6 ++----
 3 files changed, 16 insertions(+), 12 deletions(-)

diff --git a/configs/vae_magvit_v2/train/pipeline_16x128x128.py b/configs/vae_magvit_v2/train/pipeline_16x128x128.py
index a26f8a8..2751951 100644
--- a/configs/vae_magvit_v2/train/pipeline_16x128x128.py
+++ b/configs/vae_magvit_v2/train/pipeline_16x128x128.py
@@ -25,7 +25,7 @@ vae_2d = dict(
 
 model = dict(
     type="VAE_MAGVIT_V2",
-    in_out_channels = 3,
+    in_out_channels = 4,
     latent_embed_dim = 256,
     filters = 128,
     num_res_blocks = 4,
@@ -44,7 +44,7 @@ discriminator = dict(
     type="DISCRIMINATOR_3D",
     image_size = image_size,
     num_frames = num_frames,
-    in_channels = 3,
+    in_channels = 4,
     filters = 128,
     channel_multipliers = (2,4,4,4,4), # (2,4,4,4) for 64x64 resolution
 )
diff --git a/opensora/models/vae/vae_3d_v2.py b/opensora/models/vae/vae_3d_v2.py
index 05ee2a8..bf81b89 100644
--- a/opensora/models/vae/vae_3d_v2.py
+++ b/opensora/models/vae/vae_3d_v2.py
@@ -932,6 +932,8 @@ class VAE_3D_V2(nn.Module): # , ModelMixin
             video_contains_first_frame = video_contains_first_frame
         )
+        breakpoint()
+
         return recon_video, posterior
 
 
 
@@ -991,13 +993,18 @@ class VEALoss(nn.Module):
         if self.perceptual_loss_weight is not None and self.perceptual_loss_weight > 0.0:
             # handle channels
             channels = video.shape[1]
+            assert channels in {1,3,4}
             if channels == 1:
-                input_vgg_input = repeat(input_vgg_input, 'b 1 h w -> b c h w', c = 3)
-                recon_vgg_input = repeat(recon_vgg_input, 'b 1 h w -> b c h w', c = 3)
+                input_vgg_input = repeat(video, 'b 1 h w -> b c h w', c = 3)
+                recon_vgg_input = repeat(recon_video, 'b 1 h w -> b c h w', c = 3)
             elif channels == 4: # SCH: take the first 3 for perceptual loss calc
-                input_vgg_input = input_vgg_input[:, :3]
-                recon_vgg_input = recon_vgg_input[:, :3]
-            perceptual_loss = self.perceptual_loss_fn(video, recon_video)
+                input_vgg_input = video[:, :3]
+                recon_vgg_input = recon_video[:, :3]
+            else:
+                input_vgg_input = video
+                recon_vgg_input = recon_video
+
+            perceptual_loss = self.perceptual_loss_fn(input_vgg_input, recon_vgg_input)
             recon_loss = recon_loss + self.perceptual_loss_weight * perceptual_loss
 
         nll_loss = recon_loss / torch.exp(self.logvar) + self.logvar
@@ -1223,7 +1230,6 @@ def load_checkpoint_with_inflation(model, ckpt_path):
     """
     if ckpt_path.endswith(".pt") or ckpt_path.endswith(".pth"):
         state_dict = find_model(ckpt_path)
-        breakpoint() # NOTE: need to manually check before first use
         with torch.no_grad():
             for key in state_dict:
                 if key in model:
diff --git a/scripts/train-vae-v2.py b/scripts/train-vae-v2.py
index ed63341..3d10db0 100644
--- a/scripts/train-vae-v2.py
+++ b/scripts/train-vae-v2.py
@@ -314,13 +314,11 @@ def main():
                 video_contains_first_frame = True
             else:
                 video = x
-            
+
             # ===== Spatial VAE =====
             if cfg.get("use_pipeline") == True:
                 with torch.no_grad():
-                    video = vae.encode(video)
-
-                    breakpoint()
+                    video = vae_2d.encode(video)
 
             # ====== VAE ======
             recon_video, posterior = vae(