mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-10 12:49:38 +02:00
debug
This commit is contained in:
parent
06d5a4a5e9
commit
30353e1351
|
|
@ -814,6 +814,7 @@ class VAE_3D_V2(nn.Module):
|
|||
kl_loss = torch.mean(kl_loss) / kl_loss.shape[0]
|
||||
|
||||
# TODO: DOUBLE add more sophisticated discrminator loss
|
||||
gen_loss = self.zero
|
||||
if self.adversarial_loss_weight is not None and self.adversarial_loss_weight > 0:
|
||||
if video_contains_first_frame:
|
||||
# video_len = video.shape[2]
|
||||
|
|
@ -829,6 +830,7 @@ class VAE_3D_V2(nn.Module):
|
|||
|
||||
# perceptual loss
|
||||
# SCH: NOTE: if mse can pick single frame, if use sum of errors, need to calc for all frames!
|
||||
perceptual_loss = self.zero
|
||||
if self.perceptual_loss_weight is not None and self.perceptual_loss_weight > 0:
|
||||
frame_indices = torch.randn((batch, frames)).topk(1, dim = -1).indices
|
||||
|
||||
|
|
@ -847,8 +849,7 @@ class VAE_3D_V2(nn.Module):
|
|||
recon_vgg_feats = self.vgg(recon_vgg_input)
|
||||
perceptual_loss = F.mse_loss(input_vgg_feats, recon_vgg_feats)
|
||||
# perceptual_loss = self.lpips(input_vgg_input.contiguous(), recon_vgg_input.contiguous())
|
||||
else:
|
||||
perceptual_loss = self.zero
|
||||
|
||||
|
||||
|
||||
total_loss = recon_loss \
|
||||
|
|
|
|||
Loading…
Reference in a new issue