mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-10 12:49:38 +02:00
[ckpt] mitigate gpu mem peak when loading ckpt
This commit is contained in:
parent
bc4aa4f217
commit
5730060f41
|
|
@ -113,8 +113,7 @@ def load_checkpoint(
|
||||||
|
|
||||||
log_message(f"Loading checkpoint from {path}")
|
log_message(f"Loading checkpoint from {path}")
|
||||||
if path.endswith(".safetensors"):
|
if path.endswith(".safetensors"):
|
||||||
# ckpt = load_file(path, device=str(device_map))
|
ckpt = load_file(path, device='cpu')
|
||||||
ckpt = load_file(path, device=torch.cuda.current_device())
|
|
||||||
|
|
||||||
if rename_keys is not None:
|
if rename_keys is not None:
|
||||||
# rename keys in the loaded state_dict with old_key_prefix to with new_key_prefix.
|
# rename keys in the loaded state_dict with old_key_prefix to with new_key_prefix.
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue