diff --git a/opensora/models/vae/vae_3d_v2.py b/opensora/models/vae/vae_3d_v2.py index 36319f7..3e176ce 100644 --- a/opensora/models/vae/vae_3d_v2.py +++ b/opensora/models/vae/vae_3d_v2.py @@ -814,6 +814,7 @@ class VAE_3D_V2(nn.Module): kl_loss = torch.mean(kl_loss) / kl_loss.shape[0] # TODO: DOUBLE add more sophisticated discrminator loss + gen_loss = self.zero if self.adversarial_loss_weight is not None and self.adversarial_loss_weight > 0: if video_contains_first_frame: # video_len = video.shape[2] @@ -829,6 +830,7 @@ class VAE_3D_V2(nn.Module): # perceptual loss # SCH: NOTE: if mse can pick single frame, if use sum of errors, need to calc for all frames! + perceptual_loss = self.zero if self.perceptual_loss_weight is not None and self.perceptual_loss_weight > 0: frame_indices = torch.randn((batch, frames)).topk(1, dim = -1).indices @@ -847,8 +849,7 @@ class VAE_3D_V2(nn.Module): recon_vgg_feats = self.vgg(recon_vgg_input) perceptual_loss = F.mse_loss(input_vgg_feats, recon_vgg_feats) # perceptual_loss = self.lpips(input_vgg_input.contiguous(), recon_vgg_input.contiguous()) - else: - perceptual_loss = self.zero + total_loss = recon_loss \