[exp] update

This commit is contained in:
Zangwei Zheng 2024-04-19 17:56:02 +08:00
parent 15f9d702ed
commit 74aade586a

View file

@@ -210,14 +210,16 @@ def load(
) -> Tuple[int, int, int]:
booster.load_model(model, os.path.join(load_dir, "model"))
# ema is not boosted, so we don't use booster.load_model
# ema.load_state_dict(torch.load(os.path.join(load_dir, "ema.pt")))
ema.load_state_dict(torch.load(os.path.join(load_dir, "ema.pt"), map_location=torch.device("cpu")))
ema.load_state_dict(
torch.load(os.path.join(load_dir, "ema.pt"), map_location=torch.device("cpu")),
strict=False,
)
booster.load_optimizer(optimizer, os.path.join(load_dir, "optimizer"))
if lr_scheduler is not None:
booster.load_lr_scheduler(lr_scheduler, os.path.join(load_dir, "lr_scheduler"))
running_states = load_json(os.path.join(load_dir, "running_states.json"))
if sampler is not None:
sampler.load_state_dict(torch.load(os.path.join(load_dir, "sampler")), strict=False)
sampler.load_state_dict(torch.load(os.path.join(load_dir, "sampler")))
dist.barrier()
return (
running_states["epoch"],