From 8aad3d6de212fa5c2beb6fada4e932cddf84eb59 Mon Sep 17 00:00:00 2001 From: Shen-Chenhui Date: Fri, 12 Apr 2024 18:35:28 +0800 Subject: [PATCH] debug --- opensora/models/vae/vae_3d_v2.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/opensora/models/vae/vae_3d_v2.py b/opensora/models/vae/vae_3d_v2.py index 374e877..0be32da 100644 --- a/opensora/models/vae/vae_3d_v2.py +++ b/opensora/models/vae/vae_3d_v2.py @@ -915,9 +915,19 @@ class VEALoss(nn.Module): kl_loss_weight = 0.000001, vgg=None, vgg_weights: VGG16_Weights = VGG16_Weights.DEFAULT, + device = "cpu", + dtype = "bf16" ): super().__init__() + if type(dtype) == str: + if dtype == "bf16": + dtype = torch.bfloat16 + elif dtype == "fp16": + dtype = torch.float16 + else: + raise NotImplementedError(f'dtype: {dtype}') + # KL Loss self.kl_loss_weight = kl_loss_weight # Perceptual Loss @@ -929,7 +939,7 @@ class VEALoss(nn.Module): weights = vgg_weights ) vgg.classifier = Sequential(*vgg.classifier[:-2]) - self.vgg = vgg + self.vgg = vgg.to(device, dtype) def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer): breakpoint() # TODO: scrutinize