diff --git a/configs/vae_magvit_v2/train/17x128x128.py b/configs/vae_magvit_v2/train/17x128x128.py
index 5236fc0..56ee25c 100644
--- a/configs/vae_magvit_v2/train/17x128x128.py
+++ b/configs/vae_magvit_v2/train/17x128x128.py
@@ -58,8 +58,9 @@ discriminator = dict(
 kl_loss_weight = 0.000001
 perceptual_loss_weight = 0.1 # use vgg is not None and more than 0
 discriminator_factor = 1.0
+discriminator_loss_weight = 0.5 # TODO: adjust value
 discriminator_loss="hinge"
-discriminator_start = 1 # 50001 TODO: change to correct val, debug use 1 for now
+discriminator_start = -1 # 50001 TODO: change to correct val, debug use -1 for now
 
 # Others
 seed = 42
diff --git a/opensora/models/vae/vae_3d_v2.py b/opensora/models/vae/vae_3d_v2.py
index f03f057..0712c99 100644
--- a/opensora/models/vae/vae_3d_v2.py
+++ b/opensora/models/vae/vae_3d_v2.py
@@ -941,15 +941,6 @@ class VEALoss(nn.Module):
             )
             vgg.classifier = Sequential(*vgg.classifier[:-2])
             self.vgg = vgg.to(device, dtype)
-
-    def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer):
-        breakpoint() # TODO: scrutinize
-        nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
-        g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
-        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
-        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
-        d_weight = d_weight * self.discriminator_weight
-        return d_weight
 
     def forward(
         self,
@@ -1026,10 +1017,20 @@ class AdversarialLoss(nn.Module):
         self,
         discriminator_factor = 1.0,
         discriminator_start = 50001,
+        discriminator_loss_weight = 0.5,
     ):
         super().__init__()
         self.discriminator_factor = discriminator_factor
         self.discriminator_start = discriminator_start
+        self.discriminator_loss_weight = discriminator_loss_weight
+
+    def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer):
+        nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
+        g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
+        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
+        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
+        d_weight = d_weight * self.discriminator_loss_weight
+        return d_weight
 
     def forward(
         self,
@@ -1038,17 +1039,16 @@ class AdversarialLoss(nn.Module):
         last_layer,
         global_step,
         is_training = True,
-    ):
+    ):
+        gan_loss = -torch.mean(fake_logits)
         if self.discriminator_factor is not None and self.discriminator_factor > 0.0:
             try:
                 d_weight = self.calculate_adaptive_weight(nll_loss, gan_loss, last_layer)
             except RuntimeError:
                 assert not is_training
                 d_weight = torch.tensor(0.0)
-            gan_loss = -torch.mean(fake_logits)
         else:
             d_weight = torch.tensor(0.0)
-            gan_loss = torch.tensor(0.0)
 
         disc_factor = adopt_weight(self.discriminator_factor, global_step, threshold=self.discriminator_start)
         weighted_gan_loss = d_weight * disc_factor * gan_loss
diff --git a/scripts/train-vae-v2.py b/scripts/train-vae-v2.py
index b7dc6de..11276e3 100644
--- a/scripts/train-vae-v2.py
+++ b/scripts/train-vae-v2.py
@@ -271,10 +271,6 @@ def main():
         else:
             video = x
 
-        # padded videos for GAN
-        if global_step > cfg.discriminator_start:
-            real_video = pad_at_dim(video, (disc_time_padding, 0), value = 0., dim = 2)
-            fake_video = pad_at_dim(recon_video, (disc_time_padding, 0), value = 0., dim = 2)
 
         # ====== VAE ======
         optimizer.zero_grad()
@@ -290,8 +286,11 @@ def main():
                 split = "train"
             )
             vae_loss = nll_loss
 
+            # adversarial loss
             if global_step > cfg.discriminator_start:
+                # padded videos for GAN
+                fake_video = pad_at_dim(recon_video, (disc_time_padding, 0), value = 0., dim = 2)
                 fake_logits = discriminator(fake_video.contiguous())
                 adversarial_loss = adversarial_loss_fn(
                     fake_logits,
@@ -313,6 +312,7 @@ def main():
             disc_optimizer.zero_grad()
 
             # if video_contains_first_frame: # Since we don't have enough T frames, pad anyways
+            real_video = pad_at_dim(video, (disc_time_padding, 0), value = 0., dim = 2)
             real_logits = discriminator(real_video.contiguous().detach())
             fake_logits = discriminator(fake_video.contiguous().detach())
             disc_loss = disc_loss_fn(real_logits, fake_logits, global_step)
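Note: the adaptive weight moved into AdversarialLoss follows the usual VQGAN recipe: the generator (GAN) term is rescaled so its gradient norm at the decoder's last layer matches that of the reconstruction loss, then gated by adopt_weight until global_step passes discriminator_start. The snippet below is a minimal, self-contained sketch of that behaviour for sanity-checking, not code from this repo; the toy decoder, the stand-in fake_logits, and the local adopt_weight (assumed to zero the factor below the threshold) are illustrative assumptions.

    import torch
    import torch.nn as nn

    # Hypothetical toy decoder standing in for the VAE decoder; only its last
    # layer's parameters matter for the adaptive-weight computation.
    decoder = nn.Sequential(nn.Linear(8, 8), nn.Tanh(), nn.Linear(8, 3))
    last_layer = decoder[-1].weight

    x = torch.randn(4, 8)
    recon = decoder(x)
    target = torch.randn(4, 3)

    nll_loss = torch.mean(torch.abs(recon - target))   # reconstruction term
    fake_logits = recon.sum(dim=-1)                    # stand-in for discriminator(recon)
    gan_loss = -torch.mean(fake_logits)                # generator loss, as in the diff

    # Ratio of gradient norms at the decoder's last layer, scaled by
    # discriminator_loss_weight (0.5 in the config above).
    nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
    g_grads = torch.autograd.grad(gan_loss, last_layer, retain_graph=True)[0]
    d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
    d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() * 0.5

    # Assumed behaviour of adopt_weight: the factor stays at zero until
    # global_step reaches the threshold (discriminator_start).
    def adopt_weight(weight, global_step, threshold=0, value=0.0):
        return value if global_step < threshold else weight

    for step in (0, 100):
        disc_factor = adopt_weight(1.0, step, threshold=50)   # discriminator_factor = 1.0
        weighted_gan_loss = d_weight * disc_factor * gan_loss
        print(f"step {step}: weighted_gan_loss = {weighted_gan_loss.item():.4f}")

With the debug setting discriminator_start = -1 in the config above, the gate is always open, so the weighted GAN term contributes from the first step.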