diff --git a/opensora/utils/ckpt.py b/opensora/utils/ckpt.py index 2a0efdc..1065a27 100644 --- a/opensora/utils/ckpt.py +++ b/opensora/utils/ckpt.py @@ -113,8 +113,7 @@ def load_checkpoint( log_message(f"Loading checkpoint from {path}") if path.endswith(".safetensors"): - # ckpt = load_file(path, device=str(device_map)) - ckpt = load_file(path, device=torch.cuda.current_device()) + ckpt = load_file(path, device='cpu') if rename_keys is not None: # rename keys in the loaded state_dict with old_key_prefix to with new_key_prefix.