This commit is contained in:
Shen-Chenhui 2024-04-12 18:35:28 +08:00
parent c3763cb556
commit 8aad3d6de2

View file

@ -915,9 +915,19 @@ class VEALoss(nn.Module):
kl_loss_weight = 0.000001,
vgg=None,
vgg_weights: VGG16_Weights = VGG16_Weights.DEFAULT,
device = "cpu",
dtype = "bf16"
):
super().__init__()
if type(dtype) == str:
if dtype == "bf16":
dtype = torch.bfloat16
elif dtype == "fp16":
dtype = torch.float16
else:
raise NotImplementedError(f'dtype: {dtype}')
# KL Loss
self.kl_loss_weight = kl_loss_weight
# Perceptual Loss
@ -929,7 +939,7 @@ class VEALoss(nn.Module):
weights = vgg_weights
)
vgg.classifier = Sequential(*vgg.classifier[:-2])
self.vgg = vgg
self.vgg = vgg.to(device, dtype)
def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer):
breakpoint() # TODO: scrutinize