add Nlayer 3D discriminator

This commit is contained in:
Shen-Chenhui 2024-04-29 14:18:57 +08:00
parent 069ffcc4e3
commit 7811b8e99a
2 changed files with 17 additions and 1 deletions

View file

@ -73,5 +73,7 @@ CUDA_VISIBLE_DEVICES7 torchrun --master_port=29510 --nnodes=1 --nproc_per_node=1
### 2.4 Data
full data combining the follwing: `/home/shenchenhui/data/pixabay+pexels.csv`
* ~/data/pixabay: `/home/data/sora_data/pixabay/raw/data/split-0`
* pexels: `/home/litianyi/data/pexels/processed/meta/pexels_caption_vinfo_ready_noempty_clean.csv`

View file

@ -1355,7 +1355,21 @@ def VAE_MAGVIT_V2(from_pretrained=None, **kwargs):
@MODELS.register_module("DISCRIMINATOR_3D")
def DISCRIMINATOR_3D(from_pretrained=None, inflate_from_2d=False, use_pretrained=True, **kwargs):
# model = StyleGANDiscriminatorBlur(**kwargs).apply(xavier_uniform_weight_init)
model = StyleGANDiscriminatorBlur(**kwargs).apply(xavier_uniform_weight_init)
if from_pretrained is not None:
if use_pretrained:
if inflate_from_2d:
load_checkpoint_with_inflation(model, from_pretrained)
else:
load_checkpoint(model, from_pretrained, model_name="discriminator")
print(f"loaded discriminator")
else:
print(f"discriminator use_pretrained={use_pretrained}, initializing new discriminator")
return model
@MODELS.register_module("N_Layer_DISCRIMINATOR_3D")
def DISCRIMINATOR_3D(from_pretrained=None, inflate_from_2d=False, use_pretrained=True, **kwargs):
model = NLayerDiscriminator3D(input_nc=3, n_layers=3,).apply(n_layer_disc_weights_init)
if from_pretrained is not None:
if use_pretrained: