mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-25 01:48:45 +02:00
472 lines
16 KiB
Python
472 lines
16 KiB
Python
# All rights reserved.
|
||
# Copyright 2024 Vchitect/Latte
|
||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
# you may not use this file except in compliance with the License.
|
||
# You may obtain a copy of the License at
|
||
#
|
||
# http://www.apache.org/licenses/LICENSE-2.0
|
||
#
|
||
# Unless required by applicable law or agreed to in writing, software
|
||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
# See the License for the specific language governing permissions and
|
||
# limitations under the License.
|
||
#
|
||
# copied from https://github.com/Vchitect/Latte/blob/main/utils.py
|
||
|
||
import html
|
||
import logging
|
||
import math
|
||
import os
|
||
import re
|
||
import subprocess
|
||
import urllib.parse as ul
|
||
from collections import OrderedDict
|
||
from typing import Iterable, Union
|
||
|
||
import torch
|
||
import torch.distributed as dist
|
||
from diffusers.utils import is_bs4_available, is_ftfy_available
|
||
|
||
# from torch._six import inf
|
||
from torch import inf
|
||
from torch.utils.tensorboard import SummaryWriter
|
||
|
||
if is_bs4_available():
|
||
from bs4 import BeautifulSoup
|
||
|
||
if is_ftfy_available():
|
||
import ftfy
|
||
|
||
# Type alias for "a single tensor, or any iterable of tensors"; used as the
# annotation of the `parameters` argument of the gradient helpers in this module.
_tensor_or_tensors = Union[torch.Tensor, Iterable[torch.Tensor]]
|
||
|
||
|
||
#################################################################################
|
||
# Training Clip Gradients #
|
||
#################################################################################
|
||
|
||
|
||
def get_grad_norm(
    parameters: Union[torch.Tensor, Iterable[torch.Tensor]], norm_type: float = 2.0
) -> torch.Tensor:
    r"""
    Compute the total gradient norm of an iterable of parameters.

    Adapted from torch.nn.utils.clip_grad_norm_, but read-only: the gradients
    are only measured, never scaled or otherwise modified.

    The norm is computed over all gradients together, as if they were
    concatenated into a single vector. Parameters whose ``.grad`` is ``None``
    are skipped.

    Args:
        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
            single Tensor whose gradients are measured
        norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
            infinity norm.

    Returns:
        Total norm of the parameter gradients (viewed as a single vector), or
        ``torch.tensor(0.0)`` when no parameter has a gradient.
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    grads = [p.grad for p in parameters if p.grad is not None]
    norm_type = float(norm_type)
    if not grads:
        return torch.tensor(0.0)

    # Per-tensor norms are moved to the device of the first gradient so they
    # can be stacked even if the gradients live on different devices.
    device = grads[0].device
    if norm_type == inf:
        norms = [g.detach().abs().max().to(device) for g in grads]
        return norms[0] if len(norms) == 1 else torch.max(torch.stack(norms))

    per_tensor_norms = [torch.norm(g.detach(), norm_type).to(device) for g in grads]
    return torch.norm(torch.stack(per_tensor_norms), norm_type)
|
||
|
||
|
||
def clip_grad_norm_(
    parameters: Union[torch.Tensor, Iterable[torch.Tensor]],
    max_norm: float,
    norm_type: float = 2.0,
    error_if_nonfinite: bool = False,
    clip_grad=True,
) -> torch.Tensor:
    r"""
    Adapted from torch.nn.utils.clip_grad_norm_

    Computes the gradient norm of an iterable of parameters and, when
    ``clip_grad`` is true, clips the gradients to ``max_norm``.

    The norm is computed over all gradients together, as if they were
    concatenated into a single vector. When clipping is enabled the gradients
    are modified in-place; otherwise they are left untouched.

    Args:
        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
            single Tensor that will have gradients normalized
        max_norm (float or int): max norm of the gradients
        norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
            infinity norm.
        error_if_nonfinite (bool): if True, an error is thrown if the total
            norm of the gradients from :attr:`parameters` is ``nan``,
            ``inf``, or ``-inf``. Only checked when ``clip_grad`` is true.
            Default: False
        clip_grad (bool): if False, only the norm is computed and returned;
            no check is performed and no gradient is scaled. Default: True

    Returns:
        Total norm of the parameter gradients (viewed as a single vector),
        measured before any clipping, or ``torch.tensor(0.0)`` when no
        parameter has a gradient.

    Raises:
        RuntimeError: if ``clip_grad`` and ``error_if_nonfinite`` are both true
            and the total norm is ``nan`` or infinite.
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    grads = [p.grad for p in parameters if p.grad is not None]
    max_norm = float(max_norm)
    norm_type = float(norm_type)
    if not grads:
        return torch.tensor(0.0)

    # Per-tensor norms are moved to the device of the first gradient so they
    # can be stacked even if the gradients live on different devices.
    device = grads[0].device
    if norm_type == inf:
        norms = [g.detach().abs().max().to(device) for g in grads]
        total_norm = norms[0] if len(norms) == 1 else torch.max(torch.stack(norms))
    else:
        per_tensor_norms = [torch.norm(g.detach(), norm_type).to(device) for g in grads]
        total_norm = torch.norm(torch.stack(per_tensor_norms), norm_type)

    if not clip_grad:
        return total_norm

    if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()):
        raise RuntimeError(
            f"The total norm of order {norm_type} for gradients from "
            "`parameters` is non-finite, so it cannot be clipped. To disable "
            "this error and scale the gradients by the non-finite norm anyway, "
            "set `error_if_nonfinite=False`"
        )
    clip_coef = max_norm / (total_norm + 1e-6)
    # Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so
    # avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization
    # when the gradients do not reside in CPU memory.
    clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
    for g in grads:
        g.detach().mul_(clip_coef_clamped.to(g.device))
    return total_norm
|
||
|
||
|
||
def get_experiment_dir(root_dir, args):
    """
    Append one short tag per enabled training option to ``root_dir``.

    Each truthy attribute of ``args`` listed below adds its tag, in this fixed
    order; ``-512`` is added last when ``args.image_size`` equals 512. The
    extended directory name is returned.
    """
    option_tags = (
        ("use_compile", "-Compile"),  # speedup by torch compile
        ("fixed_spatial", "-FixedSpa"),
        ("enable_xformers_memory_efficient_attention", "-Xfor"),
        ("gradient_checkpointing", "-Gc"),
        ("mixed_precision", "-Amp"),
    )
    for option_name, tag in option_tags:
        if getattr(args, option_name):
            root_dir += tag
    if args.image_size == 512:
        root_dir += "-512"
    return root_dir
|
||
|
||
|
||
#################################################################################
|
||
# Training Logger #
|
||
#################################################################################
|
||
|
||
|
||
def create_logger(logging_dir):
    """
    Return this module's logger, configured according to the distributed rank.

    On rank 0, logging is set up through ``logging.basicConfig`` at INFO level
    with a console ``StreamHandler`` and a ``FileHandler`` writing to
    ``{logging_dir}/log.txt``. On every other rank the logger only receives a
    ``NullHandler``.
    """
    if dist.get_rank() != 0:
        # dummy logger (does nothing)
        quiet_logger = logging.getLogger(__name__)
        quiet_logger.addHandler(logging.NullHandler())
        return quiet_logger

    # real logger
    sinks = [logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")]
    logging.basicConfig(
        level=logging.INFO,
        format="[%(asctime)s] %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        handlers=sinks,
    )
    return logging.getLogger(__name__)
|
||
|
||
|
||
def create_tensorboard(tensorboard_dir):
    """
    Create a tensorboard writer that saves losses.

    Only rank 0 gets a real ``SummaryWriter`` writing to ``tensorboard_dir``.
    Every other rank gets ``None`` (previously the function failed there with
    ``UnboundLocalError`` because no writer was ever assigned).

    Returns:
        A ``SummaryWriter`` on rank 0, ``None`` on all other ranks.
    """
    if dist.get_rank() != 0:
        return None
    return SummaryWriter(tensorboard_dir)
|
||
|
||
|
||
def write_tensorboard(writer, *args):
    """
    Record one scalar on the tensorboard writer (rank 0 only).

    The first three positional values in ``args`` are forwarded, in order, to
    ``writer.add_scalar``. Ranks other than 0 do nothing.
    Only for pytorch DDP mode.
    """
    if dist.get_rank() != 0:
        return
    record = writer.add_scalar
    record(args[0], args[1], args[2])
|
||
|
||
|
||
#################################################################################
|
||
# EMA Update/ DDP Training Utils #
|
||
#################################################################################
|
||
|
||
|
||
@torch.no_grad()
def update_ema(ema_model, model, decay=0.9999):
    """
    Move every EMA parameter towards the matching parameter of ``model``.

    For each named parameter of ``model`` the EMA copy is updated in-place as
    ``ema = decay * ema + (1 - decay) * param``.
    """
    averaged = OrderedDict(ema_model.named_parameters())
    current = OrderedDict(model.named_parameters())
    step_weight = 1 - decay

    for key in current:
        # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed
        target = averaged[key]
        target.mul_(decay)
        target.add_(current[key].data, alpha=step_weight)
|
||
|
||
|
||
def requires_grad(model, flag=True):
    """
    Assign ``flag`` to the ``requires_grad`` attribute of each model parameter.
    """
    for weight in model.parameters():
        weight.requires_grad = flag
|
||
|
||
|
||
def cleanup():
    """
    End DDP training.

    Destroys the default process group if one exists. Calling this when
    torch.distributed is unavailable or no process group was initialized
    (e.g. from a ``finally`` block after a failed setup) is a no-op instead of
    an error.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()
|
||
|
||
|
||
def setup_distributed(backend="nccl", port=None):
    """Initialize distributed training environment.

    support both slurm and torch.distributed.launch
    see torch.distributed.init_process_group() for more details

    Under SLURM (``SLURM_JOB_ID`` is set) the rank and world size are read from
    ``SLURM_PROCID`` / ``SLURM_NTASKS`` and the variables ``MASTER_PORT``,
    ``MASTER_ADDR``, ``WORLD_SIZE``, ``LOCAL_RANK`` and ``RANK`` are exported.
    Otherwise ``RANK`` and ``WORLD_SIZE`` must already be present in the
    environment.

    Args:
        backend: backend name passed to ``dist.init_process_group``.
        port: master port to export under SLURM; when ``None`` an existing
            ``MASTER_PORT`` is kept, else ``29567 + <visible GPU count>`` is used.

    Raises:
        KeyError: if a required environment variable is missing.
    """
    num_gpus = torch.cuda.device_count()

    if "SLURM_JOB_ID" in os.environ:
        rank = int(os.environ["SLURM_PROCID"])
        world_size = int(os.environ["SLURM_NTASKS"])

        # specify master port
        if port is not None:
            os.environ["MASTER_PORT"] = str(port)
        elif "MASTER_PORT" not in os.environ:
            os.environ["MASTER_PORT"] = str(29567 + num_gpus)

        if "MASTER_ADDR" not in os.environ:
            # Only ask scontrol for the first host of the allocation when the
            # address is actually needed.
            node_list = os.environ["SLURM_NODELIST"]
            os.environ["MASTER_ADDR"] = subprocess.getoutput(f"scontrol show hostname {node_list} | head -n1")

        if num_gpus > 0:
            local_rank = rank % num_gpus
        else:
            # No visible GPU (e.g. a CPU-only backend): `rank % 0` would raise
            # ZeroDivisionError, so use SLURM's node-local task id instead.
            local_rank = int(os.environ.get("SLURM_LOCALID", 0))

        os.environ["WORLD_SIZE"] = str(world_size)
        os.environ["LOCAL_RANK"] = str(local_rank)
        os.environ["RANK"] = str(rank)
    else:
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])

    dist.init_process_group(
        backend=backend,
        world_size=world_size,
        rank=rank,
    )
|
||
|
||
|
||
#################################################################################
|
||
# Testing Utils #
|
||
#################################################################################
|
||
|
||
|
||
def save_video_grid(video, nrow=None):
    """
    Tile a batch of videos into one grid video.

    Despite the name, nothing is written to disk; the grid tensor is returned.

    Args:
        video: tensor whose shape unpacks as ``(b, t, h, w, c)`` (presumably
            batch, frames, height, width, channels). Values are stored into a
            ``torch.uint8`` grid.
        nrow: number of grid rows; defaults to ``ceil(sqrt(b))``. The number of
            columns is ``ceil(b / nrow)``.

    Returns:
        A zero-initialized ``torch.uint8`` tensor of shape
        ``(t, (1 + h) * nrow + 1, (1 + w) * ncol + 1, c)`` with video ``i``
        copied into row ``i // ncol``, column ``i % ncol``.
    """
    b, t, h, w, c = video.shape

    if nrow is None:
        nrow = math.ceil(math.sqrt(b))
    ncol = math.ceil(b / nrow)
    padding = 1
    video_grid = torch.zeros(
        (t, (padding + h) * nrow + padding, (padding + w) * ncol + padding, c),
        dtype=torch.uint8,
    )

    for i in range(b):
        # Distinct names: the column index must not overwrite the channel count `c`.
        row, col = divmod(i, ncol)
        # NOTE(review): tiles start at offset 0, so the spare padding line ends
        # up on the bottom/right edge rather than around every tile.
        start_r = (padding + h) * row
        start_c = (padding + w) * col
        video_grid[:, start_r : start_r + h, start_c : start_c + w] = video[i]

    return video_grid
|
||
|
||
|
||
#################################################################################
|
||
# MMCV Utils #
|
||
#################################################################################
|
||
|
||
|
||
def collect_env():
    """Collect the information of the running environments.

    Gathers mmcv's base environment info, adds the first seven characters of
    the git hash under the ``"MMClassification"`` key, and prints every entry
    followed by the CUDA arch list and CUDA version reported by torch.
    """
    # Copyright (c) OpenMMLab. All rights reserved.
    # Imported lazily so that mmcv is only required when this helper is called.
    from mmcv.utils import collect_env as collect_base_env
    from mmcv.utils import get_git_hash

    env_info = collect_base_env()
    env_info["MMClassification"] = get_git_hash()[:7]

    for name, val in env_info.items():
        print(f"{name}: {val}")

    print(torch.cuda.get_arch_list())
    print(torch.version.cuda)
|
||
|
||
|
||
#################################################################################
|
||
# Pixart-alpha Utils #
|
||
#################################################################################
|
||
|
||
# Matches runs of one or more "bad" punctuation characters: the listed symbols
# plus ) ( ] [ } { | \ / and *.
# Written as a single raw string so the regex escapes are not parsed as
# (invalid) Python string escapes; the compiled pattern is unchanged.
bad_punct_regex = re.compile(r"[#®•©™&@·º½¾¿¡§~\)\(\]\[\}\{\|\\/\*]{1,}")
|
||
|
||
|
||
def _clean_caption_twice(text):
    """Run the module-level ``clean_caption`` function on ``text`` two times in a row."""
    # Looked up at module scope, where ``clean_caption`` is the cleaning
    # function. Inside ``text_preprocessing`` the same name is the boolean
    # flag, and calling it there raised "'bool' object is not callable".
    return clean_caption(clean_caption(text))


def text_preprocessing(text, clean_caption=False):
    """
    Normalize one prompt or a list/tuple of prompts.

    Args:
        text: a single prompt, or a list/tuple of prompts.
        clean_caption: if True, and both bs4 and ftfy are available, every
            prompt is passed through the caption cleaner (applied twice, as in
            the original code). Otherwise each prompt is only lower-cased and
            stripped of surrounding whitespace.

    Returns:
        A list with one processed string per input prompt.
    """
    # Caption cleaning needs both optional dependencies; silently fall back to
    # the simple normalization when either one is missing.
    if clean_caption and not is_bs4_available():
        clean_caption = False

    if clean_caption and not is_ftfy_available():
        clean_caption = False

    if not isinstance(text, (tuple, list)):
        text = [text]

    def process(text: str):
        if clean_caption:
            return _clean_caption_twice(text)
        return text.lower().strip()

    return [process(t) for t in text]
|
||
|
||
|
||
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption
|
||
def clean_caption(caption):
    """
    Clean a raw (typically web-scraped) caption into plain lower-case text.

    The caption is URL-unquoted and lower-cased, then stripped of URLs, HTML
    markup, @nicknames, CJK characters, IP addresses, ids, file names, spam
    phrases and stray punctuation; dashes and quotes are normalized along the
    way. Requires ``bs4`` (``BeautifulSoup``) and ``ftfy`` to be importable.

    Args:
        caption: the caption; converted with ``str()`` first.

    Returns:
        The cleaned caption with surrounding whitespace removed.
    """
    caption = str(caption)
    caption = ul.unquote_plus(caption)
    caption = caption.strip().lower()
    caption = re.sub("<person>", "person", caption)
    # urls:
    caption = re.sub(
        r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",  # noqa
        "",
        caption,
    )  # regex for urls
    caption = re.sub(
        r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",  # noqa
        "",
        caption,
    )  # regex for urls
    # html:
    caption = BeautifulSoup(caption, features="html.parser").text

    # @<nickname>
    caption = re.sub(r"@[\w\d]+\b", "", caption)

    # 31C0—31EF CJK Strokes
    # 31F0—31FF Katakana Phonetic Extensions
    # 3200—32FF Enclosed CJK Letters and Months
    # 3300—33FF CJK Compatibility
    # 3400—4DBF CJK Unified Ideographs Extension A
    # 4DC0—4DFF Yijing Hexagram Symbols
    # 4E00—9FFF CJK Unified Ideographs
    caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
    caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
    caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
    caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
    caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
    caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
    caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
    #######################################################

    # all types of dash --> "-"
    caption = re.sub(
        r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+",  # noqa
        "-",
        caption,
    )

    # normalize quotation marks to one standard
    caption = re.sub(r"[`´«»“”¨]", '"', caption)
    caption = re.sub(r"[‘’]", "'", caption)

    # leftover "&quot;" HTML entities (with or without the trailing ";")
    caption = re.sub(r"&quot;?", "", caption)
    # leftover "&amp" HTML entities
    caption = re.sub(r"&amp", "", caption)

    # ip addresses:
    caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)

    # article ids:
    caption = re.sub(r"\d:\d\d\s+$", "", caption)

    # literal backslash-n sequences
    caption = re.sub(r"\\n", " ", caption)

    # "#123"
    caption = re.sub(r"#\d{1,3}\b", "", caption)
    # "#12345.."
    caption = re.sub(r"#\d{5,}\b", "", caption)
    # "123456.."
    caption = re.sub(r"\b\d{6,}\b", "", caption)
    # filenames:
    caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)

    # repeated quotes / repeated dots
    caption = re.sub(r"[\"\']{2,}", r'"', caption)  # """AUSVERKAUFT"""
    caption = re.sub(r"[\.]{2,}", r" ", caption)  # """AUSVERKAUFT"""

    caption = re.sub(bad_punct_regex, r" ", caption)  # ***AUSVERKAUFT***, #AUSVERKAUFT
    caption = re.sub(r"\s+\.\s+", r" ", caption)  # " . "

    # this-is-my-cute-cat / this_is_my_cute_cat
    # Only split on "-" / "_" when the caption contains more than three of them.
    regex2 = re.compile(r"(?:\-|\_)")
    if len(re.findall(regex2, caption)) > 3:
        caption = re.sub(regex2, " ", caption)

    caption = ftfy.fix_text(caption)
    caption = html.unescape(html.unescape(caption))

    caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption)  # jc6640
    caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption)  # jc6640vc
    caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption)  # 6640vc231

    caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
    caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
    caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
    caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
    caption = re.sub(r"\bpage\s+\d+\b", "", caption)

    caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption)  # j2d1a2a...

    # dimensions such as 1920x1080
    caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)

    caption = re.sub(r"\b\s+\:\s+", r": ", caption)
    caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
    caption = re.sub(r"\s+", " ", caption)

    # The stripped result must be kept: the `^` / `$` anchored clean-ups below
    # only work when no leading or trailing space is left over.
    caption = caption.strip()

    caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
    caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
    caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
    caption = re.sub(r"^\.\S+$", "", caption)

    return caption.strip()
|