mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-10 21:01:26 +02:00
first commit (#67)
This commit is contained in:
parent
9aab3ad343
commit
0ba56f5309
149
README.md
Normal file
149
README.md
Normal file
|
|
@ -0,0 +1,149 @@
|
|||
<p align="center">
|
||||
<img src="./assets/readme/icon_zw.png" width="250"/>
|
||||
<p>
|
||||
|
||||
</p>
|
||||
<div align="center">
|
||||
<a href="https://github.com/hpcaitech/Open-Sora/stargazers"><img src="https://img.shields.io/github/stars/hpcaitech/Open-Sora?style=social"></a>
|
||||
<a href="https://github.com/hpcaitech/public_assets/tree/main/colossalai/contact/slack"><img src="https://img.shields.io/badge/Slack-join-blueviolet?logo=slack&"></a>
|
||||
<a href="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/WeChat.png"><img src="https://img.shields.io/badge/微信-加入-green?logo=wechat&"></a>
|
||||
</div>
|
||||
|
||||
## Open-Sora: Towards Open Reproduction of Sora
|
||||
|
||||
**Open-Sora** is an **open-source** initiative dedicated to **efficiently** reproducing OpenAI's Sora. Our project aims to cover **the full pipeline**, including video data preprocessing, training with acceleration, efficient inference and more. Operating on a limited budget, we prioritize the vibrant open-source community, providing access to text-to-image, image captioning, and language models. We hope to make a contribution to the community and make the project more accessible to everyone.
|
||||
|
||||
## 📰 News
|
||||
|
||||
* **[2024.03.18]** 🔥 We release **Open-Sora 1.0**, an open-source project to reproduce OpenAI Sora.
|
||||
Open-Sora 1.0 supports a full pipeline of video data preprocessing, training with
|
||||
<a href="https://github.com/hpcaitech/ColossalAI"><img src="assets/readme/colossal_ai.png" width="8%" ></a> acceleration,
|
||||
inference, and more. Our provided checkpoint can produce 2s 512x512 videos.
|
||||
|
||||
## 🎥 Latest Demo
|
||||
|
||||
| **2s 512x512** | **2s 512x512** |
|
||||
| ----------------------------------------------- | ----------------------------------------------- |
|
||||
| <img src="assets/readme/sample_0.gif" width=""> | <img src="assets/readme/sample_0.gif" width=""> |
|
||||
|
||||
## 🔆 New Features/Updates
|
||||
|
||||
- 📍 Open-Sora-v1 is trained on xxx. We train the model in three stages. Model weights are available here. Training details can be found here.
|
||||
- ✅ Support training acceleration including flash-attention, accelerated T5, mixed precision, gradient checkpointing, splitted VAE, sequence parallelism, etc. XXX times. See more discussions [here]().
|
||||
- ✅ We provide video cutting and captioning tools for data preprocessing. Our data collection plan can be found [here]().
|
||||
- ✅ We find VQ-VAE from [] has a low quality and thus adopt a better VAE from []. We also find patching in the time dimension deteriorates the quality. See more discussions [here]().
|
||||
- ✅ We investigate different architectures including DiT, Latte, and our proposed STDiT. Our STDiT achieves a better trade-off between quality and speed. See more discussions [here]().
|
||||
- ✅ Support clip and t5 text conditioning.
|
||||
- ✅ By viewing images as one-frame videos, our project supports training DiT on both images and videos (e.g., ImageNet & UCF101).
|
||||
- ✅ Support inference with official weights from [DiT](https://github.com/facebookresearch/DiT), [Latte](https://github.com/Vchitect/Latte), and [PixArt](https://pixart-alpha.github.io/).
|
||||
|
||||
|
||||
|
||||
### TODO list sorted by priority
|
||||
|
||||
- [ ] Complete the data processing pipeline (including dense optical flow, aesthetics scores, text-image similarity, deduplication, etc.). See [datasets.md]() for more information. **[WIP]**
|
||||
- [ ] Training Video-VAE. **[WIP]**
|
||||
- [ ] Support image and video conditioning.
|
||||
- [ ] Evaluation pipeline.
|
||||
- [ ] Incoporate a better scheduler, e.g., rectified flow in SD3.
|
||||
- [ ] Support variable aspect ratios, resolutions, durations.
|
||||
- [ ] Support SD3 when released.
|
||||
|
||||
|
||||
## Contents
|
||||
|
||||
- [Open-Sora: Towards Open Reproduction of Sora](#open-sora-towards-open-reproduction-of-sora)
|
||||
- [📰 News](#-news)
|
||||
- [🎥 Latest Demo](#-latest-demo)
|
||||
- [🔆 New Features/Updates](#-new-featuresupdates)
|
||||
- [TODO list sorted by priority](#todo-list-sorted-by-priority)
|
||||
- [Contents](#contents)
|
||||
- [Installation](#installation)
|
||||
- [Model Weights](#model-weights)
|
||||
- [Inference](#inference)
|
||||
- [Data Processing](#data-processing)
|
||||
- [Split video into clips](#split-video-into-clips)
|
||||
- [Generate video caption](#generate-video-caption)
|
||||
- [Training](#training)
|
||||
- [Acknowledgement](#acknowledgement)
|
||||
- [Citation](#citation)
|
||||
- [Star History](#star-history)
|
||||
- [TODO](#todo)
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
git clone https://github.com/hpcaitech/Open-Sora
|
||||
cd Open-Sora
|
||||
pip install xxx
|
||||
```
|
||||
|
||||
After installation, to get fimilar with the project, you can check the [here]() for the project structure and how to use the config files.
|
||||
|
||||
## Model Weights
|
||||
|
||||
| Model | #Params | url |
|
||||
| ---------- | ------- | --- |
|
||||
| 16x256x256 | | |
|
||||
|
||||
## Inference
|
||||
|
||||
```bash
|
||||
python scripts/inference.py configs/opensora/inference/16x256x256.py
|
||||
```
|
||||
|
||||
## Data Processing
|
||||
|
||||
### Split video into clips
|
||||
|
||||
We provide code to split a long video into separate clips efficiently using `multiprocessing`. See `tools/data/scene_detect.py`.
|
||||
|
||||
### Generate video caption
|
||||
|
||||
## Training
|
||||
|
||||
## Acknowledgement
|
||||
|
||||
* [DiT](https://github.com/facebookresearch/DiT): Scalable Diffusion Models with Transformers.
|
||||
* [OpenDiT](https://github.com/NUS-HPC-AI-Lab/OpenDiT): An acceleration for DiT training. OpenDiT's team provides valuable suggestions on acceleration of our training process.
|
||||
* [PixArt](https://github.com/PixArt-alpha/PixArt-alpha): An open-source DiT-based text-to-image model.
|
||||
* [Latte](https://github.com/Vchitect/Latte): An attempt to efficiently train DiT for video.
|
||||
* [StabilityAI VAE](https://huggingface.co/stabilityai/sd-vae-ft-mse-original): A powerful image VAE model.
|
||||
* [CLIP](https://github.com/openai/CLIP): A powerful text-image embedding model.
|
||||
* [T5](https://github.com/google-research/text-to-text-transfer-transformer): The powerful text encoder.
|
||||
* [LLaVA](https://github.com/haotian-liu/LLaVA): A powerful image captioning model based on [LLaMA](https://github.com/meta-llama/llama) and [Yi-34B](https://huggingface.co/01-ai/Yi-34B).
|
||||
* [PySceneDetect](https://github.com/Breakthrough/PySceneDetect): A powerful tool to split video into clips.
|
||||
|
||||
We are grateful for their exceptional work and generous contribution to open source.
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@software{opensora,
|
||||
author = {Zangwei Zheng and Xiangyu Peng and Shenggui Li and Yang You},
|
||||
title = {Open-Sora: Towards Open Reproduction of Sora},
|
||||
month = {March},
|
||||
year = {2024},
|
||||
url = {https://github.com/hpcaitech/Open-Sora}
|
||||
}
|
||||
```
|
||||
|
||||
Zangwei Zheng and Xiangyu Peng equally contributed to this work.
|
||||
|
||||
## Star History
|
||||
|
||||
[](https://star-history.com/#hpcaitech/Open-Sora&Date)
|
||||
|
||||
## TODO
|
||||
|
||||
Modules for releasing:
|
||||
|
||||
* `configs`
|
||||
* `opensora`
|
||||
* `assets`
|
||||
* `scripts`
|
||||
* `tools`
|
||||
|
||||
packages for data processing
|
||||
|
||||
put all outputs under ./checkpoints/, including pretrained_models, checkpoints, samples
|
||||
0
docs/datasets.md
Normal file
0
docs/datasets.md
Normal file
4
opensora/__init__.py
Normal file
4
opensora/__init__.py
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
from .acceleration import *
|
||||
from .datasets import *
|
||||
from .registry import *
|
||||
from .models import *
|
||||
39
opensora/registry.py
Normal file
39
opensora/registry.py
Normal file
|
|
@ -0,0 +1,39 @@
|
|||
from copy import deepcopy
|
||||
|
||||
import torch.nn as nn
|
||||
from mmengine.registry import Registry
|
||||
|
||||
|
||||
def build_module(module, builder, **kwargs):
|
||||
"""Build module from config or return the module itself.
|
||||
|
||||
Args:
|
||||
module (Union[dict, nn.Module]): The module to build.
|
||||
builder (Registry): The registry to build module.
|
||||
*args, **kwargs: Arguments passed to build function.
|
||||
|
||||
Returns:
|
||||
Any: The built module.
|
||||
"""
|
||||
if isinstance(module, dict):
|
||||
cfg = deepcopy(module)
|
||||
for k, v in kwargs.items():
|
||||
cfg[k] = v
|
||||
return builder.build(cfg)
|
||||
elif isinstance(module, nn.Module):
|
||||
return module
|
||||
elif module is None:
|
||||
return None
|
||||
else:
|
||||
raise TypeError(f"Only support dict and nn.Module, but got {type(module)}.")
|
||||
|
||||
|
||||
MODELS = Registry(
|
||||
"model",
|
||||
locations=["opensora.models"],
|
||||
)
|
||||
|
||||
SCHEDULERS = Registry(
|
||||
"scheduler",
|
||||
locations=["opensora.schedulers"],
|
||||
)
|
||||
2
opensora/schedulers/__init__.py
Normal file
2
opensora/schedulers/__init__.py
Normal file
|
|
@ -0,0 +1,2 @@
|
|||
from .iddpm import IDDPM
|
||||
from .dpms import DPMS
|
||||
50
opensora/schedulers/dpms/__init__.py
Normal file
50
opensora/schedulers/dpms/__init__.py
Normal file
|
|
@ -0,0 +1,50 @@
|
|||
from functools import partial
|
||||
|
||||
import torch
|
||||
|
||||
from opensora.registry import SCHEDULERS
|
||||
|
||||
from .dpm_solver import DPMS
|
||||
|
||||
|
||||
@SCHEDULERS.register_module("dpm-solver")
|
||||
class DMP_SOLVER:
|
||||
def __init__(self, num_sampling_steps=None, cfg_scale=4.0):
|
||||
self.num_sampling_steps = num_sampling_steps
|
||||
self.cfg_scale = cfg_scale
|
||||
|
||||
def sample(
|
||||
self,
|
||||
model,
|
||||
text_encoder,
|
||||
z_size,
|
||||
prompts,
|
||||
device,
|
||||
additional_args=None,
|
||||
):
|
||||
n = len(prompts)
|
||||
z = torch.randn(n, *z_size, device=device)
|
||||
model_args = text_encoder.encode(prompts)
|
||||
y = model_args.pop("y")
|
||||
null_y = text_encoder.null(n)
|
||||
if additional_args is not None:
|
||||
model_args.update(additional_args)
|
||||
|
||||
dpms = DPMS(
|
||||
partial(forward_with_dpmsolver, model),
|
||||
condition=y,
|
||||
uncondition=null_y,
|
||||
cfg_scale=self.cfg_scale,
|
||||
model_kwargs=model_args,
|
||||
)
|
||||
samples = dpms.sample(z, steps=self.num_sampling_steps, order=2, skip_type="time_uniform", method="multistep")
|
||||
return samples
|
||||
|
||||
|
||||
def forward_with_dpmsolver(self, x, timestep, y, **kwargs):
|
||||
"""
|
||||
dpm solver donnot need variance prediction
|
||||
"""
|
||||
# https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
|
||||
model_out = self.forward(x, timestep, y, **kwargs)
|
||||
return model_out.chunk(2, dim=1)[0]
|
||||
1558
opensora/schedulers/dpms/dpm_solver.py
Normal file
1558
opensora/schedulers/dpms/dpm_solver.py
Normal file
File diff suppressed because it is too large
Load diff
95
opensora/schedulers/iddpm/__init__.py
Normal file
95
opensora/schedulers/iddpm/__init__.py
Normal file
|
|
@ -0,0 +1,95 @@
|
|||
from functools import partial
|
||||
|
||||
import torch
|
||||
|
||||
from opensora.registry import SCHEDULERS
|
||||
|
||||
from . import gaussian_diffusion as gd
|
||||
from .respace import SpacedDiffusion, space_timesteps
|
||||
|
||||
|
||||
@SCHEDULERS.register_module("iddpm")
|
||||
class IDDPM(SpacedDiffusion):
|
||||
def __init__(
|
||||
self,
|
||||
num_sampling_steps=None,
|
||||
timestep_respacing=None,
|
||||
noise_schedule="linear",
|
||||
use_kl=False,
|
||||
sigma_small=False,
|
||||
predict_xstart=False,
|
||||
learn_sigma=True,
|
||||
rescale_learned_sigmas=False,
|
||||
diffusion_steps=1000,
|
||||
cfg_scale=4.0,
|
||||
):
|
||||
betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps)
|
||||
if use_kl:
|
||||
loss_type = gd.LossType.RESCALED_KL
|
||||
elif rescale_learned_sigmas:
|
||||
loss_type = gd.LossType.RESCALED_MSE
|
||||
else:
|
||||
loss_type = gd.LossType.MSE
|
||||
if num_sampling_steps is not None:
|
||||
assert timestep_respacing is None
|
||||
timestep_respacing = str(num_sampling_steps)
|
||||
if timestep_respacing is None or timestep_respacing == "":
|
||||
timestep_respacing = [diffusion_steps]
|
||||
super().__init__(
|
||||
use_timesteps=space_timesteps(diffusion_steps, timestep_respacing),
|
||||
betas=betas,
|
||||
model_mean_type=(gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X),
|
||||
model_var_type=(
|
||||
(gd.ModelVarType.FIXED_LARGE if not sigma_small else gd.ModelVarType.FIXED_SMALL)
|
||||
if not learn_sigma
|
||||
else gd.ModelVarType.LEARNED_RANGE
|
||||
),
|
||||
loss_type=loss_type,
|
||||
# rescale_timesteps=rescale_timesteps,
|
||||
)
|
||||
|
||||
self.cfg_scale = cfg_scale
|
||||
|
||||
def sample(
|
||||
self,
|
||||
model,
|
||||
text_encoder,
|
||||
z_size,
|
||||
prompts,
|
||||
device,
|
||||
additional_args=None,
|
||||
):
|
||||
n = len(prompts)
|
||||
z = torch.randn(n, *z_size, device=device)
|
||||
z = torch.cat([z, z], 0)
|
||||
model_args = text_encoder.encode(prompts)
|
||||
y_null = text_encoder.null(n)
|
||||
model_args["y"] = torch.cat([model_args["y"], y_null], 0)
|
||||
if additional_args is not None:
|
||||
model_args.update(additional_args)
|
||||
|
||||
forward = partial(forward_with_cfg, model, cfg_scale=self.cfg_scale)
|
||||
samples = self.p_sample_loop(
|
||||
forward,
|
||||
z.shape,
|
||||
z,
|
||||
clip_denoised=False,
|
||||
model_kwargs=model_args,
|
||||
progress=True,
|
||||
device=device,
|
||||
)
|
||||
samples, _ = samples.chunk(2, dim=0)
|
||||
return samples
|
||||
|
||||
|
||||
def forward_with_cfg(model, x, timestep, y, cfg_scale, **kwargs):
|
||||
# https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
|
||||
half = x[: len(x) // 2]
|
||||
combined = torch.cat([half, half], dim=0)
|
||||
model_out = model.forward(combined, timestep, y, **kwargs)
|
||||
model_out = model_out["x"] if isinstance(model_out, dict) else model_out
|
||||
eps, rest = model_out[:, :3], model_out[:, 3:]
|
||||
cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
|
||||
half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
|
||||
eps = torch.cat([half_eps, half_eps], dim=0)
|
||||
return torch.cat([eps, rest], dim=1)
|
||||
87
opensora/schedulers/iddpm/diffusion_utils.py
Normal file
87
opensora/schedulers/iddpm/diffusion_utils.py
Normal file
|
|
@ -0,0 +1,87 @@
|
|||
# Adapted from DiT
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
# --------------------------------------------------------
|
||||
# References:
|
||||
# DiT: https://github.com/facebookresearch/DiT/tree/main
|
||||
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
|
||||
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
|
||||
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
|
||||
# --------------------------------------------------------
|
||||
|
||||
|
||||
import numpy as np
|
||||
import torch as th
|
||||
|
||||
|
||||
def normal_kl(mean1, logvar1, mean2, logvar2):
|
||||
"""
|
||||
Compute the KL divergence between two gaussians.
|
||||
Shapes are automatically broadcasted, so batches can be compared to
|
||||
scalars, among other use cases.
|
||||
"""
|
||||
tensor = None
|
||||
for obj in (mean1, logvar1, mean2, logvar2):
|
||||
if isinstance(obj, th.Tensor):
|
||||
tensor = obj
|
||||
break
|
||||
assert tensor is not None, "at least one argument must be a Tensor"
|
||||
|
||||
# Force variances to be Tensors. Broadcasting helps convert scalars to
|
||||
# Tensors, but it does not work for th.exp().
|
||||
logvar1, logvar2 = [x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) for x in (logvar1, logvar2)]
|
||||
|
||||
return 0.5 * (-1.0 + logvar2 - logvar1 + th.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * th.exp(-logvar2))
|
||||
|
||||
|
||||
def approx_standard_normal_cdf(x):
|
||||
"""
|
||||
A fast approximation of the cumulative distribution function of the
|
||||
standard normal.
|
||||
"""
|
||||
return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
|
||||
|
||||
|
||||
def continuous_gaussian_log_likelihood(x, *, means, log_scales):
|
||||
"""
|
||||
Compute the log-likelihood of a continuous Gaussian distribution.
|
||||
:param x: the targets
|
||||
:param means: the Gaussian mean Tensor.
|
||||
:param log_scales: the Gaussian log stddev Tensor.
|
||||
:return: a tensor like x of log probabilities (in nats).
|
||||
"""
|
||||
centered_x = x - means
|
||||
inv_stdv = th.exp(-log_scales)
|
||||
normalized_x = centered_x * inv_stdv
|
||||
log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x)
|
||||
return log_probs
|
||||
|
||||
|
||||
def discretized_gaussian_log_likelihood(x, *, means, log_scales):
|
||||
"""
|
||||
Compute the log-likelihood of a Gaussian distribution discretizing to a
|
||||
given image.
|
||||
:param x: the target images. It is assumed that this was uint8 values,
|
||||
rescaled to the range [-1, 1].
|
||||
:param means: the Gaussian mean Tensor.
|
||||
:param log_scales: the Gaussian log stddev Tensor.
|
||||
:return: a tensor like x of log probabilities (in nats).
|
||||
"""
|
||||
assert x.shape == means.shape == log_scales.shape
|
||||
centered_x = x - means
|
||||
inv_stdv = th.exp(-log_scales)
|
||||
plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
|
||||
cdf_plus = approx_standard_normal_cdf(plus_in)
|
||||
min_in = inv_stdv * (centered_x - 1.0 / 255.0)
|
||||
cdf_min = approx_standard_normal_cdf(min_in)
|
||||
log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
|
||||
log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
|
||||
cdf_delta = cdf_plus - cdf_min
|
||||
log_probs = th.where(
|
||||
x < -0.999,
|
||||
log_cdf_plus,
|
||||
th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
|
||||
)
|
||||
assert log_probs.shape == x.shape
|
||||
return log_probs
|
||||
835
opensora/schedulers/iddpm/gaussian_diffusion.py
Normal file
835
opensora/schedulers/iddpm/gaussian_diffusion.py
Normal file
|
|
@ -0,0 +1,835 @@
|
|||
# Adapted from DiT
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
# --------------------------------------------------------
|
||||
# References:
|
||||
# DiT: https://github.com/facebookresearch/DiT/tree/main
|
||||
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
|
||||
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
|
||||
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
|
||||
# --------------------------------------------------------
|
||||
|
||||
import enum
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
import torch as th
|
||||
|
||||
from .diffusion_utils import discretized_gaussian_log_likelihood, normal_kl
|
||||
|
||||
|
||||
def mean_flat(tensor):
|
||||
"""
|
||||
Take the mean over all non-batch dimensions.
|
||||
"""
|
||||
return tensor.mean(dim=list(range(1, len(tensor.shape))))
|
||||
|
||||
|
||||
class ModelMeanType(enum.Enum):
|
||||
"""
|
||||
Which type of output the model predicts.
|
||||
"""
|
||||
|
||||
PREVIOUS_X = enum.auto() # the model predicts x_{t-1}
|
||||
START_X = enum.auto() # the model predicts x_0
|
||||
EPSILON = enum.auto() # the model predicts epsilon
|
||||
|
||||
|
||||
class ModelVarType(enum.Enum):
|
||||
"""
|
||||
What is used as the model's output variance.
|
||||
The LEARNED_RANGE option has been added to allow the model to predict
|
||||
values between FIXED_SMALL and FIXED_LARGE, making its job easier.
|
||||
"""
|
||||
|
||||
LEARNED = enum.auto()
|
||||
FIXED_SMALL = enum.auto()
|
||||
FIXED_LARGE = enum.auto()
|
||||
LEARNED_RANGE = enum.auto()
|
||||
|
||||
|
||||
class LossType(enum.Enum):
|
||||
MSE = enum.auto() # use raw MSE loss (and KL when learning variances)
|
||||
RESCALED_MSE = enum.auto() # use raw MSE loss (with RESCALED_KL when learning variances)
|
||||
KL = enum.auto() # use the variational lower-bound
|
||||
RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB
|
||||
|
||||
def is_vb(self):
|
||||
return self == LossType.KL or self == LossType.RESCALED_KL
|
||||
|
||||
|
||||
def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac):
|
||||
betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
|
||||
warmup_time = int(num_diffusion_timesteps * warmup_frac)
|
||||
betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64)
|
||||
return betas
|
||||
|
||||
|
||||
def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
|
||||
"""
|
||||
This is the deprecated API for creating beta schedules.
|
||||
See get_named_beta_schedule() for the new library of schedules.
|
||||
"""
|
||||
if beta_schedule == "quad":
|
||||
betas = (
|
||||
np.linspace(
|
||||
beta_start**0.5,
|
||||
beta_end**0.5,
|
||||
num_diffusion_timesteps,
|
||||
dtype=np.float64,
|
||||
)
|
||||
** 2
|
||||
)
|
||||
elif beta_schedule == "linear":
|
||||
betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
|
||||
elif beta_schedule == "warmup10":
|
||||
betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1)
|
||||
elif beta_schedule == "warmup50":
|
||||
betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5)
|
||||
elif beta_schedule == "const":
|
||||
betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
|
||||
elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1
|
||||
betas = 1.0 / np.linspace(num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64)
|
||||
else:
|
||||
raise NotImplementedError(beta_schedule)
|
||||
assert betas.shape == (num_diffusion_timesteps,)
|
||||
return betas
|
||||
|
||||
|
||||
def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
|
||||
"""
|
||||
Get a pre-defined beta schedule for the given name.
|
||||
The beta schedule library consists of beta schedules which remain similar
|
||||
in the limit of num_diffusion_timesteps.
|
||||
Beta schedules may be added, but should not be removed or changed once
|
||||
they are committed to maintain backwards compatibility.
|
||||
"""
|
||||
if schedule_name == "linear":
|
||||
# Linear schedule from Ho et al, extended to work for any number of
|
||||
# diffusion steps.
|
||||
scale = 1000 / num_diffusion_timesteps
|
||||
return get_beta_schedule(
|
||||
"linear",
|
||||
beta_start=scale * 0.0001,
|
||||
beta_end=scale * 0.02,
|
||||
num_diffusion_timesteps=num_diffusion_timesteps,
|
||||
)
|
||||
elif schedule_name == "squaredcos_cap_v2":
|
||||
return betas_for_alpha_bar(
|
||||
num_diffusion_timesteps,
|
||||
lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
|
||||
|
||||
|
||||
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
|
||||
"""
|
||||
Create a beta schedule that discretizes the given alpha_t_bar function,
|
||||
which defines the cumulative product of (1-beta) over time from t = [0,1].
|
||||
:param num_diffusion_timesteps: the number of betas to produce.
|
||||
:param alpha_bar: a lambda that takes an argument t from 0 to 1 and
|
||||
produces the cumulative product of (1-beta) up to that
|
||||
part of the diffusion process.
|
||||
:param max_beta: the maximum beta to use; use values lower than 1 to
|
||||
prevent singularities.
|
||||
"""
|
||||
betas = []
|
||||
for i in range(num_diffusion_timesteps):
|
||||
t1 = i / num_diffusion_timesteps
|
||||
t2 = (i + 1) / num_diffusion_timesteps
|
||||
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
|
||||
return np.array(betas)
|
||||
|
||||
|
||||
class GaussianDiffusion:
|
||||
"""
|
||||
Utilities for training and sampling diffusion models.
|
||||
Original ported from this codebase:
|
||||
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
|
||||
:param betas: a 1-D numpy array of betas for each diffusion timestep,
|
||||
starting at T and going to 1.
|
||||
"""
|
||||
|
||||
def __init__(self, *, betas, model_mean_type, model_var_type, loss_type):
|
||||
self.model_mean_type = model_mean_type
|
||||
self.model_var_type = model_var_type
|
||||
self.loss_type = loss_type
|
||||
|
||||
# Use float64 for accuracy.
|
||||
betas = np.array(betas, dtype=np.float64)
|
||||
self.betas = betas
|
||||
assert len(betas.shape) == 1, "betas must be 1-D"
|
||||
assert (betas > 0).all() and (betas <= 1).all()
|
||||
|
||||
self.num_timesteps = int(betas.shape[0])
|
||||
|
||||
alphas = 1.0 - betas
|
||||
self.alphas_cumprod = np.cumprod(alphas, axis=0)
|
||||
self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
|
||||
self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
|
||||
assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)
|
||||
|
||||
# calculations for diffusion q(x_t | x_{t-1}) and others
|
||||
self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
|
||||
self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
|
||||
self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
|
||||
self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
|
||||
self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
|
||||
|
||||
# calculations for posterior q(x_{t-1} | x_t, x_0)
|
||||
self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
|
||||
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
|
||||
self.posterior_log_variance_clipped = (
|
||||
np.log(np.append(self.posterior_variance[1], self.posterior_variance[1:]))
|
||||
if len(self.posterior_variance) > 1
|
||||
else np.array([])
|
||||
)
|
||||
|
||||
self.posterior_mean_coef1 = betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
|
||||
self.posterior_mean_coef2 = (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)
|
||||
|
||||
def q_mean_variance(self, x_start, t):
|
||||
"""
|
||||
Get the distribution q(x_t | x_0).
|
||||
:param x_start: the [N x C x ...] tensor of noiseless inputs.
|
||||
:param t: the number of diffusion steps (minus 1). Here, 0 means one step.
|
||||
:return: A tuple (mean, variance, log_variance), all of x_start's shape.
|
||||
"""
|
||||
mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
|
||||
variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
|
||||
log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
|
||||
return mean, variance, log_variance
|
||||
|
||||
def q_sample(self, x_start, t, noise=None):
|
||||
"""
|
||||
Diffuse the data for a given number of diffusion steps.
|
||||
In other words, sample from q(x_t | x_0).
|
||||
:param x_start: the initial data batch.
|
||||
:param t: the number of diffusion steps (minus 1). Here, 0 means one step.
|
||||
:param noise: if specified, the split-out normal noise.
|
||||
:return: A noisy version of x_start.
|
||||
"""
|
||||
if noise is None:
|
||||
noise = th.randn_like(x_start)
|
||||
assert noise.shape == x_start.shape
|
||||
return (
|
||||
_extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
|
||||
+ _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
|
||||
)
|
||||
|
||||
def q_posterior_mean_variance(self, x_start, x_t, t):
|
||||
"""
|
||||
Compute the mean and variance of the diffusion posterior:
|
||||
q(x_{t-1} | x_t, x_0)
|
||||
"""
|
||||
assert x_start.shape == x_t.shape
|
||||
posterior_mean = (
|
||||
_extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
|
||||
+ _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
|
||||
)
|
||||
posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
|
||||
posterior_log_variance_clipped = _extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
|
||||
assert (
|
||||
posterior_mean.shape[0]
|
||||
== posterior_variance.shape[0]
|
||||
== posterior_log_variance_clipped.shape[0]
|
||||
== x_start.shape[0]
|
||||
)
|
||||
return posterior_mean, posterior_variance, posterior_log_variance_clipped
|
||||
|
||||
def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None):
|
||||
"""
|
||||
Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
|
||||
the initial x, x_0.
|
||||
:param model: the model, which takes a signal and a batch of timesteps
|
||||
as input.
|
||||
:param x: the [N x C x ...] tensor at time t.
|
||||
:param t: a 1-D Tensor of timesteps.
|
||||
:param clip_denoised: if True, clip the denoised signal into [-1, 1].
|
||||
:param denoised_fn: if not None, a function which applies to the
|
||||
x_start prediction before it is used to sample. Applies before
|
||||
clip_denoised.
|
||||
:param model_kwargs: if not None, a dict of extra keyword arguments to
|
||||
pass to the model. This can be used for conditioning.
|
||||
:return: a dict with the following keys:
|
||||
- 'mean': the model mean output.
|
||||
- 'variance': the model variance output.
|
||||
- 'log_variance': the log of 'variance'.
|
||||
- 'pred_xstart': the prediction for x_0.
|
||||
"""
|
||||
if model_kwargs is None:
|
||||
model_kwargs = {}
|
||||
|
||||
B, C = x.shape[:2]
|
||||
assert t.shape == (B,)
|
||||
model_output = model(x, t, **model_kwargs)
|
||||
if isinstance(model_output, tuple):
|
||||
model_output, extra = model_output
|
||||
else:
|
||||
extra = None
|
||||
|
||||
if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
|
||||
assert model_output.shape == (B, C * 2, *x.shape[2:])
|
||||
model_output, model_var_values = th.split(model_output, C, dim=1)
|
||||
min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape)
|
||||
max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
|
||||
# The model_var_values is [-1, 1] for [min_var, max_var].
|
||||
frac = (model_var_values + 1) / 2
|
||||
model_log_variance = frac * max_log + (1 - frac) * min_log
|
||||
model_variance = th.exp(model_log_variance)
|
||||
else:
|
||||
model_variance, model_log_variance = {
|
||||
# for fixedlarge, we set the initial (log-)variance like so
|
||||
# to get a better decoder log likelihood.
|
||||
ModelVarType.FIXED_LARGE: (
|
||||
np.append(self.posterior_variance[1], self.betas[1:]),
|
||||
np.log(np.append(self.posterior_variance[1], self.betas[1:])),
|
||||
),
|
||||
ModelVarType.FIXED_SMALL: (
|
||||
self.posterior_variance,
|
||||
self.posterior_log_variance_clipped,
|
||||
),
|
||||
}[self.model_var_type]
|
||||
model_variance = _extract_into_tensor(model_variance, t, x.shape)
|
||||
model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
|
||||
|
||||
def process_xstart(x):
|
||||
if denoised_fn is not None:
|
||||
x = denoised_fn(x)
|
||||
if clip_denoised:
|
||||
return x.clamp(-1, 1)
|
||||
return x
|
||||
|
||||
if self.model_mean_type == ModelMeanType.START_X:
|
||||
pred_xstart = process_xstart(model_output)
|
||||
else:
|
||||
pred_xstart = process_xstart(self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output))
|
||||
model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t)
|
||||
|
||||
assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
|
||||
return {
|
||||
"mean": model_mean,
|
||||
"variance": model_variance,
|
||||
"log_variance": model_log_variance,
|
||||
"pred_xstart": pred_xstart,
|
||||
"extra": extra,
|
||||
}
|
||||
|
||||
def _predict_xstart_from_eps(self, x_t, t, eps):
|
||||
assert x_t.shape == eps.shape
|
||||
return (
|
||||
_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
|
||||
- _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
|
||||
)
|
||||
|
||||
def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
|
||||
return (
|
||||
_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart
|
||||
) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
|
||||
|
||||
def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
|
||||
"""
|
||||
Compute the mean for the previous step, given a function cond_fn that
|
||||
computes the gradient of a conditional log probability with respect to
|
||||
x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
|
||||
condition on y.
|
||||
This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
|
||||
"""
|
||||
gradient = cond_fn(x, t, **model_kwargs)
|
||||
new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
|
||||
return new_mean
|
||||
|
||||
def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
|
||||
"""
|
||||
Compute what the p_mean_variance output would have been, should the
|
||||
model's score function be conditioned by cond_fn.
|
||||
See condition_mean() for details on cond_fn.
|
||||
Unlike condition_mean(), this instead uses the conditioning strategy
|
||||
from Song et al (2020).
|
||||
"""
|
||||
alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
|
||||
|
||||
eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
|
||||
eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs)
|
||||
|
||||
out = p_mean_var.copy()
|
||||
out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
|
||||
out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t)
|
||||
return out
|
||||
|
||||
def p_sample(
|
||||
self,
|
||||
model,
|
||||
x,
|
||||
t,
|
||||
clip_denoised=True,
|
||||
denoised_fn=None,
|
||||
cond_fn=None,
|
||||
model_kwargs=None,
|
||||
):
|
||||
"""
|
||||
Sample x_{t-1} from the model at the given timestep.
|
||||
:param model: the model to sample from.
|
||||
:param x: the current tensor at x_{t-1}.
|
||||
:param t: the value of t, starting at 0 for the first diffusion step.
|
||||
:param clip_denoised: if True, clip the x_start prediction to [-1, 1].
|
||||
:param denoised_fn: if not None, a function which applies to the
|
||||
x_start prediction before it is used to sample.
|
||||
:param cond_fn: if not None, this is a gradient function that acts
|
||||
similarly to the model.
|
||||
:param model_kwargs: if not None, a dict of extra keyword arguments to
|
||||
pass to the model. This can be used for conditioning.
|
||||
:return: a dict containing the following keys:
|
||||
- 'sample': a random sample from the model.
|
||||
- 'pred_xstart': a prediction of x_0.
|
||||
"""
|
||||
out = self.p_mean_variance(
|
||||
model,
|
||||
x,
|
||||
t,
|
||||
clip_denoised=clip_denoised,
|
||||
denoised_fn=denoised_fn,
|
||||
model_kwargs=model_kwargs,
|
||||
)
|
||||
noise = th.randn_like(x)
|
||||
nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) # no noise when t == 0
|
||||
if cond_fn is not None:
|
||||
out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs)
|
||||
sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
|
||||
return {"sample": sample, "pred_xstart": out["pred_xstart"]}
|
||||
|
||||
def p_sample_loop(
|
||||
self,
|
||||
model,
|
||||
shape,
|
||||
noise=None,
|
||||
clip_denoised=True,
|
||||
denoised_fn=None,
|
||||
cond_fn=None,
|
||||
model_kwargs=None,
|
||||
device=None,
|
||||
progress=False,
|
||||
):
|
||||
"""
|
||||
Generate samples from the model.
|
||||
:param model: the model module.
|
||||
:param shape: the shape of the samples, (N, C, H, W).
|
||||
:param noise: if specified, the noise from the encoder to sample.
|
||||
Should be of the same shape as `shape`.
|
||||
:param clip_denoised: if True, clip x_start predictions to [-1, 1].
|
||||
:param denoised_fn: if not None, a function which applies to the
|
||||
x_start prediction before it is used to sample.
|
||||
:param cond_fn: if not None, this is a gradient function that acts
|
||||
similarly to the model.
|
||||
:param model_kwargs: if not None, a dict of extra keyword arguments to
|
||||
pass to the model. This can be used for conditioning.
|
||||
:param device: if specified, the device to create the samples on.
|
||||
If not specified, use a model parameter's device.
|
||||
:param progress: if True, show a tqdm progress bar.
|
||||
:return: a non-differentiable batch of samples.
|
||||
"""
|
||||
final = None
|
||||
for sample in self.p_sample_loop_progressive(
|
||||
model,
|
||||
shape,
|
||||
noise=noise,
|
||||
clip_denoised=clip_denoised,
|
||||
denoised_fn=denoised_fn,
|
||||
cond_fn=cond_fn,
|
||||
model_kwargs=model_kwargs,
|
||||
device=device,
|
||||
progress=progress,
|
||||
):
|
||||
final = sample
|
||||
return final["sample"]
|
||||
|
||||
def p_sample_loop_progressive(
|
||||
self,
|
||||
model,
|
||||
shape,
|
||||
noise=None,
|
||||
clip_denoised=True,
|
||||
denoised_fn=None,
|
||||
cond_fn=None,
|
||||
model_kwargs=None,
|
||||
device=None,
|
||||
progress=False,
|
||||
):
|
||||
"""
|
||||
Generate samples from the model and yield intermediate samples from
|
||||
each timestep of diffusion.
|
||||
Arguments are the same as p_sample_loop().
|
||||
Returns a generator over dicts, where each dict is the return value of
|
||||
p_sample().
|
||||
"""
|
||||
if device is None:
|
||||
device = next(model.parameters()).device
|
||||
assert isinstance(shape, (tuple, list))
|
||||
if noise is not None:
|
||||
img = noise
|
||||
else:
|
||||
img = th.randn(*shape, device=device)
|
||||
indices = list(range(self.num_timesteps))[::-1]
|
||||
|
||||
if progress:
|
||||
# Lazy import so that we don't depend on tqdm.
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
indices = tqdm(indices)
|
||||
|
||||
for i in indices:
|
||||
t = th.tensor([i] * shape[0], device=device)
|
||||
with th.no_grad():
|
||||
out = self.p_sample(
|
||||
model,
|
||||
img,
|
||||
t,
|
||||
clip_denoised=clip_denoised,
|
||||
denoised_fn=denoised_fn,
|
||||
cond_fn=cond_fn,
|
||||
model_kwargs=model_kwargs,
|
||||
)
|
||||
yield out
|
||||
img = out["sample"]
|
||||
|
||||
def ddim_sample(
|
||||
self,
|
||||
model,
|
||||
x,
|
||||
t,
|
||||
clip_denoised=True,
|
||||
denoised_fn=None,
|
||||
cond_fn=None,
|
||||
model_kwargs=None,
|
||||
eta=0.0,
|
||||
):
|
||||
"""
|
||||
Sample x_{t-1} from the model using DDIM.
|
||||
Same usage as p_sample().
|
||||
"""
|
||||
out = self.p_mean_variance(
|
||||
model,
|
||||
x,
|
||||
t,
|
||||
clip_denoised=clip_denoised,
|
||||
denoised_fn=denoised_fn,
|
||||
model_kwargs=model_kwargs,
|
||||
)
|
||||
if cond_fn is not None:
|
||||
out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
|
||||
|
||||
# Usually our model outputs epsilon, but we re-derive it
|
||||
# in case we used x_start or x_prev prediction.
|
||||
eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
|
||||
|
||||
alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
|
||||
alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
|
||||
sigma = eta * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) * th.sqrt(1 - alpha_bar / alpha_bar_prev)
|
||||
# Equation 12.
|
||||
noise = th.randn_like(x)
|
||||
mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_prev) + th.sqrt(1 - alpha_bar_prev - sigma**2) * eps
|
||||
nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) # no noise when t == 0
|
||||
sample = mean_pred + nonzero_mask * sigma * noise
|
||||
return {"sample": sample, "pred_xstart": out["pred_xstart"]}
|
||||
|
||||
def ddim_reverse_sample(
|
||||
self,
|
||||
model,
|
||||
x,
|
||||
t,
|
||||
clip_denoised=True,
|
||||
denoised_fn=None,
|
||||
cond_fn=None,
|
||||
model_kwargs=None,
|
||||
eta=0.0,
|
||||
):
|
||||
"""
|
||||
Sample x_{t+1} from the model using DDIM reverse ODE.
|
||||
"""
|
||||
assert eta == 0.0, "Reverse ODE only for deterministic path"
|
||||
out = self.p_mean_variance(
|
||||
model,
|
||||
x,
|
||||
t,
|
||||
clip_denoised=clip_denoised,
|
||||
denoised_fn=denoised_fn,
|
||||
model_kwargs=model_kwargs,
|
||||
)
|
||||
if cond_fn is not None:
|
||||
out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
|
||||
# Usually our model outputs epsilon, but we re-derive it
|
||||
# in case we used x_start or x_prev prediction.
|
||||
eps = (
|
||||
_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x - out["pred_xstart"]
|
||||
) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
|
||||
alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
|
||||
|
||||
# Equation 12. reversed
|
||||
mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps
|
||||
|
||||
return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
|
||||
|
||||
def ddim_sample_loop(
|
||||
self,
|
||||
model,
|
||||
shape,
|
||||
noise=None,
|
||||
clip_denoised=True,
|
||||
denoised_fn=None,
|
||||
cond_fn=None,
|
||||
model_kwargs=None,
|
||||
device=None,
|
||||
progress=False,
|
||||
eta=0.0,
|
||||
):
|
||||
"""
|
||||
Generate samples from the model using DDIM.
|
||||
Same usage as p_sample_loop().
|
||||
"""
|
||||
final = None
|
||||
for sample in self.ddim_sample_loop_progressive(
|
||||
model,
|
||||
shape,
|
||||
noise=noise,
|
||||
clip_denoised=clip_denoised,
|
||||
denoised_fn=denoised_fn,
|
||||
cond_fn=cond_fn,
|
||||
model_kwargs=model_kwargs,
|
||||
device=device,
|
||||
progress=progress,
|
||||
eta=eta,
|
||||
):
|
||||
final = sample
|
||||
return final["sample"]
|
||||
|
||||
def ddim_sample_loop_progressive(
|
||||
self,
|
||||
model,
|
||||
shape,
|
||||
noise=None,
|
||||
clip_denoised=True,
|
||||
denoised_fn=None,
|
||||
cond_fn=None,
|
||||
model_kwargs=None,
|
||||
device=None,
|
||||
progress=False,
|
||||
eta=0.0,
|
||||
):
|
||||
"""
|
||||
Use DDIM to sample from the model and yield intermediate samples from
|
||||
each timestep of DDIM.
|
||||
Same usage as p_sample_loop_progressive().
|
||||
"""
|
||||
if device is None:
|
||||
device = next(model.parameters()).device
|
||||
assert isinstance(shape, (tuple, list))
|
||||
if noise is not None:
|
||||
img = noise
|
||||
else:
|
||||
img = th.randn(*shape, device=device)
|
||||
indices = list(range(self.num_timesteps))[::-1]
|
||||
|
||||
if progress:
|
||||
# Lazy import so that we don't depend on tqdm.
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
indices = tqdm(indices)
|
||||
|
||||
for i in indices:
|
||||
t = th.tensor([i] * shape[0], device=device)
|
||||
with th.no_grad():
|
||||
out = self.ddim_sample(
|
||||
model,
|
||||
img,
|
||||
t,
|
||||
clip_denoised=clip_denoised,
|
||||
denoised_fn=denoised_fn,
|
||||
cond_fn=cond_fn,
|
||||
model_kwargs=model_kwargs,
|
||||
eta=eta,
|
||||
)
|
||||
yield out
|
||||
img = out["sample"]
|
||||
|
||||
def _vb_terms_bpd(self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None):
|
||||
"""
|
||||
Get a term for the variational lower-bound.
|
||||
The resulting units are bits (rather than nats, as one might expect).
|
||||
This allows for comparison to other papers.
|
||||
:return: a dict with the following keys:
|
||||
- 'output': a shape [N] tensor of NLLs or KLs.
|
||||
- 'pred_xstart': the x_0 predictions.
|
||||
"""
|
||||
true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t)
|
||||
out = self.p_mean_variance(model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs)
|
||||
kl = normal_kl(true_mean, true_log_variance_clipped, out["mean"], out["log_variance"])
|
||||
kl = mean_flat(kl) / np.log(2.0)
|
||||
|
||||
decoder_nll = -discretized_gaussian_log_likelihood(
|
||||
x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
|
||||
)
|
||||
assert decoder_nll.shape == x_start.shape
|
||||
decoder_nll = mean_flat(decoder_nll) / np.log(2.0)
|
||||
|
||||
# At the first timestep return the decoder NLL,
|
||||
# otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
|
||||
output = th.where((t == 0), decoder_nll, kl)
|
||||
return {"output": output, "pred_xstart": out["pred_xstart"]}
|
||||
|
||||
def training_losses(self, model, x_start, t, model_kwargs=None, noise=None):
|
||||
"""
|
||||
Compute training losses for a single timestep.
|
||||
:param model: the model to evaluate loss on.
|
||||
:param x_start: the [N x C x ...] tensor of inputs.
|
||||
:param t: a batch of timestep indices.
|
||||
:param model_kwargs: if not None, a dict of extra keyword arguments to
|
||||
pass to the model. This can be used for conditioning.
|
||||
:param noise: if specified, the specific Gaussian noise to try to remove.
|
||||
:return: a dict with the key "loss" containing a tensor of shape [N].
|
||||
Some mean or variance settings may also have other keys.
|
||||
"""
|
||||
if model_kwargs is None:
|
||||
model_kwargs = {}
|
||||
if noise is None:
|
||||
noise = th.randn_like(x_start)
|
||||
x_t = self.q_sample(x_start, t, noise=noise)
|
||||
|
||||
terms = {}
|
||||
|
||||
if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
|
||||
terms["loss"] = self._vb_terms_bpd(
|
||||
model=model,
|
||||
x_start=x_start,
|
||||
x_t=x_t,
|
||||
t=t,
|
||||
clip_denoised=False,
|
||||
model_kwargs=model_kwargs,
|
||||
)["output"]
|
||||
if self.loss_type == LossType.RESCALED_KL:
|
||||
terms["loss"] *= self.num_timesteps
|
||||
elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
|
||||
model_output = model(x_t, t, **model_kwargs)
|
||||
|
||||
if self.model_var_type in [
|
||||
ModelVarType.LEARNED,
|
||||
ModelVarType.LEARNED_RANGE,
|
||||
]:
|
||||
B, C = x_t.shape[:2]
|
||||
assert model_output.shape == (B, C * 2, *x_t.shape[2:])
|
||||
model_output, model_var_values = th.split(model_output, C, dim=1)
|
||||
# Learn the variance using the variational bound, but don't let
|
||||
# it affect our mean prediction.
|
||||
frozen_out = th.cat([model_output.detach(), model_var_values], dim=1)
|
||||
terms["vb"] = self._vb_terms_bpd(
|
||||
model=lambda *args, r=frozen_out: r,
|
||||
x_start=x_start,
|
||||
x_t=x_t,
|
||||
t=t,
|
||||
clip_denoised=False,
|
||||
)["output"]
|
||||
if self.loss_type == LossType.RESCALED_MSE:
|
||||
# Divide by 1000 for equivalence with initial implementation.
|
||||
# Without a factor of 1/1000, the VB term hurts the MSE term.
|
||||
terms["vb"] *= self.num_timesteps / 1000.0
|
||||
|
||||
target = {
|
||||
ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t)[0],
|
||||
ModelMeanType.START_X: x_start,
|
||||
ModelMeanType.EPSILON: noise,
|
||||
}[self.model_mean_type]
|
||||
assert model_output.shape == target.shape == x_start.shape
|
||||
terms["mse"] = mean_flat((target - model_output) ** 2)
|
||||
if "vb" in terms:
|
||||
terms["loss"] = terms["mse"] + terms["vb"]
|
||||
else:
|
||||
terms["loss"] = terms["mse"]
|
||||
else:
|
||||
raise NotImplementedError(self.loss_type)
|
||||
|
||||
return terms
|
||||
|
||||
def _prior_bpd(self, x_start):
|
||||
"""
|
||||
Get the prior KL term for the variational lower-bound, measured in
|
||||
bits-per-dim.
|
||||
This term can't be optimized, as it only depends on the encoder.
|
||||
:param x_start: the [N x C x ...] tensor of inputs.
|
||||
:return: a batch of [N] KL values (in bits), one per batch element.
|
||||
"""
|
||||
batch_size = x_start.shape[0]
|
||||
t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
|
||||
qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
|
||||
kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
|
||||
return mean_flat(kl_prior) / np.log(2.0)
|
||||
|
||||
def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
|
||||
"""
|
||||
Compute the entire variational lower-bound, measured in bits-per-dim,
|
||||
as well as other related quantities.
|
||||
:param model: the model to evaluate loss on.
|
||||
:param x_start: the [N x C x ...] tensor of inputs.
|
||||
:param clip_denoised: if True, clip denoised samples.
|
||||
:param model_kwargs: if not None, a dict of extra keyword arguments to
|
||||
pass to the model. This can be used for conditioning.
|
||||
:return: a dict containing the following keys:
|
||||
- total_bpd: the total variational lower-bound, per batch element.
|
||||
- prior_bpd: the prior term in the lower-bound.
|
||||
- vb: an [N x T] tensor of terms in the lower-bound.
|
||||
- xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
|
||||
- mse: an [N x T] tensor of epsilon MSEs for each timestep.
|
||||
"""
|
||||
device = x_start.device
|
||||
batch_size = x_start.shape[0]
|
||||
|
||||
vb = []
|
||||
xstart_mse = []
|
||||
mse = []
|
||||
for t in list(range(self.num_timesteps))[::-1]:
|
||||
t_batch = th.tensor([t] * batch_size, device=device)
|
||||
noise = th.randn_like(x_start)
|
||||
x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
|
||||
# Calculate VLB term at the current timestep
|
||||
with th.no_grad():
|
||||
out = self._vb_terms_bpd(
|
||||
model,
|
||||
x_start=x_start,
|
||||
x_t=x_t,
|
||||
t=t_batch,
|
||||
clip_denoised=clip_denoised,
|
||||
model_kwargs=model_kwargs,
|
||||
)
|
||||
vb.append(out["output"])
|
||||
xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
|
||||
eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
|
||||
mse.append(mean_flat((eps - noise) ** 2))
|
||||
|
||||
vb = th.stack(vb, dim=1)
|
||||
xstart_mse = th.stack(xstart_mse, dim=1)
|
||||
mse = th.stack(mse, dim=1)
|
||||
|
||||
prior_bpd = self._prior_bpd(x_start)
|
||||
total_bpd = vb.sum(dim=1) + prior_bpd
|
||||
return {
|
||||
"total_bpd": total_bpd,
|
||||
"prior_bpd": prior_bpd,
|
||||
"vb": vb,
|
||||
"xstart_mse": xstart_mse,
|
||||
"mse": mse,
|
||||
}
|
||||
|
||||
|
||||
def _extract_into_tensor(arr, timesteps, broadcast_shape):
|
||||
"""
|
||||
Extract values from a 1-D numpy array for a batch of indices.
|
||||
:param arr: the 1-D numpy array.
|
||||
:param timesteps: a tensor of indices into the array to extract.
|
||||
:param broadcast_shape: a larger shape of K dimensions with the batch
|
||||
dimension equal to the length of timesteps.
|
||||
:return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
|
||||
"""
|
||||
res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
|
||||
while len(res.shape) < len(broadcast_shape):
|
||||
res = res[..., None]
|
||||
return res + th.zeros(broadcast_shape, device=timesteps.device)
|
||||
127
opensora/schedulers/iddpm/respace.py
Normal file
127
opensora/schedulers/iddpm/respace.py
Normal file
|
|
@ -0,0 +1,127 @@
|
|||
# Adapted from DiT
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
# --------------------------------------------------------
|
||||
# References:
|
||||
# DiT: https://github.com/facebookresearch/DiT/tree/main
|
||||
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
|
||||
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
|
||||
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
|
||||
# --------------------------------------------------------
|
||||
|
||||
|
||||
import numpy as np
|
||||
import torch as th
|
||||
|
||||
from .gaussian_diffusion import GaussianDiffusion
|
||||
|
||||
|
||||
def space_timesteps(num_timesteps, section_counts):
|
||||
"""
|
||||
Create a list of timesteps to use from an original diffusion process,
|
||||
given the number of timesteps we want to take from equally-sized portions
|
||||
of the original process.
|
||||
For example, if there's 300 timesteps and the section counts are [10,15,20]
|
||||
then the first 100 timesteps are strided to be 10 timesteps, the second 100
|
||||
are strided to be 15 timesteps, and the final 100 are strided to be 20.
|
||||
If the stride is a string starting with "ddim", then the fixed striding
|
||||
from the DDIM paper is used, and only one section is allowed.
|
||||
:param num_timesteps: the number of diffusion steps in the original
|
||||
process to divide up.
|
||||
:param section_counts: either a list of numbers, or a string containing
|
||||
comma-separated numbers, indicating the step count
|
||||
per section. As a special case, use "ddimN" where N
|
||||
is a number of steps to use the striding from the
|
||||
DDIM paper.
|
||||
:return: a set of diffusion steps from the original process to use.
|
||||
"""
|
||||
if isinstance(section_counts, str):
|
||||
if section_counts.startswith("ddim"):
|
||||
desired_count = int(section_counts[len("ddim") :])
|
||||
for i in range(1, num_timesteps):
|
||||
if len(range(0, num_timesteps, i)) == desired_count:
|
||||
return set(range(0, num_timesteps, i))
|
||||
raise ValueError(f"cannot create exactly {num_timesteps} steps with an integer stride")
|
||||
section_counts = [int(x) for x in section_counts.split(",")]
|
||||
size_per = num_timesteps // len(section_counts)
|
||||
extra = num_timesteps % len(section_counts)
|
||||
start_idx = 0
|
||||
all_steps = []
|
||||
for i, section_count in enumerate(section_counts):
|
||||
size = size_per + (1 if i < extra else 0)
|
||||
if size < section_count:
|
||||
raise ValueError(f"cannot divide section of {size} steps into {section_count}")
|
||||
if section_count <= 1:
|
||||
frac_stride = 1
|
||||
else:
|
||||
frac_stride = (size - 1) / (section_count - 1)
|
||||
cur_idx = 0.0
|
||||
taken_steps = []
|
||||
for _ in range(section_count):
|
||||
taken_steps.append(start_idx + round(cur_idx))
|
||||
cur_idx += frac_stride
|
||||
all_steps += taken_steps
|
||||
start_idx += size
|
||||
return set(all_steps)
|
||||
|
||||
|
||||
class SpacedDiffusion(GaussianDiffusion):
|
||||
"""
|
||||
A diffusion process which can skip steps in a base diffusion process.
|
||||
:param use_timesteps: a collection (sequence or set) of timesteps from the
|
||||
original diffusion process to retain.
|
||||
:param kwargs: the kwargs to create the base diffusion process.
|
||||
"""
|
||||
|
||||
def __init__(self, use_timesteps, **kwargs):
|
||||
self.use_timesteps = set(use_timesteps)
|
||||
self.timestep_map = []
|
||||
self.original_num_steps = len(kwargs["betas"])
|
||||
|
||||
base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
|
||||
last_alpha_cumprod = 1.0
|
||||
new_betas = []
|
||||
for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
|
||||
if i in self.use_timesteps:
|
||||
new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
|
||||
last_alpha_cumprod = alpha_cumprod
|
||||
self.timestep_map.append(i)
|
||||
kwargs["betas"] = np.array(new_betas)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def p_mean_variance(self, model, *args, **kwargs): # pylint: disable=signature-differs
|
||||
return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
|
||||
|
||||
def training_losses(self, model, *args, **kwargs): # pylint: disable=signature-differs
|
||||
return super().training_losses(self._wrap_model(model), *args, **kwargs)
|
||||
|
||||
def condition_mean(self, cond_fn, *args, **kwargs):
|
||||
return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
|
||||
|
||||
def condition_score(self, cond_fn, *args, **kwargs):
|
||||
return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
|
||||
|
||||
def _wrap_model(self, model):
|
||||
if isinstance(model, _WrappedModel):
|
||||
return model
|
||||
return _WrappedModel(model, self.timestep_map, self.original_num_steps)
|
||||
|
||||
def _scale_timesteps(self, t):
|
||||
# Scaling is done by the wrapped model.
|
||||
return t
|
||||
|
||||
|
||||
class _WrappedModel:
|
||||
def __init__(self, model, timestep_map, original_num_steps):
|
||||
self.model = model
|
||||
self.timestep_map = timestep_map
|
||||
# self.rescale_timesteps = rescale_timesteps
|
||||
self.original_num_steps = original_num_steps
|
||||
|
||||
def __call__(self, x, ts, **kwargs):
|
||||
map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
|
||||
new_ts = map_tensor[ts]
|
||||
# if self.rescale_timesteps:
|
||||
# new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
|
||||
return self.model(x, new_ts, **kwargs)
|
||||
150
opensora/schedulers/iddpm/timestep_sampler.py
Normal file
150
opensora/schedulers/iddpm/timestep_sampler.py
Normal file
|
|
@ -0,0 +1,150 @@
|
|||
# Adapted from DiT
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
# --------------------------------------------------------
|
||||
# References:
|
||||
# DiT: https://github.com/facebookresearch/DiT/tree/main
|
||||
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
|
||||
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
|
||||
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
|
||||
# --------------------------------------------------------
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import numpy as np
|
||||
import torch as th
|
||||
import torch.distributed as dist
|
||||
|
||||
|
||||
def create_named_schedule_sampler(name, diffusion):
|
||||
"""
|
||||
Create a ScheduleSampler from a library of pre-defined samplers.
|
||||
:param name: the name of the sampler.
|
||||
:param diffusion: the diffusion object to sample for.
|
||||
"""
|
||||
if name == "uniform":
|
||||
return UniformSampler(diffusion)
|
||||
elif name == "loss-second-moment":
|
||||
return LossSecondMomentResampler(diffusion)
|
||||
else:
|
||||
raise NotImplementedError(f"unknown schedule sampler: {name}")
|
||||
|
||||
|
||||
class ScheduleSampler(ABC):
|
||||
"""
|
||||
A distribution over timesteps in the diffusion process, intended to reduce
|
||||
variance of the objective.
|
||||
By default, samplers perform unbiased importance sampling, in which the
|
||||
objective's mean is unchanged.
|
||||
However, subclasses may override sample() to change how the resampled
|
||||
terms are reweighted, allowing for actual changes in the objective.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def weights(self):
|
||||
"""
|
||||
Get a numpy array of weights, one per diffusion step.
|
||||
The weights needn't be normalized, but must be positive.
|
||||
"""
|
||||
|
||||
def sample(self, batch_size, device):
|
||||
"""
|
||||
Importance-sample timesteps for a batch.
|
||||
:param batch_size: the number of timesteps.
|
||||
:param device: the torch device to save to.
|
||||
:return: a tuple (timesteps, weights):
|
||||
- timesteps: a tensor of timestep indices.
|
||||
- weights: a tensor of weights to scale the resulting losses.
|
||||
"""
|
||||
w = self.weights()
|
||||
p = w / np.sum(w)
|
||||
indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
|
||||
indices = th.from_numpy(indices_np).long().to(device)
|
||||
weights_np = 1 / (len(p) * p[indices_np])
|
||||
weights = th.from_numpy(weights_np).float().to(device)
|
||||
return indices, weights
|
||||
|
||||
|
||||
class UniformSampler(ScheduleSampler):
|
||||
def __init__(self, diffusion):
|
||||
self.diffusion = diffusion
|
||||
self._weights = np.ones([diffusion.num_timesteps])
|
||||
|
||||
def weights(self):
|
||||
return self._weights
|
||||
|
||||
|
||||
class LossAwareSampler(ScheduleSampler):
|
||||
def update_with_local_losses(self, local_ts, local_losses):
|
||||
"""
|
||||
Update the reweighting using losses from a model.
|
||||
Call this method from each rank with a batch of timesteps and the
|
||||
corresponding losses for each of those timesteps.
|
||||
This method will perform synchronization to make sure all of the ranks
|
||||
maintain the exact same reweighting.
|
||||
:param local_ts: an integer Tensor of timesteps.
|
||||
:param local_losses: a 1D Tensor of losses.
|
||||
"""
|
||||
batch_sizes = [th.tensor([0], dtype=th.int32, device=local_ts.device) for _ in range(dist.get_world_size())]
|
||||
dist.all_gather(
|
||||
batch_sizes,
|
||||
th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
|
||||
)
|
||||
|
||||
# Pad all_gather batches to be the maximum batch size.
|
||||
batch_sizes = [x.item() for x in batch_sizes]
|
||||
max_bs = max(batch_sizes)
|
||||
|
||||
timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes]
|
||||
loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes]
|
||||
dist.all_gather(timestep_batches, local_ts)
|
||||
dist.all_gather(loss_batches, local_losses)
|
||||
timesteps = [x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]]
|
||||
losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
|
||||
self.update_with_all_losses(timesteps, losses)
|
||||
|
||||
@abstractmethod
|
||||
def update_with_all_losses(self, ts, losses):
|
||||
"""
|
||||
Update the reweighting using losses from a model.
|
||||
Sub-classes should override this method to update the reweighting
|
||||
using losses from the model.
|
||||
This method directly updates the reweighting without synchronizing
|
||||
between workers. It is called by update_with_local_losses from all
|
||||
ranks with identical arguments. Thus, it should have deterministic
|
||||
behavior to maintain state across workers.
|
||||
:param ts: a list of int timesteps.
|
||||
:param losses: a list of float losses, one per timestep.
|
||||
"""
|
||||
|
||||
|
||||
class LossSecondMomentResampler(LossAwareSampler):
|
||||
def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
|
||||
self.diffusion = diffusion
|
||||
self.history_per_term = history_per_term
|
||||
self.uniform_prob = uniform_prob
|
||||
self._loss_history = np.zeros([diffusion.num_timesteps, history_per_term], dtype=np.float64)
|
||||
self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int)
|
||||
|
||||
def weights(self):
|
||||
if not self._warmed_up():
|
||||
return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
|
||||
weights = np.sqrt(np.mean(self._loss_history**2, axis=-1))
|
||||
weights /= np.sum(weights)
|
||||
weights *= 1 - self.uniform_prob
|
||||
weights += self.uniform_prob / len(weights)
|
||||
return weights
|
||||
|
||||
def update_with_all_losses(self, ts, losses):
|
||||
for t, loss in zip(ts, losses):
|
||||
if self._loss_counts[t] == self.history_per_term:
|
||||
# Shift out the oldest loss term.
|
||||
self._loss_history[t, :-1] = self._loss_history[t, 1:]
|
||||
self._loss_history[t, -1] = loss
|
||||
else:
|
||||
self._loss_history[t, self._loss_counts[t]] = loss
|
||||
self._loss_counts[t] += 1
|
||||
|
||||
def _warmed_up(self):
|
||||
return (self._loss_counts == self.history_per_term).all()
|
||||
0
opensora/utils/__init__.py
Normal file
0
opensora/utils/__init__.py
Normal file
233
opensora/utils/ckpt_utils.py
Normal file
233
opensora/utils/ckpt_utils.py
Normal file
|
|
@ -0,0 +1,233 @@
|
|||
import functools
|
||||
import json
|
||||
import logging
|
||||
import operator
|
||||
import os
|
||||
from typing import Tuple
|
||||
|
||||
import colossalai
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.booster.plugin import LowLevelZeroPlugin
|
||||
from colossalai.cluster import DistCoordinator
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import _LRScheduler
|
||||
from torchvision.datasets.utils import download_url
|
||||
|
||||
pretrained_models = {
|
||||
"DiT-XL-2-512x512.pt": "https://dl.fbaipublicfiles.com/DiT/models/DiT-XL-2-512x512.pt",
|
||||
"DiT-XL-2-256x256.pt": "https://dl.fbaipublicfiles.com/DiT/models/DiT-XL-2-256x256.pt",
|
||||
"Latte-XL-2-256x256-ucf101.pt": "https://huggingface.co/maxin-cn/Latte/resolve/main/ucf101.pt",
|
||||
"PixArt-XL-2-256x256.pth": "PixArt-XL-2-256x256.pth",
|
||||
"PixArt-XL-2-SAM-256x256.pth": "PixArt-XL-2-SAM-256x256.pth",
|
||||
"PixArt-XL-2-512x512.pth": "PixArt-XL-2-512x512.pth",
|
||||
"PixArt-XL-2-1024-MS.pth": "PixArt-XL-2-1024-MS.pth",
|
||||
}
|
||||
|
||||
|
||||
def reparameter(ckpt, name=None):
|
||||
if "DiT" in name:
|
||||
ckpt["x_embedder.proj.weight"] = ckpt["x_embedder.proj.weight"].unsqueeze(2)
|
||||
del ckpt["pos_embed"]
|
||||
elif "Latte" in name:
|
||||
ckpt = ckpt["ema"]
|
||||
ckpt["x_embedder.proj.weight"] = ckpt["x_embedder.proj.weight"].unsqueeze(2)
|
||||
del ckpt["pos_embed"]
|
||||
del ckpt["temp_embed"]
|
||||
elif "PixArt" in name:
|
||||
ckpt = ckpt["state_dict"]
|
||||
ckpt["x_embedder.proj.weight"] = ckpt["x_embedder.proj.weight"].unsqueeze(2)
|
||||
del ckpt["pos_embed"]
|
||||
return ckpt
|
||||
|
||||
|
||||
def find_model(model_name):
|
||||
"""
|
||||
Finds a pre-trained DiT model, downloading it if necessary. Alternatively, loads a model from a local path.
|
||||
"""
|
||||
if model_name in pretrained_models: # Find/download our pre-trained DiT checkpoints
|
||||
model = download_model(model_name)
|
||||
model = reparameter(model, model_name)
|
||||
return model
|
||||
else: # Load a custom DiT checkpoint:
|
||||
assert os.path.isfile(model_name), f"Could not find DiT checkpoint at {model_name}"
|
||||
checkpoint = torch.load(model_name, map_location=lambda storage, loc: storage)
|
||||
if "pos_embed_temporal" in checkpoint:
|
||||
del checkpoint["pos_embed_temporal"]
|
||||
if "pos_embed" in checkpoint:
|
||||
del checkpoint["pos_embed"]
|
||||
if "ema" in checkpoint: # supports checkpoints from train.py
|
||||
checkpoint = checkpoint["ema"]
|
||||
return checkpoint
|
||||
|
||||
|
||||
def download_model(model_name):
|
||||
"""
|
||||
Downloads a pre-trained DiT model from the web.
|
||||
"""
|
||||
assert model_name in pretrained_models
|
||||
local_path = f"pretrained_models/{model_name}"
|
||||
if not os.path.isfile(local_path):
|
||||
os.makedirs("pretrained_models", exist_ok=True)
|
||||
web_path = pretrained_models[model_name]
|
||||
download_url(web_path, "pretrained_models", model_name)
|
||||
model = torch.load(local_path, map_location=lambda storage, loc: storage)
|
||||
return model
|
||||
|
||||
|
||||
def load_from_sharded_state_dict(model, ckpt_path):
|
||||
# TODO: harded-coded for colossal loading
|
||||
os.environ["RANK"] = "0"
|
||||
os.environ["LOCAL_RANK"] = "0"
|
||||
os.environ["WORLD_SIZE"] = "1"
|
||||
os.environ["MASTER_ADDR"] = "localhost"
|
||||
os.environ["MASTER_PORT"] = "29501"
|
||||
colossalai.launch_from_torch({})
|
||||
plugin = LowLevelZeroPlugin(
|
||||
stage=2,
|
||||
precision="fp32",
|
||||
initial_scale=2**16,
|
||||
)
|
||||
booster = Booster(plugin=plugin)
|
||||
model, _, _, _, _ = booster.boost(model=model)
|
||||
booster.load_model(model, os.path.join(ckpt_path, "model"))
|
||||
|
||||
save_path = os.path.join(ckpt_path, "model_ckpt.pt")
|
||||
torch.save(model.module.state_dict(), save_path)
|
||||
print(f"Model checkpoint saved to {save_path}")
|
||||
|
||||
|
||||
def model_sharding(model: torch.nn.Module):
|
||||
global_rank = dist.get_rank()
|
||||
world_size = dist.get_world_size()
|
||||
for _, param in model.named_parameters():
|
||||
padding_size = (world_size - param.numel() % world_size) % world_size
|
||||
if padding_size > 0:
|
||||
padding_param = torch.nn.functional.pad(param.data.view(-1), [0, padding_size])
|
||||
else:
|
||||
padding_param = param.data.view(-1)
|
||||
splited_params = padding_param.split(padding_param.numel() // world_size)
|
||||
splited_params = splited_params[global_rank]
|
||||
param.data = splited_params
|
||||
|
||||
|
||||
def load_json(file_path: str):
|
||||
with open(file_path, "r") as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
def save_json(data, file_path: str):
|
||||
with open(file_path, "w") as f:
|
||||
json.dump(data, f, indent=4)
|
||||
|
||||
|
||||
def remove_padding(tensor: torch.Tensor, original_shape: Tuple) -> torch.Tensor:
|
||||
return tensor[: functools.reduce(operator.mul, original_shape)]
|
||||
|
||||
|
||||
def model_gathering(model: torch.nn.Module, model_shape_dict: dict):
|
||||
global_rank = dist.get_rank()
|
||||
global_size = dist.get_world_size()
|
||||
for name, param in model.named_parameters():
|
||||
all_params = [torch.empty_like(param.data) for _ in range(global_size)]
|
||||
dist.all_gather(all_params, param.data, group=dist.group.WORLD)
|
||||
if int(global_rank) == 0:
|
||||
all_params = torch.cat(all_params)
|
||||
param.data = remove_padding(all_params, model_shape_dict[name]).view(model_shape_dict[name])
|
||||
dist.barrier()
|
||||
|
||||
|
||||
def record_model_param_shape(model: torch.nn.Module) -> dict:
|
||||
param_shape = {}
|
||||
for name, param in model.named_parameters():
|
||||
param_shape[name] = param.shape
|
||||
return param_shape
|
||||
|
||||
|
||||
def save(
|
||||
booster: Booster,
|
||||
model: nn.Module,
|
||||
ema: nn.Module,
|
||||
optimizer: Optimizer,
|
||||
lr_scheduler: _LRScheduler,
|
||||
epoch: int,
|
||||
step: int,
|
||||
global_step: int,
|
||||
batch_size: int,
|
||||
coordinator: DistCoordinator,
|
||||
save_dir: str,
|
||||
shape_dict: dict,
|
||||
):
|
||||
save_dir = os.path.join(save_dir, f"epoch{epoch}-global_step{global_step}")
|
||||
os.makedirs(os.path.join(save_dir, "model"), exist_ok=True)
|
||||
|
||||
booster.save_model(model, os.path.join(save_dir, "model"), shard=True)
|
||||
# ema is not boosted, so we don't need to use booster.save_model
|
||||
model_gathering(ema, shape_dict)
|
||||
global_rank = dist.get_rank()
|
||||
if int(global_rank) == 0:
|
||||
torch.save(ema.state_dict(), os.path.join(save_dir, "ema.pt"))
|
||||
model_sharding(ema)
|
||||
|
||||
booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True, size_per_shard=4096)
|
||||
if lr_scheduler is not None:
|
||||
booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler"))
|
||||
running_states = {
|
||||
"epoch": epoch,
|
||||
"step": step,
|
||||
"global_step": global_step,
|
||||
"sample_start_index": step * batch_size,
|
||||
}
|
||||
if coordinator.is_master():
|
||||
save_json(running_states, os.path.join(save_dir, "running_states.json"))
|
||||
dist.barrier()
|
||||
|
||||
|
||||
def load(
|
||||
booster: Booster, model: nn.Module, ema: nn.Module, optimizer: Optimizer, lr_scheduler: _LRScheduler, load_dir: str
|
||||
) -> Tuple[int, int, int]:
|
||||
booster.load_model(model, os.path.join(load_dir, "model"))
|
||||
# ema is not boosted, so we don't use booster.load_model
|
||||
# ema.load_state_dict(torch.load(os.path.join(load_dir, "ema.pt")))
|
||||
ema.load_state_dict(torch.load(os.path.join(load_dir, "ema.pt"), map_location=torch.device("cpu")))
|
||||
booster.load_optimizer(optimizer, os.path.join(load_dir, "optimizer"))
|
||||
if lr_scheduler is not None:
|
||||
booster.load_lr_scheduler(lr_scheduler, os.path.join(load_dir, "lr_scheduler"))
|
||||
running_states = load_json(os.path.join(load_dir, "running_states.json"))
|
||||
dist.barrier()
|
||||
return running_states["epoch"], running_states["step"], running_states["sample_start_index"]
|
||||
|
||||
|
||||
def create_logger(logging_dir):
|
||||
"""
|
||||
Create a logger that writes to a log file and stdout.
|
||||
"""
|
||||
if dist.get_rank() == 0: # real logger
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="[\033[34m%(asctime)s\033[0m] %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")],
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
else: # dummy logger (does nothing)
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.addHandler(logging.NullHandler())
|
||||
return logger
|
||||
|
||||
|
||||
def load_checkpoint(model, ckpt_path, save_as_pt=True):
|
||||
if ckpt_path.endswith(".pt") or ckpt_path.endswith(".pth"):
|
||||
state_dict = find_model(ckpt_path)
|
||||
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
|
||||
print(f"Missing keys: {missing_keys}")
|
||||
print(f"Unexpected keys: {unexpected_keys}")
|
||||
elif os.path.isdir(ckpt_path):
|
||||
load_from_sharded_state_dict(model, ckpt_path)
|
||||
if save_as_pt:
|
||||
save_path = os.path.join(ckpt_path, "model_ckpt.pt")
|
||||
torch.save(model.state_dict(), save_path)
|
||||
else:
|
||||
raise ValueError(f"Invalid checkpoint path: {ckpt_path}")
|
||||
96
opensora/utils/config_utils.py
Normal file
96
opensora/utils/config_utils.py
Normal file
|
|
@ -0,0 +1,96 @@
|
|||
import argparse
|
||||
import json
|
||||
import os
|
||||
from glob import glob
|
||||
|
||||
from mmengine.config import Config
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
|
||||
def parse_args(training=False):
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
# model config
|
||||
parser.add_argument("config", help="model config file path")
|
||||
|
||||
parser.add_argument("--seed", default=42, type=int, help="generation seed")
|
||||
parser.add_argument("--ckpt-path", type=str, help="path to model ckpt; will overwrite cfg.ckpt_path if specified")
|
||||
parser.add_argument("--batch-size", default=None, type=int, help="batch size")
|
||||
|
||||
# ======================================================
|
||||
# Inference
|
||||
# ======================================================
|
||||
|
||||
if not training:
|
||||
# prompt
|
||||
parser.add_argument("--prompt-path", default=None, type=str, help="path to prompt txt file")
|
||||
parser.add_argument("--save-dir", default=None, type=str, help="path to save generated samples")
|
||||
|
||||
# hyperparameters
|
||||
parser.add_argument("--num-sampling-steps", default=None, type=int, help="sampling steps")
|
||||
parser.add_argument("--cfg-scale", default=None, type=float, help="balance between cond & uncond")
|
||||
else:
|
||||
parser.add_argument("--wandb", default=None, type=bool, help="enable wandb")
|
||||
parser.add_argument("--load", default=None, type=str, help="path to continue training")
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def merge_args(cfg, args, training=False):
|
||||
if args.ckpt_path is not None:
|
||||
cfg.model["from_pratrained"] = args.ckpt_path
|
||||
args.ckpt_path = None
|
||||
|
||||
if not training:
|
||||
if args.cfg_scale is not None:
|
||||
cfg.scheduler["cfg_scale"] = args.cfg_scale
|
||||
args.cfg_scale = None
|
||||
|
||||
if "multi_resolution" not in cfg:
|
||||
cfg["multi_resolution"] = False
|
||||
for k, v in vars(args).items():
|
||||
if k in cfg and v is not None:
|
||||
cfg[k] = v
|
||||
|
||||
return cfg
|
||||
|
||||
|
||||
def parse_configs(training=False):
|
||||
args = parse_args(training)
|
||||
cfg = Config.fromfile(args.config)
|
||||
cfg = merge_args(cfg, args, training)
|
||||
return cfg
|
||||
|
||||
|
||||
def create_experiment_workspace(cfg):
|
||||
"""
|
||||
This function creates a folder for experiment tracking.
|
||||
|
||||
Args:
|
||||
args: The parsed arguments.
|
||||
|
||||
Returns:
|
||||
exp_dir: The path to the experiment folder.
|
||||
"""
|
||||
# Make outputs folder (holds all experiment subfolders)
|
||||
os.makedirs(cfg.outputs, exist_ok=True)
|
||||
experiment_index = len(glob(f"{cfg.outputs}/*"))
|
||||
|
||||
# Create an experiment folder
|
||||
model_name = cfg.model["type"].replace("/", "-")
|
||||
exp_name = f"{experiment_index:03d}-F{cfg.num_frames}S{cfg.frame_interval}-{model_name}"
|
||||
exp_dir = f"{cfg.outputs}/{exp_name}"
|
||||
os.makedirs(exp_dir, exist_ok=True)
|
||||
return exp_name, exp_dir
|
||||
|
||||
|
||||
def save_training_config(cfg, experiment_dir):
|
||||
with open(f"{experiment_dir}/config.txt", "w") as f:
|
||||
json.dump(cfg, f, indent=4)
|
||||
|
||||
|
||||
def create_tensorboard_writer(exp_dir):
|
||||
tensorboard_dir = f"{exp_dir}/tensorboard"
|
||||
os.makedirs(tensorboard_dir, exist_ok=True)
|
||||
writer = SummaryWriter(tensorboard_dir)
|
||||
return writer
|
||||
339
opensora/utils/misc.py
Normal file
339
opensora/utils/misc.py
Normal file
|
|
@ -0,0 +1,339 @@
|
|||
import collections
|
||||
import importlib
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
from collections.abc import Sequence
|
||||
from itertools import repeat
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
|
||||
def print_rank(var_name, var_value, rank=0):
|
||||
if dist.get_rank() == rank:
|
||||
print(f"[Rank {rank}] {var_name}: {var_value}")
|
||||
|
||||
|
||||
def print_0(*args, **kwargs):
|
||||
if dist.get_rank() == 0:
|
||||
print(*args, **kwargs)
|
||||
|
||||
|
||||
def requires_grad(model: torch.nn.Module, flag: bool = True) -> None:
|
||||
"""
|
||||
Set requires_grad flag for all parameters in a model.
|
||||
"""
|
||||
for p in model.parameters():
|
||||
p.requires_grad = flag
|
||||
|
||||
|
||||
def format_numel_str(numel: int) -> str:
|
||||
B = 1024**3
|
||||
M = 1024**2
|
||||
K = 1024
|
||||
if numel >= B:
|
||||
return f"{numel / B:.2f} B"
|
||||
elif numel >= M:
|
||||
return f"{numel / M:.2f} M"
|
||||
elif numel >= K:
|
||||
return f"{numel / K:.2f} K"
|
||||
else:
|
||||
return f"{numel}"
|
||||
|
||||
|
||||
def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
|
||||
dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
|
||||
tensor.div_(dist.get_world_size())
|
||||
return tensor
|
||||
|
||||
|
||||
def get_model_numel(model: torch.nn.Module) -> (int, int):
|
||||
num_params = 0
|
||||
num_params_trainable = 0
|
||||
for p in model.parameters():
|
||||
num_params += p.numel()
|
||||
if p.requires_grad:
|
||||
num_params_trainable += p.numel()
|
||||
return num_params, num_params_trainable
|
||||
|
||||
|
||||
def try_import(name):
|
||||
"""Try to import a module.
|
||||
|
||||
Args:
|
||||
name (str): Specifies what module to import in absolute or relative
|
||||
terms (e.g. either pkg.mod or ..mod).
|
||||
Returns:
|
||||
ModuleType or None: If importing successfully, returns the imported
|
||||
module, otherwise returns None.
|
||||
"""
|
||||
try:
|
||||
return importlib.import_module(name)
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
|
||||
def transpose(x):
|
||||
"""
|
||||
transpose a list of list
|
||||
Args:
|
||||
x (list[list]):
|
||||
"""
|
||||
ret = list(map(list, zip(*x)))
|
||||
return ret
|
||||
|
||||
|
||||
def get_timestamp():
|
||||
timestamp = time.strftime("%Y%m%d-%H%M%S", time.localtime(time.time()))
|
||||
return timestamp
|
||||
|
||||
|
||||
def format_time(seconds):
|
||||
days = int(seconds / 3600 / 24)
|
||||
seconds = seconds - days * 3600 * 24
|
||||
hours = int(seconds / 3600)
|
||||
seconds = seconds - hours * 3600
|
||||
minutes = int(seconds / 60)
|
||||
seconds = seconds - minutes * 60
|
||||
secondsf = int(seconds)
|
||||
seconds = seconds - secondsf
|
||||
millis = int(seconds * 1000)
|
||||
|
||||
f = ""
|
||||
i = 1
|
||||
if days > 0:
|
||||
f += str(days) + "D"
|
||||
i += 1
|
||||
if hours > 0 and i <= 2:
|
||||
f += str(hours) + "h"
|
||||
i += 1
|
||||
if minutes > 0 and i <= 2:
|
||||
f += str(minutes) + "m"
|
||||
i += 1
|
||||
if secondsf > 0 and i <= 2:
|
||||
f += str(secondsf) + "s"
|
||||
i += 1
|
||||
if millis > 0 and i <= 2:
|
||||
f += str(millis) + "ms"
|
||||
i += 1
|
||||
if f == "":
|
||||
f = "0ms"
|
||||
return f
|
||||
|
||||
|
||||
def to_tensor(data):
|
||||
"""Convert objects of various python types to :obj:`torch.Tensor`.
|
||||
|
||||
Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
|
||||
:class:`Sequence`, :class:`int` and :class:`float`.
|
||||
|
||||
Args:
|
||||
data (torch.Tensor | numpy.ndarray | Sequence | int | float): Data to
|
||||
be converted.
|
||||
"""
|
||||
|
||||
if isinstance(data, torch.Tensor):
|
||||
return data
|
||||
elif isinstance(data, np.ndarray):
|
||||
return torch.from_numpy(data)
|
||||
elif isinstance(data, Sequence) and not isinstance(data, str):
|
||||
return torch.tensor(data)
|
||||
elif isinstance(data, int):
|
||||
return torch.LongTensor([data])
|
||||
elif isinstance(data, float):
|
||||
return torch.FloatTensor([data])
|
||||
else:
|
||||
raise TypeError(f"type {type(data)} cannot be converted to tensor.")
|
||||
|
||||
|
||||
def to_ndarray(data):
|
||||
if isinstance(data, torch.Tensor):
|
||||
return data.numpy()
|
||||
elif isinstance(data, np.ndarray):
|
||||
return data
|
||||
elif isinstance(data, Sequence):
|
||||
return np.array(data)
|
||||
elif isinstance(data, int):
|
||||
return np.ndarray([data], dtype=int)
|
||||
elif isinstance(data, float):
|
||||
return np.array([data], dtype=float)
|
||||
else:
|
||||
raise TypeError(f"type {type(data)} cannot be converted to ndarray.")
|
||||
|
||||
|
||||
def to_torch_dtype(dtype):
|
||||
if isinstance(dtype, torch.dtype):
|
||||
return dtype
|
||||
elif isinstance(dtype, str):
|
||||
dtype_mapping = {
|
||||
"float64": torch.float64,
|
||||
"float32": torch.float32,
|
||||
"float16": torch.float16,
|
||||
"fp32": torch.float32,
|
||||
"fp16": torch.float16,
|
||||
"half": torch.float16,
|
||||
"bf16": torch.bfloat16,
|
||||
}
|
||||
if dtype not in dtype_mapping:
|
||||
raise ValueError
|
||||
dtype = dtype_mapping[dtype]
|
||||
return dtype
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
|
||||
def count_params(model):
|
||||
return sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
|
||||
|
||||
def _ntuple(n):
|
||||
def parse(x):
|
||||
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
|
||||
return x
|
||||
return tuple(repeat(x, n))
|
||||
|
||||
return parse
|
||||
|
||||
|
||||
to_1tuple = _ntuple(1)
|
||||
to_2tuple = _ntuple(2)
|
||||
to_3tuple = _ntuple(3)
|
||||
to_4tuple = _ntuple(4)
|
||||
to_ntuple = _ntuple
|
||||
|
||||
|
||||
def convert_SyncBN_to_BN2d(model_cfg):
|
||||
for k in model_cfg:
|
||||
v = model_cfg[k]
|
||||
if k == "norm_cfg" and v["type"] == "SyncBN":
|
||||
v["type"] = "BN2d"
|
||||
elif isinstance(v, dict):
|
||||
convert_SyncBN_to_BN2d(v)
|
||||
|
||||
|
||||
def get_topk(x, dim=4, k=5):
|
||||
x = to_tensor(x)
|
||||
inds = x[..., dim].topk(k)[1]
|
||||
return x[inds]
|
||||
|
||||
|
||||
def param_sigmoid(x, alpha):
|
||||
ret = 1 / (1 + (-alpha * x).exp())
|
||||
return ret
|
||||
|
||||
|
||||
def inverse_param_sigmoid(x, alpha, eps=1e-5):
|
||||
x = x.clamp(min=0, max=1)
|
||||
x1 = x.clamp(min=eps)
|
||||
x2 = (1 - x).clamp(min=eps)
|
||||
return torch.log(x1 / x2) / alpha
|
||||
|
||||
|
||||
def inverse_sigmoid(x, eps=1e-5):
|
||||
"""Inverse function of sigmoid.
|
||||
|
||||
Args:
|
||||
x (Tensor): The tensor to do the
|
||||
inverse.
|
||||
eps (float): EPS avoid numerical
|
||||
overflow. Defaults 1e-5.
|
||||
Returns:
|
||||
Tensor: The x has passed the inverse
|
||||
function of sigmoid, has same
|
||||
shape with input.
|
||||
"""
|
||||
x = x.clamp(min=0, max=1)
|
||||
x1 = x.clamp(min=eps)
|
||||
x2 = (1 - x).clamp(min=eps)
|
||||
return torch.log(x1 / x2)
|
||||
|
||||
|
||||
def count_columns(df, columns):
|
||||
cnt_dict = OrderedDict()
|
||||
num_samples = len(df)
|
||||
|
||||
for col in columns:
|
||||
d_i = df[col].value_counts().to_dict()
|
||||
for k in d_i:
|
||||
d_i[k] = (d_i[k], d_i[k] / num_samples)
|
||||
cnt_dict[col] = d_i
|
||||
|
||||
return cnt_dict
|
||||
|
||||
|
||||
def build_logger(work_dir, cfgname):
|
||||
log_file = cfgname + ".log"
|
||||
log_path = os.path.join(work_dir, log_file)
|
||||
|
||||
logger = logging.getLogger(cfgname)
|
||||
logger.setLevel(logging.INFO)
|
||||
# formatter = logging.Formatter('%(asctime)s %(levelname)s %(message)s')
|
||||
formatter = logging.Formatter("%(asctime)s: %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
|
||||
|
||||
handler1 = logging.FileHandler(log_path)
|
||||
handler1.setFormatter(formatter)
|
||||
|
||||
handler2 = logging.StreamHandler()
|
||||
handler2.setFormatter(formatter)
|
||||
|
||||
logger.addHandler(handler1)
|
||||
logger.addHandler(handler2)
|
||||
logger.propagate = False
|
||||
|
||||
return logger
|
||||
|
||||
|
||||
timings = {}
|
||||
from contextlib import contextmanager
|
||||
|
||||
|
||||
@contextmanager
|
||||
def torch_timer(prefix):
|
||||
import time
|
||||
|
||||
import torch
|
||||
|
||||
torch.cuda.synchronize()
|
||||
start = time.time()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
torch.cuda.synchronize()
|
||||
t_diff = (time.time() - start) * 1000
|
||||
if prefix not in timings:
|
||||
timings[prefix] = []
|
||||
|
||||
timings[prefix].append(t_diff)
|
||||
|
||||
num_ignored = 10
|
||||
|
||||
if len(timings[prefix]) > num_ignored:
|
||||
# avg = sum(timings[prefix][num_ignored:]) / (len(timings[prefix]) - num_ignored)
|
||||
avg = sum(timings[prefix][-num_ignored:]) / num_ignored
|
||||
print("{}: {} ({})".format(prefix, t_diff, avg))
|
||||
else:
|
||||
print("{}: {}".format(prefix, t_diff))
|
||||
|
||||
|
||||
def strip_dc(x):
|
||||
"""
|
||||
strip DataContainer
|
||||
"""
|
||||
try:
|
||||
import mmcv
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
if isinstance(x, dict):
|
||||
res = {}
|
||||
for k, v in x.items():
|
||||
res[k] = strip_dc(v)
|
||||
return res
|
||||
if isinstance(x, (list, tuple)) and isinstance(x[0], mmcv.parallel.DataContainer):
|
||||
return strip_dc(x[0])
|
||||
elif isinstance(x, mmcv.parallel.DataContainer):
|
||||
return strip_dc(x.data)
|
||||
return x
|
||||
34
opensora/utils/train_utils.py
Normal file
34
opensora/utils/train_utils.py
Normal file
|
|
@ -0,0 +1,34 @@
|
|||
from collections import OrderedDict
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def update_ema(
|
||||
ema_model: torch.nn.Module, model: torch.nn.Module, optimizer=None, decay: float = 0.9999, sharded: bool = True
|
||||
) -> None:
|
||||
"""
|
||||
Step the EMA model towards the current model.
|
||||
"""
|
||||
ema_params = OrderedDict(ema_model.named_parameters())
|
||||
model_params = OrderedDict(model.named_parameters())
|
||||
|
||||
for name, param in model_params.items():
|
||||
if name == "pos_embed":
|
||||
continue
|
||||
if param.requires_grad == False:
|
||||
continue
|
||||
if not sharded:
|
||||
param_data = param.data
|
||||
ema_params[name].mul_(decay).add_(param_data, alpha=1 - decay)
|
||||
else:
|
||||
if param.data.dtype != torch.float32:
|
||||
param_id = id(param)
|
||||
master_param = optimizer._param_store.working_to_master_param[param_id]
|
||||
param_data = master_param.data
|
||||
else:
|
||||
param_data = param.data
|
||||
ema_params[name].mul_(decay).add_(param_data, alpha=1 - decay)
|
||||
|
||||
|
||||
|
||||
0
tools/data/README.md
Normal file
0
tools/data/README.md
Normal file
0
tools/data/__init__.py
Normal file
0
tools/data/__init__.py
Normal file
96
tools/data/csvutil.py
Normal file
96
tools/data/csvutil.py
Normal file
|
|
@ -0,0 +1,96 @@
|
|||
import argparse
|
||||
import csv
|
||||
import os
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
# path, name, #frames
|
||||
PREFIX = [
|
||||
"The video shows",
|
||||
"The video captures",
|
||||
"The video features",
|
||||
"The video depicts",
|
||||
"The video presents",
|
||||
"The video features",
|
||||
"The video is ",
|
||||
"In the video,",
|
||||
]
|
||||
|
||||
|
||||
def get_video_length(path):
|
||||
import cv2
|
||||
|
||||
cap = cv2.VideoCapture(path)
|
||||
return int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
|
||||
|
||||
def main(args):
|
||||
input_path = args.input
|
||||
output_path = args.output
|
||||
if output_path is None:
|
||||
name = os.path.basename(input_path)
|
||||
name, ext = os.path.splitext(name)
|
||||
if args.num_frames_min is not None:
|
||||
name += f"_fmin_{args.num_frames_min}"
|
||||
if args.num_frames_max is not None:
|
||||
name += f"_fmax_{args.num_frames_max}"
|
||||
if args.filter_null_text:
|
||||
name += "_fnt"
|
||||
if args.remove_prefix:
|
||||
name += "_rp"
|
||||
if args.root is not None:
|
||||
name += f"_root"
|
||||
if args.relength:
|
||||
name += "_relength"
|
||||
output_path = os.path.join(os.path.dirname(input_path), name + ext)
|
||||
|
||||
with open(input_path, "r") as f:
|
||||
reader = csv.reader(f)
|
||||
data = list(reader)
|
||||
print("Number of videos before filtering:", len(data))
|
||||
|
||||
data_new = []
|
||||
for i, row in tqdm(enumerate(data)):
|
||||
path = row[0]
|
||||
caption = row[1]
|
||||
n_frames = int(row[2])
|
||||
if args.num_frames_min is not None and n_frames < args.num_frames_min:
|
||||
continue
|
||||
if args.num_frames_max is not None and n_frames > args.num_frames_max:
|
||||
continue
|
||||
if args.filter_null_text and len(caption) == 0:
|
||||
continue
|
||||
if args.remove_prefix:
|
||||
for prefix in PREFIX:
|
||||
if caption.startswith(prefix):
|
||||
caption = caption[len(prefix) :].strip()
|
||||
if caption[0].islower():
|
||||
caption = caption[0].upper() + caption[1:]
|
||||
row[1] = caption
|
||||
break
|
||||
if args.root is not None:
|
||||
row[0] = os.path.join(args.root, path)
|
||||
if args.relength:
|
||||
n_frames = get_video_length(row[0])
|
||||
row[2] = n_frames
|
||||
data_new.append(row)
|
||||
|
||||
print("Number of videos after filtering:", len(data_new))
|
||||
with open(output_path, "w") as f:
|
||||
writer = csv.writer(f)
|
||||
writer.writerows(data_new)
|
||||
print("Output saved to", output_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--input", type=str, required=True)
|
||||
parser.add_argument("--output", type=str, default=None)
|
||||
parser.add_argument("--num_frames_min", type=int, default=None)
|
||||
parser.add_argument("--num_frames_max", type=int, default=None)
|
||||
parser.add_argument("--filter_null_text", action="store_true")
|
||||
parser.add_argument("--remove_prefix", action="store_true")
|
||||
parser.add_argument("--root", type=str, default=None)
|
||||
parser.add_argument("--relength", action="store_true")
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
82
tools/data/get_caption_gpt4.py
Normal file
82
tools/data/get_caption_gpt4.py
Normal file
|
|
@ -0,0 +1,82 @@
|
|||
import argparse
|
||||
import base64
|
||||
import csv
|
||||
import os
|
||||
|
||||
import cv2
|
||||
import requests
|
||||
import tqdm
|
||||
|
||||
# OpenAI API Key
|
||||
api_key = ""
|
||||
|
||||
|
||||
def get_caption(frame):
|
||||
prompt = "The middle frame from a video clip are given. Describe this video and its style to generate a description for the video. The description should be useful for AI to re-generate the video. Here are some examples of good descriptions:\n\n 1. A stylish woman walks down a Tokyo street filled with warm glowing neon and animated city signage. She wears a black leather jacket, a long red dress, and black boots, and carries a black purse. She wears sunglasses and red lipstick. She walks confidently and casually. The street is damp and reflective, creating a mirror effect of the colorful lights. Many pedestrians walk about.\n2. Several giant wooly mammoths approach treading through a snowy meadow, their long wooly fur lightly blows in the wind as they walk, snow covered trees and dramatic snow capped mountains in the distance, mid afternoon light with wispy clouds and a sun high in the distance creates a warm glow, the low camera view is stunning capturing the large furry mammal with beautiful photography, depth of field.\n 3. Drone view of waves crashing against the rugged cliffs along Big Sur's garay point beach. The crashing blue waters create white-tipped waves, while the golden light of the setting sun illuminates the rocky shore. A small island with a lighthouse sits in the distance, and green shrubbery covers the cliff's edge. The steep drop from the road down to the beach is a dramatic feat, with the cliff’s edges jutting out over the sea. This is a view that captures the raw beauty of the coast and the rugged landscape of the Pacific Coast Highway."
|
||||
|
||||
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
|
||||
payload = {
|
||||
"model": "gpt-4-vision-preview",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": prompt,
|
||||
},
|
||||
{"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{frame}"}},
|
||||
],
|
||||
}
|
||||
],
|
||||
"max_tokens": 300,
|
||||
}
|
||||
response = requests.post("https://api.openai.com/v1/chat/completions", headers=headers, json=payload, timeout=60)
|
||||
caption = response.json()["choices"][0]["message"]["content"]
|
||||
return caption
|
||||
|
||||
|
||||
# Function to encode the image
|
||||
def encode_image(image_path):
|
||||
with open(image_path, "rb") as image_file:
|
||||
return base64.b64encode(image_file.read()).decode("utf-8")
|
||||
|
||||
|
||||
def extract_frames(video_path):
|
||||
cap = cv2.VideoCapture(video_path)
|
||||
length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
point = length // 2
|
||||
cap.set(cv2.CAP_PROP_POS_FRAMES, point)
|
||||
ret, frame = cap.read()
|
||||
_, buffer = cv2.imencode(".jpg", frame)
|
||||
img_base64 = base64.b64encode(buffer).decode("utf-8")
|
||||
return img_base64
|
||||
|
||||
|
||||
def main(args):
|
||||
processed_videos = []
|
||||
if os.path.exists(args.output_file):
|
||||
with open(args.output_file, "r") as f:
|
||||
reader = csv.reader(f)
|
||||
samples = list(reader)
|
||||
processed_videos = [sample[0] for sample in samples]
|
||||
|
||||
f = open(args.output_file, "a")
|
||||
writer = csv.writer(f)
|
||||
for video in tqdm.tqdm(os.listdir(args.video_folder)):
|
||||
if video in processed_videos:
|
||||
continue
|
||||
video_path = os.path.join(args.video_folder, video)
|
||||
base64_image = extract_frames(video_path)
|
||||
caption = get_caption(base64_image)
|
||||
caption = caption.replace("\n", " ")
|
||||
writer.writerow([video, caption])
|
||||
f.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--video_folder", type=str, required=True, help="Path to the folder containing the videos.")
|
||||
parser.add_argument("--output_file", type=str, default="video_captions.csv", help="Path to the output file.")
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
382
tools/data/get_caption_llava.py
Normal file
382
tools/data/get_caption_llava.py
Normal file
|
|
@ -0,0 +1,382 @@
|
|||
import argparse
|
||||
import csv
|
||||
import os
|
||||
|
||||
import cv2
|
||||
import torch
|
||||
from llava.constants import DEFAULT_IMAGE_TOKEN, IGNORE_INDEX, IMAGE_TOKEN_INDEX
|
||||
from llava.conversation import conv_templates
|
||||
from llava.eval.run_llava import eval_model
|
||||
from llava.mm_utils import get_anyres_image_grid_shape, get_model_name_from_path, process_images, tokenizer_image_token
|
||||
from llava.model.builder import load_pretrained_model
|
||||
from llava.model.llava_arch import unpad_image
|
||||
from llava.utils import disable_torch_init
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
disable_torch_init()
|
||||
|
||||
prompts = {
|
||||
"three_frames": "A video is given by providing three frames in chronological order. Describe this video and its style to generate a description. Pay attention to all objects in the video. Do not describe each frame individually. Do not reply with words like 'first frame'. The description should be useful for AI to re-generate the video. The description should be less than six sentences. Here are some examples of good descriptions: 1. A stylish woman walks down a Tokyo street filled with warm glowing neon and animated city signage. She wears a black leather jacket, a long red dress, and black boots, and carries a black purse. She wears sunglasses and red lipstick. She walks confidently and casually. The street is damp and reflective, creating a mirror effect of the colorful lights. Many pedestrians walk about. 2. Several giant wooly mammoths approach treading through a snowy meadow, their long wooly fur lightly blows in the wind as they walk, snow covered trees and dramatic snow capped mountains in the distance, mid afternoon light with wispy clouds and a sun high in the distance creates a warm glow, the low camera view is stunning capturing the large furry mammal with beautiful photography, depth of field. 3. Drone view of waves crashing against the rugged cliffs along Big Sur's garay point beach. The crashing blue waters create white-tipped waves, while the golden light of the setting sun illuminates the rocky shore. A small island with a lighthouse sits in the distance, and green shrubbery covers the cliff's edge. The steep drop from the road down to the beach is a dramatic feat, with the cliff’s edges jutting out over the sea. This is a view that captures the raw beauty of the coast and the rugged landscape of the Pacific Coast Highway.",
|
||||
}
|
||||
|
||||
|
||||
def get_filelist(file_path):
|
||||
Filelist = []
|
||||
for home, dirs, files in os.walk(file_path):
|
||||
for filename in files:
|
||||
Filelist.append(os.path.join(home, filename))
|
||||
return Filelist
|
||||
|
||||
|
||||
def get_video_length(cap):
|
||||
return int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
|
||||
|
||||
def extract_frames(video_path, points=[0.2, 0.5, 0.8]):
|
||||
cap = cv2.VideoCapture(video_path)
|
||||
length = get_video_length(cap)
|
||||
points = [int(length * point) for point in points]
|
||||
frames = []
|
||||
if length < 3:
|
||||
return frames, length
|
||||
for point in points:
|
||||
cap.set(cv2.CAP_PROP_POS_FRAMES, point)
|
||||
ret, frame = cap.read()
|
||||
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||
frame = Image.fromarray(frame)
|
||||
frames.append(frame)
|
||||
return frames, length
|
||||
|
||||
|
||||
def prepare_inputs_labels_for_multimodal(
|
||||
self, input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes=None
|
||||
):
|
||||
# llava_arch.py
|
||||
vision_tower = self.get_vision_tower()
|
||||
if vision_tower is None or images is None or input_ids.shape[1] == 1:
|
||||
return input_ids, position_ids, attention_mask, past_key_values, None, labels
|
||||
|
||||
if type(images) is list or images.ndim == 5:
|
||||
if type(images) is list:
|
||||
images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images]
|
||||
concat_images = torch.cat([image for image in images], dim=0)
|
||||
image_features = self.encode_images(concat_images)
|
||||
split_sizes = [image.shape[0] for image in images]
|
||||
image_features = torch.split(image_features, split_sizes, dim=0)
|
||||
mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat")
|
||||
image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
|
||||
if mm_patch_merge_type == "flat":
|
||||
image_features = [x.flatten(0, 1) for x in image_features]
|
||||
elif mm_patch_merge_type.startswith("spatial"):
|
||||
new_image_features = []
|
||||
for image_idx, image_feature in enumerate(image_features):
|
||||
if image_feature.shape[0] > 1:
|
||||
base_image_feature = image_feature[0]
|
||||
image_feature = image_feature[1:]
|
||||
height = width = self.get_vision_tower().num_patches_per_side
|
||||
assert height * width == base_image_feature.shape[0]
|
||||
if image_aspect_ratio == "anyres":
|
||||
num_patch_width, num_patch_height = get_anyres_image_grid_shape(
|
||||
image_sizes[image_idx],
|
||||
self.config.image_grid_pinpoints,
|
||||
self.get_vision_tower().config.image_size,
|
||||
)
|
||||
image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
if "unpad" in mm_patch_merge_type:
|
||||
image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
|
||||
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
|
||||
image_feature = unpad_image(image_feature, image_sizes[image_idx])
|
||||
image_feature = torch.cat(
|
||||
(
|
||||
image_feature,
|
||||
self.model.image_newline[:, None, None]
|
||||
.expand(*image_feature.shape[:-1], 1)
|
||||
.to(image_feature.device),
|
||||
),
|
||||
dim=-1,
|
||||
)
|
||||
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
|
||||
else:
|
||||
image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous()
|
||||
image_feature = image_feature.flatten(0, 3)
|
||||
image_feature = torch.cat((base_image_feature, image_feature), dim=0)
|
||||
else:
|
||||
image_feature = image_feature[0]
|
||||
if "unpad" in mm_patch_merge_type:
|
||||
image_feature = torch.cat(
|
||||
(image_feature, self.model.image_newline[None].to(image_feature.device)), dim=0
|
||||
)
|
||||
new_image_features.append(image_feature)
|
||||
image_features = new_image_features
|
||||
else:
|
||||
raise ValueError(f"Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}")
|
||||
else:
|
||||
image_features = self.encode_images(images)
|
||||
|
||||
# TODO: image start / end is not implemented here to support pretraining.
|
||||
if getattr(self.config, "tune_mm_mlp_adapter", False) and getattr(self.config, "mm_use_im_start_end", False):
|
||||
raise NotImplementedError
|
||||
|
||||
# Let's just add dummy tensors if they do not exist,
|
||||
# it is a headache to deal with None all the time.
|
||||
# But it is not ideal, and if you have a better idea,
|
||||
# please open an issue / submit a PR, thanks.
|
||||
_labels = labels
|
||||
_position_ids = position_ids
|
||||
_attention_mask = attention_mask
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
|
||||
else:
|
||||
attention_mask = attention_mask.bool()
|
||||
if position_ids is None:
|
||||
position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
|
||||
if labels is None:
|
||||
labels = torch.full_like(input_ids, IGNORE_INDEX)
|
||||
|
||||
# remove the padding using attention_mask -- FIXME
|
||||
_input_ids = input_ids
|
||||
input_ids = [
|
||||
cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)
|
||||
]
|
||||
labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
|
||||
|
||||
new_input_embeds = []
|
||||
new_labels = []
|
||||
cur_image_idx = 0
|
||||
for batch_idx, cur_input_ids in enumerate(input_ids):
|
||||
num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
|
||||
if num_images == 0:
|
||||
cur_image_features = image_features[cur_image_idx]
|
||||
cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
|
||||
cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
|
||||
new_input_embeds.append(cur_input_embeds)
|
||||
new_labels.append(labels[batch_idx])
|
||||
cur_image_idx += 1
|
||||
continue
|
||||
|
||||
image_token_indices = (
|
||||
[-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
|
||||
)
|
||||
cur_input_ids_noim = []
|
||||
cur_labels = labels[batch_idx]
|
||||
cur_labels_noim = []
|
||||
for i in range(len(image_token_indices) - 1):
|
||||
cur_input_ids_noim.append(cur_input_ids[image_token_indices[i] + 1 : image_token_indices[i + 1]])
|
||||
cur_labels_noim.append(cur_labels[image_token_indices[i] + 1 : image_token_indices[i + 1]])
|
||||
split_sizes = [x.shape[0] for x in cur_labels_noim]
|
||||
cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
|
||||
cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
|
||||
cur_new_input_embeds = []
|
||||
cur_new_labels = []
|
||||
|
||||
for i in range(num_images + 1):
|
||||
cur_new_input_embeds.append(cur_input_embeds_no_im[i])
|
||||
cur_new_labels.append(cur_labels_noim[i])
|
||||
if i < num_images:
|
||||
cur_image_features = image_features[cur_image_idx]
|
||||
cur_image_idx += 1
|
||||
cur_new_input_embeds.append(cur_image_features)
|
||||
cur_new_labels.append(
|
||||
torch.full(
|
||||
(cur_image_features.shape[0],),
|
||||
IGNORE_INDEX,
|
||||
device=cur_labels.device,
|
||||
dtype=cur_labels.dtype,
|
||||
)
|
||||
)
|
||||
|
||||
cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]
|
||||
|
||||
cur_new_input_embeds = torch.cat(cur_new_input_embeds)
|
||||
cur_new_labels = torch.cat(cur_new_labels)
|
||||
|
||||
new_input_embeds.append(cur_new_input_embeds)
|
||||
new_labels.append(cur_new_labels)
|
||||
|
||||
# Truncate sequences to max length as image embeddings can make the sequence longer
|
||||
tokenizer_model_max_length = getattr(self.config, "tokenizer_model_max_length", None)
|
||||
if tokenizer_model_max_length is not None:
|
||||
new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
|
||||
new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
|
||||
|
||||
# Combine them
|
||||
max_len = max(x.shape[0] for x in new_input_embeds)
|
||||
batch_size = len(new_input_embeds)
|
||||
|
||||
new_input_embeds_padded = []
|
||||
new_labels_padded = torch.full(
|
||||
(batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device
|
||||
)
|
||||
attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
|
||||
position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
|
||||
|
||||
for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
|
||||
cur_len = cur_new_embed.shape[0]
|
||||
if getattr(self.config, "tokenizer_padding_side", "right") == "left":
|
||||
new_input_embeds_padded.append(
|
||||
torch.cat(
|
||||
(
|
||||
torch.zeros(
|
||||
(max_len - cur_len, cur_new_embed.shape[1]),
|
||||
dtype=cur_new_embed.dtype,
|
||||
device=cur_new_embed.device,
|
||||
),
|
||||
cur_new_embed,
|
||||
),
|
||||
dim=0,
|
||||
)
|
||||
)
|
||||
if cur_len > 0:
|
||||
new_labels_padded[i, -cur_len:] = cur_new_labels
|
||||
attention_mask[i, -cur_len:] = True
|
||||
position_ids[i, -cur_len:] = torch.arange(
|
||||
0, cur_len, dtype=position_ids.dtype, device=position_ids.device
|
||||
)
|
||||
else:
|
||||
new_input_embeds_padded.append(
|
||||
torch.cat(
|
||||
(
|
||||
cur_new_embed,
|
||||
torch.zeros(
|
||||
(max_len - cur_len, cur_new_embed.shape[1]),
|
||||
dtype=cur_new_embed.dtype,
|
||||
device=cur_new_embed.device,
|
||||
),
|
||||
),
|
||||
dim=0,
|
||||
)
|
||||
)
|
||||
if cur_len > 0:
|
||||
new_labels_padded[i, :cur_len] = cur_new_labels
|
||||
attention_mask[i, :cur_len] = True
|
||||
position_ids[i, :cur_len] = torch.arange(
|
||||
0, cur_len, dtype=position_ids.dtype, device=position_ids.device
|
||||
)
|
||||
|
||||
new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
|
||||
|
||||
if _labels is None:
|
||||
new_labels = None
|
||||
else:
|
||||
new_labels = new_labels_padded
|
||||
|
||||
if _attention_mask is None:
|
||||
attention_mask = None
|
||||
else:
|
||||
attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
|
||||
|
||||
if _position_ids is None:
|
||||
position_ids = None
|
||||
|
||||
return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def main(args):
|
||||
bs = args.bs
|
||||
video_folder = args.video_folder
|
||||
|
||||
processed_videos = []
|
||||
if os.path.exists(args.output_file):
|
||||
with open(args.output_file, "r") as f:
|
||||
reader = csv.reader(f)
|
||||
samples = list(reader)
|
||||
processed_videos = [sample[0] for sample in samples]
|
||||
f = open(args.output_file, "a")
|
||||
writer = csv.writer(f)
|
||||
|
||||
model_path = "liuhaotian/llava-v1.6-34b"
|
||||
query = prompts["three_frames"]
|
||||
conv = conv_templates["chatml_direct"].copy()
|
||||
conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + "\n" + query)
|
||||
prompt = conv.get_prompt()
|
||||
|
||||
tokenizer, model, image_processor, context_len = load_pretrained_model(
|
||||
model_path=model_path,
|
||||
model_base=None,
|
||||
model_name=get_model_name_from_path(model_path),
|
||||
)
|
||||
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
|
||||
input_ids = input_ids.unsqueeze(0).to(model.device)
|
||||
|
||||
videos = get_filelist(video_folder)
|
||||
print(f"Dataset contains {len(videos)} videos.")
|
||||
videos = [video for video in videos if video not in processed_videos]
|
||||
print(f"Processing {len(videos)} new videos.")
|
||||
for i in tqdm(range(0, len(videos), bs)):
|
||||
# prepare a batch of inputs
|
||||
video_files = videos[i : i + bs]
|
||||
frames = []
|
||||
video_lengths = []
|
||||
for video_file in video_files:
|
||||
frame, length = extract_frames(os.path.join(video_folder, video_file))
|
||||
if len(frame) < 3:
|
||||
continue
|
||||
frames.append(frame)
|
||||
video_lengths.append(length)
|
||||
if len(frames) == 0:
|
||||
continue
|
||||
|
||||
# encode the batch of inputs
|
||||
samples = []
|
||||
for imgs in frames:
|
||||
imgs_size = [img.size for img in imgs]
|
||||
imgs = process_images(imgs, image_processor, model.config)
|
||||
imgs = imgs.to(model.device, dtype=torch.float16)
|
||||
with torch.inference_mode():
|
||||
_, _, _, _, inputs_embeds, _ = prepare_inputs_labels_for_multimodal(
|
||||
model, input_ids, None, None, None, None, images=imgs, image_sizes=imgs_size
|
||||
)
|
||||
samples.append(inputs_embeds)
|
||||
|
||||
# padding
|
||||
max_len = max([sample.shape[1] for sample in samples])
|
||||
attention_mask = torch.tensor(
|
||||
[[0] * (max_len - samples[i].shape[1]) + [1] * samples[i].shape[1] for i in range(len(samples))]
|
||||
).to(model.device)
|
||||
inputs_embeds = [
|
||||
torch.cat(
|
||||
[
|
||||
torch.zeros(
|
||||
(1, max_len - samples[i].shape[1], samples[i].shape[-1]),
|
||||
device=model.device,
|
||||
dtype=torch.float16,
|
||||
),
|
||||
samples[i],
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
for i in range(len(samples))
|
||||
]
|
||||
inputs_embeds = torch.cat(inputs_embeds, dim=0)
|
||||
|
||||
# generate outputs
|
||||
output_ids = super(type(model), model).generate(
|
||||
inputs_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
do_sample=True,
|
||||
temperature=0.2,
|
||||
max_new_tokens=512,
|
||||
use_cache=True,
|
||||
)
|
||||
outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
|
||||
outputs = [output.replace("\n", " ").strip() for output in outputs]
|
||||
|
||||
# save results
|
||||
result = list(zip(video_files, outputs, video_lengths))
|
||||
for t in result:
|
||||
writer.writerow(t)
|
||||
|
||||
f.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--video_folder", type=str, required=True)
|
||||
parser.add_argument("--bs", type=int, default=32)
|
||||
parser.add_argument("--output_file", type=str, default="video_captions.csv")
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
139
tools/data/scene_detect.py
Normal file
139
tools/data/scene_detect.py
Normal file
|
|
@ -0,0 +1,139 @@
|
|||
import os
|
||||
from tqdm import tqdm
|
||||
from multiprocessing import Pool
|
||||
from mmengine.logging import MMLogger, print_log
|
||||
|
||||
from scenedetect import detect, ContentDetector
|
||||
|
||||
from .utils import check_mp4_integrity, split_video
|
||||
from .utils import clone_folder_structure, iterate_files, iterate_folders
|
||||
from opensora.utils.misc import get_timestamp
|
||||
|
||||
|
||||
# config
|
||||
target_fps = 30 # int
|
||||
shorter_size = 512 # int
|
||||
min_seconds = 1 # float
|
||||
max_seconds = 5 # float
|
||||
assert max_seconds > min_seconds
|
||||
cfg = dict(
|
||||
target_fps=target_fps,
|
||||
min_seconds=min_seconds,
|
||||
max_seconds=max_seconds,
|
||||
shorter_size=shorter_size,
|
||||
)
|
||||
|
||||
|
||||
def process_folder(root_src, root_dst):
|
||||
# create logger
|
||||
folder_path_log = os.path.dirname(root_dst)
|
||||
log_name = os.path.basename(root_dst)
|
||||
timestamp = get_timestamp()
|
||||
log_path = os.path.join(folder_path_log, f'{log_name}_{timestamp}.log')
|
||||
logger = MMLogger.get_instance(log_name, log_file=log_path)
|
||||
|
||||
# clone folder structure
|
||||
clone_folder_structure(root_src, root_dst)
|
||||
|
||||
# all source videos
|
||||
mp4_list = [x for x in iterate_files(root_src) if x.endswith('.mp4')]
|
||||
mp4_list = sorted(mp4_list)
|
||||
|
||||
for idx, sample_path in tqdm(enumerate(mp4_list)):
|
||||
folder_src = os.path.dirname(sample_path)
|
||||
folder_dst = os.path.join(root_dst, os.path.relpath(folder_src, root_src))
|
||||
|
||||
# check src video integrity
|
||||
if not check_mp4_integrity(sample_path, logger=logger):
|
||||
continue
|
||||
|
||||
# detect scenes
|
||||
scene_list = detect(sample_path, ContentDetector(), start_in_scene=True)
|
||||
|
||||
# split scenes
|
||||
save_path_list = split_video(sample_path, scene_list, save_dir=folder_dst, **cfg, logger=logger)
|
||||
|
||||
# check integrity of generated clips
|
||||
for x in save_path_list:
|
||||
check_mp4_integrity(x, logger=logger)
|
||||
|
||||
|
||||
def scene_detect():
|
||||
""" detect & cut scenes using a single process
|
||||
Expected dataset structure:
|
||||
data/
|
||||
your_dataset/
|
||||
raw_videos/
|
||||
xxx.mp4
|
||||
yyy.mp4
|
||||
|
||||
This function results in:
|
||||
data/
|
||||
your_dataset/
|
||||
raw_videos/
|
||||
xxx.mp4
|
||||
yyy.mp4
|
||||
zzz.mp4
|
||||
clips/
|
||||
xxx_scene-0.mp4
|
||||
yyy_scene-0.mp4
|
||||
yyy_scene-1.mp4
|
||||
"""
|
||||
# TODO: specify your dataset root
|
||||
root_src = f'./data/your_dataset/raw_videos'
|
||||
root_dst = f'./data/your_dataset/clips'
|
||||
|
||||
process_folder(root_src, root_dst)
|
||||
|
||||
|
||||
def scene_detect_mp():
|
||||
""" detect & cut scenes using multiple processes
|
||||
Expected dataset structure:
|
||||
data/
|
||||
your_dataset/
|
||||
raw_videos/
|
||||
split_0/
|
||||
xxx.mp4
|
||||
yyy.mp4
|
||||
split_1/
|
||||
xxx.mp4
|
||||
yyy.mp4
|
||||
|
||||
This function results in:
|
||||
data/
|
||||
your_dataset/
|
||||
raw_videos/
|
||||
split_0/
|
||||
xxx.mp4
|
||||
yyy.mp4
|
||||
split_1/
|
||||
xxx.mp4
|
||||
yyy.mp4
|
||||
clips/
|
||||
split_0/
|
||||
xxx_scene-0.mp4
|
||||
yyy_scene-0.mp4
|
||||
split_1/
|
||||
xxx_scene-0.mp4
|
||||
yyy_scene-0.mp4
|
||||
yyy_scene-1.mp4
|
||||
"""
|
||||
# TODO: specify your dataset root
|
||||
root_src = f'./data/your_dataset/raw_videos'
|
||||
root_dst = f'./data/your_dataset/clips'
|
||||
|
||||
# TODO: specify your splits
|
||||
splits = ['split_0', 'split_1']
|
||||
|
||||
# process folders
|
||||
root_src_list = [os.path.join(root_src, x) for x in splits]
|
||||
root_dst_list = [os.path.join(root_dst, x) for x in splits]
|
||||
|
||||
with Pool(processes=len(splits)) as pool:
|
||||
pool.starmap(process_folder, list(zip(root_src_list, root_dst_list)))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# TODO: choose single process or multiprocessing
|
||||
scene_detect()
|
||||
# scene_detect_mp()
|
||||
17
tools/data/to_csv_imagenet.py
Normal file
17
tools/data/to_csv_imagenet.py
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
import csv
|
||||
import os
|
||||
|
||||
from torchvision.datasets import ImageNet
|
||||
|
||||
root = "~/data/imagenet"
|
||||
split = "train"
|
||||
|
||||
root = os.path.expanduser(root)
|
||||
data = ImageNet(root, split=split)
|
||||
samples = [(path, data.classes[label][0]) for path, label in data.samples]
|
||||
|
||||
with open(f"preprocess/imagenet_{split}.csv", "w") as f:
|
||||
writer = csv.writer(f)
|
||||
writer.writerows(samples)
|
||||
|
||||
print(f"Saved {len(samples)} samples to preprocess/imagenet_{split}.csv.")
|
||||
36
tools/data/to_csv_ucf101.py
Normal file
36
tools/data/to_csv_ucf101.py
Normal file
|
|
@ -0,0 +1,36 @@
|
|||
import csv
|
||||
import os
|
||||
|
||||
|
||||
def get_filelist(file_path):
|
||||
Filelist = []
|
||||
for home, dirs, files in os.walk(file_path):
|
||||
for filename in files:
|
||||
Filelist.append(os.path.join(home, filename))
|
||||
return Filelist
|
||||
|
||||
|
||||
def split_by_capital(name):
|
||||
# BoxingPunchingBag -> Boxing Punching Bag
|
||||
new_name = ""
|
||||
for i in range(len(name)):
|
||||
if name[i].isupper() and i != 0:
|
||||
new_name += " "
|
||||
new_name += name[i]
|
||||
return new_name
|
||||
|
||||
|
||||
root = "~/data/ucf101"
|
||||
split = "videos"
|
||||
|
||||
root = os.path.expanduser(root)
|
||||
video_lists = get_filelist(os.path.join(root, split))
|
||||
classes = [x.split("/")[-2] for x in video_lists]
|
||||
classes = [split_by_capital(x) for x in classes]
|
||||
samples = list(zip(video_lists, classes))
|
||||
|
||||
with open(f"preprocess/ucf101_{split}.csv", "w") as f:
|
||||
writer = csv.writer(f)
|
||||
writer.writerows(samples)
|
||||
|
||||
print(f"Saved {len(samples)} samples to preprocess/ucf101_{split}.csv.")
|
||||
148
tools/data/utils.py
Normal file
148
tools/data/utils.py
Normal file
|
|
@ -0,0 +1,148 @@
|
|||
import os
|
||||
import cv2
|
||||
import subprocess
|
||||
from mmengine.logging import print_log
|
||||
|
||||
from moviepy.editor import VideoFileClip
|
||||
from imageio_ffmpeg import get_ffmpeg_exe
|
||||
from scenedetect import FrameTimecode
|
||||
|
||||
|
||||
def iterate_files(folder_path):
|
||||
for root, dirs, files in os.walk(folder_path):
|
||||
# root contains the current directory path
|
||||
# dirs contains the list of subdirectories in the current directory
|
||||
# files contains the list of files in the current directory
|
||||
|
||||
# Process files in the current directory
|
||||
for file in files:
|
||||
file_path = os.path.join(root, file)
|
||||
# print("File:", file_path)
|
||||
yield file_path
|
||||
|
||||
# Process subdirectories and recursively call the function
|
||||
for subdir in dirs:
|
||||
subdir_path = os.path.join(root, subdir)
|
||||
# print("Subdirectory:", subdir_path)
|
||||
iterate_files(subdir_path)
|
||||
|
||||
|
||||
def iterate_folders(folder_path):
|
||||
for root, dirs, files in os.walk(folder_path):
|
||||
for subdir in dirs:
|
||||
subdir_path = os.path.join(root, subdir)
|
||||
yield subdir_path
|
||||
# print("Subdirectory:", subdir_path)
|
||||
iterate_folders(subdir_path)
|
||||
|
||||
|
||||
def clone_folder_structure(root_src, root_dst, verbose=False):
|
||||
src_path_list = iterate_folders(root_src)
|
||||
src_relpath_list = [os.path.relpath(x, root_src) for x in src_path_list]
|
||||
|
||||
os.makedirs(root_dst, exist_ok=True)
|
||||
dst_path_list = [os.path.join(root_dst, x) for x in src_relpath_list]
|
||||
for folder_path in dst_path_list:
|
||||
os.makedirs(folder_path, exist_ok=True)
|
||||
if verbose:
|
||||
print(f'Create folder: \'{folder_path}\'')
|
||||
|
||||
|
||||
def count_files(root, suffix='.mp4'):
|
||||
files_list = iterate_files(root)
|
||||
cnt = len([x for x in files_list if x.endswith(suffix)])
|
||||
return cnt
|
||||
|
||||
|
||||
def check_mp4_integrity(file_path, verbose=True, logger=None):
|
||||
try:
|
||||
video_clip = VideoFileClip(file_path)
|
||||
if verbose:
|
||||
print_log(f'The MP4 file \'{file_path}\' is intact.', logger=logger)
|
||||
return True
|
||||
except Exception as e:
|
||||
if verbose:
|
||||
print_log(f'Error: {e}', logger=logger)
|
||||
print_log(f'The MP4 file \'{file_path}\' is not intact.', logger=logger)
|
||||
return False
|
||||
|
||||
|
||||
def count_frames(video_path):
|
||||
cap = cv2.VideoCapture(video_path)
|
||||
|
||||
if not cap.isOpened():
|
||||
print(f"Error: Could not open video file '{video_path}'")
|
||||
return
|
||||
|
||||
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
print(f"Total frames in the video '{video_path}': {total_frames}")
|
||||
|
||||
cap.release()
|
||||
|
||||
|
||||
def split_video(sample_path,
|
||||
scene_list,
|
||||
save_dir,
|
||||
target_fps=30,
|
||||
min_seconds=1,
|
||||
max_seconds=10,
|
||||
shorter_size=512,
|
||||
verbose=False,
|
||||
logger=None,
|
||||
):
|
||||
FFMPEG_PATH = get_ffmpeg_exe()
|
||||
|
||||
save_path_list = []
|
||||
for idx, scene in enumerate(scene_list):
|
||||
s, t = scene # FrameTimecode
|
||||
fps = s.framerate
|
||||
max_duration = FrameTimecode(timecode='00:00:00', fps=fps)
|
||||
max_duration.frame_num = round(fps * max_seconds)
|
||||
duration = min(max_duration, t - s)
|
||||
if duration.get_frames() < round(min_seconds * fps):
|
||||
continue
|
||||
|
||||
# save path
|
||||
fname = os.path.basename(sample_path)
|
||||
fname_wo_ext = os.path.splitext(fname)[0]
|
||||
# TODO: fname pattern
|
||||
save_path = os.path.join(save_dir, f'{fname_wo_ext}_scene-{idx}.mp4')
|
||||
|
||||
# ffmpeg cmd
|
||||
cmd = [FFMPEG_PATH]
|
||||
|
||||
# Only show ffmpeg output for the first call, which will display any
|
||||
# errors if it fails, and then break the loop. We only show error messages
|
||||
# for the remaining calls.
|
||||
# cmd += ['-v', 'error']
|
||||
|
||||
# input path
|
||||
cmd += ['-i', sample_path]
|
||||
|
||||
# clip to cut
|
||||
cmd += [
|
||||
'-nostdin', '-y',
|
||||
'-ss', str(s.get_seconds()),
|
||||
'-t', str(duration.get_seconds())
|
||||
]
|
||||
|
||||
# target fps
|
||||
# cmd += ['-vf', 'select=mod(n\,2)']
|
||||
cmd += ['-r', f'{target_fps}']
|
||||
|
||||
# aspect ratio
|
||||
cmd += ['-vf', f"scale='if(gt(iw,ih),-2,{shorter_size})':'if(gt(iw,ih),{shorter_size},-2)'"]
|
||||
# cmd += ['-vf', f"scale='if(gt(iw,ih),{shorter_size},trunc(ow/a/2)*2)':-2"]
|
||||
|
||||
cmd += ['-map', '0', save_path]
|
||||
|
||||
proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
|
||||
stdout, stderr = proc.communicate()
|
||||
if verbose:
|
||||
stdout = stdout.decode('utf-8')
|
||||
print_log(stdout, logger=logger)
|
||||
|
||||
save_path_list.append(sample_path)
|
||||
print_log(f'Video clip saved to \'{save_path}\'', logger=logger)
|
||||
|
||||
return save_path_list
|
||||
Loading…
Reference in a new issue