From 5730060f41f6c498a15d002ca429773ec04bcf29 Mon Sep 17 00:00:00 2001 From: hxwang Date: Wed, 26 Mar 2025 18:04:16 +0800 Subject: [PATCH] [ckpt] mitigate gpu mem peak when loading ckpt --- opensora/utils/ckpt.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/opensora/utils/ckpt.py b/opensora/utils/ckpt.py index 2a0efdc..1065a27 100644 --- a/opensora/utils/ckpt.py +++ b/opensora/utils/ckpt.py @@ -113,8 +113,7 @@ def load_checkpoint( log_message(f"Loading checkpoint from {path}") if path.endswith(".safetensors"): - # ckpt = load_file(path, device=str(device_map)) - ckpt = load_file(path, device=torch.cuda.current_device()) + ckpt = load_file(path, device='cpu') if rename_keys is not None: # rename keys in the loaded state_dict with old_key_prefix to with new_key_prefix.