diff --git a/opensora/utils/ckpt_utils.py b/opensora/utils/ckpt_utils.py index b8bdf7d..75fa700 100644 --- a/opensora/utils/ckpt_utils.py +++ b/opensora/utils/ckpt_utils.py @@ -210,14 +210,16 @@ def load( ) -> Tuple[int, int, int]: booster.load_model(model, os.path.join(load_dir, "model")) # ema is not boosted, so we don't use booster.load_model - # ema.load_state_dict(torch.load(os.path.join(load_dir, "ema.pt"))) - ema.load_state_dict(torch.load(os.path.join(load_dir, "ema.pt"), map_location=torch.device("cpu"))) + ema.load_state_dict( + torch.load(os.path.join(load_dir, "ema.pt"), map_location=torch.device("cpu")), + strict=False, + ) booster.load_optimizer(optimizer, os.path.join(load_dir, "optimizer")) if lr_scheduler is not None: booster.load_lr_scheduler(lr_scheduler, os.path.join(load_dir, "lr_scheduler")) running_states = load_json(os.path.join(load_dir, "running_states.json")) if sampler is not None: - sampler.load_state_dict(torch.load(os.path.join(load_dir, "sampler")), strict=False) + sampler.load_state_dict(torch.load(os.path.join(load_dir, "sampler"))) dist.barrier() return ( running_states["epoch"],