diff --git a/scripts/inference-vae.py b/scripts/inference-vae.py index afbc06d..bc9f597 100644 --- a/scripts/inference-vae.py +++ b/scripts/inference-vae.py @@ -145,7 +145,7 @@ def main(): model=vae, dataloader=dataloader ) # load model using booster - booster.load_model(vae, os.path.join(cfg.ckpt_path, "model")) + booster.load_model(vae, os.path.join(cfg.model["from_pretrained"], "model")) # 4.1. batch generation