mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-05-21 11:59:01 +02:00
[fix] wandb log only if record time
This commit is contained in:
parent
194e2204c1
commit
06f7eb93f0
|
|
@ -347,24 +347,27 @@ def main():
|
|||
tb_writer.add_scalar("loss", loss.item(), global_step)
|
||||
# wandb
|
||||
if cfg.get("wandb", False):
|
||||
wandb.log(
|
||||
{
|
||||
"iter": global_step,
|
||||
"acc_step": acc_step,
|
||||
"epoch": epoch,
|
||||
"loss": loss.item(),
|
||||
"avg_loss": avg_loss,
|
||||
"lr": optimizer.param_groups[0]["lr"],
|
||||
"debug/move_data_time": move_data_t.elapsed_time,
|
||||
"debug/encode_time": encode_t.elapsed_time,
|
||||
"debug/mask_time": mask_t.elapsed_time,
|
||||
"debug/diffusion_time": loss_t.elapsed_time,
|
||||
"debug/backward_time": backward_t.elapsed_time,
|
||||
"debug/update_ema_time": ema_t.elapsed_time,
|
||||
"debug/reduce_loss_time": reduce_loss_t.elapsed_time,
|
||||
},
|
||||
step=global_step,
|
||||
)
|
||||
wandb_dict = {
|
||||
"iter": global_step,
|
||||
"acc_step": acc_step,
|
||||
"epoch": epoch,
|
||||
"loss": loss.item(),
|
||||
"avg_loss": avg_loss,
|
||||
"lr": optimizer.param_groups[0]["lr"],
|
||||
}
|
||||
if record_time:
|
||||
wandb_dict.update(
|
||||
{
|
||||
"debug/move_data_time": move_data_t.elapsed_time,
|
||||
"debug/encode_time": encode_t.elapsed_time,
|
||||
"debug/mask_time": mask_t.elapsed_time,
|
||||
"debug/diffusion_time": loss_t.elapsed_time,
|
||||
"debug/backward_time": backward_t.elapsed_time,
|
||||
"debug/update_ema_time": ema_t.elapsed_time,
|
||||
"debug/reduce_loss_time": reduce_loss_t.elapsed_time,
|
||||
}
|
||||
)
|
||||
wandb.log(wandb_dict, step=global_step)
|
||||
|
||||
running_loss = 0.0
|
||||
log_step = 0
|
||||
|
|
|
|||
Loading…
Reference in a new issue