This commit is contained in:
Shen-Chenhui 2024-04-16 17:30:03 +08:00
parent dc45231c8a
commit de0199c6b2

View file

@ -890,7 +890,7 @@ class VEALoss(nn.Module):
# KL Loss
self.kl_loss_weight = kl_loss_weight
# Perceptual Loss
self.perceptual_loss_fn = LPIPS().eval()
self.perceptual_loss_fn = LPIPS().eval().to(device, dtype)
self.perceptual_loss_weight = perceptual_loss_weight
self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)