From 0dcf1ff888d35a37df1e81605ba35f3bf0533819 Mon Sep 17 00:00:00 2001
From: Zangwei Zheng
Date: Mon, 25 Mar 2024 22:20:10 +0800
Subject: [PATCH] add speediffusion

---
 configs/opensora/train/16x256x256-spee.py     | 54 +++++++++++++
 opensora/schedulers/iddpm/__init__.py         |  1 +
 .../schedulers/iddpm/gaussian_diffusion.py    |  8 +-
 opensora/schedulers/iddpm/speed.py            | 75 +++++++++++++++++++
 opensora/utils/train_utils.py                 | 14 ++--
 5 files changed, 144 insertions(+), 8 deletions(-)
 create mode 100644 configs/opensora/train/16x256x256-spee.py
 create mode 100644 opensora/schedulers/iddpm/speed.py

diff --git a/configs/opensora/train/16x256x256-spee.py b/configs/opensora/train/16x256x256-spee.py
new file mode 100644
index 0000000..12f06c9
--- /dev/null
+++ b/configs/opensora/train/16x256x256-spee.py
@@ -0,0 +1,54 @@
+num_frames = 16
+frame_interval = 3
+image_size = (256, 256)
+
+# Define dataset
+root = None
+data_path = "CSV_PATH"
+use_image_transform = False
+num_workers = 4
+
+# Define acceleration
+dtype = "bf16"
+grad_checkpoint = True
+plugin = "zero2"
+sp_size = 1
+
+# Define model
+model = dict(
+    type="STDiT-XL/2",
+    space_scale=0.5,
+    time_scale=1.0,
+    from_pretrained="PixArt-XL-2-512x512.pth",
+    enable_flashattn=True,
+    enable_layernorm_kernel=True,
+)
+mask_ratios = [0.5, 0.29, 0.07, 0.07, 0.07]
+vae = dict(
+    type="VideoAutoencoderKL",
+    from_pretrained="stabilityai/sd-vae-ft-ema",
+)
+text_encoder = dict(
+    type="t5",
+    from_pretrained="DeepFloyd/t5-v1_1-xxl",
+    model_max_length=120,
+    shardformer=True,
+)
+scheduler = dict(
+    type="iddpm-speed",
+    timestep_respacing="",
+)
+
+# Others
+seed = 42
+outputs = "outputs"
+wandb = False
+
+epochs = 1000
+log_every = 10
+ckpt_every = 1000
+load = None
+
+batch_size = 8
+lr = 2e-5
+grad_clip = 1.0
diff --git a/opensora/schedulers/iddpm/__init__.py b/opensora/schedulers/iddpm/__init__.py
index 4542e6c..6989e64 100644
--- a/opensora/schedulers/iddpm/__init__.py
+++ b/opensora/schedulers/iddpm/__init__.py
@@ -6,6 +6,7 @@ from opensora.registry import SCHEDULERS
 
 from . import gaussian_diffusion as gd
 from .respace import SpacedDiffusion, space_timesteps
+from .speed import SpeeDiffusion
 
 
 @SCHEDULERS.register_module("iddpm")
diff --git a/opensora/schedulers/iddpm/gaussian_diffusion.py b/opensora/schedulers/iddpm/gaussian_diffusion.py
index da233c3..b099de7 100644
--- a/opensora/schedulers/iddpm/gaussian_diffusion.py
+++ b/opensora/schedulers/iddpm/gaussian_diffusion.py
@@ -697,7 +697,7 @@ class GaussianDiffusion:
         output = th.where((t == 0), decoder_nll, kl)
         return {"output": output, "pred_xstart": out["pred_xstart"]}
 
-    def training_losses(self, model, x_start, t, model_kwargs=None, noise=None, mask=None):
+    def training_losses(self, model, x_start, t, model_kwargs=None, noise=None, mask=None, weights=None):
         """
         Compute training losses for a single timestep.
         :param model: the model to evaluate loss on.
@@ -763,7 +763,11 @@ class GaussianDiffusion:
                 ModelMeanType.EPSILON: noise,
             }[self.model_mean_type]
             assert model_output.shape == target.shape == x_start.shape
-            terms["mse"] = mean_flat((target - model_output) ** 2, mask=mask)
+            if weights is None:
+                terms["mse"] = mean_flat((target - model_output) ** 2, mask=mask)
+            else:
+                weight = _extract_into_tensor(weights, t, target.shape)
+                terms["mse"] = mean_flat(weight * (target - model_output) ** 2, mask=mask)
             if "vb" in terms:
                 terms["loss"] = terms["mse"] + terms["vb"]
             else:
diff --git a/opensora/schedulers/iddpm/speed.py b/opensora/schedulers/iddpm/speed.py
new file mode 100644
index 0000000..3231d02
--- /dev/null
+++ b/opensora/schedulers/iddpm/speed.py
@@ -0,0 +1,75 @@
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+from opensora.registry import SCHEDULERS
+
+from . import gaussian_diffusion as gd
+from .respace import SpacedDiffusion, space_timesteps
+
+
+@SCHEDULERS.register_module("iddpm-speed")
+class SpeeDiffusion(SpacedDiffusion):
+    def __init__(
+        self,
+        num_sampling_steps=None,
+        timestep_respacing=None,
+        noise_schedule="linear",
+        use_kl=False,
+        sigma_small=False,
+        predict_xstart=False,
+        learn_sigma=True,
+        rescale_learned_sigmas=False,
+        diffusion_steps=1000,
+        cfg_scale=4.0,
+    ):
+        betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps)
+        if use_kl:
+            loss_type = gd.LossType.RESCALED_KL
+        elif rescale_learned_sigmas:
+            loss_type = gd.LossType.RESCALED_MSE
+        else:
+            loss_type = gd.LossType.MSE
+        if num_sampling_steps is not None:
+            assert timestep_respacing is None
+            timestep_respacing = str(num_sampling_steps)
+        if timestep_respacing is None or timestep_respacing == "":
+            timestep_respacing = [diffusion_steps]
+        super().__init__(
+            use_timesteps=space_timesteps(diffusion_steps, timestep_respacing),
+            betas=betas,
+            model_mean_type=(gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X),
+            model_var_type=(
+                (gd.ModelVarType.FIXED_LARGE if not sigma_small else gd.ModelVarType.FIXED_SMALL)
+                if not learn_sigma
+                else gd.ModelVarType.LEARNED_RANGE
+            ),
+            loss_type=loss_type,
+        )
+
+        self.cfg_scale = cfg_scale
+
+        grad = np.gradient(self.sqrt_one_minus_alphas_cumprod)
+        self.meaningful_steps = np.argmax(grad < 5e-5) + 1
+
+        # p2 weighting from: Perception Prioritized Training of Diffusion Models
+        self.p2_gamma = 1
+        self.p2_k = 1
+        self.snr = 1.0 / (1 - self.alphas_cumprod) - 1
+        sqrt_one_minus_alphas_bar = torch.from_numpy(self.sqrt_one_minus_alphas_cumprod)
+        p = torch.tanh(1e6 * (torch.gradient(sqrt_one_minus_alphas_bar)[0] - 1e-4)) + 1.5
+        self.p = F.normalize(p, p=1, dim=0)
+        self.weights = 1 / (self.p2_k + self.snr) ** self.p2_gamma
+
+    def t_sample(self, n, device):
+        t = torch.multinomial(self.p, n // 2 + 1, replacement=True).to(device)
+        dual_t = torch.where(t < self.meaningful_steps, self.meaningful_steps - t, t - self.meaningful_steps)
+        t = torch.cat([t, dual_t], dim=0)[:n]
+        return t
+
+    def training_losses(self, model, x, t, *args, **kwargs):  # pylint: disable=signature-differs
+        t = self.t_sample(x.shape[0], x.device)
+        return super().training_losses(model, x, t, weights=self.weights, *args, **kwargs)
+
+    def sample(self, *args, **kwargs):
+        raise NotImplementedError("SpeeDiffusion is only for training")
diff --git a/opensora/utils/train_utils.py b/opensora/utils/train_utils.py
index c9b54d9..1bd4ea4 100644
--- a/opensora/utils/train_utils.py
+++ b/opensora/utils/train_utils.py
@@ -1,3 +1,4 @@
+import math
 import random
 from collections import OrderedDict
 
@@ -33,13 +34,14 @@ def update_ema(
 
 
 class MaskGenerator:
-    def __init__(self, mask_ratios):
+    def __init__(self, mask_ratios, condition_frames=4):
         self.mask_name = ["mask_no", "mask_random", "mask_head", "mask_tail", "mask_head_tail"]
         assert len(mask_ratios) == len(self.mask_name)
-        assert sum(mask_ratios) == 1.0
+        assert math.isclose(sum(mask_ratios), 1.0, abs_tol=1e-6)
         self.mask_prob = mask_ratios
         print(self.mask_prob)
         self.mask_acc_prob = [sum(self.mask_prob[: i + 1]) for i in range(len(self.mask_prob))]
+        self.condition_frames = condition_frames
 
     def get_mask(self, x):
         mask_type = random.random()
@@ -50,18 +52,18 @@ class MaskGenerator:
         mask = torch.ones(x.shape[2], dtype=torch.bool, device=x.device)
         if mask_name == "mask_random":
-            random_size = random.randint(1, 4)
+            random_size = random.randint(1, self.condition_frames)
             random_pos = random.randint(0, x.shape[2] - random_size)
             mask[random_pos : random_pos + random_size] = 0
             return mask
         elif mask_name == "mask_head":
-            random_size = random.randint(1, 4)
+            random_size = random.randint(1, self.condition_frames)
             mask[:random_size] = 0
         elif mask_name == "mask_tail":
-            random_size = random.randint(1, 4)
+            random_size = random.randint(1, self.condition_frames)
             mask[-random_size:] = 0
         elif mask_name == "mask_head_tail":
-            random_size = random.randint(1, 4)
+            random_size = random.randint(1, self.condition_frames)
             mask[:random_size] = 0
             mask[-random_size:] = 0
 
         return mask
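
Notes for reviewers: the heart of SpeeD sits in `speed.py`. Timesteps are not drawn uniformly; the distribution `p` suppresses the late, near-converged steps, where the gradient of sqrt(1 - alpha_bar_t) has flattened out and training signal is weak, and `t_sample` pairs every draw with a dual step mirrored around `meaningful_steps` so both ends of the schedule stay covered. The sampled steps are then reweighted with P2 weights w_t = 1 / (k + SNR(t))^gamma (k = gamma = 1), which `training_losses` threads through to the patched `mse` term in `gaussian_diffusion.py`. Below is a minimal standalone sketch of the sampling, assuming a linear beta schedule over 1000 steps; the names and constants here mirror the patch, but the snippet itself is illustrative and not part of it.

import torch
import torch.nn.functional as F

# Illustrative stand-in for gd.get_named_beta_schedule("linear", 1000).
diffusion_steps = 1000
betas = torch.linspace(1e-4, 0.02, diffusion_steps, dtype=torch.float64)
sqrt_one_minus_alphas_bar = torch.sqrt(1.0 - torch.cumprod(1.0 - betas, dim=0))

# tanh(1e6 * (grad - 1e-4)) saturates to +1 while sqrt(1 - alpha_bar_t) is
# still rising steeply and to -1 once it has flattened; adding 1.5 keeps every
# step at a small positive mass, and L1 normalization turns the curve into a pmf.
grad = torch.gradient(sqrt_one_minus_alphas_bar)[0]
p = F.normalize(torch.tanh(1e6 * (grad - 1e-4)) + 1.5, p=1.0, dim=0)
meaningful_steps = int(torch.argmax((grad < 5e-5).int())) + 1

def t_sample(n, device="cpu"):
    # Draw half a batch from p, then mirror each t around meaningful_steps,
    # following SpeeDiffusion.t_sample.
    t = torch.multinomial(p, n // 2 + 1, replacement=True).to(device)
    dual_t = torch.where(t < meaningful_steps, meaningful_steps - t, t - meaningful_steps)
    return torch.cat([t, dual_t], dim=0)[:n]

print(t_sample(8))  # eight timesteps in [0, diffusion_steps), biased away from the flat tail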
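
And a sketch of how the new scheduler is reached from the config: the `type="iddpm-speed"` entry in `16x256x256-spee.py` resolves through the `SCHEDULERS` registry, which is what the `@SCHEDULERS.register_module("iddpm-speed")` decorator is for. A training step would then look roughly like the following; `build_module`, `model`, `x`, `model_kwargs`, and `mask` are assumed to be the usual OpenSora training-loop pieces, not something this patch adds.

import torch
from opensora.registry import SCHEDULERS, build_module

scheduler = build_module(dict(type="iddpm-speed", timestep_respacing=""), SCHEDULERS)

# The t passed in here is effectively ignored: SpeeDiffusion.training_losses
# resamples it with t_sample and forwards self.weights so the MSE is P2-weighted.
t = torch.randint(0, scheduler.num_timesteps, (x.shape[0],), device=x.device)
loss_dict = scheduler.training_losses(model, x, t, model_kwargs=model_kwargs, mask=mask)
loss = loss_dict["loss"].mean()

Note that `sample()` deliberately raises NotImplementedError: the scheduler only changes how training timesteps are drawn and weighted, so inference keeps using the plain `iddpm` scheduler.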