mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-11 05:13:31 +02:00
add Nlayer 3D discriminator
This commit is contained in:
parent
069ffcc4e3
commit
7811b8e99a
|
|
@ -73,5 +73,7 @@ CUDA_VISIBLE_DEVICES7 torchrun --master_port=29510 --nnodes=1 --nproc_per_node=1
|
|||
|
||||
### 2.4 Data
|
||||
|
||||
full data combining the follwing: `/home/shenchenhui/data/pixabay+pexels.csv`
|
||||
|
||||
* ~/data/pixabay: `/home/data/sora_data/pixabay/raw/data/split-0`
|
||||
* pexels: `/home/litianyi/data/pexels/processed/meta/pexels_caption_vinfo_ready_noempty_clean.csv`
|
||||
|
|
@ -1355,7 +1355,21 @@ def VAE_MAGVIT_V2(from_pretrained=None, **kwargs):
|
|||
|
||||
@MODELS.register_module("DISCRIMINATOR_3D")
|
||||
def DISCRIMINATOR_3D(from_pretrained=None, inflate_from_2d=False, use_pretrained=True, **kwargs):
|
||||
# model = StyleGANDiscriminatorBlur(**kwargs).apply(xavier_uniform_weight_init)
|
||||
model = StyleGANDiscriminatorBlur(**kwargs).apply(xavier_uniform_weight_init)
|
||||
if from_pretrained is not None:
|
||||
if use_pretrained:
|
||||
if inflate_from_2d:
|
||||
load_checkpoint_with_inflation(model, from_pretrained)
|
||||
else:
|
||||
load_checkpoint(model, from_pretrained, model_name="discriminator")
|
||||
print(f"loaded discriminator")
|
||||
else:
|
||||
print(f"discriminator use_pretrained={use_pretrained}, initializing new discriminator")
|
||||
|
||||
return model
|
||||
|
||||
@MODELS.register_module("N_Layer_DISCRIMINATOR_3D")
|
||||
def DISCRIMINATOR_3D(from_pretrained=None, inflate_from_2d=False, use_pretrained=True, **kwargs):
|
||||
model = NLayerDiscriminator3D(input_nc=3, n_layers=3,).apply(n_layer_disc_weights_init)
|
||||
if from_pretrained is not None:
|
||||
if use_pretrained:
|
||||
|
|
|
|||
Loading…
Reference in a new issue