[fix] wandb log only if record time

This commit is contained in:
zhengzangw 2024-07-07 05:34:08 +00:00
parent 194e2204c1
commit 06f7eb93f0

View file

@ -347,24 +347,27 @@ def main():
tb_writer.add_scalar("loss", loss.item(), global_step)
# wandb
if cfg.get("wandb", False):
wandb.log(
{
"iter": global_step,
"acc_step": acc_step,
"epoch": epoch,
"loss": loss.item(),
"avg_loss": avg_loss,
"lr": optimizer.param_groups[0]["lr"],
"debug/move_data_time": move_data_t.elapsed_time,
"debug/encode_time": encode_t.elapsed_time,
"debug/mask_time": mask_t.elapsed_time,
"debug/diffusion_time": loss_t.elapsed_time,
"debug/backward_time": backward_t.elapsed_time,
"debug/update_ema_time": ema_t.elapsed_time,
"debug/reduce_loss_time": reduce_loss_t.elapsed_time,
},
step=global_step,
)
wandb_dict = {
"iter": global_step,
"acc_step": acc_step,
"epoch": epoch,
"loss": loss.item(),
"avg_loss": avg_loss,
"lr": optimizer.param_groups[0]["lr"],
}
if record_time:
wandb_dict.update(
{
"debug/move_data_time": move_data_t.elapsed_time,
"debug/encode_time": encode_t.elapsed_time,
"debug/mask_time": mask_t.elapsed_time,
"debug/diffusion_time": loss_t.elapsed_time,
"debug/backward_time": backward_t.elapsed_time,
"debug/update_ema_time": ema_t.elapsed_time,
"debug/reduce_loss_time": reduce_loss_t.elapsed_time,
}
)
wandb.log(wandb_dict, step=global_step)
running_loss = 0.0
log_step = 0