diff --git a/configs/vae_3d/train/16x256x256.py b/configs/vae_3d/train/16x256x256.py
index 554080a..6f3dd4f 100644
--- a/configs/vae_3d/train/16x256x256.py
+++ b/configs/vae_3d/train/16x256x256.py
@@ -16,8 +16,7 @@ sp_size = 1
 
 # Define Loss
 kl_weight = 0.000001
-use_perceptual_loss = True
-perceptual_weight = 1
+perceptual_weight = 1.0
 
 # Define model
 
diff --git a/opensora/models/vae/model_utils.py b/opensora/models/vae/model_utils.py
index 58eacfd..0a0b5b6 100644
--- a/opensora/models/vae/model_utils.py
+++ b/opensora/models/vae/model_utils.py
@@ -189,46 +189,46 @@ class DiagonalGaussianDistribution(object):
     def mode(self):
         return self.mean
 
-class VEA3DLoss(nn.Module):
-    def __init__(
-        self,
-        # disc_start,
-        logvar_init=0.0,
-        kl_weight=1.0,
-        pixelloss_weight=1.0,
-        perceptual_weight=1.0,
-        disc_loss="hinge"
-    ):
-        super().__init__()
-        assert disc_loss in ["hinge", "vanilla"]
-        self.kl_weight = kl_weight
-        self.pixel_weight = pixelloss_weight
-        # self.perceptual_loss = LPIPS().eval() # TODO
-        self.perceptual_weight = perceptual_weight
-        # output log variance
-        self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
+# class VEA3DLoss(nn.Module):
+#     def __init__(
+#         self,
+#         # disc_start,
+#         logvar_init=0.0,
+#         kl_weight=1.0,
+#         pixelloss_weight=1.0,
+#         perceptual_weight=1.0,
+#         disc_loss="hinge"
+#     ):
+#         super().__init__()
+#         assert disc_loss in ["hinge", "vanilla"]
+#         self.kl_weight = kl_weight
+#         self.pixel_weight = pixelloss_weight
+#         # self.perceptual_loss = LPIPS().eval() # TODO
+#         self.perceptual_weight = perceptual_weight
+#         # output log variance
+#         self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
 
-    def forward(
-        self,
-        inputs,
-        reconstructions,
-        posteriors,
-        weights=None,
-    ):
-        rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
+#     def forward(
+#         self,
+#         inputs,
+#         reconstructions,
+#         posteriors,
+#         weights=None,
+#     ):
+#         rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
 
-        nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
-        weighted_nll_loss = nll_loss
-        if weights is not None:
-            weighted_nll_loss = weights*nll_loss
-        weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
-        nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
-        kl_loss = posteriors.kl()
-        kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
+#         nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
+#         weighted_nll_loss = nll_loss
+#         if weights is not None:
+#             weighted_nll_loss = weights*nll_loss
+#         weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
+#         nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
+#         kl_loss = posteriors.kl()
+#         kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
 
-        loss = weighted_nll_loss + self.kl_weight * kl_loss # TODO: add discriminator loss later
+#         loss = weighted_nll_loss + self.kl_weight * kl_loss # TODO: add discriminator loss later
 
-        return loss
+#         return loss
 
 class VEA3DLossWithPerceptualLoss(nn.Module):
     def __init__(
@@ -328,4 +328,3 @@ class VEA3DLossWithPerceptualLoss(nn.Module):
         # }
 
         return loss
-