From d8fab103f4eb7861234e516845554b0a116695aa Mon Sep 17 00:00:00 2001 From: Shen-Chenhui Date: Thu, 2 May 2024 09:48:30 +0000 Subject: [PATCH] added disc --- configs/vae/train/video_disc.py | 12 ++++++------ opensora/models/vae/vae.py | 2 +- scripts/train-vae.py | 4 ++-- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/configs/vae/train/video_disc.py b/configs/vae/train/video_disc.py index 0e10d62..9967c5f 100644 --- a/configs/vae/train/video_disc.py +++ b/configs/vae/train/video_disc.py @@ -42,13 +42,13 @@ discriminator = dict( ) # discriminator hyper-parames TODO -discriminator_factor=1, +discriminator_factor=1 discriminator_start=-1 -generator_factor=0.5, -generator_loss_type="hinge", -discriminator_loss_type="hinge", -lecam_loss_weight=None, -gradient_penalty_loss_weight=None, +generator_factor=0.5 +generator_loss_type="hinge" +discriminator_loss_type="hinge" +lecam_loss_weight=None +gradient_penalty_loss_weight=None # loss weights perceptual_loss_weight = 0.1 # use vgg is not None and more than 0 diff --git a/opensora/models/vae/vae.py b/opensora/models/vae/vae.py index ac320b7..ec2643d 100644 --- a/opensora/models/vae/vae.py +++ b/opensora/models/vae/vae.py @@ -153,7 +153,7 @@ class VideoAutoencoderPipeline(nn.Module): return self.temporal_vae.get_latent_size(self.spatial_vae.get_latent_size(input_size)) def get_temporal_last_layer(self): - return self.temporal_vae.last_layer[0] + return self.temporal_vae.decoder.conv_out.conv.weight @property def device(self): diff --git a/scripts/train-vae.py b/scripts/train-vae.py index 58e33ed..60400d7 100644 --- a/scripts/train-vae.py +++ b/scripts/train-vae.py @@ -277,9 +277,9 @@ def main(): vae_loss += adversarial_loss # Backward & update + optimizer.zero_grad() booster.backward(loss=vae_loss, optimizer=optimizer) optimizer.step() - optimizer.zero_grad() # Adversarial Discriminator loss @@ -295,9 +295,9 @@ def main(): ) disc_loss = weighted_d_adversarial_loss # Backward & update + disc_optimizer.zero_grad() booster.backward(loss=disc_loss, optimizer=disc_optimizer) disc_optimizer.step() - disc_optimizer.zero_grad() all_reduce_mean(disc_loss) running_disc_loss += disc_loss.item()