From c631358bbf41b769fa46ea4bab4d7d6c819edb0c Mon Sep 17 00:00:00 2001 From: Shen-Chenhui Date: Mon, 29 Apr 2024 18:09:42 +0800 Subject: [PATCH] add z nll loss --- opensora/models/vae/losses.py | 1 - scripts/inference-vae.py | 2 +- scripts/train-vae.py | 8 ++++++-- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/opensora/models/vae/losses.py b/opensora/models/vae/losses.py index 0561e70..4adc85c 100644 --- a/opensora/models/vae/losses.py +++ b/opensora/models/vae/losses.py @@ -79,7 +79,6 @@ class VEALoss(nn.Module): recon_video, posterior, nll_weights=None, - split="train", ): video = rearrange(video, "b c t h w -> (b t) c h w").contiguous() recon_video = rearrange(recon_video, "b c t h w -> (b t) c h w").contiguous() diff --git a/scripts/inference-vae.py b/scripts/inference-vae.py index e566a24..bb36d91 100644 --- a/scripts/inference-vae.py +++ b/scripts/inference-vae.py @@ -171,7 +171,7 @@ def main(): if cfg.calc_loss: # ====== Calc Loss ====== # simple nll loss - nll_loss, weighted_nll_loss, weighted_kl_loss = vae_loss_fn(video, recon_video, posterior, split="eval") + nll_loss, weighted_nll_loss, weighted_kl_loss = vae_loss_fn(video, recon_video, posterior) fake_video = pad_at_dim(recon_video, (disc_time_padding, 0), value=0.0, dim=2) fake_logits = discriminator(fake_video.contiguous()) diff --git a/scripts/train-vae.py b/scripts/train-vae.py index 387e28a..aa8f07f 100644 --- a/scripts/train-vae.py +++ b/scripts/train-vae.py @@ -340,7 +340,11 @@ def main(): # ====== Generator Loss ====== # simple nll loss nll_loss, weighted_nll_loss, weighted_kl_loss = vae_loss_fn( - video, recon_video, posterior, split="train" + video, recon_video, posterior + ) + + _, weighted_z_nll_loss, _ = vae_loss_fn( + t_video, t_recon_video, posterior ) adversarial_loss = torch.tensor(0.0) @@ -357,7 +361,7 @@ def main(): is_training=vae.training, ) - vae_loss = weighted_nll_loss + weighted_kl_loss + adversarial_loss + vae_loss = weighted_nll_loss + weighted_kl_loss + adversarial_loss + weighted_z_nll_loss optimizer.zero_grad() # Backward & update