add SpeeDiffusion: an "iddpm-speed" scheduler with asymmetric timestep sampling and P2 loss weighting, plus its training config

This commit is contained in:
Zangwei Zheng 2024-03-25 22:20:10 +08:00
parent b1496b3e17
commit 0dcf1ff888
5 changed files with 144 additions and 8 deletions

View file

@ -0,0 +1,54 @@
num_frames = 16
frame_interval = 3
image_size = (256, 256)
# Define dataset
root = None
data_path = "CSV_PATH"
use_image_transform = False
num_workers = 4
# Define acceleration
dtype = "bf16"
grad_checkpoint = True
plugin = "zero2"
sp_size = 1
# Define model
model = dict(
type="STDiT-XL/2",
space_scale=0.5,
time_scale=1.0,
from_pretrained="PixArt-XL-2-512x512.pth",
enable_flashattn=True,
enable_layernorm_kernel=True,
)
mask_ratios = [0.5, 0.29, 0.07, 0.07, 0.07]
vae = dict(
type="VideoAutoencoderKL",
from_pretrained="stabilityai/sd-vae-ft-ema",
)
text_encoder = dict(
type="t5",
from_pretrained="DeepFloyd/t5-v1_1-xxl",
model_max_length=120,
shardformer=True,
)
scheduler = dict(
type="iddpm-speed",
timestep_respacing="",
)
# Others
seed = 42
outputs = "outputs"
wandb = False
epochs = 1000
log_every = 10
ckpt_every = 1000
load = None
batch_size = 8
lr = 2e-5
grad_clip = 1.0

View file

@ -6,6 +6,7 @@ from opensora.registry import SCHEDULERS
from . import gaussian_diffusion as gd
from .respace import SpacedDiffusion, space_timesteps
from .speed import SpeeDiffusion
@SCHEDULERS.register_module("iddpm")

View file

@ -697,7 +697,7 @@ class GaussianDiffusion:
output = th.where((t == 0), decoder_nll, kl)
return {"output": output, "pred_xstart": out["pred_xstart"]}
def training_losses(self, model, x_start, t, model_kwargs=None, noise=None, mask=None):
def training_losses(self, model, x_start, t, model_kwargs=None, noise=None, mask=None, weights=None):
"""
Compute training losses for a single timestep.
:param model: the model to evaluate loss on.
@ -763,7 +763,11 @@ class GaussianDiffusion:
ModelMeanType.EPSILON: noise,
}[self.model_mean_type]
assert model_output.shape == target.shape == x_start.shape
terms["mse"] = mean_flat((target - model_output) ** 2, mask=mask)
if weights is None:
terms["mse"] = mean_flat((target - model_output) ** 2, mask=mask)
else:
weight = _extract_into_tensor(weights, t, target.shape)
terms["mse"] = mean_flat(weight * (target - model_output) ** 2, mask=mask)
if "vb" in terms:
terms["loss"] = terms["mse"] + terms["vb"]
else:

View file

@ -0,0 +1,75 @@
import numpy as np
import torch
import torch.nn.functional as F
from opensora.registry import SCHEDULERS
from . import gaussian_diffusion as gd
from .respace import SpacedDiffusion, space_timesteps
@SCHEDULERS.register_module("iddpm-speed")
class SpeeDiffusion(SpacedDiffusion):
def __init__(
self,
num_sampling_steps=None,
timestep_respacing=None,
noise_schedule="linear",
use_kl=False,
sigma_small=False,
predict_xstart=False,
learn_sigma=True,
rescale_learned_sigmas=False,
diffusion_steps=1000,
cfg_scale=4.0,
):
betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps)
if use_kl:
loss_type = gd.LossType.RESCALED_KL
elif rescale_learned_sigmas:
loss_type = gd.LossType.RESCALED_MSE
else:
loss_type = gd.LossType.MSE
if num_sampling_steps is not None:
assert timestep_respacing is None
timestep_respacing = str(num_sampling_steps)
if timestep_respacing is None or timestep_respacing == "":
timestep_respacing = [diffusion_steps]
super().__init__(
use_timesteps=space_timesteps(diffusion_steps, timestep_respacing),
betas=betas,
model_mean_type=(gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X),
model_var_type=(
(gd.ModelVarType.FIXED_LARGE if not sigma_small else gd.ModelVarType.FIXED_SMALL)
if not learn_sigma
else gd.ModelVarType.LEARNED_RANGE
),
loss_type=loss_type,
)
self.cfg_scale = cfg_scale
grad = np.gradient(self.sqrt_one_minus_alphas_cumprod)
self.meaningful_steps = np.argmax(grad < 5e-5) + 1
# p2 weighting from: Perception Prioritized Training of Diffusion Models
self.p2_gamma = 1
self.p2_k = 1
self.snr = 1.0 / (1 - self.alphas_cumprod) - 1
sqrt_one_minus_alphas_bar = torch.from_numpy(self.sqrt_one_minus_alphas_cumprod)
p = torch.tanh(1e6 * (torch.gradient(sqrt_one_minus_alphas_bar)[0] - 1e-4)) + 1.5
self.p = F.normalize(p, p=1, dim=0)
self.weights = 1 / (self.p2_k + self.snr) ** self.p2_gamma
def t_sample(self, n, device):
t = torch.multinomial(self.p, n // 2 + 1, replacement=True).to(device)
dual_t = torch.where(t < self.meaningful_steps, self.meaningful_steps - t, t - self.meaningful_steps)
t = torch.cat([t, dual_t], dim=0)[:n]
return t
def training_losses(self, model, x, t, *args, **kwargs): # pylint: disable=signature-differs
t = self.t_sample(x.shape[0], x.device)
return super().training_losses(model, x, t, weights=self.weights, *args, **kwargs)
def sample(self, *args, **kwargs):
raise NotImplementedError("SpeeDiffusion is only for training")

View file

@ -1,3 +1,4 @@
import math
import random
from collections import OrderedDict
@ -33,13 +34,14 @@ def update_ema(
class MaskGenerator:
def __init__(self, mask_ratios):
def __init__(self, mask_ratios, condition_frames=4):
self.mask_name = ["mask_no", "mask_random", "mask_head", "mask_tail", "mask_head_tail"]
assert len(mask_ratios) == len(self.mask_name)
assert sum(mask_ratios) == 1.0
assert math.isclose(sum(mask_ratios), 1.0, abs_tol=1e-6)
self.mask_prob = mask_ratios
print(self.mask_prob)
self.mask_acc_prob = [sum(self.mask_prob[: i + 1]) for i in range(len(self.mask_prob))]
self.condition_frames = condition_frames
def get_mask(self, x):
mask_type = random.random()
@ -50,18 +52,18 @@ class MaskGenerator:
mask = torch.ones(x.shape[2], dtype=torch.bool, device=x.device)
if mask_name == "mask_random":
random_size = random.randint(1, 4)
random_size = random.randint(1, self.condition_frames)
random_pos = random.randint(0, x.shape[2] - random_size)
mask[random_pos : random_pos + random_size] = 0
return mask
elif mask_name == "mask_head":
random_size = random.randint(1, 4)
random_size = random.randint(1, self.condition_frames)
mask[:random_size] = 0
elif mask_name == "mask_tail":
random_size = random.randint(1, 4)
random_size = random.randint(1, self.condition_frames)
mask[-random_size:] = 0
elif mask_name == "mask_head_tail":
random_size = random.randint(1, 4)
random_size = random.randint(1, self.condition_frames)
mask[:random_size] = 0
mask[-random_size:] = 0