This commit is contained in:
Shen-Chenhui 2024-04-11 17:19:05 +08:00
parent 5abb9239fb
commit a5015c46c8

View file

@ -318,19 +318,15 @@ class StyleGANDiscriminator(nn.Module):
def forward(self, x):
breakpoint()
x = self.conv1(x)
x = self.activation_fn(x)
breakpoint()
for i in range(self.num_blocks):
x = self.res_block_list[i](x)
breakpoint()
x = self.conv2(x)
x = self.norm1(x)
x = self.activation_fn(x)
breakpoint()
x = x.reshape((x.shape[0], -1)) # SCH: [B, (C * T * W * H)] ?
x = self.linear1(x)
@ -634,7 +630,7 @@ class VAE_3D_V2(nn.Module):
self.time_downsample_factor = 2**sum(temporal_downsample)
self.time_padding = self.time_downsample_factor - 1
self.disc_time_downsample_factor = 2**len(discriminator_channel_multipliers)
self.disc_time_padding = self.disc_time_downsample_factor - 1
self.disc_time_padding = self.disc_time_downsample_factor - num_frames % self.disc_time_downsample_factor
self.separate_first_frame_encoding = separate_first_frame_encoding
image_down = 2 ** len(temporal_downsample)