This commit is contained in:
Shen-Chenhui 2024-04-11 15:38:40 +08:00
parent f6c79a75c2
commit f8a1cc34b2

View file

@ -232,10 +232,10 @@ class ResBlockDown(nn.Module):
# SCH: NOTE: use blur pooling instead, pooling bias is False following enc dec conv pool
self.blur = Blur()
self.conv_pool_residual = nn.Conv3d(in_channels * 8, in_channels, 3, use_bias=False) # NOTE: init to xavier_uniform
self.conv_pool_input = nn.Conv3d(in_channels * 8, in_channels, 3, use_bias=False) # NOTE: init to xavier_uniform
self.conv_pool_residual = nn.Conv3d(in_channels * 8, in_channels, 3, bias=False) # NOTE: init to xavier_uniform
self.conv_pool_input = nn.Conv3d(in_channels * 8, in_channels, 3, bias=False) # NOTE: init to xavier_uniform
self.conv2 = nn.Conv3d(in_channels, self.filters,(1,1,1), use_bias=False) # NOTE: init to xavier_uniform
self.conv2 = nn.Conv3d(in_channels, self.filters,(1,1,1), bias=False) # NOTE: init to xavier_uniform
self.conv3 = nn.Conv3d(in_channels, self.filters, (3,3,3)) # NOTE: init to xavier_uniform
self.norm2 = nn.GroupNorm(num_groups, self.filters, device=device, dtype=dtype)
@ -278,7 +278,7 @@ class StyleGANDiscriminator(nn.Module):
device="cpu",
):
super().__init__()
self.dtype = dtype
self.input_size = cast_tuple(image_size, 2)
self.filters = discriminator_filters