diff --git a/configs/vae_magvit_v2/train/17x128x128.py b/configs/vae_magvit_v2/train/17x128x128.py index 23e07a7..f27ac04 100644 --- a/configs/vae_magvit_v2/train/17x128x128.py +++ b/configs/vae_magvit_v2/train/17x128x128.py @@ -34,12 +34,12 @@ model = dict( separate_first_frame_encoding = False, kl_loss_weight = 0.000001, perceptual_loss_weight = 0.1, # use vgg is not None and more than 0 - adversarial_loss_weight = 1.0, + discriminator_factor = 1.0, discriminator_in_channels = 3, discriminator_filters = 128, discriminator_channel_multipliers = (2,4,4,4,4), + discriminator_loss="hinge", discriminator_start = 50001, - discriminator_weight = 0.5, ) # Others diff --git a/opensora/models/vae/vae_3d_v2.py b/opensora/models/vae/vae_3d_v2.py index a5499ea..d76ed70 100644 --- a/opensora/models/vae/vae_3d_v2.py +++ b/opensora/models/vae/vae_3d_v2.py @@ -65,6 +65,7 @@ def hinge_d_loss(logits_real, logits_fake): loss_real = torch.mean(F.relu(1. - logits_real)) loss_fake = torch.mean(F.relu(1. + logits_fake)) d_loss = 0.5 * (loss_real + loss_fake) + breakpoint() # TODO: CHECK mean rather than sum return d_loss def vanilla_d_loss(logits_real, logits_fake): @@ -94,20 +95,6 @@ def SameConv2d(dim_in, dim_out, kernel_size): return nn.Conv2d(dim_in, dim_out, kernel_size = kernel_size, padding = padding) -def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): - breakpoint() # TODO: scrutinize - if last_layer is not None: - nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] - g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] - else: - nll_grads = torch.autograd.grad(nll_loss, self.get_last_layer(), retain_graph=True)[0] - g_grads = torch.autograd.grad(g_loss, self.get_last_layer(), retain_graph=True)[0] - - d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) - d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() - d_weight = d_weight * self.discriminator_weight - return d_weight - def adopt_weight(weight, global_step, threshold=0, value=0.): if global_step < threshold: weight = value @@ -619,6 +606,7 @@ class VAE_3D_V2(nn.Module): discriminator_filters = 128, discriminator_channel_multipliers = (2,4,4,4,4), discriminator_loss="hinge", + discriminator_start=50001, num_groups = 32, # for nn.GroupNorm # conv_downsample = False, # upsample = "nearest+conv", # options: "deconv", "nearest+conv" @@ -714,6 +702,7 @@ class VAE_3D_V2(nn.Module): # Adversarial Loss self.discriminator_factor = discriminator_factor + self.discriminator_start = discriminator_start self.discriminator = None if discriminator_factor is not None and discriminator_factor > 0: self.discriminator = StyleGANDiscriminator( @@ -728,9 +717,9 @@ class VAE_3D_V2(nn.Module): ) if discriminator_loss == "hinge": - self.calc_disc_loss = hinge_d_loss + self.disc_loss_fn = hinge_d_loss elif discriminator_loss == "vanilla": - self.calc_disc_loss = vanilla_d_loss + self.disc_loss_fn = vanilla_d_loss else: raise ValueError(f"Unknown GAN loss '{discriminator_loss}'.") @@ -797,6 +786,21 @@ class VAE_3D_V2(nn.Module): def get_last_layer(self): return self.conv_out.weight + def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): + breakpoint() # TODO: scrutinize + if last_layer is not None: + nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] + else: + nll_grads = torch.autograd.grad(nll_loss, self.get_last_layer(), retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, self.get_last_layer(), retain_graph=True)[0] + + d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) + d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() + d_weight = d_weight * self.discriminator_weight + + return d_weight + def parameters(self): return [ *self.conv_in.parameters(), @@ -896,7 +900,7 @@ class VAE_3D_V2(nn.Module): d_weight = torch.tensor(0.0) gan_loss = torch.tensor(0.0) - disc_factor = adopt_weight(self.discriminator_factor, global_step, threshold=self.discriminator_iter_start) + disc_factor = adopt_weight(self.discriminator_factor, global_step, threshold=self.discriminator_start) weighted_gan_loss = d_weight * disc_factor * gan_loss breakpoint() total_loss += weighted_gan_loss @@ -922,7 +926,7 @@ class VAE_3D_V2(nn.Module): # real_logits = self.discriminator(real_video.contiguous.detach()) # fake_logits = self.discriminator(fake_video.contiguous.detach()) # disc_factor = adopt_weight(self.discriminator_factor, global_step, threshold=self.discriminator_iter_start) - # weight_discriminator_loss = disc_factor * self.calc_disc_loss(real_logits, fake_logits) + # weight_discriminator_loss = disc_factor * self.disc_loss_fn(real_logits, fake_logits) # else: # weight_discriminator_loss = 0 @@ -949,8 +953,8 @@ class VAE_3D_V2(nn.Module): fake_video = pad_at_dim(recon_video, (self.discr_time_padding, 0), value = 0., dim = 2) real_logits = self.discriminator(real_video.contiguous.detach()) fake_logits = self.discriminator(fake_video.contiguous.detach()) - disc_factor = adopt_weight(self.discriminator_factor, global_step, threshold=self.discriminator_iter_start) - weight_discriminator_loss = disc_factor * self.calc_disc_loss(real_logits, fake_logits) + disc_factor = adopt_weight(self.discriminator_factor, global_step, threshold=self.discriminator_start) + weight_discriminator_loss = disc_factor * self.disc_loss_fn(real_logits, fake_logits) else: weight_discriminator_loss = 0