This commit is contained in:
Shen-Chenhui 2024-04-12 18:18:06 +08:00
parent f905a1b69d
commit 735e9e722b

View file

@@ -691,7 +691,7 @@ class VAE_3D_V2(nn.Module):
# # Perceptual Loss
# self.vgg = None
# self.perceptual_loss_weight = perceptual_loss_weight
# if perceptual_loss_weight is not None and perceptual_loss_weight > 0:
# if perceptual_loss_weight is not None and perceptual_loss_weight > 0.0:
# # self.lpips = LPIPS().eval()
# if not exists(vgg):
# vgg = torchvision.models.vgg16(
@@ -704,7 +704,7 @@ class VAE_3D_V2(nn.Module):
# self.discriminator_factor = discriminator_factor
# self.discriminator_start = discriminator_start
# self.discriminator = None
# if discriminator_factor is not None and discriminator_factor > 0:
# if discriminator_factor is not None and discriminator_factor > 0.0:
# self.discriminator = StyleGANDiscriminator(
# image_size = image_size,
# num_frames = num_frames,
@@ -842,7 +842,7 @@ class VAE_3D_V2(nn.Module):
# # KL Loss
# weighted_kl_loss = 0
# if self.kl_loss_weight is not None and self.kl_loss_weight > 0:
# if self.kl_loss_weight is not None and self.kl_loss_weight > 0.0:
# kl_loss = posterior.kl()
# # NOTE: since we use MSE, here use mean as well, else use sum
# kl_loss = torch.mean(kl_loss) / kl_loss.shape[0]
@@ -852,7 +852,7 @@ class VAE_3D_V2(nn.Module):
# # Perceptual Loss
# # SCH: NOTE: if mse can pick single frame, if use sum of errors, need to calc for all frames!
# weighted_perceptual_loss = 0
# if self.perceptual_loss_weight is not None and self.perceptual_loss_weight > 0:
# if self.perceptual_loss_weight is not None and self.perceptual_loss_weight > 0.0:
# frame_indices = torch.randn((batch, frames)).topk(1, dim = -1).indices
# input_vgg_input = pick_video_frame(video, frame_indices)
# recon_vgg_input = pick_video_frame(recon_video, frame_indices)
@@ -954,7 +954,7 @@ class VEALoss(nn.Module):
# KL Loss
weighted_kl_loss = 0
if self.kl_loss_weight is not None and self.kl_loss_weight > 0:
if self.kl_loss_weight is not None and self.kl_loss_weight > 0.0:
kl_loss = posterior.kl()
# NOTE: since we use MSE, here use mean as well, else use sum
kl_loss = torch.mean(kl_loss) / kl_loss.shape[0]
@@ -963,7 +963,7 @@ class VEALoss(nn.Module):
# perceptual loss
weighted_perceptual_loss = 0
if self.perceptual_loss_weight is not None and self.perceptual_loss_weight > 0:
if self.perceptual_loss_weight is not None and self.perceptual_loss_weight > 0.0:
assert exists(self.vgg)
batch, channels, frames = video.shape[:3]
frame_indices = torch.randn((batch, frames)).topk(1, dim = -1).indices