diff --git a/scripts/train-vae.py b/scripts/train-vae.py index 5e850a7..2f7e5a1 100644 --- a/scripts/train-vae.py +++ b/scripts/train-vae.py @@ -175,7 +175,7 @@ def main(): if cfg.load is not None: logger.info("Loading checkpoint") booster.load_model(vae, os.path.join(cfg.load, "model")) - booster.load_optimizer(optimizer, os.path.join(cfg.load, "optimizer")) + # booster.load_optimizer(optimizer, os.path.join(cfg.load, "optimizer")) # if lr_scheduler is not None: # booster.load_lr_scheduler(lr_scheduler, os.path.join(cfg.load, "lr_scheduler")) running_states = load_json(os.path.join(cfg.load, "running_states.json"))