diff --git a/opensora/models/vae/vae_3d_v2.py b/opensora/models/vae/vae_3d_v2.py
index ff376a1..8a81e2c 100644
--- a/opensora/models/vae/vae_3d_v2.py
+++ b/opensora/models/vae/vae_3d_v2.py
@@ -691,7 +691,7 @@ class VAE_3D_V2(nn.Module):
         # # Perceptual Loss
         # self.vgg = None
         # self.perceptual_loss_weight = perceptual_loss_weight
-        # if perceptual_loss_weight is not None and perceptual_loss_weight > 0:
+        # if perceptual_loss_weight is not None and perceptual_loss_weight > 0.0:
         # # self.lpips = LPIPS().eval()
         # if not exists(vgg):
         #     vgg = torchvision.models.vgg16(
@@ -704,7 +704,7 @@ class VAE_3D_V2(nn.Module):
         # self.discriminator_factor = discriminator_factor
         # self.discriminator_start = discriminator_start
         # self.discriminator = None
-        # if discriminator_factor is not None and discriminator_factor > 0:
+        # if discriminator_factor is not None and discriminator_factor > 0.0:
         #     self.discriminator = StyleGANDiscriminator(
         #         image_size = image_size,
         #         num_frames = num_frames,
@@ -842,7 +842,7 @@ class VAE_3D_V2(nn.Module):
 
         # # KL Loss
         # weighted_kl_loss = 0
-        # if self.kl_loss_weight is not None and self.kl_loss_weight > 0:
+        # if self.kl_loss_weight is not None and self.kl_loss_weight > 0.0:
         #     kl_loss = posterior.kl()
         #     # NOTE: since we use MSE, here use mean as well, else use sum
         #     kl_loss = torch.mean(kl_loss) / kl_loss.shape[0]
@@ -852,7 +852,7 @@ class VAE_3D_V2(nn.Module):
         # # Perceptual Loss
         # # SCH: NOTE: if mse can pick single frame, if use sum of errors, need to calc for all frames!
         # weighted_perceptual_loss = 0
-        # if self.perceptual_loss_weight is not None and self.perceptual_loss_weight > 0:
+        # if self.perceptual_loss_weight is not None and self.perceptual_loss_weight > 0.0:
         #     frame_indices = torch.randn((batch, frames)).topk(1, dim = -1).indices
         #     input_vgg_input = pick_video_frame(video, frame_indices)
         #     recon_vgg_input = pick_video_frame(recon_video, frame_indices)
@@ -954,7 +954,7 @@ class VEALoss(nn.Module):
 
         # KL Loss
         weighted_kl_loss = 0
-        if self.kl_loss_weight is not None and self.kl_loss_weight > 0:
+        if self.kl_loss_weight is not None and self.kl_loss_weight > 0.0:
            kl_loss = posterior.kl()
            # NOTE: since we use MSE, here use mean as well, else use sum
            kl_loss = torch.mean(kl_loss) / kl_loss.shape[0]
@@ -963,7 +963,7 @@ class VEALoss(nn.Module):
 
        # perceptual loss
        weighted_perceptual_loss = 0
-       if self.perceptual_loss_weight is not None and self.perceptual_loss_weight > 0:
+       if self.perceptual_loss_weight is not None and self.perceptual_loss_weight > 0.0:
           assert exists(self.vgg)
           batch, channels, frames = video.shape[:3]
           frame_indices = torch.randn((batch, frames)).topk(1, dim = -1).indices
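
Note on the change: in Python, `weight > 0` and `weight > 0.0` compare identically for numeric weights, so this diff is a pure consistency pass over the guard that decides whether each optional loss term is computed, in both the active VEALoss.forward path and the commented-out VAE_3D_V2 code. Below is a minimal sketch of that guard pattern for the KL term; _ToyPosterior is a hypothetical stand-in for the file's posterior object, with only its .kl() method assumed from the diff.

import torch

class _ToyPosterior:
    # Hypothetical stand-in for the diagonal-Gaussian posterior used in
    # the file; only the .kl() method is taken from the diff.
    def __init__(self, mean, logvar):
        self.mean, self.logvar = mean, logvar

    def kl(self):
        # KL(N(mean, var) || N(0, I)) per sample, summed over non-batch dims.
        return 0.5 * torch.sum(
            self.mean**2 + self.logvar.exp() - 1.0 - self.logvar,
            dim=list(range(1, self.mean.dim())),
        )

def weighted_kl_term(posterior, kl_loss_weight):
    # The guard the diff normalizes: skip the term entirely when the
    # weight is unset or zero.
    weighted_kl_loss = 0
    if kl_loss_weight is not None and kl_loss_weight > 0.0:
        kl_loss = posterior.kl()  # shape (batch,)
        # Mean rather than sum, to stay on the scale of an MSE recon loss;
        # kl_loss.shape[0] on the RHS is read before the name is rebound.
        kl_loss = torch.mean(kl_loss) / kl_loss.shape[0]
        weighted_kl_loss = kl_loss * kl_loss_weight
    return weighted_kl_loss

# A None or zero weight short-circuits the whole computation:
post = _ToyPosterior(torch.zeros(4, 8), torch.zeros(4, 8))
print(weighted_kl_term(post, 1e-6))  # tensor(0.) for this toy posterior
print(weighted_kl_term(post, None))  # 0, term skipped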