This commit is contained in:
Shen-Chenhui 2024-04-08 17:26:36 +08:00
parent eae30f9f89
commit 9c8d084ec5

View file

@ -164,27 +164,27 @@ def main():
num_steps_per_epoch = len(dataloader)
logger.info("Boost vae for distributed training")
# # =======================================================
# # 6. training loop
# # =======================================================
# start_epoch = start_step = log_step = sampler_start_idx = 0
# running_loss = 0.0
# =======================================================
# 6. training loop
# =======================================================
start_epoch = start_step = log_step = sampler_start_idx = 0
running_loss = 0.0
# # 6.1. resume training
# if cfg.load is not None:
# logger.info("Loading checkpoint")
# booster.load_model(vae, os.path.join(cfg.load, "model"))
# booster.load_optimizer(optimizer, os.path.join(cfg.load, "optimizer"))
# if lr_scheduler is not None:
# booster.load_lr_scheduler(lr_scheduler, os.path.join(cfg.load, "lr_scheduler"))
# running_states = load_json(os.path.join(cfg.load, "running_states.json"))
# dist.barrier()
# start_epoch, start_step, sampler_start_idx = running_states["epoch"], running_states["step"], running_states["sample_start_index"]
# logger.info(f"Loaded checkpoint {cfg.load} at epoch {start_epoch} step {start_step}")
# logger.info(f"Training for {cfg.epochs} epochs with {num_steps_per_epoch} steps per epoch")
# 6.1. resume training
if cfg.load is not None:
logger.info("Loading checkpoint")
booster.load_model(vae, os.path.join(cfg.load, "model"))
booster.load_optimizer(optimizer, os.path.join(cfg.load, "optimizer"))
if lr_scheduler is not None:
booster.load_lr_scheduler(lr_scheduler, os.path.join(cfg.load, "lr_scheduler"))
running_states = load_json(os.path.join(cfg.load, "running_states.json"))
dist.barrier()
start_epoch, start_step, sampler_start_idx = running_states["epoch"], running_states["step"], running_states["sample_start_index"]
logger.info(f"Loaded checkpoint {cfg.load} at epoch {start_epoch} step {start_step}")
logger.info(f"Training for {cfg.epochs} epochs with {num_steps_per_epoch} steps per epoch")
# dataloader.sampler.set_start_index(sampler_start_idx)
dataloader.sampler.set_start_index(sampler_start_idx)
# # define loss function
# loss_function = VEA3DLoss(kl_weight=cfg.kl_weight, perceptual_weight=cfg.perceptual_weight).to(device, dtype)