mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-22 15:30:17 +02:00
use default discriminator
This commit is contained in:
parent
5084d4e4f9
commit
e241ecbbe1
|
|
@ -445,7 +445,7 @@ class StyleGANDiscriminatorBlur(nn.Module):
|
|||
self.linear1 = nn.Linear(in_features, prev_filters, device=device, dtype=dtype) # NOTE: init to xavier_uniform
|
||||
self.linear2 = nn.Linear(prev_filters, 1, device=device, dtype=dtype) # NOTE: init to xavier_uniform
|
||||
|
||||
self.apply(xavier_uniform_weight_init)
|
||||
# self.apply(xavier_uniform_weight_init)
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
|
|
@ -1197,9 +1197,9 @@ def VAE_MAGVIT_V2(from_pretrained=None, **kwargs):
|
|||
|
||||
@MODELS.register_module("DISCRIMINATOR_3D")
|
||||
def DISCRIMINATOR_3D(from_pretrained=None, **kwargs):
|
||||
# model = StyleGANDiscriminatorBlur(**kwargs).apply(xavier_uniform_weight_init)
|
||||
model = StyleGANDiscriminatorBlur(**kwargs).apply(xavier_uniform_weight_init)
|
||||
# model = StyleGANDiscriminator(**kwargs).apply(xavier_uniform_weight_init) # SCH: DEBUG: to change back
|
||||
model = NLayerDiscriminator3D(input_nc=3, n_layers=3,).apply(n_layer_disc_weights_init)
|
||||
# model = NLayerDiscriminator3D(input_nc=3, n_layers=3,).apply(n_layer_disc_weights_init)
|
||||
if from_pretrained is not None:
|
||||
load_checkpoint(model, from_pretrained)
|
||||
return model
|
||||
Loading…
Reference in a new issue