mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-11 05:13:31 +02:00
debug
This commit is contained in:
parent
eae76b5a04
commit
17d6b58178
|
|
@ -30,7 +30,7 @@ model = dict(
|
|||
kl_embed_dim = 64,
|
||||
custom_conv_padding = None,
|
||||
activation_fn = 'swish',
|
||||
image_size = image_size,
|
||||
# image_size = image_size,
|
||||
separate_first_frame_encoding = False,
|
||||
# kl_loss_weight = 0.000001,
|
||||
# perceptual_loss_weight = 0.1, # use vgg is not None and more than 0
|
||||
|
|
@ -45,10 +45,11 @@ model = dict(
|
|||
|
||||
discriminator = dict(
|
||||
type="DISCRIMINATOR_3D",
|
||||
discriminator_in_channels = 3,
|
||||
discriminator_filters = 128,
|
||||
discriminator_channel_multipliers = (2,4,4,4,4),
|
||||
discriminator_start = 50001,
|
||||
image_size = image_size,
|
||||
num_frames = num_frames,
|
||||
in_channels = 3,
|
||||
filters = 128,
|
||||
channel_multipliers = (2,4,4,4,4),
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -287,9 +287,9 @@ class StyleGANDiscriminator(nn.Module):
|
|||
self,
|
||||
image_size = (128, 128),
|
||||
num_frames = 17,
|
||||
discriminator_in_channels = 3,
|
||||
discriminator_filters = 128,
|
||||
discriminator_channel_multipliers = (2,4,4,4,4),
|
||||
in_channels = 3,
|
||||
filters = 128,
|
||||
channel_multipliers = (2,4,4,4,4),
|
||||
num_groups=32,
|
||||
dtype = torch.bfloat16,
|
||||
device="cpu",
|
||||
|
|
@ -298,11 +298,11 @@ class StyleGANDiscriminator(nn.Module):
|
|||
|
||||
self.dtype = dtype
|
||||
self.input_size = cast_tuple(image_size, 2)
|
||||
self.filters = discriminator_filters
|
||||
self.filters = filters
|
||||
self.activation_fn = nn.LeakyReLU(negative_slope=0.2)
|
||||
self.channel_multipliers = discriminator_channel_multipliers
|
||||
self.channel_multipliers = channel_multipliers
|
||||
|
||||
self.conv1 = nn.Conv3d(discriminator_in_channels, self.filters, (3, 3, 3), padding=1, device=device, dtype=dtype) # NOTE: init to xavier_uniform
|
||||
self.conv1 = nn.Conv3d(in_channels, self.filters, (3, 3, 3), padding=1, device=device, dtype=dtype) # NOTE: init to xavier_uniform
|
||||
|
||||
prev_filters = self.filters # record in_channels
|
||||
self.num_blocks = len(self.channel_multipliers)
|
||||
|
|
@ -357,7 +357,7 @@ class Encoder(nn.Module):
|
|||
channel_multipliers = (1, 2, 2, 4),
|
||||
temporal_downsample = (False, True, True),
|
||||
num_groups = 32, # for nn.GroupNorm
|
||||
in_out_channels = 3, # SCH: added, in_channels at the start
|
||||
# in_out_channels = 3, # SCH: added, in_channels at the start
|
||||
latent_embed_dim = 512, # num channels for latent vector
|
||||
# conv_downsample = False,
|
||||
custom_conv_padding = None,
|
||||
|
|
@ -463,7 +463,7 @@ class Decoder(nn.Module):
|
|||
def __init__(self,
|
||||
latent_embed_dim = 512,
|
||||
filters = 128,
|
||||
in_out_channels = 4,
|
||||
# in_out_channels = 4,
|
||||
num_res_blocks = 4,
|
||||
channel_multipliers = (1, 2, 2, 4),
|
||||
temporal_downsample = (False, True, True),
|
||||
|
|
@ -475,7 +475,7 @@ class Decoder(nn.Module):
|
|||
dtype=torch.bfloat16,
|
||||
):
|
||||
super().__init__()
|
||||
self.output_dim = in_out_channels
|
||||
# self.output_dim = in_out_channels
|
||||
self.embedding_dim = latent_embed_dim
|
||||
self.filters = filters
|
||||
self.num_res_blocks = num_res_blocks
|
||||
|
|
@ -659,7 +659,7 @@ class VAE_3D_V2(nn.Module):
|
|||
channel_multipliers=channel_multipliers,
|
||||
temporal_downsample=temporal_downsample,
|
||||
num_groups = num_groups, # for nn.GroupNorm
|
||||
in_out_channels = in_out_channels,
|
||||
# in_out_channels = in_out_channels,
|
||||
latent_embed_dim = latent_embed_dim,
|
||||
# conv_downsample = conv_downsample,
|
||||
custom_conv_padding = custom_conv_padding,
|
||||
|
|
@ -670,7 +670,7 @@ class VAE_3D_V2(nn.Module):
|
|||
self.decoder = Decoder(
|
||||
latent_embed_dim = latent_embed_dim,
|
||||
filters = filters,
|
||||
in_out_channels = in_out_channels,
|
||||
# in_out_channels = in_out_channels,
|
||||
num_res_blocks = num_res_blocks,
|
||||
channel_multipliers = channel_multipliers,
|
||||
temporal_downsample = temporal_downsample,
|
||||
|
|
|
|||
Loading…
Reference in a new issue