mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-16 21:23:27 +02:00
debug
This commit is contained in:
parent
ca1d5863c3
commit
dc63a2b4e5
|
|
@ -609,8 +609,7 @@ class VAE_3D_V2(nn.Module):
|
|||
|
||||
# perceptual loss
|
||||
|
||||
if self.use_vgg:
|
||||
|
||||
if self.perceptual_loss_weight is not None and self.perceptual_loss_weight > 0:
|
||||
frame_indices = torch.randn((batch, frames)).topk(1, dim = -1).indices
|
||||
|
||||
input_vgg_input = pick_video_frame(video, frame_indices)
|
||||
|
|
@ -626,8 +625,8 @@ class VAE_3D_V2(nn.Module):
|
|||
|
||||
input_vgg_feats = self.vgg(input_vgg_input)
|
||||
recon_vgg_feats = self.vgg(recon_vgg_input)
|
||||
|
||||
perceptual_loss = F.mse_loss(input_vgg_feats, recon_vgg_feats)
|
||||
|
||||
else:
|
||||
perceptual_loss = self.zero
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue