mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-13 06:46:08 +02:00
debug
This commit is contained in:
parent
c3763cb556
commit
8aad3d6de2
|
|
@ -915,9 +915,19 @@ class VEALoss(nn.Module):
|
|||
kl_loss_weight = 0.000001,
|
||||
vgg=None,
|
||||
vgg_weights: VGG16_Weights = VGG16_Weights.DEFAULT,
|
||||
device = "cpu",
|
||||
dtype = "bf16"
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if type(dtype) == str:
|
||||
if dtype == "bf16":
|
||||
dtype = torch.bfloat16
|
||||
elif dtype == "fp16":
|
||||
dtype = torch.float16
|
||||
else:
|
||||
raise NotImplementedError(f'dtype: {dtype}')
|
||||
|
||||
# KL Loss
|
||||
self.kl_loss_weight = kl_loss_weight
|
||||
# Perceptual Loss
|
||||
|
|
@ -929,7 +939,7 @@ class VEALoss(nn.Module):
|
|||
weights = vgg_weights
|
||||
)
|
||||
vgg.classifier = Sequential(*vgg.classifier[:-2])
|
||||
self.vgg = vgg
|
||||
self.vgg = vgg.to(device, dtype)
|
||||
|
||||
def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer):
|
||||
breakpoint() # TODO: scrutinize
|
||||
|
|
|
|||
Loading…
Reference in a new issue