mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-17 22:56:10 +02:00
add z nll loss
This commit is contained in:
parent
b78b469420
commit
c631358bbf
|
|
@ -79,7 +79,6 @@ class VEALoss(nn.Module):
|
|||
recon_video,
|
||||
posterior,
|
||||
nll_weights=None,
|
||||
split="train",
|
||||
):
|
||||
video = rearrange(video, "b c t h w -> (b t) c h w").contiguous()
|
||||
recon_video = rearrange(recon_video, "b c t h w -> (b t) c h w").contiguous()
|
||||
|
|
|
|||
|
|
@ -171,7 +171,7 @@ def main():
|
|||
if cfg.calc_loss:
|
||||
# ====== Calc Loss ======
|
||||
# simple nll loss
|
||||
nll_loss, weighted_nll_loss, weighted_kl_loss = vae_loss_fn(video, recon_video, posterior, split="eval")
|
||||
nll_loss, weighted_nll_loss, weighted_kl_loss = vae_loss_fn(video, recon_video, posterior)
|
||||
|
||||
fake_video = pad_at_dim(recon_video, (disc_time_padding, 0), value=0.0, dim=2)
|
||||
fake_logits = discriminator(fake_video.contiguous())
|
||||
|
|
|
|||
|
|
@ -340,7 +340,11 @@ def main():
|
|||
# ====== Generator Loss ======
|
||||
# simple nll loss
|
||||
nll_loss, weighted_nll_loss, weighted_kl_loss = vae_loss_fn(
|
||||
video, recon_video, posterior, split="train"
|
||||
video, recon_video, posterior
|
||||
)
|
||||
|
||||
_, weighted_z_nll_loss, _ = vae_loss_fn(
|
||||
t_video, t_recon_video, posterior
|
||||
)
|
||||
|
||||
adversarial_loss = torch.tensor(0.0)
|
||||
|
|
@ -357,7 +361,7 @@ def main():
|
|||
is_training=vae.training,
|
||||
)
|
||||
|
||||
vae_loss = weighted_nll_loss + weighted_kl_loss + adversarial_loss
|
||||
vae_loss = weighted_nll_loss + weighted_kl_loss + adversarial_loss + weighted_z_nll_loss
|
||||
|
||||
optimizer.zero_grad()
|
||||
# Backward & update
|
||||
|
|
|
|||
Loading…
Reference in a new issue