mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-11 05:13:31 +02:00
debug
This commit is contained in:
parent
f905a1b69d
commit
735e9e722b
|
|
@ -691,7 +691,7 @@ class VAE_3D_V2(nn.Module):
|
|||
# # Perceptual Loss
|
||||
# self.vgg = None
|
||||
# self.perceptual_loss_weight = perceptual_loss_weight
|
||||
# if perceptual_loss_weight is not None and perceptual_loss_weight > 0:
|
||||
# if perceptual_loss_weight is not None and perceptual_loss_weight > 0.0:
|
||||
# # self.lpips = LPIPS().eval()
|
||||
# if not exists(vgg):
|
||||
# vgg = torchvision.models.vgg16(
|
||||
|
|
@ -704,7 +704,7 @@ class VAE_3D_V2(nn.Module):
|
|||
# self.discriminator_factor = discriminator_factor
|
||||
# self.discriminator_start = discriminator_start
|
||||
# self.discriminator = None
|
||||
# if discriminator_factor is not None and discriminator_factor > 0:
|
||||
# if discriminator_factor is not None and discriminator_factor > 0.0:
|
||||
# self.discriminator = StyleGANDiscriminator(
|
||||
# image_size = image_size,
|
||||
# num_frames = num_frames,
|
||||
|
|
@ -842,7 +842,7 @@ class VAE_3D_V2(nn.Module):
|
|||
|
||||
# # KL Loss
|
||||
# weighted_kl_loss = 0
|
||||
# if self.kl_loss_weight is not None and self.kl_loss_weight > 0:
|
||||
# if self.kl_loss_weight is not None and self.kl_loss_weight > 0.0:
|
||||
# kl_loss = posterior.kl()
|
||||
# # NOTE: since we use MSE, use the mean here as well; otherwise use the sum
|
||||
# kl_loss = torch.mean(kl_loss) / kl_loss.shape[0]
|
||||
|
|
@ -852,7 +852,7 @@ class VAE_3D_V2(nn.Module):
|
|||
# # Perceptual Loss
|
||||
# # SCH: NOTE: if using MSE, we can pick a single frame; if using the sum of errors, we need to calculate over all frames!
|
||||
# weighted_perceptual_loss = 0
|
||||
# if self.perceptual_loss_weight is not None and self.perceptual_loss_weight > 0:
|
||||
# if self.perceptual_loss_weight is not None and self.perceptual_loss_weight > 0.0:
|
||||
# frame_indices = torch.randn((batch, frames)).topk(1, dim = -1).indices
|
||||
# input_vgg_input = pick_video_frame(video, frame_indices)
|
||||
# recon_vgg_input = pick_video_frame(recon_video, frame_indices)
|
||||
|
|
@ -954,7 +954,7 @@ class VEALoss(nn.Module):
|
|||
|
||||
# KL Loss
|
||||
weighted_kl_loss = 0
|
||||
if self.kl_loss_weight is not None and self.kl_loss_weight > 0:
|
||||
if self.kl_loss_weight is not None and self.kl_loss_weight > 0.0:
|
||||
kl_loss = posterior.kl()
|
||||
# NOTE: since we use MSE, use the mean here as well; otherwise use the sum
|
||||
kl_loss = torch.mean(kl_loss) / kl_loss.shape[0]
|
||||
|
|
@ -963,7 +963,7 @@ class VEALoss(nn.Module):
|
|||
|
||||
# perceptual loss
|
||||
weighted_perceptual_loss = 0
|
||||
if self.perceptual_loss_weight is not None and self.perceptual_loss_weight > 0:
|
||||
if self.perceptual_loss_weight is not None and self.perceptual_loss_weight > 0.0:
|
||||
assert exists(self.vgg)
|
||||
batch, channels, frames = video.shape[:3]
|
||||
frame_indices = torch.randn((batch, frames)).topk(1, dim = -1).indices
|
||||
|
|
|
|||
Loading…
Reference in a new issue