This commit is contained in:
Shen-Chenhui 2024-04-01 16:14:11 +08:00
parent 6bed1bdd0b
commit 562a966a77
2 changed files with 3 additions and 6 deletions

View file

@ -230,7 +230,7 @@ class DiagonalGaussianDistribution(object):
# return loss
class VEA3DLossWithPerceptualLoss(nn.Module):
class VEA3DLoss(nn.Module):
def __init__(
self,
# disc_start,

View file

@ -33,7 +33,7 @@ from opensora.utils.config_utils import (
)
from opensora.utils.misc import all_reduce_mean, format_numel_str, get_model_numel, requires_grad, to_torch_dtype
from opensora.utils.train_utils import update_ema, MaskGenerator
from opensora.models.vae.model_utils import VEA3DLoss, VEA3DLossWithPerceptualLoss
from opensora.models.vae.model_utils import VEA3DLoss
def main():
@ -189,10 +189,7 @@ def main():
dataloader.sampler.set_start_index(sampler_start_idx)
# define loss function
if cfg.use_perceptual_loss:
loss_function = VEA3DLossWithPerceptualLoss(kl_weight=cfg.kl_weight).to(device, dtype)
else:
loss_function = VEA3DLoss(kl_weight=cfg.kl_weight, perceptual_weight=cfg.perceptual_weight).to(device, dtype).to(device, dtype)
loss_function = VEA3DLoss(kl_weight=cfg.kl_weight, perceptual_weight=cfg.perceptual_weight).to(device, dtype)
# 6.2. training loop