mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-15 03:15:20 +02:00
working perceptual loss
This commit is contained in:
parent
6f460c1d05
commit
6bed1bdd0b
|
|
@ -16,8 +16,7 @@ sp_size = 1
|
|||
|
||||
# Define Loss
|
||||
kl_weight = 0.000001
|
||||
use_perceptual_loss = True
|
||||
perceptual_weight = 1
|
||||
perceptual_weight = 1.0
|
||||
|
||||
|
||||
# Define model
|
||||
|
|
|
|||
|
|
@ -189,46 +189,46 @@ class DiagonalGaussianDistribution(object):
|
|||
def mode(self):
|
||||
return self.mean
|
||||
|
||||
class VEA3DLoss(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
# disc_start,
|
||||
logvar_init=0.0,
|
||||
kl_weight=1.0,
|
||||
pixelloss_weight=1.0,
|
||||
perceptual_weight=1.0,
|
||||
disc_loss="hinge"
|
||||
):
|
||||
super().__init__()
|
||||
assert disc_loss in ["hinge", "vanilla"]
|
||||
self.kl_weight = kl_weight
|
||||
self.pixel_weight = pixelloss_weight
|
||||
# self.perceptual_loss = LPIPS().eval() # TODO
|
||||
self.perceptual_weight = perceptual_weight
|
||||
# output log variance
|
||||
self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
|
||||
# class VEA3DLoss(nn.Module):
|
||||
# def __init__(
|
||||
# self,
|
||||
# # disc_start,
|
||||
# logvar_init=0.0,
|
||||
# kl_weight=1.0,
|
||||
# pixelloss_weight=1.0,
|
||||
# perceptual_weight=1.0,
|
||||
# disc_loss="hinge"
|
||||
# ):
|
||||
# super().__init__()
|
||||
# assert disc_loss in ["hinge", "vanilla"]
|
||||
# self.kl_weight = kl_weight
|
||||
# self.pixel_weight = pixelloss_weight
|
||||
# # self.perceptual_loss = LPIPS().eval() # TODO
|
||||
# self.perceptual_weight = perceptual_weight
|
||||
# # output log variance
|
||||
# self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
inputs,
|
||||
reconstructions,
|
||||
posteriors,
|
||||
weights=None,
|
||||
):
|
||||
rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
|
||||
# def forward(
|
||||
# self,
|
||||
# inputs,
|
||||
# reconstructions,
|
||||
# posteriors,
|
||||
# weights=None,
|
||||
# ):
|
||||
# rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
|
||||
|
||||
nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
|
||||
weighted_nll_loss = nll_loss
|
||||
if weights is not None:
|
||||
weighted_nll_loss = weights*nll_loss
|
||||
weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
|
||||
nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
|
||||
kl_loss = posteriors.kl()
|
||||
kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
|
||||
# nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
|
||||
# weighted_nll_loss = nll_loss
|
||||
# if weights is not None:
|
||||
# weighted_nll_loss = weights*nll_loss
|
||||
# weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
|
||||
# nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
|
||||
# kl_loss = posteriors.kl()
|
||||
# kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
|
||||
|
||||
loss = weighted_nll_loss + self.kl_weight * kl_loss # TODO: add discriminator loss later
|
||||
# loss = weighted_nll_loss + self.kl_weight * kl_loss # TODO: add discriminator loss later
|
||||
|
||||
return loss
|
||||
# return loss
|
||||
|
||||
class VEA3DLossWithPerceptualLoss(nn.Module):
|
||||
def __init__(
|
||||
|
|
@ -328,4 +328,3 @@ class VEA3DLossWithPerceptualLoss(nn.Module):
|
|||
# }
|
||||
|
||||
return loss
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue