From dc63a2b4e5fea0b6e0522f1b0f56e7d48230b4e9 Mon Sep 17 00:00:00 2001 From: Shen-Chenhui Date: Thu, 11 Apr 2024 11:28:05 +0800 Subject: [PATCH] debug --- opensora/models/vae/vae_3d_v2.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/opensora/models/vae/vae_3d_v2.py b/opensora/models/vae/vae_3d_v2.py index 6204618..649fce4 100644 --- a/opensora/models/vae/vae_3d_v2.py +++ b/opensora/models/vae/vae_3d_v2.py @@ -609,8 +609,7 @@ class VAE_3D_V2(nn.Module): # perceptual loss - if self.use_vgg: - + if self.perceptual_loss_weight is not None and self.perceptual_loss_weight > 0: frame_indices = torch.randn((batch, frames)).topk(1, dim = -1).indices input_vgg_input = pick_video_frame(video, frame_indices) @@ -626,8 +625,8 @@ class VAE_3D_V2(nn.Module): input_vgg_feats = self.vgg(input_vgg_input) recon_vgg_feats = self.vgg(recon_vgg_input) - perceptual_loss = F.mse_loss(input_vgg_feats, recon_vgg_feats) + else: perceptual_loss = self.zero