From 7811b8e99af437c55bc35fe11908d0ef66cfc368 Mon Sep 17 00:00:00 2001 From: Shen-Chenhui Date: Mon, 29 Apr 2024 14:18:57 +0800 Subject: [PATCH] add Nlayer 3D discriminator --- opensora/models/vae/README.md | 2 ++ opensora/models/vae/vae_3d_v2.py | 16 +++++++++++++++- 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/opensora/models/vae/README.md b/opensora/models/vae/README.md index acc3d8f..0a6e6c6 100644 --- a/opensora/models/vae/README.md +++ b/opensora/models/vae/README.md @@ -73,5 +73,7 @@ CUDA_VISIBLE_DEVICES7 torchrun --master_port=29510 --nnodes=1 --nproc_per_node=1 ### 2.4 Data +full data combining the follwing: `/home/shenchenhui/data/pixabay+pexels.csv` + * ~/data/pixabay: `/home/data/sora_data/pixabay/raw/data/split-0` * pexels: `/home/litianyi/data/pexels/processed/meta/pexels_caption_vinfo_ready_noempty_clean.csv` \ No newline at end of file diff --git a/opensora/models/vae/vae_3d_v2.py b/opensora/models/vae/vae_3d_v2.py index 3cac557..777d0ea 100644 --- a/opensora/models/vae/vae_3d_v2.py +++ b/opensora/models/vae/vae_3d_v2.py @@ -1355,7 +1355,21 @@ def VAE_MAGVIT_V2(from_pretrained=None, **kwargs): @MODELS.register_module("DISCRIMINATOR_3D") def DISCRIMINATOR_3D(from_pretrained=None, inflate_from_2d=False, use_pretrained=True, **kwargs): - # model = StyleGANDiscriminatorBlur(**kwargs).apply(xavier_uniform_weight_init) + model = StyleGANDiscriminatorBlur(**kwargs).apply(xavier_uniform_weight_init) + if from_pretrained is not None: + if use_pretrained: + if inflate_from_2d: + load_checkpoint_with_inflation(model, from_pretrained) + else: + load_checkpoint(model, from_pretrained, model_name="discriminator") + print(f"loaded discriminator") + else: + print(f"discriminator use_pretrained={use_pretrained}, initializing new discriminator") + + return model + +@MODELS.register_module("N_Layer_DISCRIMINATOR_3D") +def DISCRIMINATOR_3D(from_pretrained=None, inflate_from_2d=False, use_pretrained=True, **kwargs): model = NLayerDiscriminator3D(input_nc=3, n_layers=3,).apply(n_layer_disc_weights_init) if from_pretrained is not None: if use_pretrained: