Merge pull request #841 from hpcaitech/oom_fix

[ckpt] mitigate gpu mem peak when loading ckpt
This commit is contained in:
Hanks 2025-03-27 09:53:41 +08:00 committed by GitHub
commit d0cd5ac50d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -113,8 +113,7 @@ def load_checkpoint(
log_message(f"Loading checkpoint from {path}")
if path.endswith(".safetensors"):
# ckpt = load_file(path, device=str(device_map))
ckpt = load_file(path, device=torch.cuda.current_device())
ckpt = load_file(path, device='cpu')
if rename_keys is not None:
# rename keys in the loaded state_dict with old_key_prefix to with new_key_prefix.