mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-05-21 03:33:55 +02:00
debug
This commit is contained in:
parent
dbd3982ee8
commit
06aa4589f2
|
|
@ -309,18 +309,19 @@ def main():
|
|||
running_loss += vae_loss.item()
|
||||
|
||||
# ====== Discriminator ======
|
||||
disc_optimizer.zero_grad()
|
||||
# if video_contains_first_frame:
|
||||
# Since we don't have enough T frames, pad anyways
|
||||
real_logits = discriminator(real_video.contiguous.detach())
|
||||
fake_logits = discriminator(fake_video.contiguous.detach())
|
||||
disc_loss = disc_loss_fn(real_logits, fake_logits, global_step)
|
||||
# Backward & update
|
||||
booster.backward(loss=disc_loss, optimizer=disc_optimizer)
|
||||
disc_optimizer.step()
|
||||
# Log loss values:
|
||||
all_reduce_mean(disc_loss)
|
||||
running_disc_loss += disc_loss.item()
|
||||
if global_step > cfg.discriminator_start:
|
||||
disc_optimizer.zero_grad()
|
||||
# if video_contains_first_frame:
|
||||
# Since we don't have enough T frames, pad anyways
|
||||
real_logits = discriminator(real_video.contiguous.detach())
|
||||
fake_logits = discriminator(fake_video.contiguous.detach())
|
||||
disc_loss = disc_loss_fn(real_logits, fake_logits, global_step)
|
||||
# Backward & update
|
||||
booster.backward(loss=disc_loss, optimizer=disc_optimizer)
|
||||
disc_optimizer.step()
|
||||
# Log loss values:
|
||||
all_reduce_mean(disc_loss)
|
||||
running_disc_loss += disc_loss.item()
|
||||
|
||||
log_step += 1
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue