[ckpt] mitigate gpu mem peak when loading ckpt

This commit is contained in:
hxwang 2025-03-26 18:04:16 +08:00
parent bc4aa4f217
commit 5730060f41
No known key found for this signature in database
GPG key ID: 0EC383D418F0B9F8

View file

@ -113,8 +113,7 @@ def load_checkpoint(
     log_message(f"Loading checkpoint from {path}")
     if path.endswith(".safetensors"):
-        # ckpt = load_file(path, device=str(device_map))
-        ckpt = load_file(path, device=torch.cuda.current_device())
+        ckpt = load_file(path, device='cpu')
     if rename_keys is not None:
         # rename keys in the loaded state_dict with old_key_prefix to with new_key_prefix.