working perceptual loss

This commit is contained in:
Shen-Chenhui 2024-04-01 16:11:21 +08:00
parent 6f460c1d05
commit 6bed1bdd0b
2 changed files with 37 additions and 39 deletions

View file

@ -16,8 +16,7 @@ sp_size = 1
# Define Loss
kl_weight = 0.000001
use_perceptual_loss = True
perceptual_weight = 1
perceptual_weight = 1.0
# Define model

View file

@ -189,46 +189,46 @@ class DiagonalGaussianDistribution(object):
def mode(self):
return self.mean
class VEA3DLoss(nn.Module):
def __init__(
self,
# disc_start,
logvar_init=0.0,
kl_weight=1.0,
pixelloss_weight=1.0,
perceptual_weight=1.0,
disc_loss="hinge"
):
super().__init__()
assert disc_loss in ["hinge", "vanilla"]
self.kl_weight = kl_weight
self.pixel_weight = pixelloss_weight
# self.perceptual_loss = LPIPS().eval() # TODO
self.perceptual_weight = perceptual_weight
# output log variance
self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
# class VEA3DLoss(nn.Module):
# def __init__(
# self,
# # disc_start,
# logvar_init=0.0,
# kl_weight=1.0,
# pixelloss_weight=1.0,
# perceptual_weight=1.0,
# disc_loss="hinge"
# ):
# super().__init__()
# assert disc_loss in ["hinge", "vanilla"]
# self.kl_weight = kl_weight
# self.pixel_weight = pixelloss_weight
# # self.perceptual_loss = LPIPS().eval() # TODO
# self.perceptual_weight = perceptual_weight
# # output log variance
# self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
def forward(
self,
inputs,
reconstructions,
posteriors,
weights=None,
):
rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
# def forward(
# self,
# inputs,
# reconstructions,
# posteriors,
# weights=None,
# ):
# rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
weighted_nll_loss = nll_loss
if weights is not None:
weighted_nll_loss = weights*nll_loss
weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
kl_loss = posteriors.kl()
kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
# nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
# weighted_nll_loss = nll_loss
# if weights is not None:
# weighted_nll_loss = weights*nll_loss
# weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
# nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
# kl_loss = posteriors.kl()
# kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
loss = weighted_nll_loss + self.kl_weight * kl_loss # TODO: add discriminator loss later
# loss = weighted_nll_loss + self.kl_weight * kl_loss # TODO: add discriminator loss later
return loss
# return loss
class VEA3DLossWithPerceptualLoss(nn.Module):
def __init__(
@ -328,4 +328,3 @@ class VEA3DLossWithPerceptualLoss(nn.Module):
# }
return loss