add z nll loss

This commit is contained in:
Shen-Chenhui 2024-04-29 18:09:42 +08:00
parent b78b469420
commit c631358bbf
3 changed files with 7 additions and 4 deletions

View file

@ -79,7 +79,6 @@ class VEALoss(nn.Module):
recon_video,
posterior,
nll_weights=None,
split="train",
):
video = rearrange(video, "b c t h w -> (b t) c h w").contiguous()
recon_video = rearrange(recon_video, "b c t h w -> (b t) c h w").contiguous()

View file

@ -171,7 +171,7 @@ def main():
if cfg.calc_loss:
# ====== Calc Loss ======
# simple nll loss
nll_loss, weighted_nll_loss, weighted_kl_loss = vae_loss_fn(video, recon_video, posterior, split="eval")
nll_loss, weighted_nll_loss, weighted_kl_loss = vae_loss_fn(video, recon_video, posterior)
fake_video = pad_at_dim(recon_video, (disc_time_padding, 0), value=0.0, dim=2)
fake_logits = discriminator(fake_video.contiguous())

View file

@ -340,7 +340,11 @@ def main():
# ====== Generator Loss ======
# simple nll loss
nll_loss, weighted_nll_loss, weighted_kl_loss = vae_loss_fn(
video, recon_video, posterior, split="train"
video, recon_video, posterior
)
_, weighted_z_nll_loss, _ = vae_loss_fn(
t_video, t_recon_video, posterior
)
adversarial_loss = torch.tensor(0.0)
@ -357,7 +361,7 @@ def main():
is_training=vae.training,
)
vae_loss = weighted_nll_loss + weighted_kl_loss + adversarial_loss
vae_loss = weighted_nll_loss + weighted_kl_loss + adversarial_loss + weighted_z_nll_loss
optimizer.zero_grad()
# Backward & update