diff --git a/scripts/train-vae.py b/scripts/train-vae.py index 4481ddd..389c132 100644 --- a/scripts/train-vae.py +++ b/scripts/train-vae.py @@ -164,27 +164,27 @@ def main(): num_steps_per_epoch = len(dataloader) logger.info("Boost vae for distributed training") - # # ======================================================= - # # 6. training loop - # # ======================================================= - # start_epoch = start_step = log_step = sampler_start_idx = 0 - # running_loss = 0.0 + # ======================================================= + # 6. training loop + # ======================================================= + start_epoch = start_step = log_step = sampler_start_idx = 0 + running_loss = 0.0 - # # 6.1. resume training - # if cfg.load is not None: - # logger.info("Loading checkpoint") - # booster.load_model(vae, os.path.join(cfg.load, "model")) - # booster.load_optimizer(optimizer, os.path.join(cfg.load, "optimizer")) - # if lr_scheduler is not None: - # booster.load_lr_scheduler(lr_scheduler, os.path.join(cfg.load, "lr_scheduler")) - # running_states = load_json(os.path.join(cfg.load, "running_states.json")) - # dist.barrier() - # start_epoch, start_step, sampler_start_idx = running_states["epoch"], running_states["step"], running_states["sample_start_index"] - # logger.info(f"Loaded checkpoint {cfg.load} at epoch {start_epoch} step {start_step}") - # logger.info(f"Training for {cfg.epochs} epochs with {num_steps_per_epoch} steps per epoch") + # 6.1. resume training + if cfg.load is not None: + logger.info("Loading checkpoint") + booster.load_model(vae, os.path.join(cfg.load, "model")) + booster.load_optimizer(optimizer, os.path.join(cfg.load, "optimizer")) + if lr_scheduler is not None: + booster.load_lr_scheduler(lr_scheduler, os.path.join(cfg.load, "lr_scheduler")) + running_states = load_json(os.path.join(cfg.load, "running_states.json")) + dist.barrier() + start_epoch, start_step, sampler_start_idx = running_states["epoch"], running_states["step"], running_states["sample_start_index"] + logger.info(f"Loaded checkpoint {cfg.load} at epoch {start_epoch} step {start_step}") + logger.info(f"Training for {cfg.epochs} epochs with {num_steps_per_epoch} steps per epoch") - # dataloader.sampler.set_start_index(sampler_start_idx) + dataloader.sampler.set_start_index(sampler_start_idx) # # define loss function # loss_function = VEA3DLoss(kl_weight=cfg.kl_weight, perceptual_weight=cfg.perceptual_weight).to(device, dtype)