mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-10 21:01:26 +02:00
[exp] update
This commit is contained in:
parent
15f9d702ed
commit
74aade586a
|
|
@ -210,14 +210,16 @@ def load(
|
|||
) -> Tuple[int, int, int]:
|
||||
booster.load_model(model, os.path.join(load_dir, "model"))
|
||||
# ema is not boosted, so we don't use booster.load_model
|
||||
# ema.load_state_dict(torch.load(os.path.join(load_dir, "ema.pt")))
|
||||
ema.load_state_dict(torch.load(os.path.join(load_dir, "ema.pt"), map_location=torch.device("cpu")))
|
||||
ema.load_state_dict(
|
||||
torch.load(os.path.join(load_dir, "ema.pt"), map_location=torch.device("cpu")),
|
||||
strict=False,
|
||||
)
|
||||
booster.load_optimizer(optimizer, os.path.join(load_dir, "optimizer"))
|
||||
if lr_scheduler is not None:
|
||||
booster.load_lr_scheduler(lr_scheduler, os.path.join(load_dir, "lr_scheduler"))
|
||||
running_states = load_json(os.path.join(load_dir, "running_states.json"))
|
||||
if sampler is not None:
|
||||
sampler.load_state_dict(torch.load(os.path.join(load_dir, "sampler")), strict=False)
|
||||
sampler.load_state_dict(torch.load(os.path.join(load_dir, "sampler")))
|
||||
dist.barrier()
|
||||
return (
|
||||
running_states["epoch"],
|
||||
|
|
|
|||
Loading…
Reference in a new issue