This commit is contained in:
Shen-Chenhui 2024-04-12 18:08:42 +08:00
parent eae76b5a04
commit 17d6b58178
2 changed files with 17 additions and 16 deletions

View file

@@ -30,7 +30,7 @@ model = dict(
kl_embed_dim = 64,
custom_conv_padding = None,
activation_fn = 'swish',
image_size = image_size,
# image_size = image_size,
separate_first_frame_encoding = False,
# kl_loss_weight = 0.000001,
# perceptual_loss_weight = 0.1, # use vgg is not None and more than 0
@@ -45,10 +45,11 @@ model = dict(
discriminator = dict(
type="DISCRIMINATOR_3D",
discriminator_in_channels = 3,
discriminator_filters = 128,
discriminator_channel_multipliers = (2,4,4,4,4),
discriminator_start = 50001,
image_size = image_size,
num_frames = num_frames,
in_channels = 3,
filters = 128,
channel_multipliers = (2,4,4,4,4),
)

View file

@@ -287,9 +287,9 @@ class StyleGANDiscriminator(nn.Module):
self,
image_size = (128, 128),
num_frames = 17,
discriminator_in_channels = 3,
discriminator_filters = 128,
discriminator_channel_multipliers = (2,4,4,4,4),
in_channels = 3,
filters = 128,
channel_multipliers = (2,4,4,4,4),
num_groups=32,
dtype = torch.bfloat16,
device="cpu",
@@ -298,11 +298,11 @@ class StyleGANDiscriminator(nn.Module):
self.dtype = dtype
self.input_size = cast_tuple(image_size, 2)
self.filters = discriminator_filters
self.filters = filters
self.activation_fn = nn.LeakyReLU(negative_slope=0.2)
self.channel_multipliers = discriminator_channel_multipliers
self.channel_multipliers = channel_multipliers
self.conv1 = nn.Conv3d(discriminator_in_channels, self.filters, (3, 3, 3), padding=1, device=device, dtype=dtype) # NOTE: init to xavier_uniform
self.conv1 = nn.Conv3d(in_channels, self.filters, (3, 3, 3), padding=1, device=device, dtype=dtype) # NOTE: init to xavier_uniform
prev_filters = self.filters # record in_channels
self.num_blocks = len(self.channel_multipliers)
@@ -357,7 +357,7 @@ class Encoder(nn.Module):
channel_multipliers = (1, 2, 2, 4),
temporal_downsample = (False, True, True),
num_groups = 32, # for nn.GroupNorm
in_out_channels = 3, # SCH: added, in_channels at the start
# in_out_channels = 3, # SCH: added, in_channels at the start
latent_embed_dim = 512, # num channels for latent vector
# conv_downsample = False,
custom_conv_padding = None,
@@ -463,7 +463,7 @@ class Decoder(nn.Module):
def __init__(self,
latent_embed_dim = 512,
filters = 128,
in_out_channels = 4,
# in_out_channels = 4,
num_res_blocks = 4,
channel_multipliers = (1, 2, 2, 4),
temporal_downsample = (False, True, True),
@@ -475,7 +475,7 @@ class Decoder(nn.Module):
dtype=torch.bfloat16,
):
super().__init__()
self.output_dim = in_out_channels
# self.output_dim = in_out_channels
self.embedding_dim = latent_embed_dim
self.filters = filters
self.num_res_blocks = num_res_blocks
@@ -659,7 +659,7 @@ class VAE_3D_V2(nn.Module):
channel_multipliers=channel_multipliers,
temporal_downsample=temporal_downsample,
num_groups = num_groups, # for nn.GroupNorm
in_out_channels = in_out_channels,
# in_out_channels = in_out_channels,
latent_embed_dim = latent_embed_dim,
# conv_downsample = conv_downsample,
custom_conv_padding = custom_conv_padding,
@@ -670,7 +670,7 @@ class VAE_3D_V2(nn.Module):
self.decoder = Decoder(
latent_embed_dim = latent_embed_dim,
filters = filters,
in_out_channels = in_out_channels,
# in_out_channels = in_out_channels,
num_res_blocks = num_res_blocks,
channel_multipliers = channel_multipliers,
temporal_downsample = temporal_downsample,