From 17d6b58178806bd68b34b691601fc2e6eeb75f86 Mon Sep 17 00:00:00 2001 From: Shen-Chenhui Date: Fri, 12 Apr 2024 18:08:42 +0800 Subject: [PATCH] debug --- configs/vae_magvit_v2/train/17x128x128.py | 11 ++++++----- opensora/models/vae/vae_3d_v2.py | 22 +++++++++++----------- 2 files changed, 17 insertions(+), 16 deletions(-) diff --git a/configs/vae_magvit_v2/train/17x128x128.py b/configs/vae_magvit_v2/train/17x128x128.py index 9e69791..29cfb56 100644 --- a/configs/vae_magvit_v2/train/17x128x128.py +++ b/configs/vae_magvit_v2/train/17x128x128.py @@ -30,7 +30,7 @@ model = dict( kl_embed_dim = 64, custom_conv_padding = None, activation_fn = 'swish', - image_size = image_size, + # image_size = image_size, separate_first_frame_encoding = False, # kl_loss_weight = 0.000001, # perceptual_loss_weight = 0.1, # use vgg is not None and more than 0 @@ -45,10 +45,11 @@ model = dict( discriminator = dict( type="DISCRIMINATOR_3D", - discriminator_in_channels = 3, - discriminator_filters = 128, - discriminator_channel_multipliers = (2,4,4,4,4), - discriminator_start = 50001, + image_size = image_size, + num_frames = num_frames, + in_channels = 3, + filters = 128, + channel_multipliers = (2,4,4,4,4), ) diff --git a/opensora/models/vae/vae_3d_v2.py b/opensora/models/vae/vae_3d_v2.py index 5a9a542..1ff96f8 100644 --- a/opensora/models/vae/vae_3d_v2.py +++ b/opensora/models/vae/vae_3d_v2.py @@ -287,9 +287,9 @@ class StyleGANDiscriminator(nn.Module): self, image_size = (128, 128), num_frames = 17, - discriminator_in_channels = 3, - discriminator_filters = 128, - discriminator_channel_multipliers = (2,4,4,4,4), + in_channels = 3, + filters = 128, + channel_multipliers = (2,4,4,4,4), num_groups=32, dtype = torch.bfloat16, device="cpu", @@ -298,11 +298,11 @@ class StyleGANDiscriminator(nn.Module): self.dtype = dtype self.input_size = cast_tuple(image_size, 2) - self.filters = discriminator_filters + self.filters = filters self.activation_fn = nn.LeakyReLU(negative_slope=0.2) - self.channel_multipliers = discriminator_channel_multipliers + self.channel_multipliers = channel_multipliers - self.conv1 = nn.Conv3d(discriminator_in_channels, self.filters, (3, 3, 3), padding=1, device=device, dtype=dtype) # NOTE: init to xavier_uniform + self.conv1 = nn.Conv3d(in_channels, self.filters, (3, 3, 3), padding=1, device=device, dtype=dtype) # NOTE: init to xavier_uniform prev_filters = self.filters # record in_channels self.num_blocks = len(self.channel_multipliers) @@ -357,7 +357,7 @@ class Encoder(nn.Module): channel_multipliers = (1, 2, 2, 4), temporal_downsample = (False, True, True), num_groups = 32, # for nn.GroupNorm - in_out_channels = 3, # SCH: added, in_channels at the start + # in_out_channels = 3, # SCH: added, in_channels at the start latent_embed_dim = 512, # num channels for latent vector # conv_downsample = False, custom_conv_padding = None, @@ -463,7 +463,7 @@ class Decoder(nn.Module): def __init__(self, latent_embed_dim = 512, filters = 128, - in_out_channels = 4, + # in_out_channels = 4, num_res_blocks = 4, channel_multipliers = (1, 2, 2, 4), temporal_downsample = (False, True, True), @@ -475,7 +475,7 @@ class Decoder(nn.Module): dtype=torch.bfloat16, ): super().__init__() - self.output_dim = in_out_channels + # self.output_dim = in_out_channels self.embedding_dim = latent_embed_dim self.filters = filters self.num_res_blocks = num_res_blocks @@ -659,7 +659,7 @@ class VAE_3D_V2(nn.Module): channel_multipliers=channel_multipliers, temporal_downsample=temporal_downsample, num_groups = num_groups, # for nn.GroupNorm - in_out_channels = in_out_channels, + # in_out_channels = in_out_channels, latent_embed_dim = latent_embed_dim, # conv_downsample = conv_downsample, custom_conv_padding = custom_conv_padding, @@ -670,7 +670,7 @@ class VAE_3D_V2(nn.Module): self.decoder = Decoder( latent_embed_dim = latent_embed_dim, filters = filters, - in_out_channels = in_out_channels, + # in_out_channels = in_out_channels, num_res_blocks = num_res_blocks, channel_multipliers = channel_multipliers, temporal_downsample = temporal_downsample,