This commit is contained in:
Shen-Chenhui 2024-04-12 18:39:31 +08:00
parent dbd3982ee8
commit 06aa4589f2

View file

@ -309,18 +309,19 @@ def main():
running_loss += vae_loss.item()
# ====== Discriminator ======
disc_optimizer.zero_grad()
# if video_contains_first_frame:
# Since we don't have enough T frames, pad anyways
real_logits = discriminator(real_video.contiguous.detach())
fake_logits = discriminator(fake_video.contiguous.detach())
disc_loss = disc_loss_fn(real_logits, fake_logits, global_step)
# Backward & update
booster.backward(loss=disc_loss, optimizer=disc_optimizer)
disc_optimizer.step()
# Log loss values:
all_reduce_mean(disc_loss)
running_disc_loss += disc_loss.item()
if global_step > cfg.discriminator_start:
disc_optimizer.zero_grad()
# if video_contains_first_frame:
# Since we don't have enough T frames, pad anyways
real_logits = discriminator(real_video.contiguous.detach())
fake_logits = discriminator(fake_video.contiguous.detach())
disc_loss = disc_loss_fn(real_logits, fake_logits, global_step)
# Backward & update
booster.backward(loss=disc_loss, optimizer=disc_optimizer)
disc_optimizer.step()
# Log loss values:
all_reduce_mean(disc_loss)
running_disc_loss += disc_loss.item()
log_step += 1