From 78caeea999abeab62c5c2ee7d720bad9e8effd2a Mon Sep 17 00:00:00 2001
From: Shen-Chenhui
Date: Fri, 19 Apr 2024 15:14:53 +0800
Subject: [PATCH] add support and config for z channel 4

---
 configs/vae_magvit_v2/train/16x128x128_z=4.py | 82 +++++++++++++++++++
 opensora/models/vae/vae_3d_v2.py              | 10 ++-
 2 files changed, 89 insertions(+), 3 deletions(-)
 create mode 100644 configs/vae_magvit_v2/train/16x128x128_z=4.py

diff --git a/configs/vae_magvit_v2/train/16x128x128_z=4.py b/configs/vae_magvit_v2/train/16x128x128_z=4.py
new file mode 100644
index 0000000..8dc7feb
--- /dev/null
+++ b/configs/vae_magvit_v2/train/16x128x128_z=4.py
@@ -0,0 +1,82 @@
+num_frames = 16
+frame_interval = 3
+image_size = (128, 128)
+
+# Define dataset
+root = None
+data_path = "CSV_PATH"
+use_image_transform = False
+num_workers = 4
+video_contains_first_frame = False
+
+# Define acceleration
+dtype = "bf16"
+grad_checkpoint = True
+plugin = "zero2"
+sp_size = 1
+
+
+# Define model
+
+model = dict(
+    type="VAE_MAGVIT_V2",
+    in_out_channels = 3,
+    latent_embed_dim = 4,
+    filters = 128,
+    num_res_blocks = 4,
+    channel_multipliers = (1, 2, 2, 4),
+    temporal_downsample = (False, True, True),
+    num_groups = 32, # for nn.GroupNorm
+    kl_embed_dim = 4,
+    activation_fn = 'swish',
+    separate_first_frame_encoding = False,
+    disable_space = False,
+    custom_conv_padding = None
+)
+
+
+discriminator = dict(
+    type="DISCRIMINATOR_3D",
+    image_size = image_size,
+    num_frames = num_frames,
+    in_channels = 3,
+    filters = 128,
+    channel_multipliers = (2,4,4,4,4), # (2,4,4,4) for 64x64 resolution
+)
+
+
+# loss weights
+logvar_init=0.0
+kl_loss_weight = 0.000001
+perceptual_loss_weight = 0.1 # use vgg is not None and more than 0
+discriminator_factor = 1.0 # for discriminator adversarial loss
+# discriminator_loss_weight = 0.5 # for generator adversarial loss
+generator_factor = 0.1 # SCH: generator adversarial loss, MAGVIT v2 uses 0.1
+lecam_loss_weight = None # NOTE: MAGVIT v2 uses 0.001
+discriminator_loss_type="non-saturating"
+generator_loss_type="non-saturating"
+discriminator_start = 1000 # 50000 NOTE: change to correct val, debug use -1 for now
+gradient_penalty_loss_weight = None # 10 # SCH: MAGVIT uses 10, opensora plan doesn't use
+ema_decay = 0.999 # ema decay factor for generator
+
+
+# Others
+seed = 42
+outputs = "outputs"
+wandb = False
+
+# Training
+''' NOTE:
+magvit uses about # samples (K) * epochs ~ 2-5 K, num_frames = 4, reso = 128
+==> ours num_frames = 16, reso = 256, so samples (K) * epochs ~ [500 - 1200],
+3-6 epochs for pexel, from pexel observation it's correct
+'''
+
+epochs = 200
+log_every = 1
+ckpt_every = 200
+load = None
+
+batch_size = 4
+lr = 1e-4
+grad_clip = 1.0

diff --git a/opensora/models/vae/vae_3d_v2.py b/opensora/models/vae/vae_3d_v2.py
index 225b072..fe97dda 100644
--- a/opensora/models/vae/vae_3d_v2.py
+++ b/opensora/models/vae/vae_3d_v2.py
@@ -739,6 +739,7 @@ class VAE_3D_V2(nn.Module): # , ModelMixin
         activation_fn = 'swish',
         in_out_channels = 4,
         kl_embed_dim = 64,
+        encoder_double_z = True,
         device="cpu",
         dtype="bf16",
     ):
@@ -782,7 +783,7 @@ class VAE_3D_V2(nn.Module): # , ModelMixin
             temporal_downsample=temporal_downsample,
             num_groups = num_groups, # for nn.GroupNorm
             # in_out_channels = in_out_channels,
-            latent_embed_dim = latent_embed_dim,
+            latent_embed_dim = latent_embed_dim * 2 if encoder_double_z else latent_embed_dim,
             # conv_downsample = conv_downsample,
             disable_spatial_downsample=disable_space,
             custom_conv_padding = custom_conv_padding,
@@ -805,8 +806,11 @@ class VAE_3D_V2(nn.Module): # , ModelMixin
             device = device,
             dtype = dtype,
         )
-
-        self.quant_conv = nn.Conv3d(latent_embed_dim, 2*kl_embed_dim, 1, device=device, dtype=dtype)
+
+        if encoder_double_z:
+            self.quant_conv = nn.Conv3d(2*latent_embed_dim, 2*kl_embed_dim, 1, device=device, dtype=dtype)
+        else:
+            self.quant_conv = nn.Conv3d(latent_embed_dim, 2*kl_embed_dim, 1, device=device, dtype=dtype)
         self.post_quant_conv = nn.Conv3d(kl_embed_dim, latent_embed_dim, 1, device=device, dtype=dtype)

     def get_latent_size(self, input_size):