[ckpt] mitigate gpu mem peak when loading ckpt

This commit is contained in:
hxwang 2025-03-26 18:04:16 +08:00
parent bc4aa4f217
commit 5730060f41
No known key found for this signature in database
GPG key ID: 0EC383D418F0B9F8

View file

@@ -113,8 +113,7 @@ def load_checkpoint(
     log_message(f"Loading checkpoint from {path}")
     if path.endswith(".safetensors"):
         # ckpt = load_file(path, device=str(device_map))
-        ckpt = load_file(path, device=torch.cuda.current_device())
+        ckpt = load_file(path, device='cpu')
     if rename_keys is not None:
         # rename keys in the loaded state_dict with old_key_prefix to with new_key_prefix.