mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-28 23:41:48 +02:00
debug
This commit is contained in:
parent
eae30f9f89
commit
9c8d084ec5
|
|
@ -164,27 +164,27 @@ def main():
|
|||
num_steps_per_epoch = len(dataloader)
|
||||
logger.info("Boost vae for distributed training")
|
||||
|
||||
# # =======================================================
|
||||
# # 6. training loop
|
||||
# # =======================================================
|
||||
# start_epoch = start_step = log_step = sampler_start_idx = 0
|
||||
# running_loss = 0.0
|
||||
# =======================================================
|
||||
# 6. training loop
|
||||
# =======================================================
|
||||
start_epoch = start_step = log_step = sampler_start_idx = 0
|
||||
running_loss = 0.0
|
||||
|
||||
|
||||
# # 6.1. resume training
|
||||
# if cfg.load is not None:
|
||||
# logger.info("Loading checkpoint")
|
||||
# booster.load_model(vae, os.path.join(cfg.load, "model"))
|
||||
# booster.load_optimizer(optimizer, os.path.join(cfg.load, "optimizer"))
|
||||
# if lr_scheduler is not None:
|
||||
# booster.load_lr_scheduler(lr_scheduler, os.path.join(cfg.load, "lr_scheduler"))
|
||||
# running_states = load_json(os.path.join(cfg.load, "running_states.json"))
|
||||
# dist.barrier()
|
||||
# start_epoch, start_step, sampler_start_idx = running_states["epoch"], running_states["step"], running_states["sample_start_index"]
|
||||
# logger.info(f"Loaded checkpoint {cfg.load} at epoch {start_epoch} step {start_step}")
|
||||
# logger.info(f"Training for {cfg.epochs} epochs with {num_steps_per_epoch} steps per epoch")
|
||||
# 6.1. resume training
|
||||
if cfg.load is not None:
|
||||
logger.info("Loading checkpoint")
|
||||
booster.load_model(vae, os.path.join(cfg.load, "model"))
|
||||
booster.load_optimizer(optimizer, os.path.join(cfg.load, "optimizer"))
|
||||
if lr_scheduler is not None:
|
||||
booster.load_lr_scheduler(lr_scheduler, os.path.join(cfg.load, "lr_scheduler"))
|
||||
running_states = load_json(os.path.join(cfg.load, "running_states.json"))
|
||||
dist.barrier()
|
||||
start_epoch, start_step, sampler_start_idx = running_states["epoch"], running_states["step"], running_states["sample_start_index"]
|
||||
logger.info(f"Loaded checkpoint {cfg.load} at epoch {start_epoch} step {start_step}")
|
||||
logger.info(f"Training for {cfg.epochs} epochs with {num_steps_per_epoch} steps per epoch")
|
||||
|
||||
# dataloader.sampler.set_start_index(sampler_start_idx)
|
||||
dataloader.sampler.set_start_index(sampler_start_idx)
|
||||
|
||||
# # define loss function
|
||||
# loss_function = VEA3DLoss(kl_weight=cfg.kl_weight, perceptual_weight=cfg.perceptual_weight).to(device, dtype)
|
||||
|
|
|
|||
Loading…
Reference in a new issue