From 562a966a77be3af2bba6ece36b08cbbede70ea47 Mon Sep 17 00:00:00 2001 From: Shen-Chenhui Date: Mon, 1 Apr 2024 16:14:11 +0800 Subject: [PATCH] debug --- opensora/models/vae/model_utils.py | 2 +- scripts/train-vae.py | 7 ++----- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/opensora/models/vae/model_utils.py b/opensora/models/vae/model_utils.py index 0a0b5b6..b624288 100644 --- a/opensora/models/vae/model_utils.py +++ b/opensora/models/vae/model_utils.py @@ -230,7 +230,7 @@ class DiagonalGaussianDistribution(object): # return loss -class VEA3DLossWithPerceptualLoss(nn.Module): +class VEA3DLoss(nn.Module): def __init__( self, # disc_start, diff --git a/scripts/train-vae.py b/scripts/train-vae.py index 8c36af0..89c9835 100644 --- a/scripts/train-vae.py +++ b/scripts/train-vae.py @@ -33,7 +33,7 @@ from opensora.utils.config_utils import ( ) from opensora.utils.misc import all_reduce_mean, format_numel_str, get_model_numel, requires_grad, to_torch_dtype from opensora.utils.train_utils import update_ema, MaskGenerator -from opensora.models.vae.model_utils import VEA3DLoss, VEA3DLossWithPerceptualLoss +from opensora.models.vae.model_utils import VEA3DLoss def main(): @@ -189,10 +189,7 @@ def main(): dataloader.sampler.set_start_index(sampler_start_idx) # define loss function - if cfg.use_perceptual_loss: - loss_function = VEA3DLossWithPerceptualLoss(kl_weight=cfg.kl_weight).to(device, dtype) - else: - loss_function = VEA3DLoss(kl_weight=cfg.kl_weight, perceptual_weight=cfg.perceptual_weight).to(device, dtype).to(device, dtype) + loss_function = VEA3DLoss(kl_weight=cfg.kl_weight, perceptual_weight=cfg.perceptual_weight).to(device, dtype) # 6.2. training loop