Open-Sora/opensora/utils/misc.py

413 lines
10 KiB
Python
Raw Normal View History

2024-03-15 15:16:20 +01:00
import collections
import importlib
import logging
2024-05-12 07:15:06 +02:00
import os
2024-03-15 15:16:20 +01:00
import time
from collections import OrderedDict
from collections.abc import Sequence
from itertools import repeat
from typing import Tuple
2024-03-15 15:16:20 +01:00
import numpy as np
import torch
import torch.distributed as dist
2024-05-09 10:07:56 +02:00
# ======================================================
# Logging
# ======================================================
2024-05-12 07:15:06 +02:00
def is_distributed():
    """Report whether the process appears to run under a distributed launcher.

    The check only looks for the ``WORLD_SIZE`` environment variable; it does
    not query ``torch.distributed``.

    Returns:
        bool: True when ``WORLD_SIZE`` is present in the environment.
    """
    return "WORLD_SIZE" in os.environ
def is_main_process():
    """Tell whether the current process should act as the main one.

    Returns:
        bool: True when the run is not distributed, or when this process has
        rank 0 in the default process group.
    """
    if is_distributed():
        return dist.get_rank() == 0
    # Single-process run: the only process is the main one.
    return True
2024-05-14 07:40:17 +02:00
def get_world_size():
    """Return the number of participating processes.

    Returns:
        int: ``dist.get_world_size()`` for a distributed run, otherwise 1.
    """
    return dist.get_world_size() if is_distributed() else 1
2024-05-12 07:15:06 +02:00
def create_logger(logging_dir=None):
    """
    Create a logger that writes to stdout and, optionally, to a log file.

    On the main process the root logger is configured through
    ``logging.basicConfig`` at INFO level. When ``logging_dir`` is given, a
    stream handler and a file handler writing to ``<logging_dir>/log.txt``
    are installed. On any other process the returned logger only carries a
    ``NullHandler``.

    Args:
        logging_dir (str | None): Directory that receives ``log.txt``. It is
            created if it does not exist. When None, no file is written.

    Returns:
        logging.Logger: The module-level logger.
    """
    logger = logging.getLogger(__name__)
    if is_main_process():  # real logger
        additional_args = dict()
        if logging_dir is not None:
            # FileHandler does not create missing directories; make sure the
            # target exists so logger creation does not fail on a fresh run.
            os.makedirs(logging_dir, exist_ok=True)
            additional_args["handlers"] = [
                logging.StreamHandler(),
                logging.FileHandler(f"{logging_dir}/log.txt"),
            ]
        logging.basicConfig(
            level=logging.INFO,
            format="[\033[34m%(asctime)s\033[0m] %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
            **additional_args,
        )
    else:  # dummy logger (does nothing)
        # Avoid stacking one more NullHandler on every call.
        if not any(isinstance(h, logging.NullHandler) for h in logger.handlers):
            logger.addHandler(logging.NullHandler())
    return logger
def get_logger():
    """Return the module-level logger (the one named after ``__name__``)."""
    module_logger = logging.getLogger(__name__)
    return module_logger
2024-03-15 15:16:20 +01:00
def print_rank(var_name, var_value, rank=0):
    """Print a labelled value from one specific rank only.

    When no process group is initialized (for example a single-process run),
    the current process is treated as rank 0 instead of raising.

    Args:
        var_name: Label printed before the value.
        var_value: Value to print.
        rank (int): Rank that is allowed to print. Defaults to 0.
    """
    if dist.is_available() and dist.is_initialized():
        current_rank = dist.get_rank()
    else:
        # No process group: behave like the single rank-0 process.
        current_rank = 0
    if current_rank == rank:
        print(f"[Rank {rank}] {var_name}: {var_value}")
def print_0(*args, **kwargs):
    """Call ``print`` with the given arguments on rank 0 only.

    When no process group is initialized (for example a single-process run),
    the current process is treated as rank 0 instead of raising.

    Args:
        *args: Positional arguments forwarded to ``print``.
        **kwargs: Keyword arguments forwarded to ``print``.
    """
    if dist.is_available() and dist.is_initialized():
        current_rank = dist.get_rank()
    else:
        # No process group: behave like the single rank-0 process.
        current_rank = 0
    if current_rank == 0:
        print(*args, **kwargs)
2024-05-13 13:39:16 +02:00
def create_tensorboard_writer(exp_dir):
    """Create a TensorBoard ``SummaryWriter`` under ``<exp_dir>/tensorboard``.

    The target directory is created when missing.

    Args:
        exp_dir (str): Experiment directory that hosts the ``tensorboard``
            sub-directory.

    Returns:
        SummaryWriter: Writer bound to the tensorboard directory.
    """
    # Imported lazily so tensorboard is only needed when a writer is created.
    from torch.utils.tensorboard import SummaryWriter

    log_dir = f"{exp_dir}/tensorboard"
    os.makedirs(log_dir, exist_ok=True)
    return SummaryWriter(log_dir)
2024-05-09 10:07:56 +02:00
# ======================================================
# String
# ======================================================
2024-03-15 15:16:20 +01:00
def format_numel_str(numel: int) -> str:
    """Format a count with a ``B``/``M``/``K`` suffix using powers of 1024.

    Args:
        numel (int): The number to format.

    Returns:
        str: The value divided by the largest scale it reaches, with two
        decimals and the matching suffix; the plain number below 1024.
    """
    scales = ((1024**3, "B"), (1024**2, "M"), (1024, "K"))
    for scale, suffix in scales:
        if numel >= scale:
            return f"{numel / scale:.2f} {suffix}"
    return f"{numel}"
def get_timestamp():
    """Return the current local time formatted as ``YYYYMMDD-HHMMSS``."""
    now = time.localtime(time.time())
    return time.strftime("%Y%m%d-%H%M%S", now)
def format_time(seconds):
    """Render a duration as a compact string with at most two units.

    The duration is split into days, hours, minutes, whole seconds and
    milliseconds. Only the first two non-zero components, from largest to
    smallest, are kept.

    Args:
        seconds (int | float): Duration in seconds.

    Returns:
        str: For example ``"1h1m"`` or ``"1s500ms"``; ``"0ms"`` when every
        component is zero.
    """
    remaining = seconds
    days = int(remaining / 3600 / 24)
    remaining = remaining - days * 3600 * 24
    hours = int(remaining / 3600)
    remaining = remaining - hours * 3600
    minutes = int(remaining / 60)
    remaining = remaining - minutes * 60
    whole_seconds = int(remaining)
    remaining = remaining - whole_seconds
    millis = int(remaining * 1000)

    components = (
        (days, "D"),
        (hours, "h"),
        (minutes, "m"),
        (whole_seconds, "s"),
        (millis, "ms"),
    )
    # Keep the two most significant non-zero components.
    shown = [f"{value}{unit}" for value, unit in components if value > 0][:2]
    return "".join(shown) or "0ms"
2024-05-15 10:07:49 +02:00
class BColors:
    """ANSI escape sequences for colouring and styling terminal output."""

    # Foreground colours.
    HEADER = "\033[95m"
    OKBLUE = "\033[94m"
    OKCYAN = "\033[96m"
    OKGREEN = "\033[92m"
    WARNING = "\033[93m"
    FAIL = "\033[91m"

    # Reset sequence; ends any colour or style started above or below.
    ENDC = "\033[0m"

    # Text styles.
    BOLD = "\033[1m"
    UNDERLINE = "\033[4m"
2024-05-09 10:07:56 +02:00
# ======================================================
# PyTorch
# ======================================================
def requires_grad(model: torch.nn.Module, flag: bool = True) -> None:
    """
    Switch gradient tracking on or off for every parameter of a model.

    Args:
        model (torch.nn.Module): Module whose parameters are updated in place.
        flag (bool): Value assigned to each parameter's ``requires_grad``.
    """
    for param in model.parameters():
        param.requires_grad_(flag)
def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
    """Average a tensor across all processes, in place.

    The tensor is summed with ``dist.all_reduce`` and then divided by the
    world size.

    Args:
        tensor (torch.Tensor): Tensor to reduce; it is modified in place.

    Returns:
        torch.Tensor: The same tensor object, now holding the mean.
    """
    dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
    # div_ works in place and hands back the very same tensor.
    return tensor.div_(dist.get_world_size())
def get_model_numel(model: torch.nn.Module) -> Tuple[int, int]:
    """Count the parameters of a model.

    Args:
        model (torch.nn.Module): Module to inspect.

    Returns:
        Tuple[int, int]: ``(total, trainable)`` element counts, where
        trainable covers only parameters with ``requires_grad`` set.
    """
    total = 0
    trainable = 0
    for param in model.parameters():
        size = param.numel()
        total += size
        if param.requires_grad:
            trainable += size
    return total, trainable
def count_params(model):
    """Return the number of elements in the trainable parameters of ``model``.

    Only parameters with ``requires_grad`` set are counted.
    """
    trainable = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable += param.numel()
    return trainable
2024-03-15 15:16:20 +01:00
def to_tensor(data):
    """Turn common Python and NumPy values into a :obj:`torch.Tensor`.

    Tensors are returned untouched, arrays go through ``torch.from_numpy``,
    non-string sequences through ``torch.tensor``, and a lone int or float is
    wrapped in a one-element Long or Float tensor.

    Args:
        data (torch.Tensor | numpy.ndarray | Sequence | int | float): Data to
            be converted.

    Raises:
        TypeError: If ``data`` is of none of the supported types.
    """
    if isinstance(data, torch.Tensor):
        return data
    if isinstance(data, np.ndarray):
        return torch.from_numpy(data)
    # Strings are sequences too, but they are not convertible.
    if not isinstance(data, str) and isinstance(data, Sequence):
        return torch.tensor(data)
    if isinstance(data, int):
        return torch.LongTensor([data])
    if isinstance(data, float):
        return torch.FloatTensor([data])
    raise TypeError(f"type {type(data)} cannot be converted to tensor.")
def to_ndarray(data):
    """Convert objects of various python types to :obj:`numpy.ndarray`.

    Args:
        data (torch.Tensor | numpy.ndarray | Sequence | int | float): Data to
            be converted.

    Returns:
        numpy.ndarray: Arrays are returned as-is; tensors are detached and
        moved to CPU before conversion; a single int or float becomes a
        one-element array.

    Raises:
        TypeError: If ``data`` is of none of the supported types.
    """
    if isinstance(data, torch.Tensor):
        # .numpy() refuses tensors that track gradients or live off-CPU;
        # for a plain CPU tensor detach()/cpu() still share the same storage.
        return data.detach().cpu().numpy()
    elif isinstance(data, np.ndarray):
        return data
    elif isinstance(data, Sequence):
        return np.array(data)
    elif isinstance(data, int):
        # np.array wraps the value; np.ndarray([data]) would instead allocate
        # an uninitialized array of shape (data,).
        return np.array([data], dtype=int)
    elif isinstance(data, float):
        return np.array([data], dtype=float)
    else:
        raise TypeError(f"type {type(data)} cannot be converted to ndarray.")
def to_torch_dtype(dtype):
    """Resolve a dtype given as ``torch.dtype`` or as a string alias.

    Args:
        dtype (torch.dtype | str): A torch dtype (returned unchanged) or one
            of the aliases ``"float64"``, ``"float32"``, ``"float16"``,
            ``"fp32"``, ``"fp16"``, ``"half"``, ``"bf16"``, ``"bfloat16"``.

    Returns:
        torch.dtype: The matching torch dtype.

    Raises:
        ValueError: If the alias is unknown or ``dtype`` is neither a string
            nor a ``torch.dtype``.
    """
    if isinstance(dtype, torch.dtype):
        return dtype
    if isinstance(dtype, str):
        dtype_mapping = {
            "float64": torch.float64,
            "float32": torch.float32,
            "float16": torch.float16,
            "fp32": torch.float32,
            "fp16": torch.float16,
            "half": torch.float16,
            "bf16": torch.bfloat16,
            "bfloat16": torch.bfloat16,
        }
        if dtype not in dtype_mapping:
            raise ValueError(
                f"Unsupported dtype string {dtype!r}; expected one of {sorted(dtype_mapping)}."
            )
        return dtype_mapping[dtype]
    raise ValueError(f"dtype must be a str or torch.dtype, got {type(dtype)}.")
def _ntuple(n):
    """Build a converter that expands a scalar into an ``n``-tuple.

    The returned function hands back non-string iterables unchanged and
    repeats anything else (strings included) ``n`` times.
    """

    def parse(x):
        is_iterable = isinstance(x, collections.abc.Iterable)
        if isinstance(x, str) or not is_iterable:
            return tuple(repeat(x, n))
        return x

    return parse


# Ready-made converters for the common sizes, plus the generic factory.
to_1tuple, to_2tuple, to_3tuple, to_4tuple = (_ntuple(size) for size in range(1, 5))
to_ntuple = _ntuple
def convert_SyncBN_to_BN2d(model_cfg):
    """Rewrite ``SyncBN`` norm configs to ``BN2d`` throughout a nested config.

    The config is modified in place: every ``"norm_cfg"`` entry whose
    ``"type"`` is ``"SyncBN"`` is switched to ``"BN2d"``, and every other
    dict value is searched recursively.

    Args:
        model_cfg (dict): Possibly nested model configuration.
    """
    for key in model_cfg:
        value = model_cfg[key]
        if key == "norm_cfg" and value["type"] == "SyncBN":
            value["type"] = "BN2d"
            continue
        if isinstance(value, dict):
            convert_SyncBN_to_BN2d(value)
def get_topk(x, dim=4, k=5):
    """Select the ``k`` entries of ``x`` with the largest value at index ``dim``.

    ``x`` is first converted with ``to_tensor``. Values are read at position
    ``dim`` of the last axis, and the resulting top-k indices are used to
    index ``x`` along its first axis.

    Args:
        x: Data convertible by ``to_tensor``.
        dim (int): Position in the last axis holding the ranking value.
        k (int): Number of entries to keep.

    Returns:
        torch.Tensor: ``x`` indexed by the top-k indices.
    """
    tensor = to_tensor(x)
    scores = tensor[..., dim]
    top_indices = scores.topk(k)[1]
    return tensor[top_indices]
def param_sigmoid(x, alpha):
    """Apply a sigmoid with slope ``alpha``: ``1 / (1 + exp(-alpha * x))``.

    Args:
        x: Input exposing an ``exp()`` method after scaling (e.g. a tensor).
        alpha: Slope multiplier applied to ``x``.
    """
    scaled = -alpha * x
    denominator = 1 + scaled.exp()
    return 1 / denominator
def inverse_param_sigmoid(x, alpha, eps=1e-5):
    """Invert a sigmoid of slope ``alpha``.

    Args:
        x (Tensor): Values to invert; clamped to ``[0, 1]`` first.
        alpha: Slope of the sigmoid being inverted.
        eps (float): Lower bound applied to ``x`` and ``1 - x`` before the
            division and logarithm. Defaults to 1e-5.

    Returns:
        Tensor: ``log(x / (1 - x)) / alpha`` computed on the clamped values.
    """
    clamped = x.clamp(min=0, max=1)
    numerator = clamped.clamp(min=eps)
    denominator = (1 - clamped).clamp(min=eps)
    logit = torch.log(numerator / denominator)
    return logit / alpha
def inverse_sigmoid(x, eps=1e-5):
    """Compute the logit, i.e. the inverse of the sigmoid function.

    Args:
        x (Tensor): Values to invert; clamped to ``[0, 1]`` first.
        eps (float): Lower bound applied to ``x`` and ``1 - x`` before the
            division and logarithm. Defaults to 1e-5.

    Returns:
        Tensor: ``log(x / (1 - x))`` on the clamped values, with the same
        shape as the input.
    """
    clamped = x.clamp(min=0, max=1)
    numerator = clamped.clamp(min=eps)
    denominator = (1 - clamped).clamp(min=eps)
    return torch.log(numerator / denominator)
2024-05-09 10:07:56 +02:00
# ======================================================
# Python
# ======================================================
2024-03-15 15:16:20 +01:00
def count_columns(df, columns):
    """Tally value frequencies for selected columns of a DataFrame.

    Args:
        df: DataFrame-like object supporting ``len``, column indexing and
            ``value_counts()``.
        columns: Names of the columns to summarize.

    Returns:
        OrderedDict: Maps each column name to a dict of
        ``value -> (count, count / len(df))``.
    """
    total = len(df)
    summary = OrderedDict()
    for column in columns:
        counts = df[column].value_counts().to_dict()
        summary[column] = {
            value: (count, count / total) for value, count in counts.items()
        }
    return summary
2024-05-09 10:07:56 +02:00
def try_import(name):
    """Import a module by name, tolerating its absence.

    Args:
        name (str): Specifies what module to import in absolute or relative
            terms (e.g. either pkg.mod or ..mod).

    Returns:
        ModuleType or None: The imported module, or None when the import
        raises ``ImportError``.
    """
    try:
        module = importlib.import_module(name)
    except ImportError:
        module = None
    return module
2024-03-15 15:16:20 +01:00
2024-05-09 10:07:56 +02:00
def transpose(x):
    """
    Swap rows and columns of a list of lists.

    Rows are combined with ``zip``, so the result is truncated to the length
    of the shortest row.

    Args:
        x (list[list]): Rows to transpose.

    Returns:
        list[list]: The transposed rows.
    """
    return [list(column) for column in zip(*x)]
2024-05-13 08:33:12 +02:00
def all_exists(paths):
    """Return True when every path in ``paths`` exists (True if empty)."""
    for path in paths:
        if not os.path.exists(path):
            return False
    return True
2024-05-16 10:50:24 +02:00
# ======================================================
# Profile
# ======================================================
class Timer:
    """Context manager measuring the wall-clock time spent in its body.

    When CUDA is available, the device is synchronized on entry and on exit
    so that pending GPU work is included in the measurement. On machines
    without CUDA the synchronization is skipped.

    Args:
        name: Label used in the log line.
        log (bool): When True, print the elapsed time on exit.
    """

    def __init__(self, name, log=False):
        self.name = name
        self.start_time = None
        self.end_time = None
        self.log = log

    @staticmethod
    def _sync():
        # torch.cuda.synchronize() fails without a usable CUDA device, so
        # only call it when CUDA is actually available.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    @property
    def elapsed_time(self):
        """Seconds between entry and exit; only valid after the block ended."""
        return self.end_time - self.start_time

    def __enter__(self):
        self._sync()
        self.start_time = time.time()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._sync()
        self.end_time = time.time()
        if self.log:
            print(f"Elapsed time for {self.name}: {self.elapsed_time:.2f} s")
2024-05-17 08:40:44 +02:00
def get_tensor_memory(tensor, human_readable=True):
    """Compute the size of a tensor's elements in bytes.

    Args:
        tensor: Tensor exposing ``element_size()`` and ``nelement()``.
        human_readable (bool): When True, the byte count is passed through
            ``format_numel_str`` and returned as a string.

    Returns:
        int | str: Raw byte count, or its formatted representation.
    """
    num_bytes = tensor.element_size() * tensor.nelement()
    if not human_readable:
        return num_bytes
    return format_numel_str(num_bytes)
class FeatureSaver:
    """Accumulate items and write them to numbered ``.bin`` files in batches.

    Args:
        save_dir: Directory receiving the ``.bin`` files.
        bin_size (int): Number of ``update`` calls between automatic saves.
        start_bin (int): Index used for the first file written.
    """

    def __init__(self, save_dir, bin_size=10, start_bin=0):
        self.save_dir = save_dir
        self.bin_size = bin_size
        self.bin_cnt = start_bin

        self.data_list = []
        self.cnt = 0

    def update(self, data):
        """Buffer one item and flush once ``bin_size`` more items have arrived."""
        self.cnt += 1
        self.data_list.append(data)

        bin_is_full = self.cnt % self.bin_size == 0
        if bin_is_full:
            self.save()

    def save(self):
        """Write the buffered items with ``torch.save`` and start a new bin."""
        file_name = f"{self.bin_cnt:08}.bin"
        save_path = os.path.join(self.save_dir, file_name)
        torch.save(self.data_list, save_path)
        get_logger().info("Saved to %s", save_path)
        # Start a fresh buffer and move on to the next file index.
        self.data_list = []
        self.bin_cnt += 1
2024-05-29 08:43:15 +02:00