This commit is contained in:
Shen-Chenhui 2024-04-11 11:28:05 +08:00
parent ca1d5863c3
commit dc63a2b4e5

View file

@ -609,8 +609,7 @@ class VAE_3D_V2(nn.Module):
# perceptual loss
if self.use_vgg:
if self.perceptual_loss_weight is not None and self.perceptual_loss_weight > 0:
frame_indices = torch.randn((batch, frames)).topk(1, dim = -1).indices
input_vgg_input = pick_video_frame(video, frame_indices)
@ -626,8 +625,8 @@ class VAE_3D_V2(nn.Module):
input_vgg_feats = self.vgg(input_vgg_input)
recon_vgg_feats = self.vgg(recon_vgg_input)
perceptual_loss = F.mse_loss(input_vgg_feats, recon_vgg_feats)
else:
perceptual_loss = self.zero