diff --git a/scripts/train-vae-v2.py b/scripts/train-vae-v2.py index 2580cd8..c15c1fc 100644 --- a/scripts/train-vae-v2.py +++ b/scripts/train-vae-v2.py @@ -369,7 +369,6 @@ def main(): running_loss += vae_loss.item() - # ====== Discriminator Loss ====== if global_step > cfg.discriminator_start: # if video_contains_first_frame: @@ -419,6 +418,10 @@ def main(): running_disc_loss += disc_loss.item() else: disc_loss = torch.tensor(0.0) + weighted_d_adversarial_loss = torch.tensor(0.0) + lecam_loss = torch.tensor(0.0) + gradient_penalty_loss = torch.tensor(0.0) + log_step += 1 # Log to tensorboard