diff --git a/README.md b/README.md new file mode 100644 index 0000000..3b632d0 --- /dev/null +++ b/README.md @@ -0,0 +1,149 @@ +

+ +

+ +

+
+ + + +
+ +## Open-Sora: Towards Open Reproduction of Sora + +**Open-Sora** is an **open-source** initiative dedicated to **efficiently** reproducing OpenAI's Sora. Our project aims to cover **the full pipeline**, including video data preprocessing, training with acceleration, efficient inference and more. Operating on a limited budget, we prioritize the vibrant open-source community, providing access to text-to-image, image captioning, and language models. We hope to make a contribution to the community and make the project more accessible to everyone. + +## 📰 News + +* **[2024.03.18]** 🔥 We release **Open-Sora 1.0**, an open-source project to reproduce OpenAI Sora. +Open-Sora 1.0 supports a full pipeline of video data preprocessing, training with + acceleration, +inference, and more. Our provided checkpoint can produce 2s 512x512 videos. + +## 🎥 Latest Demo + +| **2s 512x512** | **2s 512x512** | +| ----------------------------------------------- | ----------------------------------------------- | +| | | + +## 🔆 New Features/Updates + +- 📍 Open-Sora-v1 is trained on xxx. We train the model in three stages. Model weights are available here. Training details can be found here. +- ✅ Support training acceleration including flash-attention, accelerated T5, mixed precision, gradient checkpointing, splitted VAE, sequence parallelism, etc. XXX times. See more discussions [here](). +- ✅ We provide video cutting and captioning tools for data preprocessing. Our data collection plan can be found [here](). +- ✅ We find VQ-VAE from [] has a low quality and thus adopt a better VAE from []. We also find patching in the time dimension deteriorates the quality. See more discussions [here](). +- ✅ We investigate different architectures including DiT, Latte, and our proposed STDiT. Our STDiT achieves a better trade-off between quality and speed. See more discussions [here](). +- ✅ Support clip and t5 text conditioning. +- ✅ By viewing images as one-frame videos, our project supports training DiT on both images and videos (e.g., ImageNet & UCF101). +- ✅ Support inference with official weights from [DiT](https://github.com/facebookresearch/DiT), [Latte](https://github.com/Vchitect/Latte), and [PixArt](https://pixart-alpha.github.io/). + + + +### TODO list sorted by priority + +- [ ] Complete the data processing pipeline (including dense optical flow, aesthetics scores, text-image similarity, deduplication, etc.). See [datasets.md]() for more information. **[WIP]** +- [ ] Training Video-VAE. **[WIP]** +- [ ] Support image and video conditioning. +- [ ] Evaluation pipeline. +- [ ] Incoporate a better scheduler, e.g., rectified flow in SD3. +- [ ] Support variable aspect ratios, resolutions, durations. +- [ ] Support SD3 when released. + + +## Contents + +- [Open-Sora: Towards Open Reproduction of Sora](#open-sora-towards-open-reproduction-of-sora) +- [📰 News](#-news) +- [🎥 Latest Demo](#-latest-demo) +- [🔆 New Features/Updates](#-new-featuresupdates) + - [TODO list sorted by priority](#todo-list-sorted-by-priority) +- [Contents](#contents) +- [Installation](#installation) +- [Model Weights](#model-weights) +- [Inference](#inference) +- [Data Processing](#data-processing) + - [Split video into clips](#split-video-into-clips) + - [Generate video caption](#generate-video-caption) +- [Training](#training) +- [Acknowledgement](#acknowledgement) +- [Citation](#citation) +- [Star History](#star-history) +- [TODO](#todo) + +## Installation + +```bash +git clone https://github.com/hpcaitech/Open-Sora +cd Open-Sora +pip install xxx +``` + +After installation, to get fimilar with the project, you can check the [here]() for the project structure and how to use the config files. + +## Model Weights + +| Model | #Params | url | +| ---------- | ------- | --- | +| 16x256x256 | | | + +## Inference + +```bash +python scripts/inference.py configs/opensora/inference/16x256x256.py +``` + +## Data Processing + +### Split video into clips + +We provide code to split a long video into separate clips efficiently using `multiprocessing`. See `tools/data/scene_detect.py`. + +### Generate video caption + +## Training + +## Acknowledgement + +* [DiT](https://github.com/facebookresearch/DiT): Scalable Diffusion Models with Transformers. +* [OpenDiT](https://github.com/NUS-HPC-AI-Lab/OpenDiT): An acceleration for DiT training. OpenDiT's team provides valuable suggestions on acceleration of our training process. +* [PixArt](https://github.com/PixArt-alpha/PixArt-alpha): An open-source DiT-based text-to-image model. +* [Latte](https://github.com/Vchitect/Latte): An attempt to efficiently train DiT for video. +* [StabilityAI VAE](https://huggingface.co/stabilityai/sd-vae-ft-mse-original): A powerful image VAE model. +* [CLIP](https://github.com/openai/CLIP): A powerful text-image embedding model. +* [T5](https://github.com/google-research/text-to-text-transfer-transformer): The powerful text encoder. +* [LLaVA](https://github.com/haotian-liu/LLaVA): A powerful image captioning model based on [LLaMA](https://github.com/meta-llama/llama) and [Yi-34B](https://huggingface.co/01-ai/Yi-34B). +* [PySceneDetect](https://github.com/Breakthrough/PySceneDetect): A powerful tool to split video into clips. + +We are grateful for their exceptional work and generous contribution to open source. + +## Citation + +```bibtex +@software{opensora, + author = {Zangwei Zheng and Xiangyu Peng and Shenggui Li and Yang You}, + title = {Open-Sora: Towards Open Reproduction of Sora}, + month = {March}, + year = {2024}, + url = {https://github.com/hpcaitech/Open-Sora} +} +``` + +Zangwei Zheng and Xiangyu Peng equally contributed to this work. + +## Star History + +[![Star History Chart](https://api.star-history.com/svg?repos=hpcaitech/Open-Sora&type=Date)](https://star-history.com/#hpcaitech/Open-Sora&Date) + +## TODO + +Modules for releasing: + +* `configs` +* `opensora` +* `assets` +* `scripts` +* `tools` + +packages for data processing + +put all outputs under ./checkpoints/, including pretrained_models, checkpoints, samples diff --git a/docs/datasets.md b/docs/datasets.md new file mode 100644 index 0000000..e69de29 diff --git a/opensora/__init__.py b/opensora/__init__.py new file mode 100644 index 0000000..17a73fc --- /dev/null +++ b/opensora/__init__.py @@ -0,0 +1,4 @@ +from .acceleration import * +from .datasets import * +from .registry import * +from .models import * diff --git a/opensora/registry.py b/opensora/registry.py new file mode 100644 index 0000000..af1bbc6 --- /dev/null +++ b/opensora/registry.py @@ -0,0 +1,39 @@ +from copy import deepcopy + +import torch.nn as nn +from mmengine.registry import Registry + + +def build_module(module, builder, **kwargs): + """Build module from config or return the module itself. + + Args: + module (Union[dict, nn.Module]): The module to build. + builder (Registry): The registry to build module. + *args, **kwargs: Arguments passed to build function. + + Returns: + Any: The built module. + """ + if isinstance(module, dict): + cfg = deepcopy(module) + for k, v in kwargs.items(): + cfg[k] = v + return builder.build(cfg) + elif isinstance(module, nn.Module): + return module + elif module is None: + return None + else: + raise TypeError(f"Only support dict and nn.Module, but got {type(module)}.") + + +MODELS = Registry( + "model", + locations=["opensora.models"], +) + +SCHEDULERS = Registry( + "scheduler", + locations=["opensora.schedulers"], +) diff --git a/opensora/schedulers/__init__.py b/opensora/schedulers/__init__.py new file mode 100644 index 0000000..f64aa25 --- /dev/null +++ b/opensora/schedulers/__init__.py @@ -0,0 +1,2 @@ +from .iddpm import IDDPM +from .dpms import DPMS diff --git a/opensora/schedulers/dpms/__init__.py b/opensora/schedulers/dpms/__init__.py new file mode 100644 index 0000000..f0cebbc --- /dev/null +++ b/opensora/schedulers/dpms/__init__.py @@ -0,0 +1,50 @@ +from functools import partial + +import torch + +from opensora.registry import SCHEDULERS + +from .dpm_solver import DPMS + + +@SCHEDULERS.register_module("dpm-solver") +class DMP_SOLVER: + def __init__(self, num_sampling_steps=None, cfg_scale=4.0): + self.num_sampling_steps = num_sampling_steps + self.cfg_scale = cfg_scale + + def sample( + self, + model, + text_encoder, + z_size, + prompts, + device, + additional_args=None, + ): + n = len(prompts) + z = torch.randn(n, *z_size, device=device) + model_args = text_encoder.encode(prompts) + y = model_args.pop("y") + null_y = text_encoder.null(n) + if additional_args is not None: + model_args.update(additional_args) + + dpms = DPMS( + partial(forward_with_dpmsolver, model), + condition=y, + uncondition=null_y, + cfg_scale=self.cfg_scale, + model_kwargs=model_args, + ) + samples = dpms.sample(z, steps=self.num_sampling_steps, order=2, skip_type="time_uniform", method="multistep") + return samples + + +def forward_with_dpmsolver(self, x, timestep, y, **kwargs): + """ + dpm solver donnot need variance prediction + """ + # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb + model_out = self.forward(x, timestep, y, **kwargs) + return model_out.chunk(2, dim=1)[0] diff --git a/opensora/schedulers/dpms/dpm_solver.py b/opensora/schedulers/dpms/dpm_solver.py new file mode 100644 index 0000000..d887398 --- /dev/null +++ b/opensora/schedulers/dpms/dpm_solver.py @@ -0,0 +1,1558 @@ +# Adapted from dpm-solver + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# PixArt: https://github.com/PixArt-alpha/PixArt-alpha +# dpm-solver: https://github.com/LuChengTHU/dpm-solver +# -------------------------------------------------------- + +import math + +import numpy as np +import torch +from tqdm import tqdm + + +def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac): + betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) + warmup_time = int(num_diffusion_timesteps * warmup_frac) + betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64) + return betas + + +def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps): + """ + This is the deprecated API for creating beta schedules. + See get_named_beta_schedule() for the new library of schedules. + """ + if beta_schedule == "quad": + betas = ( + np.linspace( + beta_start**0.5, + beta_end**0.5, + num_diffusion_timesteps, + dtype=np.float64, + ) + ** 2 + ) + elif beta_schedule == "linear": + betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64) + elif beta_schedule == "warmup10": + betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1) + elif beta_schedule == "warmup50": + betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5) + elif beta_schedule == "const": + betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) + elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1 + betas = 1.0 / np.linspace(num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64) + else: + raise NotImplementedError(beta_schedule) + assert betas.shape == (num_diffusion_timesteps,) + return betas + + +def get_named_beta_schedule(schedule_name, num_diffusion_timesteps): + """ + Get a pre-defined beta schedule for the given name. + The beta schedule library consists of beta schedules which remain similar + in the limit of num_diffusion_timesteps. + Beta schedules may be added, but should not be removed or changed once + they are committed to maintain backwards compatibility. + """ + if schedule_name == "linear": + # Linear schedule from Ho et al, extended to work for any number of + # diffusion steps. + scale = 1000 / num_diffusion_timesteps + return get_beta_schedule( + "linear", + beta_start=scale * 0.0001, + beta_end=scale * 0.02, + num_diffusion_timesteps=num_diffusion_timesteps, + ) + elif schedule_name == "squaredcos_cap_v2": + return betas_for_alpha_bar( + num_diffusion_timesteps, + lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, + ) + else: + raise NotImplementedError(f"unknown beta schedule: {schedule_name}") + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +class NoiseScheduleVP: + def __init__( + self, + schedule="discrete", + betas=None, + alphas_cumprod=None, + continuous_beta_0=0.1, + continuous_beta_1=20.0, + dtype=torch.float32, + ): + """Create a wrapper class for the forward SDE (VP type). + + *** + Update: We support discrete-time diffusion models by implementing a picewise linear interpolation for log_alpha_t. + We recommend to use schedule='discrete' for the discrete-time diffusion models, especially for high-resolution images. + *** + + The forward SDE ensures that the condition distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ). + We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper). + Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have: + + log_alpha_t = self.marginal_log_mean_coeff(t) + sigma_t = self.marginal_std(t) + lambda_t = self.marginal_lambda(t) + + Moreover, as lambda(t) is an invertible function, we also support its inverse function: + + t = self.inverse_lambda(lambda_t) + + =============================================================== + + We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]). + + 1. For discrete-time DPMs: + + For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by: + t_i = (i + 1) / N + e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1. + We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3. + + Args: + betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details) + alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details) + + Note that we always have alphas_cumprod = cumprod(1 - betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`. + + **Important**: Please pay special attention for the args for `alphas_cumprod`: + The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that + q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ). + Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have + alpha_{t_n} = \sqrt{\hat{alpha_n}}, + and + log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}). + + + 2. For continuous-time DPMs: + + We support the linear VPSDE for the continuous time setting. The hyperparameters for the noise + schedule are the default settings in Yang Song's ScoreSDE: + + Args: + beta_min: A `float` number. The smallest beta for the linear schedule. + beta_max: A `float` number. The largest beta for the linear schedule. + T: A `float` number. The ending time of the forward process. + + =============================================================== + + Args: + schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs, + 'linear' for continuous-time DPMs. + Returns: + A wrapper object of the forward SDE (VP type). + + =============================================================== + + Example: + + # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1): + >>> ns = NoiseScheduleVP('discrete', betas=betas) + + # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1): + >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod) + + # For continuous-time DPMs (VPSDE), linear schedule: + >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.) + + """ + + if schedule not in ["discrete", "linear"]: + raise ValueError(f"Unsupported noise schedule {schedule}. The schedule needs to be 'discrete' or 'linear'") + + self.schedule = schedule + if schedule == "discrete": + if betas is not None: + log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0) + else: + assert alphas_cumprod is not None + log_alphas = 0.5 * torch.log(alphas_cumprod) + self.T = 1.0 + self.log_alpha_array = ( + self.numerical_clip_alpha(log_alphas) + .reshape( + ( + 1, + -1, + ) + ) + .to(dtype=dtype) + ) + self.total_N = self.log_alpha_array.shape[1] + self.t_array = torch.linspace(0.0, 1.0, self.total_N + 1)[1:].reshape((1, -1)).to(dtype=dtype) + else: + self.T = 1.0 + self.total_N = 1000 + self.beta_0 = continuous_beta_0 + self.beta_1 = continuous_beta_1 + + def numerical_clip_alpha(self, log_alphas, clipped_lambda=-5.1): + """ + For some beta schedules such as cosine schedule, the log-SNR has numerical isssues. + We clip the log-SNR near t=T within -5.1 to ensure the stability. + Such a trick is very useful for diffusion models with the cosine schedule, such as i-DDPM, guided-diffusion and GLIDE. + """ + log_sigmas = 0.5 * torch.log(1.0 - torch.exp(2.0 * log_alphas)) + lambs = log_alphas - log_sigmas + idx = torch.searchsorted(torch.flip(lambs, [0]), clipped_lambda) + if idx > 0: + log_alphas = log_alphas[:-idx] + return log_alphas + + def marginal_log_mean_coeff(self, t): + """ + Compute log(alpha_t) of a given continuous-time label t in [0, T]. + """ + if self.schedule == "discrete": + return interpolate_fn( + t.reshape((-1, 1)), self.t_array.to(t.device), self.log_alpha_array.to(t.device) + ).reshape((-1)) + elif self.schedule == "linear": + return -0.25 * t**2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0 + + def marginal_alpha(self, t): + """ + Compute alpha_t of a given continuous-time label t in [0, T]. + """ + return torch.exp(self.marginal_log_mean_coeff(t)) + + def marginal_std(self, t): + """ + Compute sigma_t of a given continuous-time label t in [0, T]. + """ + return torch.sqrt(1.0 - torch.exp(2.0 * self.marginal_log_mean_coeff(t))) + + def marginal_lambda(self, t): + """ + Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T]. + """ + log_mean_coeff = self.marginal_log_mean_coeff(t) + log_std = 0.5 * torch.log(1.0 - torch.exp(2.0 * log_mean_coeff)) + return log_mean_coeff - log_std + + def inverse_lambda(self, lamb): + """ + Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t. + """ + if self.schedule == "linear": + tmp = 2.0 * (self.beta_1 - self.beta_0) * torch.logaddexp(-2.0 * lamb, torch.zeros((1,)).to(lamb)) + Delta = self.beta_0**2 + tmp + return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0) + elif self.schedule == "discrete": + log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2.0 * lamb) + t = interpolate_fn( + log_alpha.reshape((-1, 1)), + torch.flip(self.log_alpha_array.to(lamb.device), [1]), + torch.flip(self.t_array.to(lamb.device), [1]), + ) + return t.reshape((-1,)) + + +def model_wrapper( + model, + noise_schedule, + model_type="noise", + model_kwargs={}, + guidance_type="uncond", + condition=None, + unconditional_condition=None, + guidance_scale=1.0, + classifier_fn=None, + classifier_kwargs={}, +): + """Create a wrapper function for the noise prediction model. + + DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to + firstly wrap the model function to a noise prediction model that accepts the continuous time as the input. + + We support four types of the diffusion model by setting `model_type`: + + 1. "noise": noise prediction model. (Trained by predicting noise). + + 2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0). + + 3. "v": velocity prediction model. (Trained by predicting the velocity). + The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2]. + + [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models." + arXiv preprint arXiv:2202.00512 (2022). + [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models." + arXiv preprint arXiv:2210.02303 (2022). + + 4. "score": marginal score function. (Trained by denoising score matching). + Note that the score function and the noise prediction model follows a simple relationship: + ``` + noise(x_t, t) = -sigma_t * score(x_t, t) + ``` + + We support three types of guided sampling by DPMs by setting `guidance_type`: + 1. "uncond": unconditional sampling by DPMs. + The input `model` has the following format: + `` + model(x, t_input, **model_kwargs) -> noise | x_start | v | score + `` + + 2. "classifier": classifier guidance sampling [3] by DPMs and another classifier. + The input `model` has the following format: + `` + model(x, t_input, **model_kwargs) -> noise | x_start | v | score + `` + + The input `classifier_fn` has the following format: + `` + classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond) + `` + + [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis," + in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794. + + 3. "classifier-free": classifier-free guidance sampling by conditional DPMs. + The input `model` has the following format: + `` + model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score + `` + And if cond == `unconditional_condition`, the model output is the unconditional DPM output. + + [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance." + arXiv preprint arXiv:2207.12598 (2022). + + + The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999) + or continuous-time labels (i.e. epsilon to T). + + We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise: + `` + def model_fn(x, t_continuous) -> noise: + t_input = get_model_input_time(t_continuous) + return noise_pred(model, x, t_input, **model_kwargs) + `` + where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver. + + =============================================================== + + Args: + model: A diffusion model with the corresponding format described above. + noise_schedule: A noise schedule object, such as NoiseScheduleVP. + model_type: A `str`. The parameterization type of the diffusion model. + "noise" or "x_start" or "v" or "score". + model_kwargs: A `dict`. A dict for the other inputs of the model function. + guidance_type: A `str`. The type of the guidance for sampling. + "uncond" or "classifier" or "classifier-free". + condition: A pytorch tensor. The condition for the guided sampling. + Only used for "classifier" or "classifier-free" guidance type. + unconditional_condition: A pytorch tensor. The condition for the unconditional sampling. + Only used for "classifier-free" guidance type. + guidance_scale: A `float`. The scale for the guided sampling. + classifier_fn: A classifier function. Only used for the classifier guidance. + classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function. + Returns: + A noise prediction model that accepts the noised data and the continuous time as the inputs. + """ + + def get_model_input_time(t_continuous): + """ + Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time. + For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N]. + For continuous-time DPMs, we just use `t_continuous`. + """ + if noise_schedule.schedule == "discrete": + return (t_continuous - 1.0 / noise_schedule.total_N) * 1000.0 + else: + return t_continuous + + def noise_pred_fn(x, t_continuous, cond=None): + t_input = get_model_input_time(t_continuous) + if cond is None: + output = model(x, t_input, **model_kwargs) + else: + output = model(x, t_input, cond, **model_kwargs) + if model_type == "noise": + return output + elif model_type == "x_start": + alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) + return (x - expand_dims(alpha_t, x.dim()) * output) / expand_dims(sigma_t, x.dim()) + elif model_type == "v": + alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) + return expand_dims(alpha_t, x.dim()) * output + expand_dims(sigma_t, x.dim()) * x + elif model_type == "score": + sigma_t = noise_schedule.marginal_std(t_continuous) + return -expand_dims(sigma_t, x.dim()) * output + + def cond_grad_fn(x, t_input): + """ + Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t). + """ + with torch.enable_grad(): + x_in = x.detach().requires_grad_(True) + log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs) + return torch.autograd.grad(log_prob.sum(), x_in)[0] + + def model_fn(x, t_continuous): + """ + The noise predicition model function that is used for DPM-Solver. + """ + if guidance_type == "uncond": + return noise_pred_fn(x, t_continuous) + elif guidance_type == "classifier": + assert classifier_fn is not None + t_input = get_model_input_time(t_continuous) + cond_grad = cond_grad_fn(x, t_input) + sigma_t = noise_schedule.marginal_std(t_continuous) + noise = noise_pred_fn(x, t_continuous) + return noise - guidance_scale * expand_dims(sigma_t, x.dim()) * cond_grad + elif guidance_type == "classifier-free": + if guidance_scale == 1.0 or unconditional_condition is None: + return noise_pred_fn(x, t_continuous, cond=condition) + x_in = torch.cat([x] * 2) + t_in = torch.cat([t_continuous] * 2) + c_in = torch.cat([unconditional_condition, condition]) + noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2) + return noise_uncond + guidance_scale * (noise - noise_uncond) + + assert model_type in ["noise", "x_start", "v", "score"] + assert guidance_type in ["uncond", "classifier", "classifier-free"] + return model_fn + + +class DPM_Solver: + def __init__( + self, + model_fn, + noise_schedule, + algorithm_type="dpmsolver++", + correcting_x0_fn=None, + correcting_xt_fn=None, + thresholding_max_val=1.0, + dynamic_thresholding_ratio=0.995, + ): + """Construct a DPM-Solver. + + We support both DPM-Solver (`algorithm_type="dpmsolver"`) and DPM-Solver++ (`algorithm_type="dpmsolver++"`). + + We also support the "dynamic thresholding" method in Imagen[1]. For pixel-space diffusion models, you + can set both `algorithm_type="dpmsolver++"` and `correcting_x0_fn="dynamic_thresholding"` to use the + dynamic thresholding. The "dynamic thresholding" can greatly improve the sample quality for pixel-space + DPMs with large guidance scales. Note that the thresholding method is **unsuitable** for latent-space + DPMs (such as stable-diffusion). + + To support advanced algorithms in image-to-image applications, we also support corrector functions for + both x0 and xt. + + Args: + model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]): + `` + def model_fn(x, t_continuous): + return noise + `` + The shape of `x` is `(batch_size, **shape)`, and the shape of `t_continuous` is `(batch_size,)`. + noise_schedule: A noise schedule object, such as NoiseScheduleVP. + algorithm_type: A `str`. Either "dpmsolver" or "dpmsolver++". + correcting_x0_fn: A `str` or a function with the following format: + ``` + def correcting_x0_fn(x0, t): + x0_new = ... + return x0_new + ``` + This function is to correct the outputs of the data prediction model at each sampling step. e.g., + ``` + x0_pred = data_pred_model(xt, t) + if correcting_x0_fn is not None: + x0_pred = correcting_x0_fn(x0_pred, t) + xt_1 = update(x0_pred, xt, t) + ``` + If `correcting_x0_fn="dynamic_thresholding"`, we use the dynamic thresholding proposed in Imagen[1]. + correcting_xt_fn: A function with the following format: + ``` + def correcting_xt_fn(xt, t, step): + x_new = ... + return x_new + ``` + This function is to correct the intermediate samples xt at each sampling step. e.g., + ``` + xt = ... + xt = correcting_xt_fn(xt, t, step) + ``` + thresholding_max_val: A `float`. The max value for thresholding. + Valid only when use `dpmsolver++` and `correcting_x0_fn="dynamic_thresholding"`. + dynamic_thresholding_ratio: A `float`. The ratio for dynamic thresholding (see Imagen[1] for details). + Valid only when use `dpmsolver++` and `correcting_x0_fn="dynamic_thresholding"`. + + [1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour, + Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. Photorealistic text-to-image diffusion models + with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b. + """ + self.model = lambda x, t: model_fn(x, t.expand((x.shape[0]))) + self.noise_schedule = noise_schedule + assert algorithm_type in ["dpmsolver", "dpmsolver++"] + self.algorithm_type = algorithm_type + if correcting_x0_fn == "dynamic_thresholding": + self.correcting_x0_fn = self.dynamic_thresholding_fn + else: + self.correcting_x0_fn = correcting_x0_fn + self.correcting_xt_fn = correcting_xt_fn + self.dynamic_thresholding_ratio = dynamic_thresholding_ratio + self.thresholding_max_val = thresholding_max_val + + def dynamic_thresholding_fn(self, x0, t): + """ + The dynamic thresholding method. + """ + dims = x0.dim() + p = self.dynamic_thresholding_ratio + s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1) + s = expand_dims(torch.maximum(s, self.thresholding_max_val * torch.ones_like(s).to(s.device)), dims) + x0 = torch.clamp(x0, -s, s) / s + return x0 + + def noise_prediction_fn(self, x, t): + """ + Return the noise prediction model. + """ + return self.model(x, t) + + def data_prediction_fn(self, x, t): + """ + Return the data prediction model (with corrector). + """ + noise = self.noise_prediction_fn(x, t) + alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t) + x0 = (x - sigma_t * noise) / alpha_t + if self.correcting_x0_fn is not None: + x0 = self.correcting_x0_fn(x0, t) + return x0 + + def model_fn(self, x, t): + """ + Convert the model to the noise prediction model or the data prediction model. + """ + if self.algorithm_type == "dpmsolver++": + return self.data_prediction_fn(x, t) + else: + return self.noise_prediction_fn(x, t) + + def get_time_steps(self, skip_type, t_T, t_0, N, device): + """Compute the intermediate time steps for sampling. + + Args: + skip_type: A `str`. The type for the spacing of the time steps. We support three types: + - 'logSNR': uniform logSNR for the time steps. + - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.) + - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.) + t_T: A `float`. The starting time of the sampling (default is T). + t_0: A `float`. The ending time of the sampling (default is epsilon). + N: A `int`. The total number of the spacing of the time steps. + device: A torch device. + Returns: + A pytorch tensor of the time steps, with the shape (N + 1,). + """ + if skip_type == "logSNR": + lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device)) + lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device)) + logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device) + return self.noise_schedule.inverse_lambda(logSNR_steps) + elif skip_type == "time_uniform": + return torch.linspace(t_T, t_0, N + 1).to(device) + elif skip_type == "time_quadratic": + t_order = 2 + return torch.linspace(t_T ** (1.0 / t_order), t_0 ** (1.0 / t_order), N + 1).pow(t_order).to(device) + else: + raise ValueError( + f"Unsupported skip_type {skip_type}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'" + ) + + def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device): + """ + Get the order of each step for sampling by the singlestep DPM-Solver. + + We combine both DPM-Solver-1,2,3 to use all the function evaluations, which is named as "DPM-Solver-fast". + Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is: + - If order == 1: + We take `steps` of DPM-Solver-1 (i.e. DDIM). + - If order == 2: + - Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling. + - If steps % 2 == 0, we use K steps of DPM-Solver-2. + - If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1. + - If order == 3: + - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling. + - If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1. + - If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1. + - If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2. + + ============================================ + Args: + order: A `int`. The max order for the solver (2 or 3). + steps: A `int`. The total number of function evaluations (NFE). + skip_type: A `str`. The type for the spacing of the time steps. We support three types: + - 'logSNR': uniform logSNR for the time steps. + - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.) + - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.) + t_T: A `float`. The starting time of the sampling (default is T). + t_0: A `float`. The ending time of the sampling (default is epsilon). + device: A torch device. + Returns: + orders: A list of the solver order of each step. + """ + if order == 3: + K = steps // 3 + 1 + if steps % 3 == 0: + orders = [ + 3, + ] * ( + K - 2 + ) + [2, 1] + elif steps % 3 == 1: + orders = [ + 3, + ] * ( + K - 1 + ) + [1] + else: + orders = [ + 3, + ] * ( + K - 1 + ) + [2] + elif order == 2: + if steps % 2 == 0: + K = steps // 2 + orders = [ + 2, + ] * K + else: + K = steps // 2 + 1 + orders = [ + 2, + ] * ( + K - 1 + ) + [1] + elif order == 1: + K = 1 + orders = [ + 1, + ] * steps + else: + raise ValueError("'order' must be '1' or '2' or '3'.") + if skip_type == "logSNR": + # To reproduce the results in DPM-Solver paper + timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device) + else: + timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[ + torch.cumsum( + torch.tensor( + [ + 0, + ] + + orders + ), + 0, + ).to(device) + ] + return timesteps_outer, orders + + def denoise_to_zero_fn(self, x, s): + """ + Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization. + """ + return self.data_prediction_fn(x, s) + + def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False): + """ + DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (1,). + t: A pytorch tensor. The ending time, with the shape (1,). + model_s: A pytorch tensor. The model function evaluated at time `s`. + If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. + return_intermediate: A `bool`. If true, also return the model value at time `s`. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + ns = self.noise_schedule + dims = x.dim() + lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) + h = lambda_t - lambda_s + log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t) + sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + if self.algorithm_type == "dpmsolver++": + phi_1 = torch.expm1(-h) + if model_s is None: + model_s = self.model_fn(x, s) + x_t = sigma_t / sigma_s * x - alpha_t * phi_1 * model_s + else: + phi_1 = torch.expm1(h) + if model_s is None: + model_s = self.model_fn(x, s) + x_t = torch.exp(log_alpha_t - log_alpha_s) * x - (sigma_t * phi_1) * model_s + return (x_t, {"model_s": model_s}) if return_intermediate else x_t + + def singlestep_dpm_solver_second_update( + self, x, s, t, r1=0.5, model_s=None, return_intermediate=False, solver_type="dpmsolver" + ): + """ + Singlestep solver DPM-Solver-2 from time `s` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (1,). + t: A pytorch tensor. The ending time, with the shape (1,). + r1: A `float`. The hyperparameter of the second-order solver. + model_s: A pytorch tensor. The model function evaluated at time `s`. + If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. + return_intermediate: A `bool`. If true, also return the model value at time `s` and `s1` (the intermediate time). + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if solver_type not in ["dpmsolver", "taylor"]: + raise ValueError(f"'solver_type' must be either 'dpmsolver' or 'taylor', got {solver_type}") + if r1 is None: + r1 = 0.5 + ns = self.noise_schedule + lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) + h = lambda_t - lambda_s + lambda_s1 = lambda_s + r1 * h + s1 = ns.inverse_lambda(lambda_s1) + log_alpha_s, log_alpha_s1, log_alpha_t = ( + ns.marginal_log_mean_coeff(s), + ns.marginal_log_mean_coeff(s1), + ns.marginal_log_mean_coeff(t), + ) + sigma_s, sigma_s1, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(t) + alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t) + + if self.algorithm_type == "dpmsolver++": + phi_11 = torch.expm1(-r1 * h) + phi_1 = torch.expm1(-h) + + if model_s is None: + model_s = self.model_fn(x, s) + x_s1 = (sigma_s1 / sigma_s) * x - (alpha_s1 * phi_11) * model_s + model_s1 = self.model_fn(x_s1, s1) + if solver_type == "dpmsolver": + x_t = ( + (sigma_t / sigma_s) * x + - (alpha_t * phi_1) * model_s + - (0.5 / r1) * (alpha_t * phi_1) * (model_s1 - model_s) + ) + elif solver_type == "taylor": + x_t = ( + (sigma_t / sigma_s) * x + - (alpha_t * phi_1) * model_s + + (1.0 / r1) * (alpha_t * (phi_1 / h + 1.0)) * (model_s1 - model_s) + ) + else: + phi_11 = torch.expm1(r1 * h) + phi_1 = torch.expm1(h) + + if model_s is None: + model_s = self.model_fn(x, s) + x_s1 = torch.exp(log_alpha_s1 - log_alpha_s) * x - (sigma_s1 * phi_11) * model_s + model_s1 = self.model_fn(x_s1, s1) + if solver_type == "dpmsolver": + x_t = ( + torch.exp(log_alpha_t - log_alpha_s) * x + - (sigma_t * phi_1) * model_s + - (0.5 / r1) * (sigma_t * phi_1) * (model_s1 - model_s) + ) + elif solver_type == "taylor": + x_t = ( + torch.exp(log_alpha_t - log_alpha_s) * x + - (sigma_t * phi_1) * model_s + - (1.0 / r1) * (sigma_t * (phi_1 / h - 1.0)) * (model_s1 - model_s) + ) + if return_intermediate: + return x_t, {"model_s": model_s, "model_s1": model_s1} + else: + return x_t + + def singlestep_dpm_solver_third_update( + self, + x, + s, + t, + r1=1.0 / 3.0, + r2=2.0 / 3.0, + model_s=None, + model_s1=None, + return_intermediate=False, + solver_type="dpmsolver", + ): + """ + Singlestep solver DPM-Solver-3 from time `s` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (1,). + t: A pytorch tensor. The ending time, with the shape (1,). + r1: A `float`. The hyperparameter of the third-order solver. + r2: A `float`. The hyperparameter of the third-order solver. + model_s: A pytorch tensor. The model function evaluated at time `s`. + If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. + model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`). + If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it. + return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times). + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if solver_type not in ["dpmsolver", "taylor"]: + raise ValueError(f"'solver_type' must be either 'dpmsolver' or 'taylor', got {solver_type}") + if r1 is None: + r1 = 1.0 / 3.0 + if r2 is None: + r2 = 2.0 / 3.0 + ns = self.noise_schedule + lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) + h = lambda_t - lambda_s + lambda_s1 = lambda_s + r1 * h + lambda_s2 = lambda_s + r2 * h + s1 = ns.inverse_lambda(lambda_s1) + s2 = ns.inverse_lambda(lambda_s2) + log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ( + ns.marginal_log_mean_coeff(s), + ns.marginal_log_mean_coeff(s1), + ns.marginal_log_mean_coeff(s2), + ns.marginal_log_mean_coeff(t), + ) + sigma_s, sigma_s1, sigma_s2, sigma_t = ( + ns.marginal_std(s), + ns.marginal_std(s1), + ns.marginal_std(s2), + ns.marginal_std(t), + ) + alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_s2), torch.exp(log_alpha_t) + + if self.algorithm_type == "dpmsolver++": + phi_11 = torch.expm1(-r1 * h) + phi_12 = torch.expm1(-r2 * h) + phi_1 = torch.expm1(-h) + phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1.0 + phi_2 = phi_1 / h + 1.0 + phi_3 = phi_2 / h - 0.5 + + if model_s is None: + model_s = self.model_fn(x, s) + if model_s1 is None: + x_s1 = (sigma_s1 / sigma_s) * x - (alpha_s1 * phi_11) * model_s + model_s1 = self.model_fn(x_s1, s1) + x_s2 = ( + (sigma_s2 / sigma_s) * x + - (alpha_s2 * phi_12) * model_s + + r2 / r1 * (alpha_s2 * phi_22) * (model_s1 - model_s) + ) + model_s2 = self.model_fn(x_s2, s2) + if solver_type == "dpmsolver": + x_t = ( + (sigma_t / sigma_s) * x + - (alpha_t * phi_1) * model_s + + (1.0 / r2) * (alpha_t * phi_2) * (model_s2 - model_s) + ) + elif solver_type == "taylor": + D1_0 = (1.0 / r1) * (model_s1 - model_s) + D1_1 = (1.0 / r2) * (model_s2 - model_s) + D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1) + D2 = 2.0 * (D1_1 - D1_0) / (r2 - r1) + x_t = ( + (sigma_t / sigma_s) * x + - (alpha_t * phi_1) * model_s + + (alpha_t * phi_2) * D1 + - (alpha_t * phi_3) * D2 + ) + else: + phi_11 = torch.expm1(r1 * h) + phi_12 = torch.expm1(r2 * h) + phi_1 = torch.expm1(h) + phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1.0 + phi_2 = phi_1 / h - 1.0 + phi_3 = phi_2 / h - 0.5 + + if model_s is None: + model_s = self.model_fn(x, s) + if model_s1 is None: + x_s1 = (torch.exp(log_alpha_s1 - log_alpha_s)) * x - (sigma_s1 * phi_11) * model_s + model_s1 = self.model_fn(x_s1, s1) + x_s2 = ( + (torch.exp(log_alpha_s2 - log_alpha_s)) * x + - (sigma_s2 * phi_12) * model_s + - r2 / r1 * (sigma_s2 * phi_22) * (model_s1 - model_s) + ) + model_s2 = self.model_fn(x_s2, s2) + if solver_type == "dpmsolver": + x_t = ( + (torch.exp(log_alpha_t - log_alpha_s)) * x + - (sigma_t * phi_1) * model_s + - (1.0 / r2) * (sigma_t * phi_2) * (model_s2 - model_s) + ) + elif solver_type == "taylor": + D1_0 = (1.0 / r1) * (model_s1 - model_s) + D1_1 = (1.0 / r2) * (model_s2 - model_s) + D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1) + D2 = 2.0 * (D1_1 - D1_0) / (r2 - r1) + x_t = ( + (torch.exp(log_alpha_t - log_alpha_s)) * x + - (sigma_t * phi_1) * model_s + - (sigma_t * phi_2) * D1 + - (sigma_t * phi_3) * D2 + ) + + if return_intermediate: + return x_t, {"model_s": model_s, "model_s1": model_s1, "model_s2": model_s2} + else: + return x_t + + def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpmsolver"): + """ + Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + model_prev_list: A list of pytorch tensor. The previous computed model values. + t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,) + t: A pytorch tensor. The ending time, with the shape (1,). + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if solver_type not in ["dpmsolver", "taylor"]: + raise ValueError(f"'solver_type' must be either 'dpmsolver' or 'taylor', got {solver_type}") + ns = self.noise_schedule + model_prev_1, model_prev_0 = model_prev_list[-2], model_prev_list[-1] + t_prev_1, t_prev_0 = t_prev_list[-2], t_prev_list[-1] + lambda_prev_1, lambda_prev_0, lambda_t = ( + ns.marginal_lambda(t_prev_1), + ns.marginal_lambda(t_prev_0), + ns.marginal_lambda(t), + ) + log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) + sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + h_0 = lambda_prev_0 - lambda_prev_1 + h = lambda_t - lambda_prev_0 + r0 = h_0 / h + D1_0 = (1.0 / r0) * (model_prev_0 - model_prev_1) + if self.algorithm_type == "dpmsolver++": + phi_1 = torch.expm1(-h) + if solver_type == "dpmsolver": + x_t = (sigma_t / sigma_prev_0) * x - (alpha_t * phi_1) * model_prev_0 - 0.5 * (alpha_t * phi_1) * D1_0 + elif solver_type == "taylor": + x_t = ( + (sigma_t / sigma_prev_0) * x + - (alpha_t * phi_1) * model_prev_0 + + (alpha_t * (phi_1 / h + 1.0)) * D1_0 + ) + else: + phi_1 = torch.expm1(h) + if solver_type == "dpmsolver": + x_t = ( + (torch.exp(log_alpha_t - log_alpha_prev_0)) * x + - (sigma_t * phi_1) * model_prev_0 + - 0.5 * (sigma_t * phi_1) * D1_0 + ) + elif solver_type == "taylor": + x_t = ( + (torch.exp(log_alpha_t - log_alpha_prev_0)) * x + - (sigma_t * phi_1) * model_prev_0 + - (sigma_t * (phi_1 / h - 1.0)) * D1_0 + ) + return x_t + + def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpmsolver"): + """ + Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + model_prev_list: A list of pytorch tensor. The previous computed model values. + t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,) + t: A pytorch tensor. The ending time, with the shape (1,). + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + ns = self.noise_schedule + model_prev_2, model_prev_1, model_prev_0 = model_prev_list + t_prev_2, t_prev_1, t_prev_0 = t_prev_list + lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ( + ns.marginal_lambda(t_prev_2), + ns.marginal_lambda(t_prev_1), + ns.marginal_lambda(t_prev_0), + ns.marginal_lambda(t), + ) + log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) + sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + h_1 = lambda_prev_1 - lambda_prev_2 + h_0 = lambda_prev_0 - lambda_prev_1 + h = lambda_t - lambda_prev_0 + r0, r1 = h_0 / h, h_1 / h + D1_0 = (1.0 / r0) * (model_prev_0 - model_prev_1) + D1_1 = (1.0 / r1) * (model_prev_1 - model_prev_2) + D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1) + D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1) + if self.algorithm_type == "dpmsolver++": + phi_1 = torch.expm1(-h) + phi_2 = phi_1 / h + 1.0 + phi_3 = phi_2 / h - 0.5 + return ( + (sigma_t / sigma_prev_0) * x + - (alpha_t * phi_1) * model_prev_0 + + (alpha_t * phi_2) * D1 + - (alpha_t * phi_3) * D2 + ) + else: + phi_1 = torch.expm1(h) + phi_2 = phi_1 / h - 1.0 + phi_3 = phi_2 / h - 0.5 + return ( + (torch.exp(log_alpha_t - log_alpha_prev_0)) * x + - (sigma_t * phi_1) * model_prev_0 + - (sigma_t * phi_2) * D1 + - (sigma_t * phi_3) * D2 + ) + + def singlestep_dpm_solver_update( + self, x, s, t, order, return_intermediate=False, solver_type="dpmsolver", r1=None, r2=None + ): + """ + Singlestep DPM-Solver with the order `order` from time `s` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (1,). + t: A pytorch tensor. The ending time, with the shape (1,). + order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3. + return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times). + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + r1: A `float`. The hyperparameter of the second-order or third-order solver. + r2: A `float`. The hyperparameter of the third-order solver. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if order == 1: + return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate) + elif order == 2: + return self.singlestep_dpm_solver_second_update( + x, s, t, return_intermediate=return_intermediate, solver_type=solver_type, r1=r1 + ) + elif order == 3: + return self.singlestep_dpm_solver_third_update( + x, s, t, return_intermediate=return_intermediate, solver_type=solver_type, r1=r1, r2=r2 + ) + else: + raise ValueError(f"Solver order must be 1 or 2 or 3, got {order}") + + def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type="dpmsolver"): + """ + Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + model_prev_list: A list of pytorch tensor. The previous computed model values. + t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,) + t: A pytorch tensor. The ending time, with the shape (1,). + order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3. + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if order == 1: + return self.dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1]) + elif order == 2: + return self.multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type) + elif order == 3: + return self.multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type) + else: + raise ValueError(f"Solver order must be 1 or 2 or 3, got {order}") + + def dpm_solver_adaptive( + self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5, solver_type="dpmsolver" + ): + """ + The adaptive step size solver based on singlestep DPM-Solver. + + Args: + x: A pytorch tensor. The initial value at time `t_T`. + order: A `int`. The (higher) order of the solver. We only support order == 2 or 3. + t_T: A `float`. The starting time of the sampling (default is T). + t_0: A `float`. The ending time of the sampling (default is epsilon). + h_init: A `float`. The initial step size (for logSNR). + atol: A `float`. The absolute tolerance of the solver. For image data, the default setting is 0.0078, followed [1]. + rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05. + theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, followed [1]. + t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the + current time and `t_0` is less than `t_err`. The default setting is 1e-5. + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + Returns: + x_0: A pytorch tensor. The approximated solution at time `t_0`. + + [1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021. + """ + ns = self.noise_schedule + s = t_T * torch.ones((1,)).to(x) + lambda_s = ns.marginal_lambda(s) + lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x)) + h = h_init * torch.ones_like(s).to(x) + x_prev = x + nfe = 0 + if order == 2: + r1 = 0.5 + lower_update = lambda x, s, t: self.dpm_solver_first_update(x, s, t, return_intermediate=True) + higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update( + x, s, t, r1=r1, solver_type=solver_type, **kwargs + ) + elif order == 3: + r1, r2 = 1.0 / 3.0, 2.0 / 3.0 + lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update( + x, s, t, r1=r1, return_intermediate=True, solver_type=solver_type + ) + higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update( + x, s, t, r1=r1, r2=r2, solver_type=solver_type, **kwargs + ) + else: + raise ValueError(f"For adaptive step size solver, order must be 2 or 3, got {order}") + while torch.abs((s - t_0)).mean() > t_err: + t = ns.inverse_lambda(lambda_s + h) + x_lower, lower_noise_kwargs = lower_update(x, s, t) + x_higher = higher_update(x, s, t, **lower_noise_kwargs) + delta = torch.max(torch.ones_like(x).to(x) * atol, rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev))) + norm_fn = lambda v: torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True)) + E = norm_fn((x_higher - x_lower) / delta).max() + if torch.all(E <= 1.0): + x = x_higher + s = t + x_prev = x_lower + lambda_s = ns.marginal_lambda(s) + h = torch.min(theta * h * torch.float_power(E, -1.0 / order).float(), lambda_0 - lambda_s) + nfe += order + print("adaptive solver nfe", nfe) + return x + + def add_noise(self, x, t, noise=None): + """ + Compute the noised input xt = alpha_t * x + sigma_t * noise. + + Args: + x: A `torch.Tensor` with shape `(batch_size, *shape)`. + t: A `torch.Tensor` with shape `(t_size,)`. + Returns: + xt with shape `(t_size, batch_size, *shape)`. + """ + alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t) + if noise is None: + noise = torch.randn((t.shape[0], *x.shape), device=x.device) + x = x.reshape((-1, *x.shape)) + xt = expand_dims(alpha_t, x.dim()) * x + expand_dims(sigma_t, x.dim()) * noise + return xt.squeeze(0) if t.shape[0] == 1 else xt + + def inverse( + self, + x, + steps=20, + t_start=None, + t_end=None, + order=2, + skip_type="time_uniform", + method="multistep", + lower_order_final=True, + denoise_to_zero=False, + solver_type="dpmsolver", + atol=0.0078, + rtol=0.05, + return_intermediate=False, + ): + """ + Inverse the sample `x` from time `t_start` to `t_end` by DPM-Solver. + For discrete-time DPMs, we use `t_start=1/N`, where `N` is the total time steps during training. + """ + t_0 = 1.0 / self.noise_schedule.total_N if t_start is None else t_start + t_T = self.noise_schedule.T if t_end is None else t_end + assert ( + t_0 > 0 and t_T > 0 + ), "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array" + return self.sample( + x, + steps=steps, + t_start=t_0, + t_end=t_T, + order=order, + skip_type=skip_type, + method=method, + lower_order_final=lower_order_final, + denoise_to_zero=denoise_to_zero, + solver_type=solver_type, + atol=atol, + rtol=rtol, + return_intermediate=return_intermediate, + ) + + def sample( + self, + x, + steps=20, + t_start=None, + t_end=None, + order=2, + skip_type="time_uniform", + method="multistep", + lower_order_final=True, + denoise_to_zero=False, + solver_type="dpmsolver", + atol=0.0078, + rtol=0.05, + return_intermediate=False, + ): + """ + Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`. + + ===================================================== + + We support the following algorithms for both noise prediction model and data prediction model: + - 'singlestep': + Singlestep DPM-Solver (i.e. "DPM-Solver-fast" in the paper), which combines different orders of singlestep DPM-Solver. + We combine all the singlestep solvers with order <= `order` to use up all the function evaluations (steps). + The total number of function evaluations (NFE) == `steps`. + Given a fixed NFE == `steps`, the sampling procedure is: + - If `order` == 1: + - Denote K = steps. We use K steps of DPM-Solver-1 (i.e. DDIM). + - If `order` == 2: + - Denote K = (steps // 2) + (steps % 2). We take K intermediate time steps for sampling. + - If steps % 2 == 0, we use K steps of singlestep DPM-Solver-2. + - If steps % 2 == 1, we use (K - 1) steps of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1. + - If `order` == 3: + - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling. + - If steps % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3, and 1 step of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1. + - If steps % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of DPM-Solver-1. + - If steps % 3 == 2, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of singlestep DPM-Solver-2. + - 'multistep': + Multistep DPM-Solver with the order of `order`. The total number of function evaluations (NFE) == `steps`. + We initialize the first `order` values by lower order multistep solvers. + Given a fixed NFE == `steps`, the sampling procedure is: + Denote K = steps. + - If `order` == 1: + - We use K steps of DPM-Solver-1 (i.e. DDIM). + - If `order` == 2: + - We firstly use 1 step of DPM-Solver-1, then use (K - 1) step of multistep DPM-Solver-2. + - If `order` == 3: + - We firstly use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2, then (K - 2) step of multistep DPM-Solver-3. + - 'singlestep_fixed': + Fixed order singlestep DPM-Solver (i.e. DPM-Solver-1 or singlestep DPM-Solver-2 or singlestep DPM-Solver-3). + We use singlestep DPM-Solver-`order` for `order`=1 or 2 or 3, with total [`steps` // `order`] * `order` NFE. + - 'adaptive': + Adaptive step size DPM-Solver (i.e. "DPM-Solver-12" and "DPM-Solver-23" in the paper). + We ignore `steps` and use adaptive step size DPM-Solver with a higher order of `order`. + You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to balance the computatation costs + (NFE) and the sample quality. + - If `order` == 2, we use DPM-Solver-12 which combines DPM-Solver-1 and singlestep DPM-Solver-2. + - If `order` == 3, we use DPM-Solver-23 which combines singlestep DPM-Solver-2 and singlestep DPM-Solver-3. + + ===================================================== + + Some advices for choosing the algorithm: + - For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs: + Use singlestep DPM-Solver or DPM-Solver++ ("DPM-Solver-fast" in the paper) with `order = 3`. + e.g., DPM-Solver: + >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver") + >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3, + skip_type='time_uniform', method='singlestep') + e.g., DPM-Solver++: + >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++") + >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3, + skip_type='time_uniform', method='singlestep') + - For **guided sampling with large guidance scale** by DPMs: + Use multistep DPM-Solver with `algorithm_type="dpmsolver++"` and `order = 2`. + e.g. + >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++") + >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2, + skip_type='time_uniform', method='multistep') + + We support three types of `skip_type`: + - 'logSNR': uniform logSNR for the time steps. **Recommended for low-resolutional images** + - 'time_uniform': uniform time for the time steps. **Recommended for high-resolutional images**. + - 'time_quadratic': quadratic time for the time steps. + + ===================================================== + Args: + x: A pytorch tensor. The initial value at time `t_start` + e.g. if `t_start` == T, then `x` is a sample from the standard normal distribution. + steps: A `int`. The total number of function evaluations (NFE). + t_start: A `float`. The starting time of the sampling. + If `T` is None, we use self.noise_schedule.T (default is 1.0). + t_end: A `float`. The ending time of the sampling. + If `t_end` is None, we use 1. / self.noise_schedule.total_N. + e.g. if total_N == 1000, we have `t_end` == 1e-3. + For discrete-time DPMs: + - We recommend `t_end` == 1. / self.noise_schedule.total_N. + For continuous-time DPMs: + - We recommend `t_end` == 1e-3 when `steps` <= 15; and `t_end` == 1e-4 when `steps` > 15. + order: A `int`. The order of DPM-Solver. + skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'. + method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'. + denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step. + Default is `False`. If `denoise_to_zero` is `True`, the total NFE is (`steps` + 1). + + This trick is firstly proposed by DDPM (https://arxiv.org/abs/2006.11239) and + score_sde (https://arxiv.org/abs/2011.13456). Such trick can improve the FID + for diffusion models sampling by diffusion SDEs for low-resolutional images + (such as CIFAR-10). However, we observed that such trick does not matter for + high-resolutional images. As it needs an additional NFE, we do not recommend + it for high-resolutional images. + lower_order_final: A `bool`. Whether to use lower order solvers at the final steps. + Only valid for `method=multistep` and `steps < 15`. We empirically find that + this trick is a key to stabilizing the sampling by DPM-Solver with very few steps + (especially for steps <= 10). So we recommend to set it to be `True`. + solver_type: A `str`. The taylor expansion type for the solver. `dpmsolver` or `taylor`. We recommend `dpmsolver`. + atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'. + rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'. + return_intermediate: A `bool`. Whether to save the xt at each step. + When set to `True`, method returns a tuple (x0, intermediates); when set to False, method returns only x0. + Returns: + x_end: A pytorch tensor. The approximated solution at time `t_end`. + + """ + t_0 = 1.0 / self.noise_schedule.total_N if t_end is None else t_end + t_T = self.noise_schedule.T if t_start is None else t_start + assert ( + t_0 > 0 and t_T > 0 + ), "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array" + if return_intermediate: + assert method in [ + "multistep", + "singlestep", + "singlestep_fixed", + ], "Cannot use adaptive solver when saving intermediate values" + if self.correcting_xt_fn is not None: + assert method in [ + "multistep", + "singlestep", + "singlestep_fixed", + ], "Cannot use adaptive solver when correcting_xt_fn is not None" + device = x.device + intermediates = [] + with torch.no_grad(): + if method == "adaptive": + x = self.dpm_solver_adaptive( + x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol, solver_type=solver_type + ) + elif method == "multistep": + assert steps >= order + timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device) + assert timesteps.shape[0] - 1 == steps + # Init the initial values. + step = 0 + t = timesteps[step] + t_prev_list = [t] + model_prev_list = [self.model_fn(x, t)] + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step) + if return_intermediate: + intermediates.append(x) + # Init the first `order` values by lower order multistep DPM-Solver. + for step in range(1, order): + t = timesteps[step] + x = self.multistep_dpm_solver_update( + x, model_prev_list, t_prev_list, t, step, solver_type=solver_type + ) + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step) + if return_intermediate: + intermediates.append(x) + t_prev_list.append(t) + model_prev_list.append(self.model_fn(x, t)) + # Compute the remaining values by `order`-th order multistep DPM-Solver. + for step in tqdm(range(order, steps + 1)): + t = timesteps[step] + # We only use lower order for steps < 10 + if lower_order_final and steps < 10: + step_order = min(order, steps + 1 - step) + else: + step_order = order + x = self.multistep_dpm_solver_update( + x, model_prev_list, t_prev_list, t, step_order, solver_type=solver_type + ) + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step) + if return_intermediate: + intermediates.append(x) + for i in range(order - 1): + t_prev_list[i] = t_prev_list[i + 1] + model_prev_list[i] = model_prev_list[i + 1] + t_prev_list[-1] = t + # We do not need to evaluate the final model value. + if step < steps: + model_prev_list[-1] = self.model_fn(x, t) + elif method in ["singlestep", "singlestep_fixed"]: + if method == "singlestep": + timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver( + steps=steps, order=order, skip_type=skip_type, t_T=t_T, t_0=t_0, device=device + ) + elif method == "singlestep_fixed": + K = steps // order + orders = [ + order, + ] * K + timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device) + for step, order in enumerate(orders): + s, t = timesteps_outer[step], timesteps_outer[step + 1] + timesteps_inner = self.get_time_steps( + skip_type=skip_type, t_T=s.item(), t_0=t.item(), N=order, device=device + ) + lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner) + h = lambda_inner[-1] - lambda_inner[0] + r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h + r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h + x = self.singlestep_dpm_solver_update(x, s, t, order, solver_type=solver_type, r1=r1, r2=r2) + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step) + if return_intermediate: + intermediates.append(x) + else: + raise ValueError(f"Got wrong method {method}") + if denoise_to_zero: + t = torch.ones((1,)).to(device) * t_0 + x = self.denoise_to_zero_fn(x, t) + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step + 1) + if return_intermediate: + intermediates.append(x) + return (x, intermediates) if return_intermediate else x + + +############################################################# +# other utility functions +############################################################# + + +def interpolate_fn(x, xp, yp): + """ + A piecewise linear function y = f(x), using xp and yp as keypoints. + We implement f(x) in a differentiable way (i.e. applicable for autograd). + The function f(x) is well-defined for all x-axis. (For x beyond the bounds of xp, we use the outmost points of xp to define the linear function.) + + Args: + x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver). + xp: PyTorch tensor with shape [C, K], where K is the number of keypoints. + yp: PyTorch tensor with shape [C, K]. + Returns: + The function values f(x), with shape [N, C]. + """ + N, K = x.shape[0], xp.shape[1] + all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2) + sorted_all_x, x_indices = torch.sort(all_x, dim=2) + x_idx = torch.argmin(x_indices, dim=2) + cand_start_idx = x_idx - 1 + start_idx = torch.where( + torch.eq(x_idx, 0), + torch.tensor(1, device=x.device), + torch.where( + torch.eq(x_idx, K), + torch.tensor(K - 2, device=x.device), + cand_start_idx, + ), + ) + end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1) + start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2) + end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2) + start_idx2 = torch.where( + torch.eq(x_idx, 0), + torch.tensor(0, device=x.device), + torch.where( + torch.eq(x_idx, K), + torch.tensor(K - 2, device=x.device), + cand_start_idx, + ), + ) + y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1) + start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2) + end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2) + return start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x) + + +def expand_dims(v, dims): + """ + Expand the tensor `v` to the dim `dims`. + + Args: + `v`: a PyTorch tensor with shape [N]. + `dim`: a `int`. + Returns: + a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`. + """ + return v[(...,) + (None,) * (dims - 1)] + + +def DPMS( + model, + condition, + uncondition, + cfg_scale, + model_type="noise", + noise_schedule="linear", + guidance_type="classifier-free", + model_kwargs=None, + diffusion_steps=1000, +): + if model_kwargs is None: + model_kwargs = {} + betas = torch.tensor(get_named_beta_schedule(noise_schedule, diffusion_steps)) + + ## 1. Define the noise schedule. + noise_schedule = NoiseScheduleVP(schedule="discrete", betas=betas) + + ## 2. Convert your discrete-time `model` to the continuous-time + ## noise prediction model. Here is an example for a diffusion model + ## `model` with the noise prediction type ("noise") . + model_fn = model_wrapper( + model, + noise_schedule, + model_type=model_type, + model_kwargs=model_kwargs, + guidance_type=guidance_type, + condition=condition, + unconditional_condition=uncondition, + guidance_scale=cfg_scale, + ) + ## 3. Define dpm-solver and sample by multistep DPM-Solver. + return DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++") diff --git a/opensora/schedulers/iddpm/__init__.py b/opensora/schedulers/iddpm/__init__.py new file mode 100644 index 0000000..b9806ad --- /dev/null +++ b/opensora/schedulers/iddpm/__init__.py @@ -0,0 +1,95 @@ +from functools import partial + +import torch + +from opensora.registry import SCHEDULERS + +from . import gaussian_diffusion as gd +from .respace import SpacedDiffusion, space_timesteps + + +@SCHEDULERS.register_module("iddpm") +class IDDPM(SpacedDiffusion): + def __init__( + self, + num_sampling_steps=None, + timestep_respacing=None, + noise_schedule="linear", + use_kl=False, + sigma_small=False, + predict_xstart=False, + learn_sigma=True, + rescale_learned_sigmas=False, + diffusion_steps=1000, + cfg_scale=4.0, + ): + betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps) + if use_kl: + loss_type = gd.LossType.RESCALED_KL + elif rescale_learned_sigmas: + loss_type = gd.LossType.RESCALED_MSE + else: + loss_type = gd.LossType.MSE + if num_sampling_steps is not None: + assert timestep_respacing is None + timestep_respacing = str(num_sampling_steps) + if timestep_respacing is None or timestep_respacing == "": + timestep_respacing = [diffusion_steps] + super().__init__( + use_timesteps=space_timesteps(diffusion_steps, timestep_respacing), + betas=betas, + model_mean_type=(gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X), + model_var_type=( + (gd.ModelVarType.FIXED_LARGE if not sigma_small else gd.ModelVarType.FIXED_SMALL) + if not learn_sigma + else gd.ModelVarType.LEARNED_RANGE + ), + loss_type=loss_type, + # rescale_timesteps=rescale_timesteps, + ) + + self.cfg_scale = cfg_scale + + def sample( + self, + model, + text_encoder, + z_size, + prompts, + device, + additional_args=None, + ): + n = len(prompts) + z = torch.randn(n, *z_size, device=device) + z = torch.cat([z, z], 0) + model_args = text_encoder.encode(prompts) + y_null = text_encoder.null(n) + model_args["y"] = torch.cat([model_args["y"], y_null], 0) + if additional_args is not None: + model_args.update(additional_args) + + forward = partial(forward_with_cfg, model, cfg_scale=self.cfg_scale) + samples = self.p_sample_loop( + forward, + z.shape, + z, + clip_denoised=False, + model_kwargs=model_args, + progress=True, + device=device, + ) + samples, _ = samples.chunk(2, dim=0) + return samples + + +def forward_with_cfg(model, x, timestep, y, cfg_scale, **kwargs): + # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb + half = x[: len(x) // 2] + combined = torch.cat([half, half], dim=0) + model_out = model.forward(combined, timestep, y, **kwargs) + model_out = model_out["x"] if isinstance(model_out, dict) else model_out + eps, rest = model_out[:, :3], model_out[:, 3:] + cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) + half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) + eps = torch.cat([half_eps, half_eps], dim=0) + return torch.cat([eps, rest], dim=1) diff --git a/opensora/schedulers/iddpm/diffusion_utils.py b/opensora/schedulers/iddpm/diffusion_utils.py new file mode 100644 index 0000000..c097ac5 --- /dev/null +++ b/opensora/schedulers/iddpm/diffusion_utils.py @@ -0,0 +1,87 @@ +# Adapted from DiT + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# DiT: https://github.com/facebookresearch/DiT/tree/main +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +# -------------------------------------------------------- + + +import numpy as np +import torch as th + + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + Compute the KL divergence between two gaussians. + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, th.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for th.exp(). + logvar1, logvar2 = [x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) for x in (logvar1, logvar2)] + + return 0.5 * (-1.0 + logvar2 - logvar1 + th.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * th.exp(-logvar2)) + + +def approx_standard_normal_cdf(x): + """ + A fast approximation of the cumulative distribution function of the + standard normal. + """ + return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) + + +def continuous_gaussian_log_likelihood(x, *, means, log_scales): + """ + Compute the log-likelihood of a continuous Gaussian distribution. + :param x: the targets + :param means: the Gaussian mean Tensor. + :param log_scales: the Gaussian log stddev Tensor. + :return: a tensor like x of log probabilities (in nats). + """ + centered_x = x - means + inv_stdv = th.exp(-log_scales) + normalized_x = centered_x * inv_stdv + log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x) + return log_probs + + +def discretized_gaussian_log_likelihood(x, *, means, log_scales): + """ + Compute the log-likelihood of a Gaussian distribution discretizing to a + given image. + :param x: the target images. It is assumed that this was uint8 values, + rescaled to the range [-1, 1]. + :param means: the Gaussian mean Tensor. + :param log_scales: the Gaussian log stddev Tensor. + :return: a tensor like x of log probabilities (in nats). + """ + assert x.shape == means.shape == log_scales.shape + centered_x = x - means + inv_stdv = th.exp(-log_scales) + plus_in = inv_stdv * (centered_x + 1.0 / 255.0) + cdf_plus = approx_standard_normal_cdf(plus_in) + min_in = inv_stdv * (centered_x - 1.0 / 255.0) + cdf_min = approx_standard_normal_cdf(min_in) + log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) + log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) + cdf_delta = cdf_plus - cdf_min + log_probs = th.where( + x < -0.999, + log_cdf_plus, + th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), + ) + assert log_probs.shape == x.shape + return log_probs diff --git a/opensora/schedulers/iddpm/gaussian_diffusion.py b/opensora/schedulers/iddpm/gaussian_diffusion.py new file mode 100644 index 0000000..4a74592 --- /dev/null +++ b/opensora/schedulers/iddpm/gaussian_diffusion.py @@ -0,0 +1,835 @@ +# Adapted from DiT + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# DiT: https://github.com/facebookresearch/DiT/tree/main +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +# -------------------------------------------------------- + +import enum +import math + +import numpy as np +import torch as th + +from .diffusion_utils import discretized_gaussian_log_likelihood, normal_kl + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +class ModelMeanType(enum.Enum): + """ + Which type of output the model predicts. + """ + + PREVIOUS_X = enum.auto() # the model predicts x_{t-1} + START_X = enum.auto() # the model predicts x_0 + EPSILON = enum.auto() # the model predicts epsilon + + +class ModelVarType(enum.Enum): + """ + What is used as the model's output variance. + The LEARNED_RANGE option has been added to allow the model to predict + values between FIXED_SMALL and FIXED_LARGE, making its job easier. + """ + + LEARNED = enum.auto() + FIXED_SMALL = enum.auto() + FIXED_LARGE = enum.auto() + LEARNED_RANGE = enum.auto() + + +class LossType(enum.Enum): + MSE = enum.auto() # use raw MSE loss (and KL when learning variances) + RESCALED_MSE = enum.auto() # use raw MSE loss (with RESCALED_KL when learning variances) + KL = enum.auto() # use the variational lower-bound + RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB + + def is_vb(self): + return self == LossType.KL or self == LossType.RESCALED_KL + + +def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac): + betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) + warmup_time = int(num_diffusion_timesteps * warmup_frac) + betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64) + return betas + + +def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps): + """ + This is the deprecated API for creating beta schedules. + See get_named_beta_schedule() for the new library of schedules. + """ + if beta_schedule == "quad": + betas = ( + np.linspace( + beta_start**0.5, + beta_end**0.5, + num_diffusion_timesteps, + dtype=np.float64, + ) + ** 2 + ) + elif beta_schedule == "linear": + betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64) + elif beta_schedule == "warmup10": + betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1) + elif beta_schedule == "warmup50": + betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5) + elif beta_schedule == "const": + betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) + elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1 + betas = 1.0 / np.linspace(num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64) + else: + raise NotImplementedError(beta_schedule) + assert betas.shape == (num_diffusion_timesteps,) + return betas + + +def get_named_beta_schedule(schedule_name, num_diffusion_timesteps): + """ + Get a pre-defined beta schedule for the given name. + The beta schedule library consists of beta schedules which remain similar + in the limit of num_diffusion_timesteps. + Beta schedules may be added, but should not be removed or changed once + they are committed to maintain backwards compatibility. + """ + if schedule_name == "linear": + # Linear schedule from Ho et al, extended to work for any number of + # diffusion steps. + scale = 1000 / num_diffusion_timesteps + return get_beta_schedule( + "linear", + beta_start=scale * 0.0001, + beta_end=scale * 0.02, + num_diffusion_timesteps=num_diffusion_timesteps, + ) + elif schedule_name == "squaredcos_cap_v2": + return betas_for_alpha_bar( + num_diffusion_timesteps, + lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, + ) + else: + raise NotImplementedError(f"unknown beta schedule: {schedule_name}") + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +class GaussianDiffusion: + """ + Utilities for training and sampling diffusion models. + Original ported from this codebase: + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42 + :param betas: a 1-D numpy array of betas for each diffusion timestep, + starting at T and going to 1. + """ + + def __init__(self, *, betas, model_mean_type, model_var_type, loss_type): + self.model_mean_type = model_mean_type + self.model_var_type = model_var_type + self.loss_type = loss_type + + # Use float64 for accuracy. + betas = np.array(betas, dtype=np.float64) + self.betas = betas + assert len(betas.shape) == 1, "betas must be 1-D" + assert (betas > 0).all() and (betas <= 1).all() + + self.num_timesteps = int(betas.shape[0]) + + alphas = 1.0 - betas + self.alphas_cumprod = np.cumprod(alphas, axis=0) + self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1]) + self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0) + assert self.alphas_cumprod_prev.shape == (self.num_timesteps,) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod) + self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod) + self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod) + self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod) + self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.posterior_log_variance_clipped = ( + np.log(np.append(self.posterior_variance[1], self.posterior_variance[1:])) + if len(self.posterior_variance) > 1 + else np.array([]) + ) + + self.posterior_mean_coef1 = betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + self.posterior_mean_coef2 = (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod) + + def q_mean_variance(self, x_start, t): + """ + Get the distribution q(x_t | x_0). + :param x_start: the [N x C x ...] tensor of noiseless inputs. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :return: A tuple (mean, variance, log_variance), all of x_start's shape. + """ + mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) + log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape) + return mean, variance, log_variance + + def q_sample(self, x_start, t, noise=None): + """ + Diffuse the data for a given number of diffusion steps. + In other words, sample from q(x_t | x_0). + :param x_start: the initial data batch. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :param noise: if specified, the split-out normal noise. + :return: A noisy version of x_start. + """ + if noise is None: + noise = th.randn_like(x_start) + assert noise.shape == x_start.shape + return ( + _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise + ) + + def q_posterior_mean_variance(self, x_start, x_t, t): + """ + Compute the mean and variance of the diffusion posterior: + q(x_{t-1} | x_t, x_0) + """ + assert x_start.shape == x_t.shape + posterior_mean = ( + _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = _extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape) + assert ( + posterior_mean.shape[0] + == posterior_variance.shape[0] + == posterior_log_variance_clipped.shape[0] + == x_start.shape[0] + ) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None): + """ + Apply the model to get p(x_{t-1} | x_t), as well as a prediction of + the initial x, x_0. + :param model: the model, which takes a signal and a batch of timesteps + as input. + :param x: the [N x C x ...] tensor at time t. + :param t: a 1-D Tensor of timesteps. + :param clip_denoised: if True, clip the denoised signal into [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. Applies before + clip_denoised. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict with the following keys: + - 'mean': the model mean output. + - 'variance': the model variance output. + - 'log_variance': the log of 'variance'. + - 'pred_xstart': the prediction for x_0. + """ + if model_kwargs is None: + model_kwargs = {} + + B, C = x.shape[:2] + assert t.shape == (B,) + model_output = model(x, t, **model_kwargs) + if isinstance(model_output, tuple): + model_output, extra = model_output + else: + extra = None + + if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]: + assert model_output.shape == (B, C * 2, *x.shape[2:]) + model_output, model_var_values = th.split(model_output, C, dim=1) + min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape) + max_log = _extract_into_tensor(np.log(self.betas), t, x.shape) + # The model_var_values is [-1, 1] for [min_var, max_var]. + frac = (model_var_values + 1) / 2 + model_log_variance = frac * max_log + (1 - frac) * min_log + model_variance = th.exp(model_log_variance) + else: + model_variance, model_log_variance = { + # for fixedlarge, we set the initial (log-)variance like so + # to get a better decoder log likelihood. + ModelVarType.FIXED_LARGE: ( + np.append(self.posterior_variance[1], self.betas[1:]), + np.log(np.append(self.posterior_variance[1], self.betas[1:])), + ), + ModelVarType.FIXED_SMALL: ( + self.posterior_variance, + self.posterior_log_variance_clipped, + ), + }[self.model_var_type] + model_variance = _extract_into_tensor(model_variance, t, x.shape) + model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape) + + def process_xstart(x): + if denoised_fn is not None: + x = denoised_fn(x) + if clip_denoised: + return x.clamp(-1, 1) + return x + + if self.model_mean_type == ModelMeanType.START_X: + pred_xstart = process_xstart(model_output) + else: + pred_xstart = process_xstart(self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)) + model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t) + + assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape + return { + "mean": model_mean, + "variance": model_variance, + "log_variance": model_log_variance, + "pred_xstart": pred_xstart, + "extra": extra, + } + + def _predict_xstart_from_eps(self, x_t, t, eps): + assert x_t.shape == eps.shape + return ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t + - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps + ) + + def _predict_eps_from_xstart(self, x_t, t, pred_xstart): + return ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart + ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + + def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None): + """ + Compute the mean for the previous step, given a function cond_fn that + computes the gradient of a conditional log probability with respect to + x. In particular, cond_fn computes grad(log(p(y|x))), and we want to + condition on y. + This uses the conditioning strategy from Sohl-Dickstein et al. (2015). + """ + gradient = cond_fn(x, t, **model_kwargs) + new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float() + return new_mean + + def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None): + """ + Compute what the p_mean_variance output would have been, should the + model's score function be conditioned by cond_fn. + See condition_mean() for details on cond_fn. + Unlike condition_mean(), this instead uses the conditioning strategy + from Song et al (2020). + """ + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + + eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"]) + eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs) + + out = p_mean_var.copy() + out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps) + out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t) + return out + + def p_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + ): + """ + Sample x_{t-1} from the model at the given timestep. + :param model: the model to sample from. + :param x: the current tensor at x_{t-1}. + :param t: the value of t, starting at 0 for the first diffusion step. + :param clip_denoised: if True, clip the x_start prediction to [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. + :param cond_fn: if not None, this is a gradient function that acts + similarly to the model. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict containing the following keys: + - 'sample': a random sample from the model. + - 'pred_xstart': a prediction of x_0. + """ + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + noise = th.randn_like(x) + nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) # no noise when t == 0 + if cond_fn is not None: + out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs) + sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise + return {"sample": sample, "pred_xstart": out["pred_xstart"]} + + def p_sample_loop( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + ): + """ + Generate samples from the model. + :param model: the model module. + :param shape: the shape of the samples, (N, C, H, W). + :param noise: if specified, the noise from the encoder to sample. + Should be of the same shape as `shape`. + :param clip_denoised: if True, clip x_start predictions to [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. + :param cond_fn: if not None, this is a gradient function that acts + similarly to the model. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :param device: if specified, the device to create the samples on. + If not specified, use a model parameter's device. + :param progress: if True, show a tqdm progress bar. + :return: a non-differentiable batch of samples. + """ + final = None + for sample in self.p_sample_loop_progressive( + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + device=device, + progress=progress, + ): + final = sample + return final["sample"] + + def p_sample_loop_progressive( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + ): + """ + Generate samples from the model and yield intermediate samples from + each timestep of diffusion. + Arguments are the same as p_sample_loop(). + Returns a generator over dicts, where each dict is the return value of + p_sample(). + """ + if device is None: + device = next(model.parameters()).device + assert isinstance(shape, (tuple, list)) + if noise is not None: + img = noise + else: + img = th.randn(*shape, device=device) + indices = list(range(self.num_timesteps))[::-1] + + if progress: + # Lazy import so that we don't depend on tqdm. + from tqdm.auto import tqdm + + indices = tqdm(indices) + + for i in indices: + t = th.tensor([i] * shape[0], device=device) + with th.no_grad(): + out = self.p_sample( + model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + ) + yield out + img = out["sample"] + + def ddim_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + eta=0.0, + ): + """ + Sample x_{t-1} from the model using DDIM. + Same usage as p_sample(). + """ + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + if cond_fn is not None: + out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs) + + # Usually our model outputs epsilon, but we re-derive it + # in case we used x_start or x_prev prediction. + eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"]) + + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) + sigma = eta * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) * th.sqrt(1 - alpha_bar / alpha_bar_prev) + # Equation 12. + noise = th.randn_like(x) + mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_prev) + th.sqrt(1 - alpha_bar_prev - sigma**2) * eps + nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) # no noise when t == 0 + sample = mean_pred + nonzero_mask * sigma * noise + return {"sample": sample, "pred_xstart": out["pred_xstart"]} + + def ddim_reverse_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + eta=0.0, + ): + """ + Sample x_{t+1} from the model using DDIM reverse ODE. + """ + assert eta == 0.0, "Reverse ODE only for deterministic path" + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + if cond_fn is not None: + out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs) + # Usually our model outputs epsilon, but we re-derive it + # in case we used x_start or x_prev prediction. + eps = ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x - out["pred_xstart"] + ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape) + alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape) + + # Equation 12. reversed + mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps + + return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]} + + def ddim_sample_loop( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + eta=0.0, + ): + """ + Generate samples from the model using DDIM. + Same usage as p_sample_loop(). + """ + final = None + for sample in self.ddim_sample_loop_progressive( + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + device=device, + progress=progress, + eta=eta, + ): + final = sample + return final["sample"] + + def ddim_sample_loop_progressive( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + eta=0.0, + ): + """ + Use DDIM to sample from the model and yield intermediate samples from + each timestep of DDIM. + Same usage as p_sample_loop_progressive(). + """ + if device is None: + device = next(model.parameters()).device + assert isinstance(shape, (tuple, list)) + if noise is not None: + img = noise + else: + img = th.randn(*shape, device=device) + indices = list(range(self.num_timesteps))[::-1] + + if progress: + # Lazy import so that we don't depend on tqdm. + from tqdm.auto import tqdm + + indices = tqdm(indices) + + for i in indices: + t = th.tensor([i] * shape[0], device=device) + with th.no_grad(): + out = self.ddim_sample( + model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + eta=eta, + ) + yield out + img = out["sample"] + + def _vb_terms_bpd(self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None): + """ + Get a term for the variational lower-bound. + The resulting units are bits (rather than nats, as one might expect). + This allows for comparison to other papers. + :return: a dict with the following keys: + - 'output': a shape [N] tensor of NLLs or KLs. + - 'pred_xstart': the x_0 predictions. + """ + true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t) + out = self.p_mean_variance(model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs) + kl = normal_kl(true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]) + kl = mean_flat(kl) / np.log(2.0) + + decoder_nll = -discretized_gaussian_log_likelihood( + x_start, means=out["mean"], log_scales=0.5 * out["log_variance"] + ) + assert decoder_nll.shape == x_start.shape + decoder_nll = mean_flat(decoder_nll) / np.log(2.0) + + # At the first timestep return the decoder NLL, + # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t)) + output = th.where((t == 0), decoder_nll, kl) + return {"output": output, "pred_xstart": out["pred_xstart"]} + + def training_losses(self, model, x_start, t, model_kwargs=None, noise=None): + """ + Compute training losses for a single timestep. + :param model: the model to evaluate loss on. + :param x_start: the [N x C x ...] tensor of inputs. + :param t: a batch of timestep indices. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :param noise: if specified, the specific Gaussian noise to try to remove. + :return: a dict with the key "loss" containing a tensor of shape [N]. + Some mean or variance settings may also have other keys. + """ + if model_kwargs is None: + model_kwargs = {} + if noise is None: + noise = th.randn_like(x_start) + x_t = self.q_sample(x_start, t, noise=noise) + + terms = {} + + if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL: + terms["loss"] = self._vb_terms_bpd( + model=model, + x_start=x_start, + x_t=x_t, + t=t, + clip_denoised=False, + model_kwargs=model_kwargs, + )["output"] + if self.loss_type == LossType.RESCALED_KL: + terms["loss"] *= self.num_timesteps + elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE: + model_output = model(x_t, t, **model_kwargs) + + if self.model_var_type in [ + ModelVarType.LEARNED, + ModelVarType.LEARNED_RANGE, + ]: + B, C = x_t.shape[:2] + assert model_output.shape == (B, C * 2, *x_t.shape[2:]) + model_output, model_var_values = th.split(model_output, C, dim=1) + # Learn the variance using the variational bound, but don't let + # it affect our mean prediction. + frozen_out = th.cat([model_output.detach(), model_var_values], dim=1) + terms["vb"] = self._vb_terms_bpd( + model=lambda *args, r=frozen_out: r, + x_start=x_start, + x_t=x_t, + t=t, + clip_denoised=False, + )["output"] + if self.loss_type == LossType.RESCALED_MSE: + # Divide by 1000 for equivalence with initial implementation. + # Without a factor of 1/1000, the VB term hurts the MSE term. + terms["vb"] *= self.num_timesteps / 1000.0 + + target = { + ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t)[0], + ModelMeanType.START_X: x_start, + ModelMeanType.EPSILON: noise, + }[self.model_mean_type] + assert model_output.shape == target.shape == x_start.shape + terms["mse"] = mean_flat((target - model_output) ** 2) + if "vb" in terms: + terms["loss"] = terms["mse"] + terms["vb"] + else: + terms["loss"] = terms["mse"] + else: + raise NotImplementedError(self.loss_type) + + return terms + + def _prior_bpd(self, x_start): + """ + Get the prior KL term for the variational lower-bound, measured in + bits-per-dim. + This term can't be optimized, as it only depends on the encoder. + :param x_start: the [N x C x ...] tensor of inputs. + :return: a batch of [N] KL values (in bits), one per batch element. + """ + batch_size = x_start.shape[0] + t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) + qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) + kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0) + return mean_flat(kl_prior) / np.log(2.0) + + def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None): + """ + Compute the entire variational lower-bound, measured in bits-per-dim, + as well as other related quantities. + :param model: the model to evaluate loss on. + :param x_start: the [N x C x ...] tensor of inputs. + :param clip_denoised: if True, clip denoised samples. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict containing the following keys: + - total_bpd: the total variational lower-bound, per batch element. + - prior_bpd: the prior term in the lower-bound. + - vb: an [N x T] tensor of terms in the lower-bound. + - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep. + - mse: an [N x T] tensor of epsilon MSEs for each timestep. + """ + device = x_start.device + batch_size = x_start.shape[0] + + vb = [] + xstart_mse = [] + mse = [] + for t in list(range(self.num_timesteps))[::-1]: + t_batch = th.tensor([t] * batch_size, device=device) + noise = th.randn_like(x_start) + x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise) + # Calculate VLB term at the current timestep + with th.no_grad(): + out = self._vb_terms_bpd( + model, + x_start=x_start, + x_t=x_t, + t=t_batch, + clip_denoised=clip_denoised, + model_kwargs=model_kwargs, + ) + vb.append(out["output"]) + xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2)) + eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"]) + mse.append(mean_flat((eps - noise) ** 2)) + + vb = th.stack(vb, dim=1) + xstart_mse = th.stack(xstart_mse, dim=1) + mse = th.stack(mse, dim=1) + + prior_bpd = self._prior_bpd(x_start) + total_bpd = vb.sum(dim=1) + prior_bpd + return { + "total_bpd": total_bpd, + "prior_bpd": prior_bpd, + "vb": vb, + "xstart_mse": xstart_mse, + "mse": mse, + } + + +def _extract_into_tensor(arr, timesteps, broadcast_shape): + """ + Extract values from a 1-D numpy array for a batch of indices. + :param arr: the 1-D numpy array. + :param timesteps: a tensor of indices into the array to extract. + :param broadcast_shape: a larger shape of K dimensions with the batch + dimension equal to the length of timesteps. + :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. + """ + res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float() + while len(res.shape) < len(broadcast_shape): + res = res[..., None] + return res + th.zeros(broadcast_shape, device=timesteps.device) diff --git a/opensora/schedulers/iddpm/respace.py b/opensora/schedulers/iddpm/respace.py new file mode 100644 index 0000000..d5ea16c --- /dev/null +++ b/opensora/schedulers/iddpm/respace.py @@ -0,0 +1,127 @@ +# Adapted from DiT + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# DiT: https://github.com/facebookresearch/DiT/tree/main +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +# -------------------------------------------------------- + + +import numpy as np +import torch as th + +from .gaussian_diffusion import GaussianDiffusion + + +def space_timesteps(num_timesteps, section_counts): + """ + Create a list of timesteps to use from an original diffusion process, + given the number of timesteps we want to take from equally-sized portions + of the original process. + For example, if there's 300 timesteps and the section counts are [10,15,20] + then the first 100 timesteps are strided to be 10 timesteps, the second 100 + are strided to be 15 timesteps, and the final 100 are strided to be 20. + If the stride is a string starting with "ddim", then the fixed striding + from the DDIM paper is used, and only one section is allowed. + :param num_timesteps: the number of diffusion steps in the original + process to divide up. + :param section_counts: either a list of numbers, or a string containing + comma-separated numbers, indicating the step count + per section. As a special case, use "ddimN" where N + is a number of steps to use the striding from the + DDIM paper. + :return: a set of diffusion steps from the original process to use. + """ + if isinstance(section_counts, str): + if section_counts.startswith("ddim"): + desired_count = int(section_counts[len("ddim") :]) + for i in range(1, num_timesteps): + if len(range(0, num_timesteps, i)) == desired_count: + return set(range(0, num_timesteps, i)) + raise ValueError(f"cannot create exactly {num_timesteps} steps with an integer stride") + section_counts = [int(x) for x in section_counts.split(",")] + size_per = num_timesteps // len(section_counts) + extra = num_timesteps % len(section_counts) + start_idx = 0 + all_steps = [] + for i, section_count in enumerate(section_counts): + size = size_per + (1 if i < extra else 0) + if size < section_count: + raise ValueError(f"cannot divide section of {size} steps into {section_count}") + if section_count <= 1: + frac_stride = 1 + else: + frac_stride = (size - 1) / (section_count - 1) + cur_idx = 0.0 + taken_steps = [] + for _ in range(section_count): + taken_steps.append(start_idx + round(cur_idx)) + cur_idx += frac_stride + all_steps += taken_steps + start_idx += size + return set(all_steps) + + +class SpacedDiffusion(GaussianDiffusion): + """ + A diffusion process which can skip steps in a base diffusion process. + :param use_timesteps: a collection (sequence or set) of timesteps from the + original diffusion process to retain. + :param kwargs: the kwargs to create the base diffusion process. + """ + + def __init__(self, use_timesteps, **kwargs): + self.use_timesteps = set(use_timesteps) + self.timestep_map = [] + self.original_num_steps = len(kwargs["betas"]) + + base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa + last_alpha_cumprod = 1.0 + new_betas = [] + for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): + if i in self.use_timesteps: + new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) + last_alpha_cumprod = alpha_cumprod + self.timestep_map.append(i) + kwargs["betas"] = np.array(new_betas) + super().__init__(**kwargs) + + def p_mean_variance(self, model, *args, **kwargs): # pylint: disable=signature-differs + return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) + + def training_losses(self, model, *args, **kwargs): # pylint: disable=signature-differs + return super().training_losses(self._wrap_model(model), *args, **kwargs) + + def condition_mean(self, cond_fn, *args, **kwargs): + return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) + + def condition_score(self, cond_fn, *args, **kwargs): + return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) + + def _wrap_model(self, model): + if isinstance(model, _WrappedModel): + return model + return _WrappedModel(model, self.timestep_map, self.original_num_steps) + + def _scale_timesteps(self, t): + # Scaling is done by the wrapped model. + return t + + +class _WrappedModel: + def __init__(self, model, timestep_map, original_num_steps): + self.model = model + self.timestep_map = timestep_map + # self.rescale_timesteps = rescale_timesteps + self.original_num_steps = original_num_steps + + def __call__(self, x, ts, **kwargs): + map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) + new_ts = map_tensor[ts] + # if self.rescale_timesteps: + # new_ts = new_ts.float() * (1000.0 / self.original_num_steps) + return self.model(x, new_ts, **kwargs) diff --git a/opensora/schedulers/iddpm/timestep_sampler.py b/opensora/schedulers/iddpm/timestep_sampler.py new file mode 100644 index 0000000..52b6717 --- /dev/null +++ b/opensora/schedulers/iddpm/timestep_sampler.py @@ -0,0 +1,150 @@ +# Adapted from DiT + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# DiT: https://github.com/facebookresearch/DiT/tree/main +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +# -------------------------------------------------------- + +from abc import ABC, abstractmethod + +import numpy as np +import torch as th +import torch.distributed as dist + + +def create_named_schedule_sampler(name, diffusion): + """ + Create a ScheduleSampler from a library of pre-defined samplers. + :param name: the name of the sampler. + :param diffusion: the diffusion object to sample for. + """ + if name == "uniform": + return UniformSampler(diffusion) + elif name == "loss-second-moment": + return LossSecondMomentResampler(diffusion) + else: + raise NotImplementedError(f"unknown schedule sampler: {name}") + + +class ScheduleSampler(ABC): + """ + A distribution over timesteps in the diffusion process, intended to reduce + variance of the objective. + By default, samplers perform unbiased importance sampling, in which the + objective's mean is unchanged. + However, subclasses may override sample() to change how the resampled + terms are reweighted, allowing for actual changes in the objective. + """ + + @abstractmethod + def weights(self): + """ + Get a numpy array of weights, one per diffusion step. + The weights needn't be normalized, but must be positive. + """ + + def sample(self, batch_size, device): + """ + Importance-sample timesteps for a batch. + :param batch_size: the number of timesteps. + :param device: the torch device to save to. + :return: a tuple (timesteps, weights): + - timesteps: a tensor of timestep indices. + - weights: a tensor of weights to scale the resulting losses. + """ + w = self.weights() + p = w / np.sum(w) + indices_np = np.random.choice(len(p), size=(batch_size,), p=p) + indices = th.from_numpy(indices_np).long().to(device) + weights_np = 1 / (len(p) * p[indices_np]) + weights = th.from_numpy(weights_np).float().to(device) + return indices, weights + + +class UniformSampler(ScheduleSampler): + def __init__(self, diffusion): + self.diffusion = diffusion + self._weights = np.ones([diffusion.num_timesteps]) + + def weights(self): + return self._weights + + +class LossAwareSampler(ScheduleSampler): + def update_with_local_losses(self, local_ts, local_losses): + """ + Update the reweighting using losses from a model. + Call this method from each rank with a batch of timesteps and the + corresponding losses for each of those timesteps. + This method will perform synchronization to make sure all of the ranks + maintain the exact same reweighting. + :param local_ts: an integer Tensor of timesteps. + :param local_losses: a 1D Tensor of losses. + """ + batch_sizes = [th.tensor([0], dtype=th.int32, device=local_ts.device) for _ in range(dist.get_world_size())] + dist.all_gather( + batch_sizes, + th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device), + ) + + # Pad all_gather batches to be the maximum batch size. + batch_sizes = [x.item() for x in batch_sizes] + max_bs = max(batch_sizes) + + timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes] + loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes] + dist.all_gather(timestep_batches, local_ts) + dist.all_gather(loss_batches, local_losses) + timesteps = [x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]] + losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] + self.update_with_all_losses(timesteps, losses) + + @abstractmethod + def update_with_all_losses(self, ts, losses): + """ + Update the reweighting using losses from a model. + Sub-classes should override this method to update the reweighting + using losses from the model. + This method directly updates the reweighting without synchronizing + between workers. It is called by update_with_local_losses from all + ranks with identical arguments. Thus, it should have deterministic + behavior to maintain state across workers. + :param ts: a list of int timesteps. + :param losses: a list of float losses, one per timestep. + """ + + +class LossSecondMomentResampler(LossAwareSampler): + def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001): + self.diffusion = diffusion + self.history_per_term = history_per_term + self.uniform_prob = uniform_prob + self._loss_history = np.zeros([diffusion.num_timesteps, history_per_term], dtype=np.float64) + self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int) + + def weights(self): + if not self._warmed_up(): + return np.ones([self.diffusion.num_timesteps], dtype=np.float64) + weights = np.sqrt(np.mean(self._loss_history**2, axis=-1)) + weights /= np.sum(weights) + weights *= 1 - self.uniform_prob + weights += self.uniform_prob / len(weights) + return weights + + def update_with_all_losses(self, ts, losses): + for t, loss in zip(ts, losses): + if self._loss_counts[t] == self.history_per_term: + # Shift out the oldest loss term. + self._loss_history[t, :-1] = self._loss_history[t, 1:] + self._loss_history[t, -1] = loss + else: + self._loss_history[t, self._loss_counts[t]] = loss + self._loss_counts[t] += 1 + + def _warmed_up(self): + return (self._loss_counts == self.history_per_term).all() diff --git a/opensora/utils/__init__.py b/opensora/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/opensora/utils/ckpt_utils.py b/opensora/utils/ckpt_utils.py new file mode 100644 index 0000000..7960f3e --- /dev/null +++ b/opensora/utils/ckpt_utils.py @@ -0,0 +1,233 @@ +import functools +import json +import logging +import operator +import os +from typing import Tuple + +import colossalai +import torch +import torch.distributed as dist +import torch.nn as nn +from colossalai.booster import Booster +from colossalai.booster.plugin import LowLevelZeroPlugin +from colossalai.cluster import DistCoordinator +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler +from torchvision.datasets.utils import download_url + +pretrained_models = { + "DiT-XL-2-512x512.pt": "https://dl.fbaipublicfiles.com/DiT/models/DiT-XL-2-512x512.pt", + "DiT-XL-2-256x256.pt": "https://dl.fbaipublicfiles.com/DiT/models/DiT-XL-2-256x256.pt", + "Latte-XL-2-256x256-ucf101.pt": "https://huggingface.co/maxin-cn/Latte/resolve/main/ucf101.pt", + "PixArt-XL-2-256x256.pth": "PixArt-XL-2-256x256.pth", + "PixArt-XL-2-SAM-256x256.pth": "PixArt-XL-2-SAM-256x256.pth", + "PixArt-XL-2-512x512.pth": "PixArt-XL-2-512x512.pth", + "PixArt-XL-2-1024-MS.pth": "PixArt-XL-2-1024-MS.pth", +} + + +def reparameter(ckpt, name=None): + if "DiT" in name: + ckpt["x_embedder.proj.weight"] = ckpt["x_embedder.proj.weight"].unsqueeze(2) + del ckpt["pos_embed"] + elif "Latte" in name: + ckpt = ckpt["ema"] + ckpt["x_embedder.proj.weight"] = ckpt["x_embedder.proj.weight"].unsqueeze(2) + del ckpt["pos_embed"] + del ckpt["temp_embed"] + elif "PixArt" in name: + ckpt = ckpt["state_dict"] + ckpt["x_embedder.proj.weight"] = ckpt["x_embedder.proj.weight"].unsqueeze(2) + del ckpt["pos_embed"] + return ckpt + + +def find_model(model_name): + """ + Finds a pre-trained DiT model, downloading it if necessary. Alternatively, loads a model from a local path. + """ + if model_name in pretrained_models: # Find/download our pre-trained DiT checkpoints + model = download_model(model_name) + model = reparameter(model, model_name) + return model + else: # Load a custom DiT checkpoint: + assert os.path.isfile(model_name), f"Could not find DiT checkpoint at {model_name}" + checkpoint = torch.load(model_name, map_location=lambda storage, loc: storage) + if "pos_embed_temporal" in checkpoint: + del checkpoint["pos_embed_temporal"] + if "pos_embed" in checkpoint: + del checkpoint["pos_embed"] + if "ema" in checkpoint: # supports checkpoints from train.py + checkpoint = checkpoint["ema"] + return checkpoint + + +def download_model(model_name): + """ + Downloads a pre-trained DiT model from the web. + """ + assert model_name in pretrained_models + local_path = f"pretrained_models/{model_name}" + if not os.path.isfile(local_path): + os.makedirs("pretrained_models", exist_ok=True) + web_path = pretrained_models[model_name] + download_url(web_path, "pretrained_models", model_name) + model = torch.load(local_path, map_location=lambda storage, loc: storage) + return model + + +def load_from_sharded_state_dict(model, ckpt_path): + # TODO: harded-coded for colossal loading + os.environ["RANK"] = "0" + os.environ["LOCAL_RANK"] = "0" + os.environ["WORLD_SIZE"] = "1" + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "29501" + colossalai.launch_from_torch({}) + plugin = LowLevelZeroPlugin( + stage=2, + precision="fp32", + initial_scale=2**16, + ) + booster = Booster(plugin=plugin) + model, _, _, _, _ = booster.boost(model=model) + booster.load_model(model, os.path.join(ckpt_path, "model")) + + save_path = os.path.join(ckpt_path, "model_ckpt.pt") + torch.save(model.module.state_dict(), save_path) + print(f"Model checkpoint saved to {save_path}") + + +def model_sharding(model: torch.nn.Module): + global_rank = dist.get_rank() + world_size = dist.get_world_size() + for _, param in model.named_parameters(): + padding_size = (world_size - param.numel() % world_size) % world_size + if padding_size > 0: + padding_param = torch.nn.functional.pad(param.data.view(-1), [0, padding_size]) + else: + padding_param = param.data.view(-1) + splited_params = padding_param.split(padding_param.numel() // world_size) + splited_params = splited_params[global_rank] + param.data = splited_params + + +def load_json(file_path: str): + with open(file_path, "r") as f: + return json.load(f) + + +def save_json(data, file_path: str): + with open(file_path, "w") as f: + json.dump(data, f, indent=4) + + +def remove_padding(tensor: torch.Tensor, original_shape: Tuple) -> torch.Tensor: + return tensor[: functools.reduce(operator.mul, original_shape)] + + +def model_gathering(model: torch.nn.Module, model_shape_dict: dict): + global_rank = dist.get_rank() + global_size = dist.get_world_size() + for name, param in model.named_parameters(): + all_params = [torch.empty_like(param.data) for _ in range(global_size)] + dist.all_gather(all_params, param.data, group=dist.group.WORLD) + if int(global_rank) == 0: + all_params = torch.cat(all_params) + param.data = remove_padding(all_params, model_shape_dict[name]).view(model_shape_dict[name]) + dist.barrier() + + +def record_model_param_shape(model: torch.nn.Module) -> dict: + param_shape = {} + for name, param in model.named_parameters(): + param_shape[name] = param.shape + return param_shape + + +def save( + booster: Booster, + model: nn.Module, + ema: nn.Module, + optimizer: Optimizer, + lr_scheduler: _LRScheduler, + epoch: int, + step: int, + global_step: int, + batch_size: int, + coordinator: DistCoordinator, + save_dir: str, + shape_dict: dict, +): + save_dir = os.path.join(save_dir, f"epoch{epoch}-global_step{global_step}") + os.makedirs(os.path.join(save_dir, "model"), exist_ok=True) + + booster.save_model(model, os.path.join(save_dir, "model"), shard=True) + # ema is not boosted, so we don't need to use booster.save_model + model_gathering(ema, shape_dict) + global_rank = dist.get_rank() + if int(global_rank) == 0: + torch.save(ema.state_dict(), os.path.join(save_dir, "ema.pt")) + model_sharding(ema) + + booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True, size_per_shard=4096) + if lr_scheduler is not None: + booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler")) + running_states = { + "epoch": epoch, + "step": step, + "global_step": global_step, + "sample_start_index": step * batch_size, + } + if coordinator.is_master(): + save_json(running_states, os.path.join(save_dir, "running_states.json")) + dist.barrier() + + +def load( + booster: Booster, model: nn.Module, ema: nn.Module, optimizer: Optimizer, lr_scheduler: _LRScheduler, load_dir: str +) -> Tuple[int, int, int]: + booster.load_model(model, os.path.join(load_dir, "model")) + # ema is not boosted, so we don't use booster.load_model + # ema.load_state_dict(torch.load(os.path.join(load_dir, "ema.pt"))) + ema.load_state_dict(torch.load(os.path.join(load_dir, "ema.pt"), map_location=torch.device("cpu"))) + booster.load_optimizer(optimizer, os.path.join(load_dir, "optimizer")) + if lr_scheduler is not None: + booster.load_lr_scheduler(lr_scheduler, os.path.join(load_dir, "lr_scheduler")) + running_states = load_json(os.path.join(load_dir, "running_states.json")) + dist.barrier() + return running_states["epoch"], running_states["step"], running_states["sample_start_index"] + + +def create_logger(logging_dir): + """ + Create a logger that writes to a log file and stdout. + """ + if dist.get_rank() == 0: # real logger + logging.basicConfig( + level=logging.INFO, + format="[\033[34m%(asctime)s\033[0m] %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")], + ) + logger = logging.getLogger(__name__) + else: # dummy logger (does nothing) + logger = logging.getLogger(__name__) + logger.addHandler(logging.NullHandler()) + return logger + + +def load_checkpoint(model, ckpt_path, save_as_pt=True): + if ckpt_path.endswith(".pt") or ckpt_path.endswith(".pth"): + state_dict = find_model(ckpt_path) + missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) + print(f"Missing keys: {missing_keys}") + print(f"Unexpected keys: {unexpected_keys}") + elif os.path.isdir(ckpt_path): + load_from_sharded_state_dict(model, ckpt_path) + if save_as_pt: + save_path = os.path.join(ckpt_path, "model_ckpt.pt") + torch.save(model.state_dict(), save_path) + else: + raise ValueError(f"Invalid checkpoint path: {ckpt_path}") diff --git a/opensora/utils/config_utils.py b/opensora/utils/config_utils.py new file mode 100644 index 0000000..48f2d9a --- /dev/null +++ b/opensora/utils/config_utils.py @@ -0,0 +1,96 @@ +import argparse +import json +import os +from glob import glob + +from mmengine.config import Config +from torch.utils.tensorboard import SummaryWriter + + +def parse_args(training=False): + parser = argparse.ArgumentParser() + + # model config + parser.add_argument("config", help="model config file path") + + parser.add_argument("--seed", default=42, type=int, help="generation seed") + parser.add_argument("--ckpt-path", type=str, help="path to model ckpt; will overwrite cfg.ckpt_path if specified") + parser.add_argument("--batch-size", default=None, type=int, help="batch size") + + # ====================================================== + # Inference + # ====================================================== + + if not training: + # prompt + parser.add_argument("--prompt-path", default=None, type=str, help="path to prompt txt file") + parser.add_argument("--save-dir", default=None, type=str, help="path to save generated samples") + + # hyperparameters + parser.add_argument("--num-sampling-steps", default=None, type=int, help="sampling steps") + parser.add_argument("--cfg-scale", default=None, type=float, help="balance between cond & uncond") + else: + parser.add_argument("--wandb", default=None, type=bool, help="enable wandb") + parser.add_argument("--load", default=None, type=str, help="path to continue training") + + return parser.parse_args() + + +def merge_args(cfg, args, training=False): + if args.ckpt_path is not None: + cfg.model["from_pratrained"] = args.ckpt_path + args.ckpt_path = None + + if not training: + if args.cfg_scale is not None: + cfg.scheduler["cfg_scale"] = args.cfg_scale + args.cfg_scale = None + + if "multi_resolution" not in cfg: + cfg["multi_resolution"] = False + for k, v in vars(args).items(): + if k in cfg and v is not None: + cfg[k] = v + + return cfg + + +def parse_configs(training=False): + args = parse_args(training) + cfg = Config.fromfile(args.config) + cfg = merge_args(cfg, args, training) + return cfg + + +def create_experiment_workspace(cfg): + """ + This function creates a folder for experiment tracking. + + Args: + args: The parsed arguments. + + Returns: + exp_dir: The path to the experiment folder. + """ + # Make outputs folder (holds all experiment subfolders) + os.makedirs(cfg.outputs, exist_ok=True) + experiment_index = len(glob(f"{cfg.outputs}/*")) + + # Create an experiment folder + model_name = cfg.model["type"].replace("/", "-") + exp_name = f"{experiment_index:03d}-F{cfg.num_frames}S{cfg.frame_interval}-{model_name}" + exp_dir = f"{cfg.outputs}/{exp_name}" + os.makedirs(exp_dir, exist_ok=True) + return exp_name, exp_dir + + +def save_training_config(cfg, experiment_dir): + with open(f"{experiment_dir}/config.txt", "w") as f: + json.dump(cfg, f, indent=4) + + +def create_tensorboard_writer(exp_dir): + tensorboard_dir = f"{exp_dir}/tensorboard" + os.makedirs(tensorboard_dir, exist_ok=True) + writer = SummaryWriter(tensorboard_dir) + return writer diff --git a/opensora/utils/misc.py b/opensora/utils/misc.py new file mode 100644 index 0000000..056b6ae --- /dev/null +++ b/opensora/utils/misc.py @@ -0,0 +1,339 @@ +import collections +import importlib +import logging +import os +import time +from collections import OrderedDict +from collections.abc import Sequence +from itertools import repeat + +import numpy as np +import torch +import torch.distributed as dist + + +def print_rank(var_name, var_value, rank=0): + if dist.get_rank() == rank: + print(f"[Rank {rank}] {var_name}: {var_value}") + + +def print_0(*args, **kwargs): + if dist.get_rank() == 0: + print(*args, **kwargs) + + +def requires_grad(model: torch.nn.Module, flag: bool = True) -> None: + """ + Set requires_grad flag for all parameters in a model. + """ + for p in model.parameters(): + p.requires_grad = flag + + +def format_numel_str(numel: int) -> str: + B = 1024**3 + M = 1024**2 + K = 1024 + if numel >= B: + return f"{numel / B:.2f} B" + elif numel >= M: + return f"{numel / M:.2f} M" + elif numel >= K: + return f"{numel / K:.2f} K" + else: + return f"{numel}" + + +def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor: + dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM) + tensor.div_(dist.get_world_size()) + return tensor + + +def get_model_numel(model: torch.nn.Module) -> (int, int): + num_params = 0 + num_params_trainable = 0 + for p in model.parameters(): + num_params += p.numel() + if p.requires_grad: + num_params_trainable += p.numel() + return num_params, num_params_trainable + + +def try_import(name): + """Try to import a module. + + Args: + name (str): Specifies what module to import in absolute or relative + terms (e.g. either pkg.mod or ..mod). + Returns: + ModuleType or None: If importing successfully, returns the imported + module, otherwise returns None. + """ + try: + return importlib.import_module(name) + except ImportError: + return None + + +def transpose(x): + """ + transpose a list of list + Args: + x (list[list]): + """ + ret = list(map(list, zip(*x))) + return ret + + +def get_timestamp(): + timestamp = time.strftime("%Y%m%d-%H%M%S", time.localtime(time.time())) + return timestamp + + +def format_time(seconds): + days = int(seconds / 3600 / 24) + seconds = seconds - days * 3600 * 24 + hours = int(seconds / 3600) + seconds = seconds - hours * 3600 + minutes = int(seconds / 60) + seconds = seconds - minutes * 60 + secondsf = int(seconds) + seconds = seconds - secondsf + millis = int(seconds * 1000) + + f = "" + i = 1 + if days > 0: + f += str(days) + "D" + i += 1 + if hours > 0 and i <= 2: + f += str(hours) + "h" + i += 1 + if minutes > 0 and i <= 2: + f += str(minutes) + "m" + i += 1 + if secondsf > 0 and i <= 2: + f += str(secondsf) + "s" + i += 1 + if millis > 0 and i <= 2: + f += str(millis) + "ms" + i += 1 + if f == "": + f = "0ms" + return f + + +def to_tensor(data): + """Convert objects of various python types to :obj:`torch.Tensor`. + + Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`, + :class:`Sequence`, :class:`int` and :class:`float`. + + Args: + data (torch.Tensor | numpy.ndarray | Sequence | int | float): Data to + be converted. + """ + + if isinstance(data, torch.Tensor): + return data + elif isinstance(data, np.ndarray): + return torch.from_numpy(data) + elif isinstance(data, Sequence) and not isinstance(data, str): + return torch.tensor(data) + elif isinstance(data, int): + return torch.LongTensor([data]) + elif isinstance(data, float): + return torch.FloatTensor([data]) + else: + raise TypeError(f"type {type(data)} cannot be converted to tensor.") + + +def to_ndarray(data): + if isinstance(data, torch.Tensor): + return data.numpy() + elif isinstance(data, np.ndarray): + return data + elif isinstance(data, Sequence): + return np.array(data) + elif isinstance(data, int): + return np.ndarray([data], dtype=int) + elif isinstance(data, float): + return np.array([data], dtype=float) + else: + raise TypeError(f"type {type(data)} cannot be converted to ndarray.") + + +def to_torch_dtype(dtype): + if isinstance(dtype, torch.dtype): + return dtype + elif isinstance(dtype, str): + dtype_mapping = { + "float64": torch.float64, + "float32": torch.float32, + "float16": torch.float16, + "fp32": torch.float32, + "fp16": torch.float16, + "half": torch.float16, + "bf16": torch.bfloat16, + } + if dtype not in dtype_mapping: + raise ValueError + dtype = dtype_mapping[dtype] + return dtype + else: + raise ValueError + + +def count_params(model): + return sum(p.numel() for p in model.parameters() if p.requires_grad) + + +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): + return x + return tuple(repeat(x, n)) + + return parse + + +to_1tuple = _ntuple(1) +to_2tuple = _ntuple(2) +to_3tuple = _ntuple(3) +to_4tuple = _ntuple(4) +to_ntuple = _ntuple + + +def convert_SyncBN_to_BN2d(model_cfg): + for k in model_cfg: + v = model_cfg[k] + if k == "norm_cfg" and v["type"] == "SyncBN": + v["type"] = "BN2d" + elif isinstance(v, dict): + convert_SyncBN_to_BN2d(v) + + +def get_topk(x, dim=4, k=5): + x = to_tensor(x) + inds = x[..., dim].topk(k)[1] + return x[inds] + + +def param_sigmoid(x, alpha): + ret = 1 / (1 + (-alpha * x).exp()) + return ret + + +def inverse_param_sigmoid(x, alpha, eps=1e-5): + x = x.clamp(min=0, max=1) + x1 = x.clamp(min=eps) + x2 = (1 - x).clamp(min=eps) + return torch.log(x1 / x2) / alpha + + +def inverse_sigmoid(x, eps=1e-5): + """Inverse function of sigmoid. + + Args: + x (Tensor): The tensor to do the + inverse. + eps (float): EPS avoid numerical + overflow. Defaults 1e-5. + Returns: + Tensor: The x has passed the inverse + function of sigmoid, has same + shape with input. + """ + x = x.clamp(min=0, max=1) + x1 = x.clamp(min=eps) + x2 = (1 - x).clamp(min=eps) + return torch.log(x1 / x2) + + +def count_columns(df, columns): + cnt_dict = OrderedDict() + num_samples = len(df) + + for col in columns: + d_i = df[col].value_counts().to_dict() + for k in d_i: + d_i[k] = (d_i[k], d_i[k] / num_samples) + cnt_dict[col] = d_i + + return cnt_dict + + +def build_logger(work_dir, cfgname): + log_file = cfgname + ".log" + log_path = os.path.join(work_dir, log_file) + + logger = logging.getLogger(cfgname) + logger.setLevel(logging.INFO) + # formatter = logging.Formatter('%(asctime)s %(levelname)s %(message)s') + formatter = logging.Formatter("%(asctime)s: %(message)s", datefmt="%Y-%m-%d %H:%M:%S") + + handler1 = logging.FileHandler(log_path) + handler1.setFormatter(formatter) + + handler2 = logging.StreamHandler() + handler2.setFormatter(formatter) + + logger.addHandler(handler1) + logger.addHandler(handler2) + logger.propagate = False + + return logger + + +timings = {} +from contextlib import contextmanager + + +@contextmanager +def torch_timer(prefix): + import time + + import torch + + torch.cuda.synchronize() + start = time.time() + try: + yield + finally: + torch.cuda.synchronize() + t_diff = (time.time() - start) * 1000 + if prefix not in timings: + timings[prefix] = [] + + timings[prefix].append(t_diff) + + num_ignored = 10 + + if len(timings[prefix]) > num_ignored: + # avg = sum(timings[prefix][num_ignored:]) / (len(timings[prefix]) - num_ignored) + avg = sum(timings[prefix][-num_ignored:]) / num_ignored + print("{}: {} ({})".format(prefix, t_diff, avg)) + else: + print("{}: {}".format(prefix, t_diff)) + + +def strip_dc(x): + """ + strip DataContainer + """ + try: + import mmcv + except ImportError: + pass + + if isinstance(x, dict): + res = {} + for k, v in x.items(): + res[k] = strip_dc(v) + return res + if isinstance(x, (list, tuple)) and isinstance(x[0], mmcv.parallel.DataContainer): + return strip_dc(x[0]) + elif isinstance(x, mmcv.parallel.DataContainer): + return strip_dc(x.data) + return x diff --git a/opensora/utils/train_utils.py b/opensora/utils/train_utils.py new file mode 100644 index 0000000..8a5c90f --- /dev/null +++ b/opensora/utils/train_utils.py @@ -0,0 +1,34 @@ +from collections import OrderedDict + +import torch + + +@torch.no_grad() +def update_ema( + ema_model: torch.nn.Module, model: torch.nn.Module, optimizer=None, decay: float = 0.9999, sharded: bool = True +) -> None: + """ + Step the EMA model towards the current model. + """ + ema_params = OrderedDict(ema_model.named_parameters()) + model_params = OrderedDict(model.named_parameters()) + + for name, param in model_params.items(): + if name == "pos_embed": + continue + if param.requires_grad == False: + continue + if not sharded: + param_data = param.data + ema_params[name].mul_(decay).add_(param_data, alpha=1 - decay) + else: + if param.data.dtype != torch.float32: + param_id = id(param) + master_param = optimizer._param_store.working_to_master_param[param_id] + param_data = master_param.data + else: + param_data = param.data + ema_params[name].mul_(decay).add_(param_data, alpha=1 - decay) + + + diff --git a/tools/data/README.md b/tools/data/README.md new file mode 100644 index 0000000..e69de29 diff --git a/tools/data/__init__.py b/tools/data/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tools/data/csvutil.py b/tools/data/csvutil.py new file mode 100644 index 0000000..9bd5db2 --- /dev/null +++ b/tools/data/csvutil.py @@ -0,0 +1,96 @@ +import argparse +import csv +import os + +from tqdm import tqdm + +# path, name, #frames +PREFIX = [ + "The video shows", + "The video captures", + "The video features", + "The video depicts", + "The video presents", + "The video features", + "The video is ", + "In the video,", +] + + +def get_video_length(path): + import cv2 + + cap = cv2.VideoCapture(path) + return int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + + +def main(args): + input_path = args.input + output_path = args.output + if output_path is None: + name = os.path.basename(input_path) + name, ext = os.path.splitext(name) + if args.num_frames_min is not None: + name += f"_fmin_{args.num_frames_min}" + if args.num_frames_max is not None: + name += f"_fmax_{args.num_frames_max}" + if args.filter_null_text: + name += "_fnt" + if args.remove_prefix: + name += "_rp" + if args.root is not None: + name += f"_root" + if args.relength: + name += "_relength" + output_path = os.path.join(os.path.dirname(input_path), name + ext) + + with open(input_path, "r") as f: + reader = csv.reader(f) + data = list(reader) + print("Number of videos before filtering:", len(data)) + + data_new = [] + for i, row in tqdm(enumerate(data)): + path = row[0] + caption = row[1] + n_frames = int(row[2]) + if args.num_frames_min is not None and n_frames < args.num_frames_min: + continue + if args.num_frames_max is not None and n_frames > args.num_frames_max: + continue + if args.filter_null_text and len(caption) == 0: + continue + if args.remove_prefix: + for prefix in PREFIX: + if caption.startswith(prefix): + caption = caption[len(prefix) :].strip() + if caption[0].islower(): + caption = caption[0].upper() + caption[1:] + row[1] = caption + break + if args.root is not None: + row[0] = os.path.join(args.root, path) + if args.relength: + n_frames = get_video_length(row[0]) + row[2] = n_frames + data_new.append(row) + + print("Number of videos after filtering:", len(data_new)) + with open(output_path, "w") as f: + writer = csv.writer(f) + writer.writerows(data_new) + print("Output saved to", output_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--input", type=str, required=True) + parser.add_argument("--output", type=str, default=None) + parser.add_argument("--num_frames_min", type=int, default=None) + parser.add_argument("--num_frames_max", type=int, default=None) + parser.add_argument("--filter_null_text", action="store_true") + parser.add_argument("--remove_prefix", action="store_true") + parser.add_argument("--root", type=str, default=None) + parser.add_argument("--relength", action="store_true") + args = parser.parse_args() + main(args) diff --git a/tools/data/get_caption_gpt4.py b/tools/data/get_caption_gpt4.py new file mode 100644 index 0000000..8ca973f --- /dev/null +++ b/tools/data/get_caption_gpt4.py @@ -0,0 +1,82 @@ +import argparse +import base64 +import csv +import os + +import cv2 +import requests +import tqdm + +# OpenAI API Key +api_key = "" + + +def get_caption(frame): + prompt = "The middle frame from a video clip are given. Describe this video and its style to generate a description for the video. The description should be useful for AI to re-generate the video. Here are some examples of good descriptions:\n\n 1. A stylish woman walks down a Tokyo street filled with warm glowing neon and animated city signage. She wears a black leather jacket, a long red dress, and black boots, and carries a black purse. She wears sunglasses and red lipstick. She walks confidently and casually. The street is damp and reflective, creating a mirror effect of the colorful lights. Many pedestrians walk about.\n2. Several giant wooly mammoths approach treading through a snowy meadow, their long wooly fur lightly blows in the wind as they walk, snow covered trees and dramatic snow capped mountains in the distance, mid afternoon light with wispy clouds and a sun high in the distance creates a warm glow, the low camera view is stunning capturing the large furry mammal with beautiful photography, depth of field.\n 3. Drone view of waves crashing against the rugged cliffs along Big Sur's garay point beach. The crashing blue waters create white-tipped waves, while the golden light of the setting sun illuminates the rocky shore. A small island with a lighthouse sits in the distance, and green shrubbery covers the cliff's edge. The steep drop from the road down to the beach is a dramatic feat, with the cliff’s edges jutting out over the sea. This is a view that captures the raw beauty of the coast and the rugged landscape of the Pacific Coast Highway." + + headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"} + payload = { + "model": "gpt-4-vision-preview", + "messages": [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": prompt, + }, + {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{frame}"}}, + ], + } + ], + "max_tokens": 300, + } + response = requests.post("https://api.openai.com/v1/chat/completions", headers=headers, json=payload, timeout=60) + caption = response.json()["choices"][0]["message"]["content"] + return caption + + +# Function to encode the image +def encode_image(image_path): + with open(image_path, "rb") as image_file: + return base64.b64encode(image_file.read()).decode("utf-8") + + +def extract_frames(video_path): + cap = cv2.VideoCapture(video_path) + length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + point = length // 2 + cap.set(cv2.CAP_PROP_POS_FRAMES, point) + ret, frame = cap.read() + _, buffer = cv2.imencode(".jpg", frame) + img_base64 = base64.b64encode(buffer).decode("utf-8") + return img_base64 + + +def main(args): + processed_videos = [] + if os.path.exists(args.output_file): + with open(args.output_file, "r") as f: + reader = csv.reader(f) + samples = list(reader) + processed_videos = [sample[0] for sample in samples] + + f = open(args.output_file, "a") + writer = csv.writer(f) + for video in tqdm.tqdm(os.listdir(args.video_folder)): + if video in processed_videos: + continue + video_path = os.path.join(args.video_folder, video) + base64_image = extract_frames(video_path) + caption = get_caption(base64_image) + caption = caption.replace("\n", " ") + writer.writerow([video, caption]) + f.close() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--video_folder", type=str, required=True, help="Path to the folder containing the videos.") + parser.add_argument("--output_file", type=str, default="video_captions.csv", help="Path to the output file.") + args = parser.parse_args() + main(args) diff --git a/tools/data/get_caption_llava.py b/tools/data/get_caption_llava.py new file mode 100644 index 0000000..bb54974 --- /dev/null +++ b/tools/data/get_caption_llava.py @@ -0,0 +1,382 @@ +import argparse +import csv +import os + +import cv2 +import torch +from llava.constants import DEFAULT_IMAGE_TOKEN, IGNORE_INDEX, IMAGE_TOKEN_INDEX +from llava.conversation import conv_templates +from llava.eval.run_llava import eval_model +from llava.mm_utils import get_anyres_image_grid_shape, get_model_name_from_path, process_images, tokenizer_image_token +from llava.model.builder import load_pretrained_model +from llava.model.llava_arch import unpad_image +from llava.utils import disable_torch_init +from PIL import Image +from tqdm import tqdm + +disable_torch_init() + +prompts = { + "three_frames": "A video is given by providing three frames in chronological order. Describe this video and its style to generate a description. Pay attention to all objects in the video. Do not describe each frame individually. Do not reply with words like 'first frame'. The description should be useful for AI to re-generate the video. The description should be less than six sentences. Here are some examples of good descriptions: 1. A stylish woman walks down a Tokyo street filled with warm glowing neon and animated city signage. She wears a black leather jacket, a long red dress, and black boots, and carries a black purse. She wears sunglasses and red lipstick. She walks confidently and casually. The street is damp and reflective, creating a mirror effect of the colorful lights. Many pedestrians walk about. 2. Several giant wooly mammoths approach treading through a snowy meadow, their long wooly fur lightly blows in the wind as they walk, snow covered trees and dramatic snow capped mountains in the distance, mid afternoon light with wispy clouds and a sun high in the distance creates a warm glow, the low camera view is stunning capturing the large furry mammal with beautiful photography, depth of field. 3. Drone view of waves crashing against the rugged cliffs along Big Sur's garay point beach. The crashing blue waters create white-tipped waves, while the golden light of the setting sun illuminates the rocky shore. A small island with a lighthouse sits in the distance, and green shrubbery covers the cliff's edge. The steep drop from the road down to the beach is a dramatic feat, with the cliff’s edges jutting out over the sea. This is a view that captures the raw beauty of the coast and the rugged landscape of the Pacific Coast Highway.", +} + + +def get_filelist(file_path): + Filelist = [] + for home, dirs, files in os.walk(file_path): + for filename in files: + Filelist.append(os.path.join(home, filename)) + return Filelist + + +def get_video_length(cap): + return int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + + +def extract_frames(video_path, points=[0.2, 0.5, 0.8]): + cap = cv2.VideoCapture(video_path) + length = get_video_length(cap) + points = [int(length * point) for point in points] + frames = [] + if length < 3: + return frames, length + for point in points: + cap.set(cv2.CAP_PROP_POS_FRAMES, point) + ret, frame = cap.read() + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + frame = Image.fromarray(frame) + frames.append(frame) + return frames, length + + +def prepare_inputs_labels_for_multimodal( + self, input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes=None +): + # llava_arch.py + vision_tower = self.get_vision_tower() + if vision_tower is None or images is None or input_ids.shape[1] == 1: + return input_ids, position_ids, attention_mask, past_key_values, None, labels + + if type(images) is list or images.ndim == 5: + if type(images) is list: + images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images] + concat_images = torch.cat([image for image in images], dim=0) + image_features = self.encode_images(concat_images) + split_sizes = [image.shape[0] for image in images] + image_features = torch.split(image_features, split_sizes, dim=0) + mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat") + image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square") + if mm_patch_merge_type == "flat": + image_features = [x.flatten(0, 1) for x in image_features] + elif mm_patch_merge_type.startswith("spatial"): + new_image_features = [] + for image_idx, image_feature in enumerate(image_features): + if image_feature.shape[0] > 1: + base_image_feature = image_feature[0] + image_feature = image_feature[1:] + height = width = self.get_vision_tower().num_patches_per_side + assert height * width == base_image_feature.shape[0] + if image_aspect_ratio == "anyres": + num_patch_width, num_patch_height = get_anyres_image_grid_shape( + image_sizes[image_idx], + self.config.image_grid_pinpoints, + self.get_vision_tower().config.image_size, + ) + image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1) + else: + raise NotImplementedError + if "unpad" in mm_patch_merge_type: + image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous() + image_feature = image_feature.flatten(1, 2).flatten(2, 3) + image_feature = unpad_image(image_feature, image_sizes[image_idx]) + image_feature = torch.cat( + ( + image_feature, + self.model.image_newline[:, None, None] + .expand(*image_feature.shape[:-1], 1) + .to(image_feature.device), + ), + dim=-1, + ) + image_feature = image_feature.flatten(1, 2).transpose(0, 1) + else: + image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous() + image_feature = image_feature.flatten(0, 3) + image_feature = torch.cat((base_image_feature, image_feature), dim=0) + else: + image_feature = image_feature[0] + if "unpad" in mm_patch_merge_type: + image_feature = torch.cat( + (image_feature, self.model.image_newline[None].to(image_feature.device)), dim=0 + ) + new_image_features.append(image_feature) + image_features = new_image_features + else: + raise ValueError(f"Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}") + else: + image_features = self.encode_images(images) + + # TODO: image start / end is not implemented here to support pretraining. + if getattr(self.config, "tune_mm_mlp_adapter", False) and getattr(self.config, "mm_use_im_start_end", False): + raise NotImplementedError + + # Let's just add dummy tensors if they do not exist, + # it is a headache to deal with None all the time. + # But it is not ideal, and if you have a better idea, + # please open an issue / submit a PR, thanks. + _labels = labels + _position_ids = position_ids + _attention_mask = attention_mask + if attention_mask is None: + attention_mask = torch.ones_like(input_ids, dtype=torch.bool) + else: + attention_mask = attention_mask.bool() + if position_ids is None: + position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device) + if labels is None: + labels = torch.full_like(input_ids, IGNORE_INDEX) + + # remove the padding using attention_mask -- FIXME + _input_ids = input_ids + input_ids = [ + cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask) + ] + labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)] + + new_input_embeds = [] + new_labels = [] + cur_image_idx = 0 + for batch_idx, cur_input_ids in enumerate(input_ids): + num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum() + if num_images == 0: + cur_image_features = image_features[cur_image_idx] + cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids) + cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0) + new_input_embeds.append(cur_input_embeds) + new_labels.append(labels[batch_idx]) + cur_image_idx += 1 + continue + + image_token_indices = ( + [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]] + ) + cur_input_ids_noim = [] + cur_labels = labels[batch_idx] + cur_labels_noim = [] + for i in range(len(image_token_indices) - 1): + cur_input_ids_noim.append(cur_input_ids[image_token_indices[i] + 1 : image_token_indices[i + 1]]) + cur_labels_noim.append(cur_labels[image_token_indices[i] + 1 : image_token_indices[i + 1]]) + split_sizes = [x.shape[0] for x in cur_labels_noim] + cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim)) + cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0) + cur_new_input_embeds = [] + cur_new_labels = [] + + for i in range(num_images + 1): + cur_new_input_embeds.append(cur_input_embeds_no_im[i]) + cur_new_labels.append(cur_labels_noim[i]) + if i < num_images: + cur_image_features = image_features[cur_image_idx] + cur_image_idx += 1 + cur_new_input_embeds.append(cur_image_features) + cur_new_labels.append( + torch.full( + (cur_image_features.shape[0],), + IGNORE_INDEX, + device=cur_labels.device, + dtype=cur_labels.dtype, + ) + ) + + cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds] + + cur_new_input_embeds = torch.cat(cur_new_input_embeds) + cur_new_labels = torch.cat(cur_new_labels) + + new_input_embeds.append(cur_new_input_embeds) + new_labels.append(cur_new_labels) + + # Truncate sequences to max length as image embeddings can make the sequence longer + tokenizer_model_max_length = getattr(self.config, "tokenizer_model_max_length", None) + if tokenizer_model_max_length is not None: + new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds] + new_labels = [x[:tokenizer_model_max_length] for x in new_labels] + + # Combine them + max_len = max(x.shape[0] for x in new_input_embeds) + batch_size = len(new_input_embeds) + + new_input_embeds_padded = [] + new_labels_padded = torch.full( + (batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device + ) + attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device) + position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device) + + for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)): + cur_len = cur_new_embed.shape[0] + if getattr(self.config, "tokenizer_padding_side", "right") == "left": + new_input_embeds_padded.append( + torch.cat( + ( + torch.zeros( + (max_len - cur_len, cur_new_embed.shape[1]), + dtype=cur_new_embed.dtype, + device=cur_new_embed.device, + ), + cur_new_embed, + ), + dim=0, + ) + ) + if cur_len > 0: + new_labels_padded[i, -cur_len:] = cur_new_labels + attention_mask[i, -cur_len:] = True + position_ids[i, -cur_len:] = torch.arange( + 0, cur_len, dtype=position_ids.dtype, device=position_ids.device + ) + else: + new_input_embeds_padded.append( + torch.cat( + ( + cur_new_embed, + torch.zeros( + (max_len - cur_len, cur_new_embed.shape[1]), + dtype=cur_new_embed.dtype, + device=cur_new_embed.device, + ), + ), + dim=0, + ) + ) + if cur_len > 0: + new_labels_padded[i, :cur_len] = cur_new_labels + attention_mask[i, :cur_len] = True + position_ids[i, :cur_len] = torch.arange( + 0, cur_len, dtype=position_ids.dtype, device=position_ids.device + ) + + new_input_embeds = torch.stack(new_input_embeds_padded, dim=0) + + if _labels is None: + new_labels = None + else: + new_labels = new_labels_padded + + if _attention_mask is None: + attention_mask = None + else: + attention_mask = attention_mask.to(dtype=_attention_mask.dtype) + + if _position_ids is None: + position_ids = None + + return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels + + +@torch.inference_mode() +def main(args): + bs = args.bs + video_folder = args.video_folder + + processed_videos = [] + if os.path.exists(args.output_file): + with open(args.output_file, "r") as f: + reader = csv.reader(f) + samples = list(reader) + processed_videos = [sample[0] for sample in samples] + f = open(args.output_file, "a") + writer = csv.writer(f) + + model_path = "liuhaotian/llava-v1.6-34b" + query = prompts["three_frames"] + conv = conv_templates["chatml_direct"].copy() + conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + "\n" + query) + prompt = conv.get_prompt() + + tokenizer, model, image_processor, context_len = load_pretrained_model( + model_path=model_path, + model_base=None, + model_name=get_model_name_from_path(model_path), + ) + input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt") + input_ids = input_ids.unsqueeze(0).to(model.device) + + videos = get_filelist(video_folder) + print(f"Dataset contains {len(videos)} videos.") + videos = [video for video in videos if video not in processed_videos] + print(f"Processing {len(videos)} new videos.") + for i in tqdm(range(0, len(videos), bs)): + # prepare a batch of inputs + video_files = videos[i : i + bs] + frames = [] + video_lengths = [] + for video_file in video_files: + frame, length = extract_frames(os.path.join(video_folder, video_file)) + if len(frame) < 3: + continue + frames.append(frame) + video_lengths.append(length) + if len(frames) == 0: + continue + + # encode the batch of inputs + samples = [] + for imgs in frames: + imgs_size = [img.size for img in imgs] + imgs = process_images(imgs, image_processor, model.config) + imgs = imgs.to(model.device, dtype=torch.float16) + with torch.inference_mode(): + _, _, _, _, inputs_embeds, _ = prepare_inputs_labels_for_multimodal( + model, input_ids, None, None, None, None, images=imgs, image_sizes=imgs_size + ) + samples.append(inputs_embeds) + + # padding + max_len = max([sample.shape[1] for sample in samples]) + attention_mask = torch.tensor( + [[0] * (max_len - samples[i].shape[1]) + [1] * samples[i].shape[1] for i in range(len(samples))] + ).to(model.device) + inputs_embeds = [ + torch.cat( + [ + torch.zeros( + (1, max_len - samples[i].shape[1], samples[i].shape[-1]), + device=model.device, + dtype=torch.float16, + ), + samples[i], + ], + dim=1, + ) + for i in range(len(samples)) + ] + inputs_embeds = torch.cat(inputs_embeds, dim=0) + + # generate outputs + output_ids = super(type(model), model).generate( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + do_sample=True, + temperature=0.2, + max_new_tokens=512, + use_cache=True, + ) + outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True) + outputs = [output.replace("\n", " ").strip() for output in outputs] + + # save results + result = list(zip(video_files, outputs, video_lengths)) + for t in result: + writer.writerow(t) + + f.close() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--video_folder", type=str, required=True) + parser.add_argument("--bs", type=int, default=32) + parser.add_argument("--output_file", type=str, default="video_captions.csv") + args = parser.parse_args() + main(args) diff --git a/tools/data/scene_detect.py b/tools/data/scene_detect.py new file mode 100644 index 0000000..115f683 --- /dev/null +++ b/tools/data/scene_detect.py @@ -0,0 +1,139 @@ +import os +from tqdm import tqdm +from multiprocessing import Pool +from mmengine.logging import MMLogger, print_log + +from scenedetect import detect, ContentDetector + +from .utils import check_mp4_integrity, split_video +from .utils import clone_folder_structure, iterate_files, iterate_folders +from opensora.utils.misc import get_timestamp + + +# config +target_fps = 30 # int +shorter_size = 512 # int +min_seconds = 1 # float +max_seconds = 5 # float +assert max_seconds > min_seconds +cfg = dict( + target_fps=target_fps, + min_seconds=min_seconds, + max_seconds=max_seconds, + shorter_size=shorter_size, +) + + +def process_folder(root_src, root_dst): + # create logger + folder_path_log = os.path.dirname(root_dst) + log_name = os.path.basename(root_dst) + timestamp = get_timestamp() + log_path = os.path.join(folder_path_log, f'{log_name}_{timestamp}.log') + logger = MMLogger.get_instance(log_name, log_file=log_path) + + # clone folder structure + clone_folder_structure(root_src, root_dst) + + # all source videos + mp4_list = [x for x in iterate_files(root_src) if x.endswith('.mp4')] + mp4_list = sorted(mp4_list) + + for idx, sample_path in tqdm(enumerate(mp4_list)): + folder_src = os.path.dirname(sample_path) + folder_dst = os.path.join(root_dst, os.path.relpath(folder_src, root_src)) + + # check src video integrity + if not check_mp4_integrity(sample_path, logger=logger): + continue + + # detect scenes + scene_list = detect(sample_path, ContentDetector(), start_in_scene=True) + + # split scenes + save_path_list = split_video(sample_path, scene_list, save_dir=folder_dst, **cfg, logger=logger) + + # check integrity of generated clips + for x in save_path_list: + check_mp4_integrity(x, logger=logger) + + +def scene_detect(): + """ detect & cut scenes using a single process + Expected dataset structure: + data/ + your_dataset/ + raw_videos/ + xxx.mp4 + yyy.mp4 + + This function results in: + data/ + your_dataset/ + raw_videos/ + xxx.mp4 + yyy.mp4 + zzz.mp4 + clips/ + xxx_scene-0.mp4 + yyy_scene-0.mp4 + yyy_scene-1.mp4 + """ + # TODO: specify your dataset root + root_src = f'./data/your_dataset/raw_videos' + root_dst = f'./data/your_dataset/clips' + + process_folder(root_src, root_dst) + + +def scene_detect_mp(): + """ detect & cut scenes using multiple processes + Expected dataset structure: + data/ + your_dataset/ + raw_videos/ + split_0/ + xxx.mp4 + yyy.mp4 + split_1/ + xxx.mp4 + yyy.mp4 + + This function results in: + data/ + your_dataset/ + raw_videos/ + split_0/ + xxx.mp4 + yyy.mp4 + split_1/ + xxx.mp4 + yyy.mp4 + clips/ + split_0/ + xxx_scene-0.mp4 + yyy_scene-0.mp4 + split_1/ + xxx_scene-0.mp4 + yyy_scene-0.mp4 + yyy_scene-1.mp4 + """ + # TODO: specify your dataset root + root_src = f'./data/your_dataset/raw_videos' + root_dst = f'./data/your_dataset/clips' + + # TODO: specify your splits + splits = ['split_0', 'split_1'] + + # process folders + root_src_list = [os.path.join(root_src, x) for x in splits] + root_dst_list = [os.path.join(root_dst, x) for x in splits] + + with Pool(processes=len(splits)) as pool: + pool.starmap(process_folder, list(zip(root_src_list, root_dst_list))) + + +if __name__ == '__main__': + # TODO: choose single process or multiprocessing + scene_detect() + # scene_detect_mp() diff --git a/tools/data/to_csv_imagenet.py b/tools/data/to_csv_imagenet.py new file mode 100644 index 0000000..8b18faf --- /dev/null +++ b/tools/data/to_csv_imagenet.py @@ -0,0 +1,17 @@ +import csv +import os + +from torchvision.datasets import ImageNet + +root = "~/data/imagenet" +split = "train" + +root = os.path.expanduser(root) +data = ImageNet(root, split=split) +samples = [(path, data.classes[label][0]) for path, label in data.samples] + +with open(f"preprocess/imagenet_{split}.csv", "w") as f: + writer = csv.writer(f) + writer.writerows(samples) + +print(f"Saved {len(samples)} samples to preprocess/imagenet_{split}.csv.") diff --git a/tools/data/to_csv_ucf101.py b/tools/data/to_csv_ucf101.py new file mode 100644 index 0000000..fd99b7e --- /dev/null +++ b/tools/data/to_csv_ucf101.py @@ -0,0 +1,36 @@ +import csv +import os + + +def get_filelist(file_path): + Filelist = [] + for home, dirs, files in os.walk(file_path): + for filename in files: + Filelist.append(os.path.join(home, filename)) + return Filelist + + +def split_by_capital(name): + # BoxingPunchingBag -> Boxing Punching Bag + new_name = "" + for i in range(len(name)): + if name[i].isupper() and i != 0: + new_name += " " + new_name += name[i] + return new_name + + +root = "~/data/ucf101" +split = "videos" + +root = os.path.expanduser(root) +video_lists = get_filelist(os.path.join(root, split)) +classes = [x.split("/")[-2] for x in video_lists] +classes = [split_by_capital(x) for x in classes] +samples = list(zip(video_lists, classes)) + +with open(f"preprocess/ucf101_{split}.csv", "w") as f: + writer = csv.writer(f) + writer.writerows(samples) + +print(f"Saved {len(samples)} samples to preprocess/ucf101_{split}.csv.") diff --git a/tools/data/utils.py b/tools/data/utils.py new file mode 100644 index 0000000..b65270f --- /dev/null +++ b/tools/data/utils.py @@ -0,0 +1,148 @@ +import os +import cv2 +import subprocess +from mmengine.logging import print_log + +from moviepy.editor import VideoFileClip +from imageio_ffmpeg import get_ffmpeg_exe +from scenedetect import FrameTimecode + + +def iterate_files(folder_path): + for root, dirs, files in os.walk(folder_path): + # root contains the current directory path + # dirs contains the list of subdirectories in the current directory + # files contains the list of files in the current directory + + # Process files in the current directory + for file in files: + file_path = os.path.join(root, file) + # print("File:", file_path) + yield file_path + + # Process subdirectories and recursively call the function + for subdir in dirs: + subdir_path = os.path.join(root, subdir) + # print("Subdirectory:", subdir_path) + iterate_files(subdir_path) + + +def iterate_folders(folder_path): + for root, dirs, files in os.walk(folder_path): + for subdir in dirs: + subdir_path = os.path.join(root, subdir) + yield subdir_path + # print("Subdirectory:", subdir_path) + iterate_folders(subdir_path) + + +def clone_folder_structure(root_src, root_dst, verbose=False): + src_path_list = iterate_folders(root_src) + src_relpath_list = [os.path.relpath(x, root_src) for x in src_path_list] + + os.makedirs(root_dst, exist_ok=True) + dst_path_list = [os.path.join(root_dst, x) for x in src_relpath_list] + for folder_path in dst_path_list: + os.makedirs(folder_path, exist_ok=True) + if verbose: + print(f'Create folder: \'{folder_path}\'') + + +def count_files(root, suffix='.mp4'): + files_list = iterate_files(root) + cnt = len([x for x in files_list if x.endswith(suffix)]) + return cnt + + +def check_mp4_integrity(file_path, verbose=True, logger=None): + try: + video_clip = VideoFileClip(file_path) + if verbose: + print_log(f'The MP4 file \'{file_path}\' is intact.', logger=logger) + return True + except Exception as e: + if verbose: + print_log(f'Error: {e}', logger=logger) + print_log(f'The MP4 file \'{file_path}\' is not intact.', logger=logger) + return False + + +def count_frames(video_path): + cap = cv2.VideoCapture(video_path) + + if not cap.isOpened(): + print(f"Error: Could not open video file '{video_path}'") + return + + total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + print(f"Total frames in the video '{video_path}': {total_frames}") + + cap.release() + + +def split_video(sample_path, + scene_list, + save_dir, + target_fps=30, + min_seconds=1, + max_seconds=10, + shorter_size=512, + verbose=False, + logger=None, + ): + FFMPEG_PATH = get_ffmpeg_exe() + + save_path_list = [] + for idx, scene in enumerate(scene_list): + s, t = scene # FrameTimecode + fps = s.framerate + max_duration = FrameTimecode(timecode='00:00:00', fps=fps) + max_duration.frame_num = round(fps * max_seconds) + duration = min(max_duration, t - s) + if duration.get_frames() < round(min_seconds * fps): + continue + + # save path + fname = os.path.basename(sample_path) + fname_wo_ext = os.path.splitext(fname)[0] + # TODO: fname pattern + save_path = os.path.join(save_dir, f'{fname_wo_ext}_scene-{idx}.mp4') + + # ffmpeg cmd + cmd = [FFMPEG_PATH] + + # Only show ffmpeg output for the first call, which will display any + # errors if it fails, and then break the loop. We only show error messages + # for the remaining calls. + # cmd += ['-v', 'error'] + + # input path + cmd += ['-i', sample_path] + + # clip to cut + cmd += [ + '-nostdin', '-y', + '-ss', str(s.get_seconds()), + '-t', str(duration.get_seconds()) + ] + + # target fps + # cmd += ['-vf', 'select=mod(n\,2)'] + cmd += ['-r', f'{target_fps}'] + + # aspect ratio + cmd += ['-vf', f"scale='if(gt(iw,ih),-2,{shorter_size})':'if(gt(iw,ih),{shorter_size},-2)'"] + # cmd += ['-vf', f"scale='if(gt(iw,ih),{shorter_size},trunc(ow/a/2)*2)':-2"] + + cmd += ['-map', '0', save_path] + + proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) + stdout, stderr = proc.communicate() + if verbose: + stdout = stdout.decode('utf-8') + print_log(stdout, logger=logger) + + save_path_list.append(sample_path) + print_log(f'Video clip saved to \'{save_path}\'', logger=logger) + + return save_path_list