mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-16 12:55:02 +02:00
debug
This commit is contained in:
parent
6bed1bdd0b
commit
562a966a77
|
|
@ -230,7 +230,7 @@ class DiagonalGaussianDistribution(object):
|
|||
|
||||
# return loss
|
||||
|
||||
class VEA3DLossWithPerceptualLoss(nn.Module):
|
||||
class VEA3DLoss(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
# disc_start,
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@ from opensora.utils.config_utils import (
|
|||
)
|
||||
from opensora.utils.misc import all_reduce_mean, format_numel_str, get_model_numel, requires_grad, to_torch_dtype
|
||||
from opensora.utils.train_utils import update_ema, MaskGenerator
|
||||
from opensora.models.vae.model_utils import VEA3DLoss, VEA3DLossWithPerceptualLoss
|
||||
from opensora.models.vae.model_utils import VEA3DLoss
|
||||
|
||||
|
||||
def main():
|
||||
|
|
@ -189,10 +189,7 @@ def main():
|
|||
dataloader.sampler.set_start_index(sampler_start_idx)
|
||||
|
||||
# define loss function
|
||||
if cfg.use_perceptual_loss:
|
||||
loss_function = VEA3DLossWithPerceptualLoss(kl_weight=cfg.kl_weight).to(device, dtype)
|
||||
else:
|
||||
loss_function = VEA3DLoss(kl_weight=cfg.kl_weight, perceptual_weight=cfg.perceptual_weight).to(device, dtype).to(device, dtype)
|
||||
loss_function = VEA3DLoss(kl_weight=cfg.kl_weight, perceptual_weight=cfg.perceptual_weight).to(device, dtype)
|
||||
|
||||
|
||||
# 6.2. training loop
|
||||
|
|
|
|||
Loading…
Reference in a new issue