diff --git a/configs/vae/train/video_disc.py b/configs/vae/train/video_disc.py
new file mode 100644
--- /dev/null
+++ b/configs/vae/train/video_disc.py
@@ -0,0 +1,65 @@
+num_frames = 17
+image_size = (256, 256)
+
+# Define dataset
+dataset = dict(
+    type="VideoTextDataset",
+    data_path=None,
+    num_frames=num_frames,
+    frame_interval=1,
+    image_size=image_size,
+)
+
+# Define acceleration
+num_workers = 16
+dtype = "bf16"
+grad_checkpoint = True
+plugin = "zero2"
+
+# Define model
+model = dict(
+    type="VideoAutoencoderPipeline",
+    freeze_vae_2d=False,
+    from_pretrained=None,
+    vae_2d=dict(
+        type="VideoAutoencoderKL",
+        from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
+        subfolder="vae",
+        local_files_only=True,
+    ),
+    vae_temporal=dict(
+        type="VAE_Temporal_SD",
+        from_pretrained=None,
+    ),
+)
+
+discriminator = dict(
+    type="NLayerDiscriminator",
+    from_pretrained="/home/shenchenhui/opensoraplan-v1.0.0-discriminator.pt",  # FIXME: user-specific absolute path; override via CLI/config, do not commit
+    input_nc=3,
+    n_layers=3,
+    use_actnorm=False,
+)
+
+# loss weights
+perceptual_loss_weight = 0.1  # use vgg is not None and more than 0
+kl_loss_weight = 1e-6
+
+mixed_image_ratio = 0.2
+use_real_rec_loss = True
+use_z_rec_loss = False
+use_image_identity_loss = False
+
+# Others
+seed = 42
+outputs = "outputs"
+wandb = False
+
+epochs = 100
+log_every = 1
+ckpt_every = 1000
+load = None
+
+batch_size = 1
+lr = 1e-5
+grad_clip = 1.0
diff --git a/opensora/models/vae/discriminator.py b/opensora/models/vae/discriminator.py
index e510387..ab00b93 100644
--- a/opensora/models/vae/discriminator.py
+++ b/opensora/models/vae/discriminator.py
@@ -148,6 +148,65 @@ class ResBlockDown(nn.Module):
         out = (residual + x) / math.sqrt(2)
         return out
 
+@MODELS.register_module()
+class NLayerDiscriminator(nn.Module):
+    """Defines a PatchGAN discriminator as in Pix2Pix
+    --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
+    """
+    def __init__(self, input_nc=3, ndf=64, n_layers=3,
+                 use_actnorm=False, from_pretrained=None):
+        """Construct a PatchGAN discriminator
+        Parameters:
+            input_nc (int)        -- the number of channels in input images
+            ndf (int)             -- the number of filters in the last conv layer
+            n_layers (int)        -- the number of conv layers in the discriminator
+            use_actnorm (bool)    -- kept for config compatibility; currently unused (BatchNorm2d is always used)
+            from_pretrained (str) -- optional checkpoint path to load weights from
+        """
+        super().__init__()
+
+        norm_layer = nn.BatchNorm2d
+
+        # BatchNorm2d has affine parameters, so bias in the convs is redundant.
+        if isinstance(norm_layer, functools.partial):
+            use_bias = norm_layer.func != nn.BatchNorm2d
+        else:
+            use_bias = norm_layer != nn.BatchNorm2d
+
+        kw = 4
+        padw = 1
+        sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
+        nf_mult = 1
+        nf_mult_prev = 1
+        for n in range(1, n_layers):  # gradually increase the number of filters
+            nf_mult_prev = nf_mult
+            nf_mult = min(2 ** n, 8)
+            sequence += [
+                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
+                norm_layer(ndf * nf_mult),
+                nn.LeakyReLU(0.2, True)
+            ]
+
+        nf_mult_prev = nf_mult
+        nf_mult = min(2 ** n_layers, 8)
+        sequence += [
+            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
+            norm_layer(ndf * nf_mult),
+            nn.LeakyReLU(0.2, True)
+        ]
+
+        sequence += [
+            nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]  # output 1 channel prediction map
+        self.main = nn.Sequential(*sequence)
+
+        if from_pretrained is not None:
+            load_checkpoint(self, from_pretrained)
+
+    def forward(self, input):
+        """Standard forward."""
+        return self.main(input)
+
+
 class NLayerDiscriminator3D(nn.Module):
     """Defines a 3D PatchGAN discriminator as in Pix2Pix but for 3D inputs."""
 
diff --git a/scripts/train-vae.py b/scripts/train-vae.py
index 82a325f..a100291 100644
--- a/scripts/train-vae.py
+++ b/scripts/train-vae.py
@@ -109,6 +109,14 @@ def main():
         f"Trainable model params: {format_numel_str(model_numel_trainable)}, Total model params: {format_numel_str(model_numel)}"
     )
 
+    # Optionally build the GAN discriminator when the config declares one.
+    if cfg.get("discriminator") is not None:
+        discriminator = build_module(cfg.discriminator, MODELS)
+        discriminator.to(device, dtype)
+        discriminator_numel, discriminator_numel_trainable = get_model_numel(discriminator)
+        logger.info(
+            f"Trainable discriminator params: {format_numel_str(discriminator_numel_trainable)}, Total discriminator params: {format_numel_str(discriminator_numel)}"
+        )
+
     # 4.4 loss functions
     vae_loss_fn = VAELoss(
         logvar_init=cfg.get("logvar_init", 0.0),