mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-11 21:42:26 +02:00
* upload v2.0 * update docs * [hotfix] fit latest fa3 (#802) * update readme * update readme * update readme * update train readme * update readme * update readme: motion score * cleaning video dc ae WIP * update config * add dependency functions * undo cleaning * use latest dcae * complete high compression training * update hcae config * cleaned up vae * update ae.md * further cleanup * update vae & ae paths * align naming of ae * [hotfix] fix ring attn bwd for fa3 (#803) * train ae default without wandb * update config * update evaluation results * added hcae report * update readme * update readme demo * update readme demo * update readme gif * display demo directly in readme * update paper * delete files --------- Co-authored-by: Hongxin Liu <lhx0217@gmail.com> Co-authored-by: Shen-Chenhui <shen_chenhui@u.nus.edu> Co-authored-by: wuxiwen <wuxiwen.simon@gmail.com>
77 lines
3.2 KiB
Python
77 lines
3.2 KiB
Python
import threading
|
|
from typing import Dict, List, Optional
|
|
|
|
import torch
|
|
|
|
|
|
class PinMemoryCache:
    """Pool of reusable pinned (``pin_memory=True``) CPU buffers.

    ``get`` hands out a view over a flat pinned buffer, shaped like the
    requested tensor, and ``remove`` gives that buffer back to the pool so a
    later ``get`` can reuse it instead of allocating again. Only memory is
    managed here; tensor contents are never copied.

    Outstanding views are tracked by ``id(view)``, so ``remove`` must be
    called with the very tensor object that ``get`` returned.

    Class-level settings (assign them on the class before instantiating):
        force_dtype: dtype used for every buffer. When ``None`` the dtype of
            the requested tensor is used instead.
        min_cache_numel: lower bound on the element count of buffers that
            ``get`` allocates on a miss.
        pre_alloc_numels: element counts of buffers allocated eagerly in
            ``__init__``; only honoured when ``force_dtype`` is set.
    """

    force_dtype: Optional[torch.dtype] = None
    min_cache_numel: int = 0
    pre_alloc_numels: List[int] = []

    def __init__(self):
        # id(buffer) -> flat pinned buffer owned by the pool.
        self.cache: Dict[int, torch.Tensor] = {}
        # id(view handed out by get) -> id(buffer backing it).
        self.output_to_cache: Dict[int, int] = {}
        # Reverse mapping; a buffer is free iff its id is absent here.
        self.cache_to_output: Dict[int, int] = {}
        self.lock = threading.Lock()
        # Statistics reported by __str__.
        self.total_cnt = 0
        self.hit_cnt = 0

        if self.pre_alloc_numels and self.force_dtype is not None:
            for n in self.pre_alloc_numels:
                cache_tensor = torch.empty(n, dtype=self.force_dtype, device="cpu", pin_memory=True)
                with self.lock:
                    self.cache[id(cache_tensor)] = cache_tensor

    def _checkout(self, cache_id: int, cache_tensor: torch.Tensor, tensor: torch.Tensor) -> torch.Tensor:
        """Carve a view shaped like ``tensor`` out of ``cache_tensor`` and mark the buffer as in use.

        The caller must hold ``self.lock``.
        """
        view = cache_tensor[: tensor.numel()].view(tensor.shape)
        out_id = id(view)
        self.output_to_cache[out_id] = cache_id
        self.cache_to_output[cache_id] = out_id
        return view

    def get(self, tensor: torch.Tensor) -> torch.Tensor:
        """Receive a cpu tensor and return the corresponding pinned tensor.

        Note that this only manages memory allocation; it doesn't copy content.

        Args:
            tensor (torch.Tensor): The tensor to be pinned.

        Returns:
            torch.Tensor: A pinned tensor with the same shape as ``tensor``. Its
                dtype is ``force_dtype`` when that is set, otherwise ``tensor.dtype``.
        """
        numel = tensor.numel()
        dtype = self.force_dtype if self.force_dtype is not None else tensor.dtype

        with self.lock:
            # Counters are updated under the lock so concurrent callers do not lose increments.
            self.total_cnt += 1
            # Look for a free buffer that is large enough. The dtype must match too:
            # without force_dtype the pool can hold buffers of several dtypes, and
            # reusing a mismatched one would return a tensor of the wrong dtype.
            for cache_id, cache_tensor in self.cache.items():
                if (
                    cache_id not in self.cache_to_output
                    and cache_tensor.dtype == dtype
                    and cache_tensor.numel() >= numel
                ):
                    self.hit_cnt += 1
                    return self._checkout(cache_id, cache_tensor, tensor)

        # No free buffer: allocate a new one. This is done without holding the
        # lock, presumably so a slow pinned allocation does not block other threads.
        cache_numel = max(numel, self.min_cache_numel)
        cache_tensor = torch.empty(cache_numel, dtype=dtype, device="cpu", pin_memory=True)
        with self.lock:
            cache_id = id(cache_tensor)
            self.cache[cache_id] = cache_tensor
            return self._checkout(cache_id, cache_tensor, tensor)

    def remove(self, output_tensor: torch.Tensor) -> None:
        """Release corresponding cache tensor.

        The backing buffer stays in the pool and becomes available to later ``get`` calls.

        Args:
            output_tensor (torch.Tensor): The tensor to be released; must be the
                object previously returned by ``get``.

        Raises:
            ValueError: If ``output_tensor`` is not currently handed out by this cache.
        """
        out_id = id(output_tensor)
        with self.lock:
            if out_id not in self.output_to_cache:
                raise ValueError("Tensor not found in cache.")
            cache_id = self.output_to_cache.pop(out_id)
            del self.cache_to_output[cache_id]

    def __str__(self):
        """Return a one-line summary: buffer counts, total pinned size in GB, and hit rate."""
        with self.lock:
            num_cached = len(self.cache)
            num_used = len(self.output_to_cache)
            total_cache_size = sum(v.numel() * v.element_size() for v in self.cache.values())
            # Before the first get() there is nothing to divide by; report 0 instead of raising.
            hit_rate = self.hit_cnt / self.total_cnt if self.total_cnt else 0.0
        return f"PinMemoryCache(num_cached={num_cached}, num_used={num_used}, total_cache_size={total_cache_size / 1024**3:.2f} GB, hit rate={hit_rate:.2f})"
|