From 8bf4a0fa77f97b13a3794ebdbbd6c9fa23c114b2 Mon Sep 17 00:00:00 2001 From: Shen-Chenhui Date: Thu, 18 Apr 2024 12:23:02 +0800 Subject: [PATCH] add disable space in vae v2 --- configs/vae_magvit_v2/train/16x128x128.py | 3 +- .../train/pipeline_16x128x128.py | 83 +++++++++++++++++++ opensora/models/vae/vae_3d_v2.py | 52 ++++++++++-- 3 files changed, 129 insertions(+), 9 deletions(-) create mode 100644 configs/vae_magvit_v2/train/pipeline_16x128x128.py diff --git a/configs/vae_magvit_v2/train/16x128x128.py b/configs/vae_magvit_v2/train/16x128x128.py index 269788a..722f5cb 100644 --- a/configs/vae_magvit_v2/train/16x128x128.py +++ b/configs/vae_magvit_v2/train/16x128x128.py @@ -30,6 +30,7 @@ model = dict( kl_embed_dim = 64, activation_fn = 'swish', separate_first_frame_encoding = False, + disable_space = True, custom_conv_padding = None ) @@ -40,7 +41,7 @@ discriminator = dict( num_frames = num_frames, in_channels = 3, filters = 128, - channel_multipliers = (2,4,4,4,4) # (2,4,4,4) for 64x64 resolution + channel_multipliers = (2,4,4,4,4), # (2,4,4,4) for 64x64 resolution ) diff --git a/configs/vae_magvit_v2/train/pipeline_16x128x128.py b/configs/vae_magvit_v2/train/pipeline_16x128x128.py new file mode 100644 index 0000000..e86c2cc --- /dev/null +++ b/configs/vae_magvit_v2/train/pipeline_16x128x128.py @@ -0,0 +1,83 @@ +num_frames = 16 +frame_interval = 3 +image_size = (128, 128) + +# Define dataset +root = None +data_path = "CSV_PATH" +use_image_transform = False +num_workers = 4 +video_contains_first_frame = False + +# Define acceleration +dtype = "bf16" +grad_checkpoint = True +plugin = "zero2" +sp_size = 1 + + +# Define model + +model = dict( + type="VAE_MAGVIT_V2", + in_out_channels = 3, + latent_embed_dim = 256, + filters = 128, + num_res_blocks = 4, + channel_multipliers = (1, 2, 2, 4), + temporal_downsample = (False, True, True), + num_groups = 32, # for nn.GroupNorm + kl_embed_dim = 64, + activation_fn = 'swish', + separate_first_frame_encoding = False, + disable_space = True,
custom_conv_padding = None ) discriminator = dict( type="DISCRIMINATOR_3D", image_size = image_size, num_frames = num_frames, in_channels = 3, filters = 128, channel_multipliers = (2,4,4,4,4), # (2,4,4,4) for 64x64 resolution ) # loss weights logvar_init=0.0 kl_loss_weight = 0.000001 perceptual_loss_weight = 0.1 # use vgg is not None and more than 0 discriminator_factor = 1.0 # for discriminator adversarial loss # discriminator_loss_weight = 0.5 # for generator adversarial loss generator_factor = 0.1 # SCH: generator adversarial loss, MAGVIT v2 uses 0.1 lecam_loss_weight = None # NOTE: MAGVIT v2 uses 0.001 # discriminator_loss_type="non-saturating" # generator_loss_type="non-saturating" discriminator_loss_type="hinge" generator_loss_type="hinge" discriminator_start = 30000 # 50000 NOTE: change to correct val, debug use -1 for now gradient_penalty_loss_weight = None # 10 # SCH: MAGVIT uses 10, opensora plan doesn't use ema_decay = 0.999 # ema decay factor for generator # Others seed = 42 outputs = "outputs" wandb = False # Training ''' NOTE: magvit uses about # samples (K) * epochs ~ 2-5 K, num_frames = 4, reso = 128 ==> ours num_frames = 16, reso = 256, so samples (K) * epochs ~ [500 - 1200], 3-6 epochs for pexel, from pexel observation it's correct ''' epochs = 10 log_every = 1 ckpt_every = 1000 load = None batch_size = 4 lr = 1e-4 grad_clip = 1.0 diff --git a/opensora/models/vae/vae_3d_v2.py b/opensora/models/vae/vae_3d_v2.py index cdfbf74..0ff219b 100644 --- a/opensora/models/vae/vae_3d_v2.py +++ b/opensora/models/vae/vae_3d_v2.py @@ -7,7 +7,7 @@ import numpy as np from numpy import typing as nptyping from opensora.models.vae import model_utils from opensora.registry import MODELS -from opensora.utils.ckpt_utils import load_checkpoint +from opensora.utils.ckpt_utils import load_checkpoint, find_model from einops import rearrange, repeat, pack, unpack import torch.nn.functional as F import torchvision
@@ -483,6 +483,7 @@ class Encoder(nn.Module): # in_out_channels = 3, # SCH: added, in_channels at the start latent_embed_dim = 512, # num channels for latent vector # conv_downsample = False, + disable_spatial_downsample = False, # for vae pipeline custom_conv_padding = None, activation_fn = 'swish', device="cpu", @@ -496,6 +497,7 @@ class Encoder(nn.Module): self.num_groups = num_groups self.embedding_dim = latent_embed_dim + self.disable_spatial_downsample = disable_spatial_downsample # self.conv_downsample = conv_downsample self.custom_conv_padding = custom_conv_padding @@ -542,9 +544,10 @@ class Encoder(nn.Module): prev_filters = filters # update in_channels self.block_res_blocks.append(block_items) - if i < self.num_blocks - 1: # SCH: T-Causal Conv 3x3x3, 128->128, stride t x 2 x 2 + if i < self.num_blocks - 1: # SCH: T-Causal Conv 3x3x3, 128->128, stride t x stride s x stride s t_stride = 2 if self.temporal_downsample[i] else 1 - self.conv_blocks.append(self.conv_fn(prev_filters, filters, kernel_size=(3, 3, 3), strides=(t_stride, 2, 2))) # SCH: should be same in_channel and out_channel + s_stride = 2 if not self.disable_spatial_downsample else 1 + self.conv_blocks.append(self.conv_fn(prev_filters, filters, kernel_size=(3, 3, 3), strides=(t_stride, s_stride, s_stride))) # SCH: should be same in_channel and out_channel prev_filters = filters # update in_channels @@ -598,6 +601,7 @@ class Decoder(nn.Module): temporal_downsample = (False, True, True), num_groups = 32, # for nn.GroupNorm # upsample = "nearest+conv", # options: "deconv", "nearest+conv" + disable_spatial_upsample = False, # for vae pipeline custom_conv_padding = None, activation_fn = 'swish', device="cpu", @@ -613,6 +617,7 @@ class Decoder(nn.Module): self.num_groups = num_groups # self.upsample = upsample + self.s_stride = 1 if disable_spatial_upsample else 2 # spatial stride; read the ctor arg — the attribute is never assigned self.custom_conv_padding = custom_conv_padding # self.norm_type = self.config.vqvae.norm_type # self.num_remat_block =
self.config.vqvae.get('num_dec_remat_blocks', 0) @@ -677,7 +682,7 @@ class Decoder(nn.Module): t_stride = 2 if self.temporal_downsample[i - 1] else 1 # SCH: T-Causal Conv 3x3x3, f -> (t_stride * 2 * 2) * f, depth to space t_stride x 2 x 2 self.conv_blocks.insert(0, - self.conv_fn(prev_filters, prev_filters * t_stride * 4, kernel_size=(3,3,3)) + self.conv_fn(prev_filters, prev_filters * t_stride * self.s_stride * self.s_stride, kernel_size=(3,3,3)) ) self.norm1 = nn.GroupNorm(self.num_groups, prev_filters, device=device, dtype=dtype) @@ -706,7 +711,7 @@ class Decoder(nn.Module): t_stride = 2 if self.temporal_downsample[i - 1] else 1 # SCH: T-Causal Conv 3x3x3, f -> (t_stride * 2 * 2) * f, depth to space t_stride x 2 x 2 x = self.conv_blocks[i-1](x) - x = rearrange(x, "B (C ts hs ws) T H W -> B C (T ts) (H hs) (W ws)", ts=t_stride, hs=2, ws=2) + x = rearrange(x, "B (C ts hs ws) T H W -> B C (T ts) (H hs) (W ws)", ts=t_stride, hs=self.s_stride, ws=self.s_stride) # print("decoder:", x.size()) x = self.norm1(x) @@ -728,6 +733,7 @@ class VAE_3D_V2(nn.Module): # , ModelMixin channel_multipliers = (1, 2, 2, 4), temporal_downsample = (True, True, False), num_groups = 32, # for nn.GroupNorm + disable_space = False, custom_conv_padding = None, activation_fn = 'swish', in_out_channels = 4, @@ -777,6 +783,7 @@ class VAE_3D_V2(nn.Module): # , ModelMixin # in_out_channels = in_out_channels, latent_embed_dim = latent_embed_dim, # conv_downsample = conv_downsample, + disable_spatial_downsample=disable_space, custom_conv_padding = custom_conv_padding, activation_fn = activation_fn, device = device, @@ -791,6 +798,7 @@ class VAE_3D_V2(nn.Module): # , ModelMixin temporal_downsample = temporal_downsample, num_groups = num_groups, # for nn.GroupNorm # upsample = upsample, # options: "deconv", "nearest+conv" + disable_spatial_upsample=disable_space, custom_conv_padding = custom_conv_padding, activation_fn = activation_fn, device = device, @@ -1196,10 +1204,38 @@ def 
VAE_MAGVIT_V2(from_pretrained=None, **kwargs): return model @MODELS.register_module("DISCRIMINATOR_3D") -def DISCRIMINATOR_3D(from_pretrained=None, **kwargs): +def DISCRIMINATOR_3D(from_pretrained=None, inflate_from_2d=False, **kwargs): model = StyleGANDiscriminatorBlur(**kwargs).apply(xavier_uniform_weight_init) # model = StyleGANDiscriminator(**kwargs).apply(xavier_uniform_weight_init) # SCH: DEBUG: to change back # model = NLayerDiscriminator3D(input_nc=3, n_layers=3,).apply(n_layer_disc_weights_init) if from_pretrained is not None: - load_checkpoint(model, from_pretrained) - return model \ No newline at end of file + if inflate_from_2d: + load_checkpoint_with_inflation(model, from_pretrained) + else: + load_checkpoint(model, from_pretrained) + return model + + + +def load_checkpoint_with_inflation(model, ckpt_path): + """ + pre-train using image, then inflate to 3D videos + """ + if ckpt_path.endswith(".pt") or ckpt_path.endswith(".pth"): + state_dict = find_model(ckpt_path); model_state = model.state_dict() + # NOTE(review): manually check the inflation mapping before first use (debug breakpoint removed) + with torch.no_grad(): + for key in state_dict: + if key in model_state: + # central inflation + if model_state[key].dim() == 5 and state_dict[key].size() == model_state[key][:, :, 0, :, :].size(): + # temporal dimension + val = torch.zeros_like(model_state[key]) + centre = int(model_state[key].size(2) // 2) + val[:, :, centre, :, :] = state_dict[key]; state_dict[key] = val # write inflated weight back + missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) + print(f"Missing keys: {missing_keys}") + print(f"Unexpected keys: {unexpected_keys}") + else: + load_checkpoint(model, ckpt_path) # use the default function + \ No newline at end of file