first commit (#67)

This commit is contained in:
Zheng Zangwei (Alex Zheng) 2024-03-15 22:06:36 +08:00 committed by GitHub
parent 9aab3ad343
commit 0ba56f5309
26 changed files with 4698 additions and 0 deletions

149
README.md Normal file

@@ -0,0 +1,149 @@
<p align="center">
  <img src="./assets/readme/icon_zw.png" width="250"/>
</p>
<div align="center">
<a href="https://github.com/hpcaitech/Open-Sora/stargazers"><img src="https://img.shields.io/github/stars/hpcaitech/Open-Sora?style=social"></a>
<a href="https://github.com/hpcaitech/public_assets/tree/main/colossalai/contact/slack"><img src="https://img.shields.io/badge/Slack-join-blueviolet?logo=slack"></a>
<a href="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/WeChat.png"><img src="https://img.shields.io/badge/微信-加入-green?logo=wechat"></a>
</div>
## Open-Sora: Towards Open Reproduction of Sora
**Open-Sora** is an **open-source** initiative dedicated to **efficiently** reproducing OpenAI's Sora. Our project aims to cover **the full pipeline**, including video data preprocessing, training with acceleration, efficient inference, and more. Operating on a limited budget, we build on the vibrant open-source community, leveraging existing text-to-image, image-captioning, and language models. We hope to contribute to the community and make the project accessible to everyone.
## 📰 News
* **[2024.03.18]** 🔥 We release **Open-Sora 1.0**, an open-source project to reproduce OpenAI's Sora.
Open-Sora 1.0 supports a full pipeline of video data preprocessing, training with
<a href="https://github.com/hpcaitech/ColossalAI"><img src="assets/readme/colossal_ai.png" width="8%" ></a> acceleration,
inference, and more. Our provided checkpoint can produce 2s 512x512 videos.
## 🎥 Latest Demo
| **2s 512x512** | **2s 512x512** |
| ----------------------------------------------- | ----------------------------------------------- |
| <img src="assets/readme/sample_0.gif" width=""> | <img src="assets/readme/sample_0.gif" width=""> |
## 🔆 New Features/Updates
- 📍 Open-Sora-v1 is trained on xxx. We train the model in three stages. Model weights are available here. Training details can be found here.
- ✅ Support training acceleration, including flash-attention, accelerated T5, mixed precision, gradient checkpointing, a split VAE, sequence parallelism, etc., for an XXX× speedup. See more discussions [here]().
- ✅ We provide video cutting and captioning tools for data preprocessing. Our data collection plan can be found [here]().
- ✅ We find the VQ-VAE from [] to be of low quality and thus adopt a better VAE from []. We also find that patching in the time dimension deteriorates quality. See more discussions [here]().
- ✅ We investigate different architectures including DiT, Latte, and our proposed STDiT. Our STDiT achieves a better trade-off between quality and speed. See more discussions [here]().
- ✅ Support CLIP and T5 text conditioning.
- ✅ By viewing images as one-frame videos, our project supports training DiT on both images and videos (e.g., ImageNet & UCF101).
- ✅ Support inference with official weights from [DiT](https://github.com/facebookresearch/DiT), [Latte](https://github.com/Vchitect/Latte), and [PixArt](https://pixart-alpha.github.io/).
### TODO list sorted by priority
- [ ] Complete the data processing pipeline (including dense optical flow, aesthetics scores, text-image similarity, deduplication, etc.). See [datasets.md]() for more information. **[WIP]**
- [ ] Training Video-VAE. **[WIP]**
- [ ] Support image and video conditioning.
- [ ] Evaluation pipeline.
- [ ] Incorporate a better scheduler, e.g., rectified flow in SD3.
- [ ] Support variable aspect ratios, resolutions, and durations.
- [ ] Support SD3 when released.
## Contents
- [Open-Sora: Towards Open Reproduction of Sora](#open-sora-towards-open-reproduction-of-sora)
- [📰 News](#-news)
- [🎥 Latest Demo](#-latest-demo)
- [🔆 New Features/Updates](#-new-featuresupdates)
- [TODO list sorted by priority](#todo-list-sorted-by-priority)
- [Contents](#contents)
- [Installation](#installation)
- [Model Weights](#model-weights)
- [Inference](#inference)
- [Data Processing](#data-processing)
- [Split video into clips](#split-video-into-clips)
- [Generate video caption](#generate-video-caption)
- [Training](#training)
- [Acknowledgement](#acknowledgement)
- [Citation](#citation)
- [Star History](#star-history)
- [TODO](#todo)
## Installation
```bash
git clone https://github.com/hpcaitech/Open-Sora
cd Open-Sora
pip install xxx
```
After installation, to get familiar with the project, you can check [here]() for the project structure and how to use the config files.
## Model Weights
| Model | #Params | url |
| ---------- | ------- | --- |
| 16x256x256 | | |
## Inference
```bash
python scripts/inference.py configs/opensora/inference/16x256x256.py
```
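The config passed above is part of this commit, but its diff is not expanded in this view. As a rough, hypothetical sketch (assuming mmengine-style Python configs whose `dict(type=...)` entries are resolved through `build_module` in `opensora/registry.py`), such a file might look like:

```python
# Hypothetical sketch of an inference config; the real
# configs/opensora/inference/16x256x256.py is not shown in this diff,
# so every field name below is an illustrative assumption.
num_frames = 16
image_size = (256, 256)

# Each dict is resolved by opensora.registry.build_module, which looks up
# "type" in the matching Registry (MODELS, SCHEDULERS, ...).
model = dict(type="STDiT-XL/2")
scheduler = dict(type="iddpm", num_sampling_steps=100, cfg_scale=7.0)
```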
## Data Processing
### Split video into clips
We provide code to split a long video into separate clips efficiently using `multiprocessing`. See `tools/data/scene_detect.py`.
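`tools/data/scene_detect.py` itself is not expanded in this view; a minimal sketch of the idea, assuming PySceneDetect's `detect`/`ContentDetector`/`split_video_ffmpeg` API and a hypothetical list of input paths, could look like:

```python
# Illustrative sketch only: split long videos into scene-level clips with
# PySceneDetect, parallelized across videos via multiprocessing.
from multiprocessing import Pool

from scenedetect import ContentDetector, detect, split_video_ffmpeg


def split_one(video_path: str) -> int:
    scenes = detect(video_path, ContentDetector())  # find scene boundaries
    split_video_ffmpeg(video_path, scenes)  # write one clip per scene
    return len(scenes)


if __name__ == "__main__":
    videos = ["data/raw/video_0.mp4", "data/raw/video_1.mp4"]  # hypothetical paths
    with Pool(processes=8) as pool:
        print(pool.map(split_one, videos))
```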
### Generate video caption
## Training
## Acknowledgement
* [DiT](https://github.com/facebookresearch/DiT): Scalable Diffusion Models with Transformers.
* [OpenDiT](https://github.com/NUS-HPC-AI-Lab/OpenDiT): An acceleration for DiT training. OpenDiT's team provides valuable suggestions on acceleration of our training process.
* [PixArt](https://github.com/PixArt-alpha/PixArt-alpha): An open-source DiT-based text-to-image model.
* [Latte](https://github.com/Vchitect/Latte): An attempt to efficiently train DiT for video.
* [StabilityAI VAE](https://huggingface.co/stabilityai/sd-vae-ft-mse-original): A powerful image VAE model.
* [CLIP](https://github.com/openai/CLIP): A powerful text-image embedding model.
* [T5](https://github.com/google-research/text-to-text-transfer-transformer): The powerful text encoder.
* [LLaVA](https://github.com/haotian-liu/LLaVA): A powerful image captioning model based on [LLaMA](https://github.com/meta-llama/llama) and [Yi-34B](https://huggingface.co/01-ai/Yi-34B).
* [PySceneDetect](https://github.com/Breakthrough/PySceneDetect): A powerful tool to split video into clips.
We are grateful for their exceptional work and generous contribution to open source.
## Citation
```bibtex
@software{opensora,
author = {Zangwei Zheng and Xiangyu Peng and Shenggui Li and Yang You},
title = {Open-Sora: Towards Open Reproduction of Sora},
month = {March},
year = {2024},
url = {https://github.com/hpcaitech/Open-Sora}
}
```
Zangwei Zheng and Xiangyu Peng equally contributed to this work.
## Star History
[![Star History Chart](https://api.star-history.com/svg?repos=hpcaitech/Open-Sora&type=Date)](https://star-history.com/#hpcaitech/Open-Sora&Date)
## TODO
Modules for releasing:
* `configs`
* `opensora`
* `assets`
* `scripts`
* `tools`
* packages for data processing
* put all outputs under `./checkpoints/`, including `pretrained_models`, checkpoints, and samples

0
docs/datasets.md Normal file

4
opensora/__init__.py Normal file

@@ -0,0 +1,4 @@
from .acceleration import *
from .datasets import *
from .registry import *
from .models import *

39
opensora/registry.py Normal file

@@ -0,0 +1,39 @@
from copy import deepcopy
import torch.nn as nn
from mmengine.registry import Registry
def build_module(module, builder, **kwargs):
"""Build module from config or return the module itself.
Args:
module (Union[dict, nn.Module]): The module to build.
builder (Registry): The registry to build module.
**kwargs: Additional keyword arguments that override entries in the config.
Returns:
Any: The built module.
"""
if isinstance(module, dict):
cfg = deepcopy(module)
for k, v in kwargs.items():
cfg[k] = v
return builder.build(cfg)
elif isinstance(module, nn.Module):
return module
elif module is None:
return None
else:
raise TypeError(f"Only support dict and nn.Module, but got {type(module)}.")
MODELS = Registry(
"model",
locations=["opensora.models"],
)
SCHEDULERS = Registry(
"scheduler",
locations=["opensora.schedulers"],
)
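For illustration (not part of this file), a module registered in `MODELS` can then be built from a plain config dict through `build_module`; the toy model below is hypothetical:

```python
# Hedged usage sketch for the registry above; "toy" is an illustrative name.
import torch.nn as nn

from opensora.registry import MODELS, build_module


@MODELS.register_module("toy")
class ToyModel(nn.Module):
    def __init__(self, hidden_size=128):
        super().__init__()
        self.proj = nn.Linear(hidden_size, hidden_size)


# build from a config dict; extra kwargs override entries in the config
model = build_module(dict(type="toy"), MODELS, hidden_size=256)
# an already-instantiated nn.Module passes through unchanged
assert build_module(model, MODELS) is model
```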


@@ -0,0 +1,2 @@
from .iddpm import IDDPM
from .dpms import DPMS


@@ -0,0 +1,50 @@
from functools import partial
import torch
from opensora.registry import SCHEDULERS
from .dpm_solver import DPMS
@SCHEDULERS.register_module("dpm-solver")
class DPM_SOLVER:
def __init__(self, num_sampling_steps=None, cfg_scale=4.0):
self.num_sampling_steps = num_sampling_steps
self.cfg_scale = cfg_scale
def sample(
self,
model,
text_encoder,
z_size,
prompts,
device,
additional_args=None,
):
n = len(prompts)
z = torch.randn(n, *z_size, device=device)
model_args = text_encoder.encode(prompts)
y = model_args.pop("y")
null_y = text_encoder.null(n)
if additional_args is not None:
model_args.update(additional_args)
dpms = DPMS(
partial(forward_with_dpmsolver, model),
condition=y,
uncondition=null_y,
cfg_scale=self.cfg_scale,
model_kwargs=model_args,
)
samples = dpms.sample(z, steps=self.num_sampling_steps, order=2, skip_type="time_uniform", method="multistep")
return samples
def forward_with_dpmsolver(self, x, timestep, y, **kwargs):
"""
DPM-Solver does not need variance prediction, so return only the first half
of the output channels (the predicted epsilon). The first argument is the
model itself, bound via functools.partial in sample() above.
"""
# https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
model_out = self.forward(x, timestep, y, **kwargs)
return model_out.chunk(2, dim=1)[0]

File diff suppressed because it is too large


@@ -0,0 +1,95 @@
from functools import partial
import torch
from opensora.registry import SCHEDULERS
from . import gaussian_diffusion as gd
from .respace import SpacedDiffusion, space_timesteps
@SCHEDULERS.register_module("iddpm")
class IDDPM(SpacedDiffusion):
def __init__(
self,
num_sampling_steps=None,
timestep_respacing=None,
noise_schedule="linear",
use_kl=False,
sigma_small=False,
predict_xstart=False,
learn_sigma=True,
rescale_learned_sigmas=False,
diffusion_steps=1000,
cfg_scale=4.0,
):
betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps)
if use_kl:
loss_type = gd.LossType.RESCALED_KL
elif rescale_learned_sigmas:
loss_type = gd.LossType.RESCALED_MSE
else:
loss_type = gd.LossType.MSE
if num_sampling_steps is not None:
assert timestep_respacing is None
timestep_respacing = str(num_sampling_steps)
if timestep_respacing is None or timestep_respacing == "":
timestep_respacing = [diffusion_steps]
super().__init__(
use_timesteps=space_timesteps(diffusion_steps, timestep_respacing),
betas=betas,
model_mean_type=(gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X),
model_var_type=(
(gd.ModelVarType.FIXED_LARGE if not sigma_small else gd.ModelVarType.FIXED_SMALL)
if not learn_sigma
else gd.ModelVarType.LEARNED_RANGE
),
loss_type=loss_type,
# rescale_timesteps=rescale_timesteps,
)
self.cfg_scale = cfg_scale
def sample(
self,
model,
text_encoder,
z_size,
prompts,
device,
additional_args=None,
):
n = len(prompts)
z = torch.randn(n, *z_size, device=device)
z = torch.cat([z, z], 0)
model_args = text_encoder.encode(prompts)
y_null = text_encoder.null(n)
model_args["y"] = torch.cat([model_args["y"], y_null], 0)
if additional_args is not None:
model_args.update(additional_args)
forward = partial(forward_with_cfg, model, cfg_scale=self.cfg_scale)
samples = self.p_sample_loop(
forward,
z.shape,
z,
clip_denoised=False,
model_kwargs=model_args,
progress=True,
device=device,
)
samples, _ = samples.chunk(2, dim=0)
return samples
def forward_with_cfg(model, x, timestep, y, cfg_scale, **kwargs):
# https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
half = x[: len(x) // 2]
combined = torch.cat([half, half], dim=0)
model_out = model.forward(combined, timestep, y, **kwargs)
model_out = model_out["x"] if isinstance(model_out, dict) else model_out
eps, rest = model_out[:, :3], model_out[:, 3:]
cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
eps = torch.cat([half_eps, half_eps], dim=0)
return torch.cat([eps, rest], dim=1)
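`forward_with_cfg` above applies classifier-free guidance with the usual batching trick: the conditional and null-conditioned halves share one forward pass, and only the epsilon channels (not the learned variance channels) are guided. A tiny self-contained check of that combination step, with constants standing in for real model outputs:

```python
# Numeric check of the guidance formula used above (illustrative only).
import torch

cfg_scale = 4.0
cond_eps = torch.full((2, 3, 8, 8), 1.0)    # epsilon with the text condition
uncond_eps = torch.full((2, 3, 8, 8), 0.5)  # epsilon with the null condition

# push the prediction from unconditional toward conditional, scaled by cfg_scale
half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
assert torch.allclose(half_eps, torch.full_like(half_eps, 2.5))
```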


@@ -0,0 +1,87 @@
# Adapted from DiT
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# DiT: https://github.com/facebookresearch/DiT/tree/main
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
# --------------------------------------------------------
import numpy as np
import torch as th
def normal_kl(mean1, logvar1, mean2, logvar2):
"""
Compute the KL divergence between two gaussians.
Shapes are automatically broadcasted, so batches can be compared to
scalars, among other use cases.
"""
tensor = None
for obj in (mean1, logvar1, mean2, logvar2):
if isinstance(obj, th.Tensor):
tensor = obj
break
assert tensor is not None, "at least one argument must be a Tensor"
# Force variances to be Tensors. Broadcasting helps convert scalars to
# Tensors, but it does not work for th.exp().
logvar1, logvar2 = [x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) for x in (logvar1, logvar2)]
return 0.5 * (-1.0 + logvar2 - logvar1 + th.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * th.exp(-logvar2))
def approx_standard_normal_cdf(x):
"""
A fast approximation of the cumulative distribution function of the
standard normal.
"""
return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
def continuous_gaussian_log_likelihood(x, *, means, log_scales):
"""
Compute the log-likelihood of a continuous Gaussian distribution.
:param x: the targets
:param means: the Gaussian mean Tensor.
:param log_scales: the Gaussian log stddev Tensor.
:return: a tensor like x of log probabilities (in nats).
"""
centered_x = x - means
inv_stdv = th.exp(-log_scales)
normalized_x = centered_x * inv_stdv
log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x)
return log_probs
def discretized_gaussian_log_likelihood(x, *, means, log_scales):
"""
Compute the log-likelihood of a Gaussian distribution discretizing to a
given image.
:param x: the target images. It is assumed that this was uint8 values,
rescaled to the range [-1, 1].
:param means: the Gaussian mean Tensor.
:param log_scales: the Gaussian log stddev Tensor.
:return: a tensor like x of log probabilities (in nats).
"""
assert x.shape == means.shape == log_scales.shape
centered_x = x - means
inv_stdv = th.exp(-log_scales)
plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
cdf_plus = approx_standard_normal_cdf(plus_in)
min_in = inv_stdv * (centered_x - 1.0 / 255.0)
cdf_min = approx_standard_normal_cdf(min_in)
log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
cdf_delta = cdf_plus - cdf_min
log_probs = th.where(
x < -0.999,
log_cdf_plus,
th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
)
assert log_probs.shape == x.shape
return log_probs


@@ -0,0 +1,835 @@
# Adapted from DiT
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# DiT: https://github.com/facebookresearch/DiT/tree/main
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
# --------------------------------------------------------
import enum
import math
import numpy as np
import torch as th
from .diffusion_utils import discretized_gaussian_log_likelihood, normal_kl
def mean_flat(tensor):
"""
Take the mean over all non-batch dimensions.
"""
return tensor.mean(dim=list(range(1, len(tensor.shape))))
class ModelMeanType(enum.Enum):
"""
Which type of output the model predicts.
"""
PREVIOUS_X = enum.auto() # the model predicts x_{t-1}
START_X = enum.auto() # the model predicts x_0
EPSILON = enum.auto() # the model predicts epsilon
class ModelVarType(enum.Enum):
"""
What is used as the model's output variance.
The LEARNED_RANGE option has been added to allow the model to predict
values between FIXED_SMALL and FIXED_LARGE, making its job easier.
"""
LEARNED = enum.auto()
FIXED_SMALL = enum.auto()
FIXED_LARGE = enum.auto()
LEARNED_RANGE = enum.auto()
class LossType(enum.Enum):
MSE = enum.auto() # use raw MSE loss (and KL when learning variances)
RESCALED_MSE = enum.auto() # use raw MSE loss (with RESCALED_KL when learning variances)
KL = enum.auto() # use the variational lower-bound
RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB
def is_vb(self):
return self == LossType.KL or self == LossType.RESCALED_KL
def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac):
betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
warmup_time = int(num_diffusion_timesteps * warmup_frac)
betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64)
return betas
def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
"""
This is the deprecated API for creating beta schedules.
See get_named_beta_schedule() for the new library of schedules.
"""
if beta_schedule == "quad":
betas = (
np.linspace(
beta_start**0.5,
beta_end**0.5,
num_diffusion_timesteps,
dtype=np.float64,
)
** 2
)
elif beta_schedule == "linear":
betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
elif beta_schedule == "warmup10":
betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1)
elif beta_schedule == "warmup50":
betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5)
elif beta_schedule == "const":
betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1
betas = 1.0 / np.linspace(num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64)
else:
raise NotImplementedError(beta_schedule)
assert betas.shape == (num_diffusion_timesteps,)
return betas
def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
"""
Get a pre-defined beta schedule for the given name.
The beta schedule library consists of beta schedules which remain similar
in the limit of num_diffusion_timesteps.
Beta schedules may be added, but should not be removed or changed once
they are committed to maintain backwards compatibility.
"""
if schedule_name == "linear":
# Linear schedule from Ho et al, extended to work for any number of
# diffusion steps.
scale = 1000 / num_diffusion_timesteps
return get_beta_schedule(
"linear",
beta_start=scale * 0.0001,
beta_end=scale * 0.02,
num_diffusion_timesteps=num_diffusion_timesteps,
)
elif schedule_name == "squaredcos_cap_v2":
return betas_for_alpha_bar(
num_diffusion_timesteps,
lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
)
else:
raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
"""
Create a beta schedule that discretizes the given alpha_t_bar function,
which defines the cumulative product of (1-beta) over time from t = [0,1].
:param num_diffusion_timesteps: the number of betas to produce.
:param alpha_bar: a lambda that takes an argument t from 0 to 1 and
produces the cumulative product of (1-beta) up to that
part of the diffusion process.
:param max_beta: the maximum beta to use; use values lower than 1 to
prevent singularities.
"""
betas = []
for i in range(num_diffusion_timesteps):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
return np.array(betas)
class GaussianDiffusion:
"""
Utilities for training and sampling diffusion models.
Original ported from this codebase:
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
:param betas: a 1-D numpy array of betas for each diffusion timestep,
starting at T and going to 1.
"""
def __init__(self, *, betas, model_mean_type, model_var_type, loss_type):
self.model_mean_type = model_mean_type
self.model_var_type = model_var_type
self.loss_type = loss_type
# Use float64 for accuracy.
betas = np.array(betas, dtype=np.float64)
self.betas = betas
assert len(betas.shape) == 1, "betas must be 1-D"
assert (betas > 0).all() and (betas <= 1).all()
self.num_timesteps = int(betas.shape[0])
alphas = 1.0 - betas
self.alphas_cumprod = np.cumprod(alphas, axis=0)
self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)
# calculations for diffusion q(x_t | x_{t-1}) and others
self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
# calculations for posterior q(x_{t-1} | x_t, x_0)
self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
self.posterior_log_variance_clipped = (
np.log(np.append(self.posterior_variance[1], self.posterior_variance[1:]))
if len(self.posterior_variance) > 1
else np.array([])
)
self.posterior_mean_coef1 = betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
self.posterior_mean_coef2 = (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)
def q_mean_variance(self, x_start, t):
"""
Get the distribution q(x_t | x_0).
:param x_start: the [N x C x ...] tensor of noiseless inputs.
:param t: the number of diffusion steps (minus 1). Here, 0 means one step.
:return: A tuple (mean, variance, log_variance), all of x_start's shape.
"""
mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
return mean, variance, log_variance
def q_sample(self, x_start, t, noise=None):
"""
Diffuse the data for a given number of diffusion steps.
In other words, sample from q(x_t | x_0).
:param x_start: the initial data batch.
:param t: the number of diffusion steps (minus 1). Here, 0 means one step.
:param noise: if specified, the split-out normal noise.
:return: A noisy version of x_start.
"""
if noise is None:
noise = th.randn_like(x_start)
assert noise.shape == x_start.shape
return (
_extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+ _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
)
def q_posterior_mean_variance(self, x_start, x_t, t):
"""
Compute the mean and variance of the diffusion posterior:
q(x_{t-1} | x_t, x_0)
"""
assert x_start.shape == x_t.shape
posterior_mean = (
_extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
+ _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
)
posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
posterior_log_variance_clipped = _extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
assert (
posterior_mean.shape[0]
== posterior_variance.shape[0]
== posterior_log_variance_clipped.shape[0]
== x_start.shape[0]
)
return posterior_mean, posterior_variance, posterior_log_variance_clipped
def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None):
"""
Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
the initial x, x_0.
:param model: the model, which takes a signal and a batch of timesteps
as input.
:param x: the [N x C x ...] tensor at time t.
:param t: a 1-D Tensor of timesteps.
:param clip_denoised: if True, clip the denoised signal into [-1, 1].
:param denoised_fn: if not None, a function which applies to the
x_start prediction before it is used to sample. Applies before
clip_denoised.
:param model_kwargs: if not None, a dict of extra keyword arguments to
pass to the model. This can be used for conditioning.
:return: a dict with the following keys:
- 'mean': the model mean output.
- 'variance': the model variance output.
- 'log_variance': the log of 'variance'.
- 'pred_xstart': the prediction for x_0.
"""
if model_kwargs is None:
model_kwargs = {}
B, C = x.shape[:2]
assert t.shape == (B,)
model_output = model(x, t, **model_kwargs)
if isinstance(model_output, tuple):
model_output, extra = model_output
else:
extra = None
if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
assert model_output.shape == (B, C * 2, *x.shape[2:])
model_output, model_var_values = th.split(model_output, C, dim=1)
min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape)
max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
# The model_var_values is [-1, 1] for [min_var, max_var].
frac = (model_var_values + 1) / 2
model_log_variance = frac * max_log + (1 - frac) * min_log
model_variance = th.exp(model_log_variance)
else:
model_variance, model_log_variance = {
# for fixedlarge, we set the initial (log-)variance like so
# to get a better decoder log likelihood.
ModelVarType.FIXED_LARGE: (
np.append(self.posterior_variance[1], self.betas[1:]),
np.log(np.append(self.posterior_variance[1], self.betas[1:])),
),
ModelVarType.FIXED_SMALL: (
self.posterior_variance,
self.posterior_log_variance_clipped,
),
}[self.model_var_type]
model_variance = _extract_into_tensor(model_variance, t, x.shape)
model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
def process_xstart(x):
if denoised_fn is not None:
x = denoised_fn(x)
if clip_denoised:
return x.clamp(-1, 1)
return x
if self.model_mean_type == ModelMeanType.START_X:
pred_xstart = process_xstart(model_output)
else:
pred_xstart = process_xstart(self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output))
model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t)
assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
return {
"mean": model_mean,
"variance": model_variance,
"log_variance": model_log_variance,
"pred_xstart": pred_xstart,
"extra": extra,
}
def _predict_xstart_from_eps(self, x_t, t, eps):
assert x_t.shape == eps.shape
return (
_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
- _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
)
def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
return (
_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart
) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
"""
Compute the mean for the previous step, given a function cond_fn that
computes the gradient of a conditional log probability with respect to
x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
condition on y.
This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
"""
gradient = cond_fn(x, t, **model_kwargs)
new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
return new_mean
def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
"""
Compute what the p_mean_variance output would have been, should the
model's score function be conditioned by cond_fn.
See condition_mean() for details on cond_fn.
Unlike condition_mean(), this instead uses the conditioning strategy
from Song et al (2020).
"""
alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs)
out = p_mean_var.copy()
out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t)
return out
def p_sample(
self,
model,
x,
t,
clip_denoised=True,
denoised_fn=None,
cond_fn=None,
model_kwargs=None,
):
"""
Sample x_{t-1} from the model at the given timestep.
:param model: the model to sample from.
:param x: the current tensor at x_{t-1}.
:param t: the value of t, starting at 0 for the first diffusion step.
:param clip_denoised: if True, clip the x_start prediction to [-1, 1].
:param denoised_fn: if not None, a function which applies to the
x_start prediction before it is used to sample.
:param cond_fn: if not None, this is a gradient function that acts
similarly to the model.
:param model_kwargs: if not None, a dict of extra keyword arguments to
pass to the model. This can be used for conditioning.
:return: a dict containing the following keys:
- 'sample': a random sample from the model.
- 'pred_xstart': a prediction of x_0.
"""
out = self.p_mean_variance(
model,
x,
t,
clip_denoised=clip_denoised,
denoised_fn=denoised_fn,
model_kwargs=model_kwargs,
)
noise = th.randn_like(x)
nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) # no noise when t == 0
if cond_fn is not None:
out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs)
sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
return {"sample": sample, "pred_xstart": out["pred_xstart"]}
def p_sample_loop(
self,
model,
shape,
noise=None,
clip_denoised=True,
denoised_fn=None,
cond_fn=None,
model_kwargs=None,
device=None,
progress=False,
):
"""
Generate samples from the model.
:param model: the model module.
:param shape: the shape of the samples, (N, C, H, W).
:param noise: if specified, the noise from the encoder to sample.
Should be of the same shape as `shape`.
:param clip_denoised: if True, clip x_start predictions to [-1, 1].
:param denoised_fn: if not None, a function which applies to the
x_start prediction before it is used to sample.
:param cond_fn: if not None, this is a gradient function that acts
similarly to the model.
:param model_kwargs: if not None, a dict of extra keyword arguments to
pass to the model. This can be used for conditioning.
:param device: if specified, the device to create the samples on.
If not specified, use a model parameter's device.
:param progress: if True, show a tqdm progress bar.
:return: a non-differentiable batch of samples.
"""
final = None
for sample in self.p_sample_loop_progressive(
model,
shape,
noise=noise,
clip_denoised=clip_denoised,
denoised_fn=denoised_fn,
cond_fn=cond_fn,
model_kwargs=model_kwargs,
device=device,
progress=progress,
):
final = sample
return final["sample"]
def p_sample_loop_progressive(
self,
model,
shape,
noise=None,
clip_denoised=True,
denoised_fn=None,
cond_fn=None,
model_kwargs=None,
device=None,
progress=False,
):
"""
Generate samples from the model and yield intermediate samples from
each timestep of diffusion.
Arguments are the same as p_sample_loop().
Returns a generator over dicts, where each dict is the return value of
p_sample().
"""
if device is None:
device = next(model.parameters()).device
assert isinstance(shape, (tuple, list))
if noise is not None:
img = noise
else:
img = th.randn(*shape, device=device)
indices = list(range(self.num_timesteps))[::-1]
if progress:
# Lazy import so that we don't depend on tqdm.
from tqdm.auto import tqdm
indices = tqdm(indices)
for i in indices:
t = th.tensor([i] * shape[0], device=device)
with th.no_grad():
out = self.p_sample(
model,
img,
t,
clip_denoised=clip_denoised,
denoised_fn=denoised_fn,
cond_fn=cond_fn,
model_kwargs=model_kwargs,
)
yield out
img = out["sample"]
def ddim_sample(
self,
model,
x,
t,
clip_denoised=True,
denoised_fn=None,
cond_fn=None,
model_kwargs=None,
eta=0.0,
):
"""
Sample x_{t-1} from the model using DDIM.
Same usage as p_sample().
"""
out = self.p_mean_variance(
model,
x,
t,
clip_denoised=clip_denoised,
denoised_fn=denoised_fn,
model_kwargs=model_kwargs,
)
if cond_fn is not None:
out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
# Usually our model outputs epsilon, but we re-derive it
# in case we used x_start or x_prev prediction.
eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
sigma = eta * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) * th.sqrt(1 - alpha_bar / alpha_bar_prev)
# Equation 12.
noise = th.randn_like(x)
mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_prev) + th.sqrt(1 - alpha_bar_prev - sigma**2) * eps
nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) # no noise when t == 0
sample = mean_pred + nonzero_mask * sigma * noise
return {"sample": sample, "pred_xstart": out["pred_xstart"]}
def ddim_reverse_sample(
self,
model,
x,
t,
clip_denoised=True,
denoised_fn=None,
cond_fn=None,
model_kwargs=None,
eta=0.0,
):
"""
Sample x_{t+1} from the model using DDIM reverse ODE.
"""
assert eta == 0.0, "Reverse ODE only for deterministic path"
out = self.p_mean_variance(
model,
x,
t,
clip_denoised=clip_denoised,
denoised_fn=denoised_fn,
model_kwargs=model_kwargs,
)
if cond_fn is not None:
out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
# Usually our model outputs epsilon, but we re-derive it
# in case we used x_start or x_prev prediction.
eps = (
_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x - out["pred_xstart"]
) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
# Equation 12. reversed
mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps
return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
def ddim_sample_loop(
self,
model,
shape,
noise=None,
clip_denoised=True,
denoised_fn=None,
cond_fn=None,
model_kwargs=None,
device=None,
progress=False,
eta=0.0,
):
"""
Generate samples from the model using DDIM.
Same usage as p_sample_loop().
"""
final = None
for sample in self.ddim_sample_loop_progressive(
model,
shape,
noise=noise,
clip_denoised=clip_denoised,
denoised_fn=denoised_fn,
cond_fn=cond_fn,
model_kwargs=model_kwargs,
device=device,
progress=progress,
eta=eta,
):
final = sample
return final["sample"]
def ddim_sample_loop_progressive(
self,
model,
shape,
noise=None,
clip_denoised=True,
denoised_fn=None,
cond_fn=None,
model_kwargs=None,
device=None,
progress=False,
eta=0.0,
):
"""
Use DDIM to sample from the model and yield intermediate samples from
each timestep of DDIM.
Same usage as p_sample_loop_progressive().
"""
if device is None:
device = next(model.parameters()).device
assert isinstance(shape, (tuple, list))
if noise is not None:
img = noise
else:
img = th.randn(*shape, device=device)
indices = list(range(self.num_timesteps))[::-1]
if progress:
# Lazy import so that we don't depend on tqdm.
from tqdm.auto import tqdm
indices = tqdm(indices)
for i in indices:
t = th.tensor([i] * shape[0], device=device)
with th.no_grad():
out = self.ddim_sample(
model,
img,
t,
clip_denoised=clip_denoised,
denoised_fn=denoised_fn,
cond_fn=cond_fn,
model_kwargs=model_kwargs,
eta=eta,
)
yield out
img = out["sample"]
def _vb_terms_bpd(self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None):
"""
Get a term for the variational lower-bound.
The resulting units are bits (rather than nats, as one might expect).
This allows for comparison to other papers.
:return: a dict with the following keys:
- 'output': a shape [N] tensor of NLLs or KLs.
- 'pred_xstart': the x_0 predictions.
"""
true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t)
out = self.p_mean_variance(model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs)
kl = normal_kl(true_mean, true_log_variance_clipped, out["mean"], out["log_variance"])
kl = mean_flat(kl) / np.log(2.0)
decoder_nll = -discretized_gaussian_log_likelihood(
x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
)
assert decoder_nll.shape == x_start.shape
decoder_nll = mean_flat(decoder_nll) / np.log(2.0)
# At the first timestep return the decoder NLL,
# otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
output = th.where((t == 0), decoder_nll, kl)
return {"output": output, "pred_xstart": out["pred_xstart"]}
def training_losses(self, model, x_start, t, model_kwargs=None, noise=None):
"""
Compute training losses for a single timestep.
:param model: the model to evaluate loss on.
:param x_start: the [N x C x ...] tensor of inputs.
:param t: a batch of timestep indices.
:param model_kwargs: if not None, a dict of extra keyword arguments to
pass to the model. This can be used for conditioning.
:param noise: if specified, the specific Gaussian noise to try to remove.
:return: a dict with the key "loss" containing a tensor of shape [N].
Some mean or variance settings may also have other keys.
"""
if model_kwargs is None:
model_kwargs = {}
if noise is None:
noise = th.randn_like(x_start)
x_t = self.q_sample(x_start, t, noise=noise)
terms = {}
if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
terms["loss"] = self._vb_terms_bpd(
model=model,
x_start=x_start,
x_t=x_t,
t=t,
clip_denoised=False,
model_kwargs=model_kwargs,
)["output"]
if self.loss_type == LossType.RESCALED_KL:
terms["loss"] *= self.num_timesteps
elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
model_output = model(x_t, t, **model_kwargs)
if self.model_var_type in [
ModelVarType.LEARNED,
ModelVarType.LEARNED_RANGE,
]:
B, C = x_t.shape[:2]
assert model_output.shape == (B, C * 2, *x_t.shape[2:])
model_output, model_var_values = th.split(model_output, C, dim=1)
# Learn the variance using the variational bound, but don't let
# it affect our mean prediction.
frozen_out = th.cat([model_output.detach(), model_var_values], dim=1)
terms["vb"] = self._vb_terms_bpd(
model=lambda *args, r=frozen_out: r,
x_start=x_start,
x_t=x_t,
t=t,
clip_denoised=False,
)["output"]
if self.loss_type == LossType.RESCALED_MSE:
# Divide by 1000 for equivalence with initial implementation.
# Without a factor of 1/1000, the VB term hurts the MSE term.
terms["vb"] *= self.num_timesteps / 1000.0
target = {
ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t)[0],
ModelMeanType.START_X: x_start,
ModelMeanType.EPSILON: noise,
}[self.model_mean_type]
assert model_output.shape == target.shape == x_start.shape
terms["mse"] = mean_flat((target - model_output) ** 2)
if "vb" in terms:
terms["loss"] = terms["mse"] + terms["vb"]
else:
terms["loss"] = terms["mse"]
else:
raise NotImplementedError(self.loss_type)
return terms
def _prior_bpd(self, x_start):
"""
Get the prior KL term for the variational lower-bound, measured in
bits-per-dim.
This term can't be optimized, as it only depends on the encoder.
:param x_start: the [N x C x ...] tensor of inputs.
:return: a batch of [N] KL values (in bits), one per batch element.
"""
batch_size = x_start.shape[0]
t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
return mean_flat(kl_prior) / np.log(2.0)
def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
"""
Compute the entire variational lower-bound, measured in bits-per-dim,
as well as other related quantities.
:param model: the model to evaluate loss on.
:param x_start: the [N x C x ...] tensor of inputs.
:param clip_denoised: if True, clip denoised samples.
:param model_kwargs: if not None, a dict of extra keyword arguments to
pass to the model. This can be used for conditioning.
:return: a dict containing the following keys:
- total_bpd: the total variational lower-bound, per batch element.
- prior_bpd: the prior term in the lower-bound.
- vb: an [N x T] tensor of terms in the lower-bound.
- xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
- mse: an [N x T] tensor of epsilon MSEs for each timestep.
"""
device = x_start.device
batch_size = x_start.shape[0]
vb = []
xstart_mse = []
mse = []
for t in list(range(self.num_timesteps))[::-1]:
t_batch = th.tensor([t] * batch_size, device=device)
noise = th.randn_like(x_start)
x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
# Calculate VLB term at the current timestep
with th.no_grad():
out = self._vb_terms_bpd(
model,
x_start=x_start,
x_t=x_t,
t=t_batch,
clip_denoised=clip_denoised,
model_kwargs=model_kwargs,
)
vb.append(out["output"])
xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
mse.append(mean_flat((eps - noise) ** 2))
vb = th.stack(vb, dim=1)
xstart_mse = th.stack(xstart_mse, dim=1)
mse = th.stack(mse, dim=1)
prior_bpd = self._prior_bpd(x_start)
total_bpd = vb.sum(dim=1) + prior_bpd
return {
"total_bpd": total_bpd,
"prior_bpd": prior_bpd,
"vb": vb,
"xstart_mse": xstart_mse,
"mse": mse,
}
def _extract_into_tensor(arr, timesteps, broadcast_shape):
"""
Extract values from a 1-D numpy array for a batch of indices.
:param arr: the 1-D numpy array.
:param timesteps: a tensor of indices into the array to extract.
:param broadcast_shape: a larger shape of K dimensions with the batch
dimension equal to the length of timesteps.
:return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
"""
res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
while len(res.shape) < len(broadcast_shape):
res = res[..., None]
return res + th.zeros(broadcast_shape, device=timesteps.device)
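As a quick sanity check of the utilities above (illustrative, not part of the commit; the import path assumes this file lives at `opensora/schedulers/iddpm/gaussian_diffusion.py`, as the surrounding imports suggest):

```python
import torch as th

from opensora.schedulers.iddpm import gaussian_diffusion as gd

diffusion = gd.GaussianDiffusion(
    betas=gd.get_named_beta_schedule("linear", 1000),
    model_mean_type=gd.ModelMeanType.EPSILON,
    model_var_type=gd.ModelVarType.FIXED_SMALL,
    loss_type=gd.LossType.MSE,
)

x0 = th.randn(4, 3, 32, 32)
t = th.zeros(4, dtype=th.long)  # t = 0: only a tiny amount of noise is mixed in
x_t = diffusion.q_sample(x0, t)
# at t = 0, sqrt(alpha_bar) is close to 1 and sqrt(1 - alpha_bar) is ~0.01
assert (x_t - x0).abs().mean() < 0.1
```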


@@ -0,0 +1,127 @@
# Adapted from DiT
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# DiT: https://github.com/facebookresearch/DiT/tree/main
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
# --------------------------------------------------------
import numpy as np
import torch as th
from .gaussian_diffusion import GaussianDiffusion
def space_timesteps(num_timesteps, section_counts):
"""
Create a list of timesteps to use from an original diffusion process,
given the number of timesteps we want to take from equally-sized portions
of the original process.
For example, if there's 300 timesteps and the section counts are [10,15,20]
then the first 100 timesteps are strided to be 10 timesteps, the second 100
are strided to be 15 timesteps, and the final 100 are strided to be 20.
If the stride is a string starting with "ddim", then the fixed striding
from the DDIM paper is used, and only one section is allowed.
:param num_timesteps: the number of diffusion steps in the original
process to divide up.
:param section_counts: either a list of numbers, or a string containing
comma-separated numbers, indicating the step count
per section. As a special case, use "ddimN" where N
is a number of steps to use the striding from the
DDIM paper.
:return: a set of diffusion steps from the original process to use.
"""
if isinstance(section_counts, str):
if section_counts.startswith("ddim"):
desired_count = int(section_counts[len("ddim") :])
for i in range(1, num_timesteps):
if len(range(0, num_timesteps, i)) == desired_count:
return set(range(0, num_timesteps, i))
raise ValueError(f"cannot create exactly {num_timesteps} steps with an integer stride")
section_counts = [int(x) for x in section_counts.split(",")]
size_per = num_timesteps // len(section_counts)
extra = num_timesteps % len(section_counts)
start_idx = 0
all_steps = []
for i, section_count in enumerate(section_counts):
size = size_per + (1 if i < extra else 0)
if size < section_count:
raise ValueError(f"cannot divide section of {size} steps into {section_count}")
if section_count <= 1:
frac_stride = 1
else:
frac_stride = (size - 1) / (section_count - 1)
cur_idx = 0.0
taken_steps = []
for _ in range(section_count):
taken_steps.append(start_idx + round(cur_idx))
cur_idx += frac_stride
all_steps += taken_steps
start_idx += size
return set(all_steps)
class SpacedDiffusion(GaussianDiffusion):
"""
A diffusion process which can skip steps in a base diffusion process.
:param use_timesteps: a collection (sequence or set) of timesteps from the
original diffusion process to retain.
:param kwargs: the kwargs to create the base diffusion process.
"""
def __init__(self, use_timesteps, **kwargs):
self.use_timesteps = set(use_timesteps)
self.timestep_map = []
self.original_num_steps = len(kwargs["betas"])
base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
last_alpha_cumprod = 1.0
new_betas = []
for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
if i in self.use_timesteps:
new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
last_alpha_cumprod = alpha_cumprod
self.timestep_map.append(i)
kwargs["betas"] = np.array(new_betas)
super().__init__(**kwargs)
def p_mean_variance(self, model, *args, **kwargs): # pylint: disable=signature-differs
return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
def training_losses(self, model, *args, **kwargs): # pylint: disable=signature-differs
return super().training_losses(self._wrap_model(model), *args, **kwargs)
def condition_mean(self, cond_fn, *args, **kwargs):
return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
def condition_score(self, cond_fn, *args, **kwargs):
return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
def _wrap_model(self, model):
if isinstance(model, _WrappedModel):
return model
return _WrappedModel(model, self.timestep_map, self.original_num_steps)
def _scale_timesteps(self, t):
# Scaling is done by the wrapped model.
return t
class _WrappedModel:
def __init__(self, model, timestep_map, original_num_steps):
self.model = model
self.timestep_map = timestep_map
# self.rescale_timesteps = rescale_timesteps
self.original_num_steps = original_num_steps
def __call__(self, x, ts, **kwargs):
map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
new_ts = map_tensor[ts]
# if self.rescale_timesteps:
# new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
return self.model(x, new_ts, **kwargs)
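To make the striding concrete, a hedged example mirroring the `space_timesteps` docstring (illustrative; the import path assumes this file is `opensora/schedulers/iddpm/respace.py`):

```python
from opensora.schedulers.iddpm.respace import space_timesteps

# 300 original steps in three 100-step sections, keeping 10, 15, and 20 steps
steps = space_timesteps(300, [10, 15, 20])
assert len(steps) == 10 + 15 + 20

# DDIM-style fixed striding: 25 evenly spaced steps out of 100
assert space_timesteps(100, "ddim25") == set(range(0, 100, 4))
```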


@@ -0,0 +1,150 @@
# Adapted from DiT
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# DiT: https://github.com/facebookresearch/DiT/tree/main
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
# --------------------------------------------------------
from abc import ABC, abstractmethod
import numpy as np
import torch as th
import torch.distributed as dist
def create_named_schedule_sampler(name, diffusion):
"""
Create a ScheduleSampler from a library of pre-defined samplers.
:param name: the name of the sampler.
:param diffusion: the diffusion object to sample for.
"""
if name == "uniform":
return UniformSampler(diffusion)
elif name == "loss-second-moment":
return LossSecondMomentResampler(diffusion)
else:
raise NotImplementedError(f"unknown schedule sampler: {name}")
class ScheduleSampler(ABC):
"""
A distribution over timesteps in the diffusion process, intended to reduce
variance of the objective.
By default, samplers perform unbiased importance sampling, in which the
objective's mean is unchanged.
However, subclasses may override sample() to change how the resampled
terms are reweighted, allowing for actual changes in the objective.
"""
@abstractmethod
def weights(self):
"""
Get a numpy array of weights, one per diffusion step.
The weights needn't be normalized, but must be positive.
"""
def sample(self, batch_size, device):
"""
Importance-sample timesteps for a batch.
:param batch_size: the number of timesteps.
:param device: the torch device to save to.
:return: a tuple (timesteps, weights):
- timesteps: a tensor of timestep indices.
- weights: a tensor of weights to scale the resulting losses.
"""
w = self.weights()
p = w / np.sum(w)
indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
indices = th.from_numpy(indices_np).long().to(device)
weights_np = 1 / (len(p) * p[indices_np])
weights = th.from_numpy(weights_np).float().to(device)
return indices, weights
class UniformSampler(ScheduleSampler):
def __init__(self, diffusion):
self.diffusion = diffusion
self._weights = np.ones([diffusion.num_timesteps])
def weights(self):
return self._weights
class LossAwareSampler(ScheduleSampler):
def update_with_local_losses(self, local_ts, local_losses):
"""
Update the reweighting using losses from a model.
Call this method from each rank with a batch of timesteps and the
corresponding losses for each of those timesteps.
This method will perform synchronization to make sure all of the ranks
maintain the exact same reweighting.
:param local_ts: an integer Tensor of timesteps.
:param local_losses: a 1D Tensor of losses.
"""
batch_sizes = [th.tensor([0], dtype=th.int32, device=local_ts.device) for _ in range(dist.get_world_size())]
dist.all_gather(
batch_sizes,
th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
)
# Pad all_gather batches to be the maximum batch size.
batch_sizes = [x.item() for x in batch_sizes]
max_bs = max(batch_sizes)
timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes]
loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes]
dist.all_gather(timestep_batches, local_ts)
dist.all_gather(loss_batches, local_losses)
timesteps = [x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]]
losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
self.update_with_all_losses(timesteps, losses)
@abstractmethod
def update_with_all_losses(self, ts, losses):
"""
Update the reweighting using losses from a model.
Sub-classes should override this method to update the reweighting
using losses from the model.
This method directly updates the reweighting without synchronizing
between workers. It is called by update_with_local_losses from all
ranks with identical arguments. Thus, it should have deterministic
behavior to maintain state across workers.
:param ts: a list of int timesteps.
:param losses: a list of float losses, one per timestep.
"""
class LossSecondMomentResampler(LossAwareSampler):
def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
self.diffusion = diffusion
self.history_per_term = history_per_term
self.uniform_prob = uniform_prob
self._loss_history = np.zeros([diffusion.num_timesteps, history_per_term], dtype=np.float64)
self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int64)  # np.int was removed in NumPy 1.24
def weights(self):
if not self._warmed_up():
return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
weights = np.sqrt(np.mean(self._loss_history**2, axis=-1))
weights /= np.sum(weights)
weights *= 1 - self.uniform_prob
weights += self.uniform_prob / len(weights)
return weights
def update_with_all_losses(self, ts, losses):
for t, loss in zip(ts, losses):
if self._loss_counts[t] == self.history_per_term:
# Shift out the oldest loss term.
self._loss_history[t, :-1] = self._loss_history[t, 1:]
self._loss_history[t, -1] = loss
else:
self._loss_history[t, self._loss_counts[t]] = loss
self._loss_counts[t] += 1
def _warmed_up(self):
return (self._loss_counts == self.history_per_term).all()
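A hedged sketch of how these samplers would plug into a training step (the file name in the import is an assumption, since it is not visible in this diff; the real training script lives elsewhere):

```python
# Illustrative only: importance-sample timesteps and weight per-sample losses
# so the overall objective stays unbiased.
import torch

from opensora.schedulers.iddpm import IDDPM
from opensora.schedulers.iddpm.timestep_sampler import create_named_schedule_sampler

diffusion = IDDPM()  # defaults to the full 1000-step linear schedule
sampler = create_named_schedule_sampler("uniform", diffusion)

x_start = torch.randn(8, 4, 16, 32, 32)  # hypothetical latent video batch
t, weights = sampler.sample(x_start.shape[0], x_start.device)
# losses = diffusion.training_losses(model, x_start, t)  # model omitted here
# loss = (losses["loss"] * weights).mean()
```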


@@ -0,0 +1,233 @@
import functools
import json
import logging
import operator
import os
from typing import Tuple
import colossalai
import torch
import torch.distributed as dist
import torch.nn as nn
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.cluster import DistCoordinator
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torchvision.datasets.utils import download_url
pretrained_models = {
"DiT-XL-2-512x512.pt": "https://dl.fbaipublicfiles.com/DiT/models/DiT-XL-2-512x512.pt",
"DiT-XL-2-256x256.pt": "https://dl.fbaipublicfiles.com/DiT/models/DiT-XL-2-256x256.pt",
"Latte-XL-2-256x256-ucf101.pt": "https://huggingface.co/maxin-cn/Latte/resolve/main/ucf101.pt",
"PixArt-XL-2-256x256.pth": "PixArt-XL-2-256x256.pth",
"PixArt-XL-2-SAM-256x256.pth": "PixArt-XL-2-SAM-256x256.pth",
"PixArt-XL-2-512x512.pth": "PixArt-XL-2-512x512.pth",
"PixArt-XL-2-1024-MS.pth": "PixArt-XL-2-1024-MS.pth",
}
def reparameter(ckpt, name=None):
if "DiT" in name:
ckpt["x_embedder.proj.weight"] = ckpt["x_embedder.proj.weight"].unsqueeze(2)
del ckpt["pos_embed"]
elif "Latte" in name:
ckpt = ckpt["ema"]
ckpt["x_embedder.proj.weight"] = ckpt["x_embedder.proj.weight"].unsqueeze(2)
del ckpt["pos_embed"]
del ckpt["temp_embed"]
elif "PixArt" in name:
ckpt = ckpt["state_dict"]
ckpt["x_embedder.proj.weight"] = ckpt["x_embedder.proj.weight"].unsqueeze(2)
del ckpt["pos_embed"]
return ckpt
def find_model(model_name):
"""
Finds a pre-trained DiT model, downloading it if necessary. Alternatively, loads a model from a local path.
"""
if model_name in pretrained_models: # Find/download our pre-trained DiT checkpoints
model = download_model(model_name)
model = reparameter(model, model_name)
return model
else: # Load a custom DiT checkpoint:
assert os.path.isfile(model_name), f"Could not find DiT checkpoint at {model_name}"
checkpoint = torch.load(model_name, map_location=lambda storage, loc: storage)
if "pos_embed_temporal" in checkpoint:
del checkpoint["pos_embed_temporal"]
if "pos_embed" in checkpoint:
del checkpoint["pos_embed"]
if "ema" in checkpoint: # supports checkpoints from train.py
checkpoint = checkpoint["ema"]
return checkpoint
def download_model(model_name):
"""
Downloads a pre-trained DiT model from the web.
"""
assert model_name in pretrained_models
local_path = f"pretrained_models/{model_name}"
if not os.path.isfile(local_path):
os.makedirs("pretrained_models", exist_ok=True)
web_path = pretrained_models[model_name]
download_url(web_path, "pretrained_models", model_name)
model = torch.load(local_path, map_location=lambda storage, loc: storage)
return model
def load_from_sharded_state_dict(model, ckpt_path):
# TODO: hard-coded for ColossalAI loading
os.environ["RANK"] = "0"
os.environ["LOCAL_RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29501"
colossalai.launch_from_torch({})
plugin = LowLevelZeroPlugin(
stage=2,
precision="fp32",
initial_scale=2**16,
)
booster = Booster(plugin=plugin)
model, _, _, _, _ = booster.boost(model=model)
booster.load_model(model, os.path.join(ckpt_path, "model"))
save_path = os.path.join(ckpt_path, "model_ckpt.pt")
torch.save(model.module.state_dict(), save_path)
print(f"Model checkpoint saved to {save_path}")
def model_sharding(model: torch.nn.Module):
global_rank = dist.get_rank()
world_size = dist.get_world_size()
for _, param in model.named_parameters():
padding_size = (world_size - param.numel() % world_size) % world_size
if padding_size > 0:
padding_param = torch.nn.functional.pad(param.data.view(-1), [0, padding_size])
else:
padding_param = param.data.view(-1)
split_params = padding_param.split(padding_param.numel() // world_size)
param.data = split_params[global_rank]
def load_json(file_path: str):
with open(file_path, "r") as f:
return json.load(f)
def save_json(data, file_path: str):
with open(file_path, "w") as f:
json.dump(data, f, indent=4)
def remove_padding(tensor: torch.Tensor, original_shape: Tuple) -> torch.Tensor:
return tensor[: functools.reduce(operator.mul, original_shape)]
def model_gathering(model: torch.nn.Module, model_shape_dict: dict):
global_rank = dist.get_rank()
global_size = dist.get_world_size()
for name, param in model.named_parameters():
all_params = [torch.empty_like(param.data) for _ in range(global_size)]
dist.all_gather(all_params, param.data, group=dist.group.WORLD)
if int(global_rank) == 0:
all_params = torch.cat(all_params)
param.data = remove_padding(all_params, model_shape_dict[name]).view(model_shape_dict[name])
dist.barrier()
def record_model_param_shape(model: torch.nn.Module) -> dict:
param_shape = {}
for name, param in model.named_parameters():
param_shape[name] = param.shape
return param_shape
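# Usage sketch (assumes torch.distributed is initialized, e.g. via torchrun).
# Full shapes must be recorded *before* sharding, since afterwards each rank
# only holds a flat 1/world_size slice of every parameter:
#   shape_dict = record_model_param_shape(ema)
#   model_sharding(ema)                      # shard across ranks
#   ...                                      # EMA updates happen on the shards
#   model_gathering(ema, shape_dict)         # rank 0 reassembles the full tensors
#   if dist.get_rank() == 0:
#       torch.save(ema.state_dict(), "ema.pt")
#   model_sharding(ema)                      # re-shard before resuming training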
def save(
booster: Booster,
model: nn.Module,
ema: nn.Module,
optimizer: Optimizer,
lr_scheduler: _LRScheduler,
epoch: int,
step: int,
global_step: int,
batch_size: int,
coordinator: DistCoordinator,
save_dir: str,
shape_dict: dict,
):
save_dir = os.path.join(save_dir, f"epoch{epoch}-global_step{global_step}")
os.makedirs(os.path.join(save_dir, "model"), exist_ok=True)
booster.save_model(model, os.path.join(save_dir, "model"), shard=True)
# ema is not boosted, so we don't need to use booster.save_model
model_gathering(ema, shape_dict)
global_rank = dist.get_rank()
if int(global_rank) == 0:
torch.save(ema.state_dict(), os.path.join(save_dir, "ema.pt"))
model_sharding(ema)
booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True, size_per_shard=4096)
if lr_scheduler is not None:
booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler"))
running_states = {
"epoch": epoch,
"step": step,
"global_step": global_step,
"sample_start_index": step * batch_size,
}
if coordinator.is_master():
save_json(running_states, os.path.join(save_dir, "running_states.json"))
dist.barrier()
def load(
booster: Booster, model: nn.Module, ema: nn.Module, optimizer: Optimizer, lr_scheduler: _LRScheduler, load_dir: str
) -> Tuple[int, int, int]:
booster.load_model(model, os.path.join(load_dir, "model"))
# ema is not boosted, so we don't use booster.load_model
# ema.load_state_dict(torch.load(os.path.join(load_dir, "ema.pt")))
ema.load_state_dict(torch.load(os.path.join(load_dir, "ema.pt"), map_location=torch.device("cpu")))
booster.load_optimizer(optimizer, os.path.join(load_dir, "optimizer"))
if lr_scheduler is not None:
booster.load_lr_scheduler(lr_scheduler, os.path.join(load_dir, "lr_scheduler"))
running_states = load_json(os.path.join(load_dir, "running_states.json"))
dist.barrier()
return running_states["epoch"], running_states["step"], running_states["sample_start_index"]
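# Usage sketch: resume from a directory produced by save() (path is a placeholder):
#   start_epoch, start_step, sample_start_index = load(
#       booster, model, ema, optimizer, lr_scheduler, "outputs/epoch2-global_step5000"
#   )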
def create_logger(logging_dir):
"""
Create a logger that writes to a log file and stdout.
"""
if dist.get_rank() == 0: # real logger
logging.basicConfig(
level=logging.INFO,
format="[\033[34m%(asctime)s\033[0m] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")],
)
logger = logging.getLogger(__name__)
else: # dummy logger (does nothing)
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())
return logger
def load_checkpoint(model, ckpt_path, save_as_pt=True):
if ckpt_path.endswith(".pt") or ckpt_path.endswith(".pth"):
state_dict = find_model(ckpt_path)
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
print(f"Missing keys: {missing_keys}")
print(f"Unexpected keys: {unexpected_keys}")
elif os.path.isdir(ckpt_path):
load_from_sharded_state_dict(model, ckpt_path)
if save_as_pt:
save_path = os.path.join(ckpt_path, "model_ckpt.pt")
torch.save(model.state_dict(), save_path)
else:
raise ValueError(f"Invalid checkpoint path: {ckpt_path}")

@@ -0,0 +1,96 @@
import argparse
import json
import os
from glob import glob
from mmengine.config import Config
from torch.utils.tensorboard import SummaryWriter
def parse_args(training=False):
parser = argparse.ArgumentParser()
# model config
parser.add_argument("config", help="model config file path")
parser.add_argument("--seed", default=42, type=int, help="generation seed")
parser.add_argument("--ckpt-path", type=str, help="path to model ckpt; will overwrite cfg.ckpt_path if specified")
parser.add_argument("--batch-size", default=None, type=int, help="batch size")
# ======================================================
# Inference
# ======================================================
if not training:
# prompt
parser.add_argument("--prompt-path", default=None, type=str, help="path to prompt txt file")
parser.add_argument("--save-dir", default=None, type=str, help="path to save generated samples")
# hyperparameters
parser.add_argument("--num-sampling-steps", default=None, type=int, help="sampling steps")
parser.add_argument("--cfg-scale", default=None, type=float, help="balance between cond & uncond")
else:
parser.add_argument("--wandb", default=None, type=bool, help="enable wandb")
parser.add_argument("--load", default=None, type=str, help="path to continue training")
return parser.parse_args()
def merge_args(cfg, args, training=False):
if args.ckpt_path is not None:
cfg.model["from_pratrained"] = args.ckpt_path
args.ckpt_path = None
if not training:
if args.cfg_scale is not None:
cfg.scheduler["cfg_scale"] = args.cfg_scale
args.cfg_scale = None
if "multi_resolution" not in cfg:
cfg["multi_resolution"] = False
for k, v in vars(args).items():
if k in cfg and v is not None:
cfg[k] = v
return cfg
def parse_configs(training=False):
args = parse_args(training)
cfg = Config.fromfile(args.config)
cfg = merge_args(cfg, args, training)
return cfg
def create_experiment_workspace(cfg):
    """
    This function creates a folder for experiment tracking.
    Args:
        cfg: The parsed config.
    Returns:
        exp_name: The name of the experiment folder.
        exp_dir: The path to the experiment folder.
    """
# Make outputs folder (holds all experiment subfolders)
os.makedirs(cfg.outputs, exist_ok=True)
experiment_index = len(glob(f"{cfg.outputs}/*"))
# Create an experiment folder
model_name = cfg.model["type"].replace("/", "-")
exp_name = f"{experiment_index:03d}-F{cfg.num_frames}S{cfg.frame_interval}-{model_name}"
exp_dir = f"{cfg.outputs}/{exp_name}"
os.makedirs(exp_dir, exist_ok=True)
return exp_name, exp_dir
def save_training_config(cfg, experiment_dir):
with open(f"{experiment_dir}/config.txt", "w") as f:
json.dump(cfg, f, indent=4)
def create_tensorboard_writer(exp_dir):
tensorboard_dir = f"{exp_dir}/tensorboard"
os.makedirs(tensorboard_dir, exist_ok=True)
writer = SummaryWriter(tensorboard_dir)
return writer
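# Usage sketch (hypothetical entry point; the config file is assumed to define
# `outputs`, `num_frames`, and `frame_interval`):
#   cfg = parse_configs(training=True)
#   exp_name, exp_dir = create_experiment_workspace(cfg)
#   writer = create_tensorboard_writer(exp_dir)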

339
opensora/utils/misc.py Normal file
@@ -0,0 +1,339 @@
import collections
import importlib
import logging
import os
import time
from collections import OrderedDict
from collections.abc import Sequence
from itertools import repeat
import numpy as np
import torch
import torch.distributed as dist
def print_rank(var_name, var_value, rank=0):
if dist.get_rank() == rank:
print(f"[Rank {rank}] {var_name}: {var_value}")
def print_0(*args, **kwargs):
if dist.get_rank() == 0:
print(*args, **kwargs)
def requires_grad(model: torch.nn.Module, flag: bool = True) -> None:
"""
Set requires_grad flag for all parameters in a model.
"""
for p in model.parameters():
p.requires_grad = flag
def format_numel_str(numel: int) -> str:
B = 1024**3
M = 1024**2
K = 1024
if numel >= B:
return f"{numel / B:.2f} B"
elif numel >= M:
return f"{numel / M:.2f} M"
elif numel >= K:
return f"{numel / K:.2f} K"
else:
return f"{numel}"
def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
tensor.div_(dist.get_world_size())
return tensor
def get_model_numel(model: torch.nn.Module) -> "tuple[int, int]":
num_params = 0
num_params_trainable = 0
for p in model.parameters():
num_params += p.numel()
if p.requires_grad:
num_params_trainable += p.numel()
return num_params, num_params_trainable
def try_import(name):
"""Try to import a module.
Args:
name (str): Specifies what module to import in absolute or relative
terms (e.g. either pkg.mod or ..mod).
Returns:
ModuleType or None: If importing successfully, returns the imported
module, otherwise returns None.
"""
try:
return importlib.import_module(name)
except ImportError:
return None
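# e.g. mmcv = try_import("mmcv")  # returns the module, or None if mmcv is not installed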
def transpose(x):
    """
    Transpose a list of lists.
    Args:
        x (list[list]): the nested list to transpose.
    """
ret = list(map(list, zip(*x)))
return ret
def get_timestamp():
timestamp = time.strftime("%Y%m%d-%H%M%S", time.localtime(time.time()))
return timestamp
def format_time(seconds):
days = int(seconds / 3600 / 24)
seconds = seconds - days * 3600 * 24
hours = int(seconds / 3600)
seconds = seconds - hours * 3600
minutes = int(seconds / 60)
seconds = seconds - minutes * 60
secondsf = int(seconds)
seconds = seconds - secondsf
millis = int(seconds * 1000)
f = ""
i = 1
if days > 0:
f += str(days) + "D"
i += 1
if hours > 0 and i <= 2:
f += str(hours) + "h"
i += 1
if minutes > 0 and i <= 2:
f += str(minutes) + "m"
i += 1
if secondsf > 0 and i <= 2:
f += str(secondsf) + "s"
i += 1
if millis > 0 and i <= 2:
f += str(millis) + "ms"
i += 1
if f == "":
f = "0ms"
return f
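# e.g. format_time(90061.5) -> "1D1h": at most two units are emitted, largest first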
def to_tensor(data):
"""Convert objects of various python types to :obj:`torch.Tensor`.
Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
:class:`Sequence`, :class:`int` and :class:`float`.
Args:
data (torch.Tensor | numpy.ndarray | Sequence | int | float): Data to
be converted.
"""
if isinstance(data, torch.Tensor):
return data
elif isinstance(data, np.ndarray):
return torch.from_numpy(data)
elif isinstance(data, Sequence) and not isinstance(data, str):
return torch.tensor(data)
elif isinstance(data, int):
return torch.LongTensor([data])
elif isinstance(data, float):
return torch.FloatTensor([data])
else:
raise TypeError(f"type {type(data)} cannot be converted to tensor.")
def to_ndarray(data):
if isinstance(data, torch.Tensor):
return data.numpy()
elif isinstance(data, np.ndarray):
return data
elif isinstance(data, Sequence):
return np.array(data)
elif isinstance(data, int):
        return np.array([data], dtype=int)
elif isinstance(data, float):
return np.array([data], dtype=float)
else:
raise TypeError(f"type {type(data)} cannot be converted to ndarray.")
def to_torch_dtype(dtype):
if isinstance(dtype, torch.dtype):
return dtype
elif isinstance(dtype, str):
dtype_mapping = {
"float64": torch.float64,
"float32": torch.float32,
"float16": torch.float16,
"fp32": torch.float32,
"fp16": torch.float16,
"half": torch.float16,
"bf16": torch.bfloat16,
}
        if dtype not in dtype_mapping:
            raise ValueError(f"Unknown dtype string: {dtype}")
        return dtype_mapping[dtype]
    else:
        raise ValueError(f"Cannot convert {type(dtype)} to a torch dtype")
def count_params(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
def _ntuple(n):
def parse(x):
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
return x
return tuple(repeat(x, n))
return parse
to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = _ntuple
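# e.g. to_2tuple(3) -> (3, 3), while to_2tuple((4, 5)) passes the iterable through unchanged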
def convert_SyncBN_to_BN2d(model_cfg):
for k in model_cfg:
v = model_cfg[k]
if k == "norm_cfg" and v["type"] == "SyncBN":
v["type"] = "BN2d"
elif isinstance(v, dict):
convert_SyncBN_to_BN2d(v)
def get_topk(x, dim=4, k=5):
x = to_tensor(x)
inds = x[..., dim].topk(k)[1]
return x[inds]
def param_sigmoid(x, alpha):
ret = 1 / (1 + (-alpha * x).exp())
return ret
def inverse_param_sigmoid(x, alpha, eps=1e-5):
x = x.clamp(min=0, max=1)
x1 = x.clamp(min=eps)
x2 = (1 - x).clamp(min=eps)
return torch.log(x1 / x2) / alpha
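# param_sigmoid computes sigmoid(alpha * x), so inverse_param_sigmoid is its
# inverse: inverse_param_sigmoid(param_sigmoid(x, a), a) ≈ x (up to clamping and eps).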
def inverse_sigmoid(x, eps=1e-5):
"""Inverse function of sigmoid.
Args:
x (Tensor): The tensor to do the
inverse.
eps (float): EPS avoid numerical
overflow. Defaults 1e-5.
Returns:
Tensor: The x has passed the inverse
function of sigmoid, has same
shape with input.
"""
x = x.clamp(min=0, max=1)
x1 = x.clamp(min=eps)
x2 = (1 - x).clamp(min=eps)
return torch.log(x1 / x2)
def count_columns(df, columns):
cnt_dict = OrderedDict()
num_samples = len(df)
for col in columns:
d_i = df[col].value_counts().to_dict()
for k in d_i:
d_i[k] = (d_i[k], d_i[k] / num_samples)
cnt_dict[col] = d_i
return cnt_dict
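# Usage sketch (assumes a pandas DataFrame `df`):
#   df = pd.DataFrame({"split": ["train", "train", "val"]})
#   count_columns(df, ["split"])
#   # -> OrderedDict([("split", {"train": (2, 0.667), "val": (1, 0.333)})])  # (count, fraction), approx.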
def build_logger(work_dir, cfgname):
log_file = cfgname + ".log"
log_path = os.path.join(work_dir, log_file)
logger = logging.getLogger(cfgname)
logger.setLevel(logging.INFO)
# formatter = logging.Formatter('%(asctime)s %(levelname)s %(message)s')
formatter = logging.Formatter("%(asctime)s: %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
handler1 = logging.FileHandler(log_path)
handler1.setFormatter(formatter)
handler2 = logging.StreamHandler()
handler2.setFormatter(formatter)
logger.addHandler(handler1)
logger.addHandler(handler2)
logger.propagate = False
return logger
timings = {}
from contextlib import contextmanager
@contextmanager
def torch_timer(prefix):
torch.cuda.synchronize()
start = time.time()
try:
yield
finally:
torch.cuda.synchronize()
t_diff = (time.time() - start) * 1000
if prefix not in timings:
timings[prefix] = []
timings[prefix].append(t_diff)
        num_recent = 10
        if len(timings[prefix]) > num_recent:
            # also report the average over the most recent 10 timings
            avg = sum(timings[prefix][-num_recent:]) / num_recent
            print("{}: {} ({})".format(prefix, t_diff, avg))
else:
print("{}: {}".format(prefix, t_diff))
def strip_dc(x):
    """
    Strip mmcv DataContainer wrappers from a (possibly nested) structure.
    """
    try:
        import mmcv
    except ImportError:
        # nothing to strip if mmcv is unavailable
        return x
    if isinstance(x, dict):
        return {k: strip_dc(v) for k, v in x.items()}
    if isinstance(x, (list, tuple)) and len(x) > 0 and isinstance(x[0], mmcv.parallel.DataContainer):
        return strip_dc(x[0])
    if isinstance(x, mmcv.parallel.DataContainer):
        return strip_dc(x.data)
    return x

@@ -0,0 +1,34 @@
from collections import OrderedDict
import torch
@torch.no_grad()
def update_ema(
ema_model: torch.nn.Module, model: torch.nn.Module, optimizer=None, decay: float = 0.9999, sharded: bool = True
) -> None:
"""
Step the EMA model towards the current model.
"""
ema_params = OrderedDict(ema_model.named_parameters())
model_params = OrderedDict(model.named_parameters())
for name, param in model_params.items():
if name == "pos_embed":
continue
        if not param.requires_grad:
continue
if not sharded:
param_data = param.data
ema_params[name].mul_(decay).add_(param_data, alpha=1 - decay)
else:
if param.data.dtype != torch.float32:
param_id = id(param)
master_param = optimizer._param_store.working_to_master_param[param_id]
param_data = master_param.data
else:
param_data = param.data
ema_params[name].mul_(decay).add_(param_data, alpha=1 - decay)
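# Usage sketch (non-sharded case; `model` is a placeholder for the online network):
#   import copy
#   ema = copy.deepcopy(model).eval()
#   for p in ema.parameters():
#       p.requires_grad = False
#   ...
#   optimizer.step()
#   update_ema(ema, model, sharded=False)
# In the sharded case, the ColossalAI zero optimizer must be passed so fp16
# working params can be mapped back to their fp32 master copies.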

0
tools/data/README.md Normal file

0
tools/data/__init__.py Normal file

96
tools/data/csvutil.py Normal file
@@ -0,0 +1,96 @@
import argparse
import csv
import os
from tqdm import tqdm
# path, name, #frames
PREFIX = [
"The video shows",
"The video captures",
"The video features",
"The video depicts",
"The video presents",
"The video features",
"The video is ",
"In the video,",
]
def get_video_length(path):
    import cv2

    cap = cv2.VideoCapture(path)
    length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    cap.release()
    return length
def main(args):
input_path = args.input
output_path = args.output
if output_path is None:
name = os.path.basename(input_path)
name, ext = os.path.splitext(name)
if args.num_frames_min is not None:
name += f"_fmin_{args.num_frames_min}"
if args.num_frames_max is not None:
name += f"_fmax_{args.num_frames_max}"
if args.filter_null_text:
name += "_fnt"
if args.remove_prefix:
name += "_rp"
        if args.root is not None:
            name += "_root"
if args.relength:
name += "_relength"
output_path = os.path.join(os.path.dirname(input_path), name + ext)
with open(input_path, "r") as f:
reader = csv.reader(f)
data = list(reader)
print("Number of videos before filtering:", len(data))
data_new = []
    for row in tqdm(data):
path = row[0]
caption = row[1]
n_frames = int(row[2])
if args.num_frames_min is not None and n_frames < args.num_frames_min:
continue
if args.num_frames_max is not None and n_frames > args.num_frames_max:
continue
if args.filter_null_text and len(caption) == 0:
continue
if args.remove_prefix:
for prefix in PREFIX:
if caption.startswith(prefix):
                    caption = caption[len(prefix) :].strip()
                    if caption and caption[0].islower():
                        caption = caption[0].upper() + caption[1:]
row[1] = caption
break
if args.root is not None:
row[0] = os.path.join(args.root, path)
if args.relength:
n_frames = get_video_length(row[0])
row[2] = n_frames
data_new.append(row)
print("Number of videos after filtering:", len(data_new))
with open(output_path, "w") as f:
writer = csv.writer(f)
writer.writerows(data_new)
print("Output saved to", output_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--input", type=str, required=True)
parser.add_argument("--output", type=str, default=None)
parser.add_argument("--num_frames_min", type=int, default=None)
parser.add_argument("--num_frames_max", type=int, default=None)
parser.add_argument("--filter_null_text", action="store_true")
parser.add_argument("--remove_prefix", action="store_true")
parser.add_argument("--root", type=str, default=None)
parser.add_argument("--relength", action="store_true")
args = parser.parse_args()
main(args)
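# Example invocation (paths are placeholders):
#   python csvutil.py --input data/videos.csv --num_frames_min 16 --filter_null_text --remove_prefix
# writes data/videos_fmin_16_fnt_rp.csv next to the input, with short or
# caption-less clips dropped and prefixes such as "The video shows" stripped.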

@@ -0,0 +1,82 @@
import argparse
import base64
import csv
import os
import cv2
import requests
import tqdm
# OpenAI API Key
api_key = ""
def get_caption(frame):
prompt = "The middle frame from a video clip are given. Describe this video and its style to generate a description for the video. The description should be useful for AI to re-generate the video. Here are some examples of good descriptions:\n\n 1. A stylish woman walks down a Tokyo street filled with warm glowing neon and animated city signage. She wears a black leather jacket, a long red dress, and black boots, and carries a black purse. She wears sunglasses and red lipstick. She walks confidently and casually. The street is damp and reflective, creating a mirror effect of the colorful lights. Many pedestrians walk about.\n2. Several giant wooly mammoths approach treading through a snowy meadow, their long wooly fur lightly blows in the wind as they walk, snow covered trees and dramatic snow capped mountains in the distance, mid afternoon light with wispy clouds and a sun high in the distance creates a warm glow, the low camera view is stunning capturing the large furry mammal with beautiful photography, depth of field.\n 3. Drone view of waves crashing against the rugged cliffs along Big Sur's garay point beach. The crashing blue waters create white-tipped waves, while the golden light of the setting sun illuminates the rocky shore. A small island with a lighthouse sits in the distance, and green shrubbery covers the cliff's edge. The steep drop from the road down to the beach is a dramatic feat, with the cliffs edges jutting out over the sea. This is a view that captures the raw beauty of the coast and the rugged landscape of the Pacific Coast Highway."
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
payload = {
"model": "gpt-4-vision-preview",
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": prompt,
},
{"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{frame}"}},
],
}
],
"max_tokens": 300,
}
response = requests.post("https://api.openai.com/v1/chat/completions", headers=headers, json=payload, timeout=60)
caption = response.json()["choices"][0]["message"]["content"]
return caption
# Function to encode the image
def encode_image(image_path):
with open(image_path, "rb") as image_file:
return base64.b64encode(image_file.read()).decode("utf-8")
def extract_frames(video_path):
    cap = cv2.VideoCapture(video_path)
    length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    cap.set(cv2.CAP_PROP_POS_FRAMES, length // 2)  # middle frame
    ret, frame = cap.read()
    cap.release()
    _, buffer = cv2.imencode(".jpg", frame)
    img_base64 = base64.b64encode(buffer).decode("utf-8")
    return img_base64
def main(args):
processed_videos = []
if os.path.exists(args.output_file):
with open(args.output_file, "r") as f:
reader = csv.reader(f)
samples = list(reader)
processed_videos = [sample[0] for sample in samples]
f = open(args.output_file, "a")
writer = csv.writer(f)
for video in tqdm.tqdm(os.listdir(args.video_folder)):
if video in processed_videos:
continue
video_path = os.path.join(args.video_folder, video)
base64_image = extract_frames(video_path)
caption = get_caption(base64_image)
caption = caption.replace("\n", " ")
writer.writerow([video, caption])
f.close()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--video_folder", type=str, required=True, help="Path to the folder containing the videos.")
parser.add_argument("--output_file", type=str, default="video_captions.csv", help="Path to the output file.")
args = parser.parse_args()
main(args)
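# Example invocation (script name is a placeholder; set `api_key` above first):
#   python caption_gpt4v.py --video_folder data/clips --output_file video_captions.csv
# Videos already listed in the output CSV are skipped, so interrupted runs can resume.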

@@ -0,0 +1,382 @@
import argparse
import csv
import os
import cv2
import torch
from llava.constants import DEFAULT_IMAGE_TOKEN, IGNORE_INDEX, IMAGE_TOKEN_INDEX
from llava.conversation import conv_templates
from llava.eval.run_llava import eval_model
from llava.mm_utils import get_anyres_image_grid_shape, get_model_name_from_path, process_images, tokenizer_image_token
from llava.model.builder import load_pretrained_model
from llava.model.llava_arch import unpad_image
from llava.utils import disable_torch_init
from PIL import Image
from tqdm import tqdm
disable_torch_init()
prompts = {
"three_frames": "A video is given by providing three frames in chronological order. Describe this video and its style to generate a description. Pay attention to all objects in the video. Do not describe each frame individually. Do not reply with words like 'first frame'. The description should be useful for AI to re-generate the video. The description should be less than six sentences. Here are some examples of good descriptions: 1. A stylish woman walks down a Tokyo street filled with warm glowing neon and animated city signage. She wears a black leather jacket, a long red dress, and black boots, and carries a black purse. She wears sunglasses and red lipstick. She walks confidently and casually. The street is damp and reflective, creating a mirror effect of the colorful lights. Many pedestrians walk about. 2. Several giant wooly mammoths approach treading through a snowy meadow, their long wooly fur lightly blows in the wind as they walk, snow covered trees and dramatic snow capped mountains in the distance, mid afternoon light with wispy clouds and a sun high in the distance creates a warm glow, the low camera view is stunning capturing the large furry mammal with beautiful photography, depth of field. 3. Drone view of waves crashing against the rugged cliffs along Big Sur's garay point beach. The crashing blue waters create white-tipped waves, while the golden light of the setting sun illuminates the rocky shore. A small island with a lighthouse sits in the distance, and green shrubbery covers the cliff's edge. The steep drop from the road down to the beach is a dramatic feat, with the cliffs edges jutting out over the sea. This is a view that captures the raw beauty of the coast and the rugged landscape of the Pacific Coast Highway.",
}
def get_filelist(file_path):
Filelist = []
for home, dirs, files in os.walk(file_path):
for filename in files:
Filelist.append(os.path.join(home, filename))
return Filelist
def get_video_length(cap):
return int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
def extract_frames(video_path, points=(0.2, 0.5, 0.8)):
    cap = cv2.VideoCapture(video_path)
    length = get_video_length(cap)
    frames = []
    if length < 3:
        cap.release()
        return frames, length
    for point in [int(length * p) for p in points]:
        cap.set(cv2.CAP_PROP_POS_FRAMES, point)
        ret, frame = cap.read()
        if not ret:
            continue
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frames.append(Image.fromarray(frame))
    cap.release()
    return frames, length
def prepare_inputs_labels_for_multimodal(
self, input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes=None
):
    # adapted from prepare_inputs_labels_for_multimodal in llava.model.llava_arch
vision_tower = self.get_vision_tower()
if vision_tower is None or images is None or input_ids.shape[1] == 1:
return input_ids, position_ids, attention_mask, past_key_values, None, labels
if type(images) is list or images.ndim == 5:
if type(images) is list:
images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images]
concat_images = torch.cat([image for image in images], dim=0)
image_features = self.encode_images(concat_images)
split_sizes = [image.shape[0] for image in images]
image_features = torch.split(image_features, split_sizes, dim=0)
mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat")
image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
if mm_patch_merge_type == "flat":
image_features = [x.flatten(0, 1) for x in image_features]
elif mm_patch_merge_type.startswith("spatial"):
new_image_features = []
for image_idx, image_feature in enumerate(image_features):
if image_feature.shape[0] > 1:
base_image_feature = image_feature[0]
image_feature = image_feature[1:]
height = width = self.get_vision_tower().num_patches_per_side
assert height * width == base_image_feature.shape[0]
if image_aspect_ratio == "anyres":
num_patch_width, num_patch_height = get_anyres_image_grid_shape(
image_sizes[image_idx],
self.config.image_grid_pinpoints,
self.get_vision_tower().config.image_size,
)
image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
else:
raise NotImplementedError
if "unpad" in mm_patch_merge_type:
image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
image_feature = unpad_image(image_feature, image_sizes[image_idx])
image_feature = torch.cat(
(
image_feature,
self.model.image_newline[:, None, None]
.expand(*image_feature.shape[:-1], 1)
.to(image_feature.device),
),
dim=-1,
)
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
else:
image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous()
image_feature = image_feature.flatten(0, 3)
image_feature = torch.cat((base_image_feature, image_feature), dim=0)
else:
image_feature = image_feature[0]
if "unpad" in mm_patch_merge_type:
image_feature = torch.cat(
(image_feature, self.model.image_newline[None].to(image_feature.device)), dim=0
)
new_image_features.append(image_feature)
image_features = new_image_features
else:
raise ValueError(f"Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}")
else:
image_features = self.encode_images(images)
# TODO: image start / end is not implemented here to support pretraining.
if getattr(self.config, "tune_mm_mlp_adapter", False) and getattr(self.config, "mm_use_im_start_end", False):
raise NotImplementedError
# Let's just add dummy tensors if they do not exist,
# it is a headache to deal with None all the time.
# But it is not ideal, and if you have a better idea,
# please open an issue / submit a PR, thanks.
_labels = labels
_position_ids = position_ids
_attention_mask = attention_mask
if attention_mask is None:
attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
else:
attention_mask = attention_mask.bool()
if position_ids is None:
position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
if labels is None:
labels = torch.full_like(input_ids, IGNORE_INDEX)
# remove the padding using attention_mask -- FIXME
_input_ids = input_ids
input_ids = [
cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)
]
labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
new_input_embeds = []
new_labels = []
cur_image_idx = 0
for batch_idx, cur_input_ids in enumerate(input_ids):
num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
if num_images == 0:
cur_image_features = image_features[cur_image_idx]
cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
new_input_embeds.append(cur_input_embeds)
new_labels.append(labels[batch_idx])
cur_image_idx += 1
continue
image_token_indices = (
[-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
)
cur_input_ids_noim = []
cur_labels = labels[batch_idx]
cur_labels_noim = []
for i in range(len(image_token_indices) - 1):
cur_input_ids_noim.append(cur_input_ids[image_token_indices[i] + 1 : image_token_indices[i + 1]])
cur_labels_noim.append(cur_labels[image_token_indices[i] + 1 : image_token_indices[i + 1]])
split_sizes = [x.shape[0] for x in cur_labels_noim]
cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
cur_new_input_embeds = []
cur_new_labels = []
for i in range(num_images + 1):
cur_new_input_embeds.append(cur_input_embeds_no_im[i])
cur_new_labels.append(cur_labels_noim[i])
if i < num_images:
cur_image_features = image_features[cur_image_idx]
cur_image_idx += 1
cur_new_input_embeds.append(cur_image_features)
cur_new_labels.append(
torch.full(
(cur_image_features.shape[0],),
IGNORE_INDEX,
device=cur_labels.device,
dtype=cur_labels.dtype,
)
)
cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]
cur_new_input_embeds = torch.cat(cur_new_input_embeds)
cur_new_labels = torch.cat(cur_new_labels)
new_input_embeds.append(cur_new_input_embeds)
new_labels.append(cur_new_labels)
# Truncate sequences to max length as image embeddings can make the sequence longer
tokenizer_model_max_length = getattr(self.config, "tokenizer_model_max_length", None)
if tokenizer_model_max_length is not None:
new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
# Combine them
max_len = max(x.shape[0] for x in new_input_embeds)
batch_size = len(new_input_embeds)
new_input_embeds_padded = []
new_labels_padded = torch.full(
(batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device
)
attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
cur_len = cur_new_embed.shape[0]
if getattr(self.config, "tokenizer_padding_side", "right") == "left":
new_input_embeds_padded.append(
torch.cat(
(
torch.zeros(
(max_len - cur_len, cur_new_embed.shape[1]),
dtype=cur_new_embed.dtype,
device=cur_new_embed.device,
),
cur_new_embed,
),
dim=0,
)
)
if cur_len > 0:
new_labels_padded[i, -cur_len:] = cur_new_labels
attention_mask[i, -cur_len:] = True
position_ids[i, -cur_len:] = torch.arange(
0, cur_len, dtype=position_ids.dtype, device=position_ids.device
)
else:
new_input_embeds_padded.append(
torch.cat(
(
cur_new_embed,
torch.zeros(
(max_len - cur_len, cur_new_embed.shape[1]),
dtype=cur_new_embed.dtype,
device=cur_new_embed.device,
),
),
dim=0,
)
)
if cur_len > 0:
new_labels_padded[i, :cur_len] = cur_new_labels
attention_mask[i, :cur_len] = True
position_ids[i, :cur_len] = torch.arange(
0, cur_len, dtype=position_ids.dtype, device=position_ids.device
)
new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
if _labels is None:
new_labels = None
else:
new_labels = new_labels_padded
if _attention_mask is None:
attention_mask = None
else:
attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
if _position_ids is None:
position_ids = None
return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
@torch.inference_mode()
def main(args):
bs = args.bs
video_folder = args.video_folder
processed_videos = []
if os.path.exists(args.output_file):
with open(args.output_file, "r") as f:
reader = csv.reader(f)
samples = list(reader)
processed_videos = [sample[0] for sample in samples]
f = open(args.output_file, "a")
writer = csv.writer(f)
model_path = "liuhaotian/llava-v1.6-34b"
query = prompts["three_frames"]
conv = conv_templates["chatml_direct"].copy()
conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + "\n" + query)
prompt = conv.get_prompt()
tokenizer, model, image_processor, context_len = load_pretrained_model(
model_path=model_path,
model_base=None,
model_name=get_model_name_from_path(model_path),
)
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
input_ids = input_ids.unsqueeze(0).to(model.device)
videos = get_filelist(video_folder)
print(f"Dataset contains {len(videos)} videos.")
videos = [video for video in videos if video not in processed_videos]
print(f"Processing {len(videos)} new videos.")
for i in tqdm(range(0, len(videos), bs)):
# prepare a batch of inputs
video_files = videos[i : i + bs]
frames = []
video_lengths = []
for video_file in video_files:
            frame, length = extract_frames(video_file)  # get_filelist already returns full paths
if len(frame) < 3:
continue
frames.append(frame)
video_lengths.append(length)
if len(frames) == 0:
continue
# encode the batch of inputs
samples = []
for imgs in frames:
imgs_size = [img.size for img in imgs]
imgs = process_images(imgs, image_processor, model.config)
imgs = imgs.to(model.device, dtype=torch.float16)
with torch.inference_mode():
_, _, _, _, inputs_embeds, _ = prepare_inputs_labels_for_multimodal(
model, input_ids, None, None, None, None, images=imgs, image_sizes=imgs_size
)
samples.append(inputs_embeds)
# padding
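        # Left padding: the embedding sequences are right-aligned so that the last
        # position of every row is real content, as decoder-only generation expects;
        # the zero-filled positions on the left are masked out via attention_mask.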
max_len = max([sample.shape[1] for sample in samples])
attention_mask = torch.tensor(
[[0] * (max_len - samples[i].shape[1]) + [1] * samples[i].shape[1] for i in range(len(samples))]
).to(model.device)
inputs_embeds = [
torch.cat(
[
torch.zeros(
(1, max_len - samples[i].shape[1], samples[i].shape[-1]),
device=model.device,
dtype=torch.float16,
),
samples[i],
],
dim=1,
)
for i in range(len(samples))
]
inputs_embeds = torch.cat(inputs_embeds, dim=0)
# generate outputs
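        # call the plain transformers generate (grandparent class), bypassing LLaVA's
        # override, since the multimodal inputs_embeds have already been built above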
output_ids = super(type(model), model).generate(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
do_sample=True,
temperature=0.2,
max_new_tokens=512,
use_cache=True,
)
outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
outputs = [output.replace("\n", " ").strip() for output in outputs]
# save results
result = list(zip(video_files, outputs, video_lengths))
for t in result:
writer.writerow(t)
f.close()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--video_folder", type=str, required=True)
parser.add_argument("--bs", type=int, default=32)
parser.add_argument("--output_file", type=str, default="video_captions.csv")
args = parser.parse_args()
main(args)

139
tools/data/scene_detect.py Normal file
@@ -0,0 +1,139 @@
import os
from tqdm import tqdm
from multiprocessing import Pool
from mmengine.logging import MMLogger, print_log
from scenedetect import detect, ContentDetector
from .utils import check_mp4_integrity, split_video
from .utils import clone_folder_structure, iterate_files, iterate_folders
from opensora.utils.misc import get_timestamp
# config
target_fps = 30 # int
shorter_size = 512 # int
min_seconds = 1 # float
max_seconds = 5 # float
assert max_seconds > min_seconds
cfg = dict(
target_fps=target_fps,
min_seconds=min_seconds,
max_seconds=max_seconds,
shorter_size=shorter_size,
)
def process_folder(root_src, root_dst):
# create logger
folder_path_log = os.path.dirname(root_dst)
log_name = os.path.basename(root_dst)
timestamp = get_timestamp()
log_path = os.path.join(folder_path_log, f'{log_name}_{timestamp}.log')
logger = MMLogger.get_instance(log_name, log_file=log_path)
# clone folder structure
clone_folder_structure(root_src, root_dst)
# all source videos
mp4_list = [x for x in iterate_files(root_src) if x.endswith('.mp4')]
mp4_list = sorted(mp4_list)
    for sample_path in tqdm(mp4_list):
folder_src = os.path.dirname(sample_path)
folder_dst = os.path.join(root_dst, os.path.relpath(folder_src, root_src))
# check src video integrity
if not check_mp4_integrity(sample_path, logger=logger):
continue
# detect scenes
scene_list = detect(sample_path, ContentDetector(), start_in_scene=True)
# split scenes
save_path_list = split_video(sample_path, scene_list, save_dir=folder_dst, **cfg, logger=logger)
# check integrity of generated clips
for x in save_path_list:
check_mp4_integrity(x, logger=logger)
def scene_detect():
""" detect & cut scenes using a single process
Expected dataset structure:
data/
your_dataset/
raw_videos/
xxx.mp4
yyy.mp4
This function results in:
data/
your_dataset/
raw_videos/
xxx.mp4
yyy.mp4
clips/
xxx_scene-0.mp4
yyy_scene-0.mp4
yyy_scene-1.mp4
"""
# TODO: specify your dataset root
    root_src = './data/your_dataset/raw_videos'
    root_dst = './data/your_dataset/clips'
process_folder(root_src, root_dst)
def scene_detect_mp():
""" detect & cut scenes using multiple processes
Expected dataset structure:
data/
your_dataset/
raw_videos/
split_0/
xxx.mp4
yyy.mp4
split_1/
xxx.mp4
yyy.mp4
This function results in:
data/
your_dataset/
raw_videos/
split_0/
xxx.mp4
yyy.mp4
split_1/
xxx.mp4
yyy.mp4
clips/
split_0/
xxx_scene-0.mp4
yyy_scene-0.mp4
split_1/
xxx_scene-0.mp4
yyy_scene-0.mp4
yyy_scene-1.mp4
"""
# TODO: specify your dataset root
    root_src = './data/your_dataset/raw_videos'
    root_dst = './data/your_dataset/clips'
# TODO: specify your splits
splits = ['split_0', 'split_1']
# process folders
root_src_list = [os.path.join(root_src, x) for x in splits]
root_dst_list = [os.path.join(root_dst, x) for x in splits]
with Pool(processes=len(splits)) as pool:
pool.starmap(process_folder, list(zip(root_src_list, root_dst_list)))
if __name__ == '__main__':
# TODO: choose single process or multiprocessing
scene_detect()
# scene_detect_mp()
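    # Both entry points use the module-level `cfg` above: clips are re-encoded
    # at 30 fps, cut to 1-5 s, and scaled so the shorter side is 512 px.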

@@ -0,0 +1,17 @@
import csv
import os
from torchvision.datasets import ImageNet
root = "~/data/imagenet"
split = "train"
root = os.path.expanduser(root)
data = ImageNet(root, split=split)
samples = [(path, data.classes[label][0]) for path, label in data.samples]
with open(f"preprocess/imagenet_{split}.csv", "w") as f:
writer = csv.writer(f)
writer.writerows(samples)
print(f"Saved {len(samples)} samples to preprocess/imagenet_{split}.csv.")

@@ -0,0 +1,36 @@
import csv
import os
def get_filelist(file_path):
Filelist = []
for home, dirs, files in os.walk(file_path):
for filename in files:
Filelist.append(os.path.join(home, filename))
return Filelist
def split_by_capital(name):
# BoxingPunchingBag -> Boxing Punching Bag
new_name = ""
for i in range(len(name)):
if name[i].isupper() and i != 0:
new_name += " "
new_name += name[i]
return new_name
root = "~/data/ucf101"
split = "videos"
root = os.path.expanduser(root)
video_lists = get_filelist(os.path.join(root, split))
classes = [x.split("/")[-2] for x in video_lists]
classes = [split_by_capital(x) for x in classes]
samples = list(zip(video_lists, classes))
with open(f"preprocess/ucf101_{split}.csv", "w") as f:
writer = csv.writer(f)
writer.writerows(samples)
print(f"Saved {len(samples)} samples to preprocess/ucf101_{split}.csv.")

148
tools/data/utils.py Normal file
@@ -0,0 +1,148 @@
import os
import cv2
import subprocess
from mmengine.logging import print_log
from moviepy.editor import VideoFileClip
from imageio_ffmpeg import get_ffmpeg_exe
from scenedetect import FrameTimecode
def iterate_files(folder_path):
    # os.walk already traverses subdirectories recursively,
    # so a single pass over its output covers the whole tree
    for root, dirs, files in os.walk(folder_path):
        for file in files:
            yield os.path.join(root, file)
def iterate_folders(folder_path):
    # os.walk already recurses; yield every subdirectory it visits
    for root, dirs, files in os.walk(folder_path):
        for subdir in dirs:
            yield os.path.join(root, subdir)
def clone_folder_structure(root_src, root_dst, verbose=False):
src_path_list = iterate_folders(root_src)
src_relpath_list = [os.path.relpath(x, root_src) for x in src_path_list]
os.makedirs(root_dst, exist_ok=True)
dst_path_list = [os.path.join(root_dst, x) for x in src_relpath_list]
for folder_path in dst_path_list:
os.makedirs(folder_path, exist_ok=True)
if verbose:
print(f'Create folder: \'{folder_path}\'')
def count_files(root, suffix='.mp4'):
files_list = iterate_files(root)
cnt = len([x for x in files_list if x.endswith(suffix)])
return cnt
def check_mp4_integrity(file_path, verbose=True, logger=None):
    try:
        video_clip = VideoFileClip(file_path)
        video_clip.close()
        if verbose:
            print_log(f'The MP4 file \'{file_path}\' is intact.', logger=logger)
        return True
    except Exception as e:
        if verbose:
            print_log(f'Error: {e}', logger=logger)
            print_log(f'The MP4 file \'{file_path}\' is not intact.', logger=logger)
        return False
def count_frames(video_path):
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
print(f"Error: Could not open video file '{video_path}'")
return
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
print(f"Total frames in the video '{video_path}': {total_frames}")
cap.release()
def split_video(sample_path,
scene_list,
save_dir,
target_fps=30,
min_seconds=1,
max_seconds=10,
shorter_size=512,
verbose=False,
logger=None,
):
FFMPEG_PATH = get_ffmpeg_exe()
save_path_list = []
for idx, scene in enumerate(scene_list):
s, t = scene # FrameTimecode
fps = s.framerate
max_duration = FrameTimecode(timecode='00:00:00', fps=fps)
max_duration.frame_num = round(fps * max_seconds)
duration = min(max_duration, t - s)
if duration.get_frames() < round(min_seconds * fps):
continue
# save path
fname = os.path.basename(sample_path)
fname_wo_ext = os.path.splitext(fname)[0]
# TODO: fname pattern
save_path = os.path.join(save_dir, f'{fname_wo_ext}_scene-{idx}.mp4')
# ffmpeg cmd
cmd = [FFMPEG_PATH]
        # uncomment to silence ffmpeg output except for errors
        # cmd += ['-v', 'error']
# input path
cmd += ['-i', sample_path]
# clip to cut
cmd += [
'-nostdin', '-y',
'-ss', str(s.get_seconds()),
'-t', str(duration.get_seconds())
]
# target fps
# cmd += ['-vf', 'select=mod(n\,2)']
cmd += ['-r', f'{target_fps}']
# aspect ratio
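        # -2 lets ffmpeg derive that dimension from the aspect ratio, rounded to an
        # even number as most codecs require; the shorter side becomes `shorter_size`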
cmd += ['-vf', f"scale='if(gt(iw,ih),-2,{shorter_size})':'if(gt(iw,ih),{shorter_size},-2)'"]
# cmd += ['-vf', f"scale='if(gt(iw,ih),{shorter_size},trunc(ow/a/2)*2)':-2"]
cmd += ['-map', '0', save_path]
proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
stdout, stderr = proc.communicate()
if verbose:
stdout = stdout.decode('utf-8')
print_log(stdout, logger=logger)
        save_path_list.append(save_path)
print_log(f'Video clip saved to \'{save_path}\'', logger=logger)
return save_path_list