[stage2] update config

This commit is contained in:
zhengzangw 2024-06-05 15:49:51 +00:00
parent 5f9e642278
commit ce601066c5
3 changed files with 9 additions and 9 deletions

View file

@ -88,5 +88,4 @@ grad_clip = 1.0
lr = 1e-4
ema_decay = 0.99
adam_eps = 1e-15
warmup_steps = 1000

View file

@ -70,6 +70,7 @@ def parse_args(training=False):
parser.add_argument("--wandb", default=None, type=bool, help="enable wandb")
parser.add_argument("--load", default=None, type=str, help="path to continue training")
parser.add_argument("--start-from-scratch", action="store_true", help="start training from scratch")
parser.add_argument("--warmup-steps", default=None, type=int, help="warmup steps")
return parser.parse_args()

View file

@ -330,18 +330,18 @@ def main():
wandb.log(
{
"iter": global_step,
"acc_step": acc_step,
"epoch": epoch,
"loss": loss.item(),
"avg_loss": avg_loss,
"acc_step": acc_step,
"lr": optimizer.param_groups[0]["lr"],
"move_data_time": move_data_t.elapsed_time,
"encode_time": encode_t.elapsed_time,
"mask_time": mask_t.elapsed_time,
"diffusion_time": loss_t.elapsed_time,
"backward_time": backward_t.elapsed_time,
"update_ema_time": ema_t.elapsed_time,
"reduce_loss_time": reduce_loss_t.elapsed_time,
"debug/move_data_time": move_data_t.elapsed_time,
"debug/encode_time": encode_t.elapsed_time,
"debug/mask_time": mask_t.elapsed_time,
"debug/diffusion_time": loss_t.elapsed_time,
"debug/backward_time": backward_t.elapsed_time,
"debug/update_ema_time": ema_t.elapsed_time,
"debug/reduce_loss_time": reduce_loss_t.elapsed_time,
},
step=global_step,
)