use default discriminator

This commit is contained in:
Shen-Chenhui 2024-04-18 09:32:17 +08:00
parent 5084d4e4f9
commit e241ecbbe1

View file

@ -445,7 +445,7 @@ class StyleGANDiscriminatorBlur(nn.Module):
self.linear1 = nn.Linear(in_features, prev_filters, device=device, dtype=dtype) # NOTE: init to xavier_uniform
self.linear2 = nn.Linear(prev_filters, 1, device=device, dtype=dtype) # NOTE: init to xavier_uniform
self.apply(xavier_uniform_weight_init)
# self.apply(xavier_uniform_weight_init)
def forward(self, x):
@ -1197,9 +1197,9 @@ def VAE_MAGVIT_V2(from_pretrained=None, **kwargs):
@MODELS.register_module("DISCRIMINATOR_3D")
def DISCRIMINATOR_3D(from_pretrained=None, **kwargs):
# model = StyleGANDiscriminatorBlur(**kwargs).apply(xavier_uniform_weight_init)
model = StyleGANDiscriminatorBlur(**kwargs).apply(xavier_uniform_weight_init)
# model = StyleGANDiscriminator(**kwargs).apply(xavier_uniform_weight_init) # SCH: DEBUG: to change back
model = NLayerDiscriminator3D(input_nc=3, n_layers=3,).apply(n_layer_disc_weights_init)
# model = NLayerDiscriminator3D(input_nc=3, n_layers=3,).apply(n_layer_disc_weights_init)
if from_pretrained is not None:
load_checkpoint(model, from_pretrained)
return model