mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-15 03:15:20 +02:00
add speediffusion
This commit is contained in:
parent
b1496b3e17
commit
0dcf1ff888
54
configs/opensora/train/16x256x256-spee.py
Normal file
54
configs/opensora/train/16x256x256-spee.py
Normal file
|
|
@ -0,0 +1,54 @@
|
|||
# Training config: Open-Sora STDiT-XL/2 on 16-frame 256x256 clips using the
# SpeeDiffusion ("iddpm-speed") training scheduler.
# (Reconstructed from a scraped diff view; artifact lines removed.)

num_frames = 16
frame_interval = 3
image_size = (256, 256)

# Define dataset
root = None
data_path = "CSV_PATH"  # placeholder — replace with the path to the training CSV
use_image_transform = False
num_workers = 4

# Define acceleration
dtype = "bf16"
grad_checkpoint = True
plugin = "zero2"
sp_size = 1  # sequence-parallel group size

# Define model
model = dict(
    type="STDiT-XL/2",
    space_scale=0.5,
    time_scale=1.0,
    from_pretrained="PixArt-XL-2-512x512.pth",
    enable_flashattn=True,
    enable_layernorm_kernel=True,
)
# Probabilities for the five mask modes of MaskGenerator
# (no / random / head / tail / head_tail); must sum to 1.0.
mask_ratios = [0.5, 0.29, 0.07, 0.07, 0.07]
vae = dict(
    type="VideoAutoencoderKL",
    from_pretrained="stabilityai/sd-vae-ft-ema",
)
text_encoder = dict(
    type="t5",
    from_pretrained="DeepFloyd/t5-v1_1-xxl",
    model_max_length=120,
    shardformer=True,
)
scheduler = dict(
    type="iddpm-speed",  # SpeeDiffusion: training-only accelerated scheduler
    timestep_respacing="",
)

# Others
seed = 42
outputs = "outputs"
wandb = False

epochs = 1000
log_every = 10
ckpt_every = 1000
load = None

batch_size = 8
lr = 2e-5
grad_clip = 1.0
|
||||
|
|
@ -6,6 +6,7 @@ from opensora.registry import SCHEDULERS
|
|||
|
||||
from . import gaussian_diffusion as gd
|
||||
from .respace import SpacedDiffusion, space_timesteps
|
||||
from .speed import SpeeDiffusion
|
||||
|
||||
|
||||
@SCHEDULERS.register_module("iddpm")
|
||||
|
|
|
|||
|
|
@ -697,7 +697,7 @@ class GaussianDiffusion:
|
|||
output = th.where((t == 0), decoder_nll, kl)
|
||||
return {"output": output, "pred_xstart": out["pred_xstart"]}
|
||||
|
||||
def training_losses(self, model, x_start, t, model_kwargs=None, noise=None, mask=None):
|
||||
def training_losses(self, model, x_start, t, model_kwargs=None, noise=None, mask=None, weights=None):
|
||||
"""
|
||||
Compute training losses for a single timestep.
|
||||
:param model: the model to evaluate loss on.
|
||||
|
|
@ -763,7 +763,11 @@ class GaussianDiffusion:
|
|||
ModelMeanType.EPSILON: noise,
|
||||
}[self.model_mean_type]
|
||||
assert model_output.shape == target.shape == x_start.shape
|
||||
terms["mse"] = mean_flat((target - model_output) ** 2, mask=mask)
|
||||
if weights is None:
|
||||
terms["mse"] = mean_flat((target - model_output) ** 2, mask=mask)
|
||||
else:
|
||||
weight = _extract_into_tensor(weights, t, target.shape)
|
||||
terms["mse"] = mean_flat(weight * (target - model_output) ** 2, mask=mask)
|
||||
if "vb" in terms:
|
||||
terms["loss"] = terms["mse"] + terms["vb"]
|
||||
else:
|
||||
|
|
|
|||
75
opensora/schedulers/iddpm/speed.py
Normal file
75
opensora/schedulers/iddpm/speed.py
Normal file
|
|
@ -0,0 +1,75 @@
|
|||
import numpy as np
import torch
import torch.nn.functional as F

from opensora.registry import SCHEDULERS

from . import gaussian_diffusion as gd
from .respace import SpacedDiffusion, space_timesteps


@SCHEDULERS.register_module("iddpm-speed")
class SpeeDiffusion(SpacedDiffusion):
    """Spaced IDDPM scheduler with SpeeD-style accelerated training.

    Differences from the plain ``iddpm`` scheduler:

    * training timesteps are drawn non-uniformly (``t_sample``) from a
      distribution ``self.p`` built from the gradient of
      ``sqrt(1 - alpha_bar)``, concentrating samples where the noise
      schedule changes fastest;
    * per-timestep loss weights ``self.weights`` (P2 weighting) are passed
      through to ``SpacedDiffusion.training_losses``.

    ``sample`` is deliberately unsupported — this scheduler is training-only.
    """

    def __init__(
        self,
        num_sampling_steps=None,
        timestep_respacing=None,
        noise_schedule="linear",
        use_kl=False,
        sigma_small=False,
        predict_xstart=False,
        learn_sigma=True,
        rescale_learned_sigmas=False,
        diffusion_steps=1000,
        cfg_scale=4.0,
    ):
        betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps)
        # Loss type selection mirrors the plain IDDPM scheduler.
        if use_kl:
            loss_type = gd.LossType.RESCALED_KL
        elif rescale_learned_sigmas:
            loss_type = gd.LossType.RESCALED_MSE
        else:
            loss_type = gd.LossType.MSE
        # num_sampling_steps and timestep_respacing are mutually exclusive.
        if num_sampling_steps is not None:
            assert timestep_respacing is None
            timestep_respacing = str(num_sampling_steps)
        if timestep_respacing is None or timestep_respacing == "":
            timestep_respacing = [diffusion_steps]
        super().__init__(
            use_timesteps=space_timesteps(diffusion_steps, timestep_respacing),
            betas=betas,
            model_mean_type=(gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X),
            model_var_type=(
                (gd.ModelVarType.FIXED_LARGE if not sigma_small else gd.ModelVarType.FIXED_SMALL)
                if not learn_sigma
                else gd.ModelVarType.LEARNED_RANGE
            ),
            loss_type=loss_type,
        )

        self.cfg_scale = cfg_scale

        # First index where the slope of sqrt(1 - alpha_bar) drops below 5e-5,
        # i.e. where the schedule has (nearly) plateaued. Used as the pivot for
        # the dual-timestep reflection in t_sample.
        grad = np.gradient(self.sqrt_one_minus_alphas_cumprod)
        self.meaningful_steps = np.argmax(grad < 5e-5) + 1

        # p2 weighting from: Perception Prioritized Training of Diffusion Models
        self.p2_gamma = 1
        self.p2_k = 1
        self.snr = 1.0 / (1 - self.alphas_cumprod) - 1
        # Sampling distribution over timesteps: a tanh soft-threshold of the
        # schedule gradient (shifted to stay positive), then L1-normalized so
        # it sums to 1 and is usable with torch.multinomial.
        sqrt_one_minus_alphas_bar = torch.from_numpy(self.sqrt_one_minus_alphas_cumprod)
        p = torch.tanh(1e6 * (torch.gradient(sqrt_one_minus_alphas_bar)[0] - 1e-4)) + 1.5
        self.p = F.normalize(p, p=1, dim=0)
        self.weights = 1 / (self.p2_k + self.snr) ** self.p2_gamma

    def t_sample(self, n, device):
        """Draw ``n`` training timesteps on ``device``.

        Half (plus one) are sampled from ``self.p``; each also yields a "dual"
        timestep reflected around ``self.meaningful_steps``. The concatenation
        is truncated to exactly ``n`` entries.
        """
        t = torch.multinomial(self.p, n // 2 + 1, replacement=True).to(device)
        dual_t = torch.where(t < self.meaningful_steps, self.meaningful_steps - t, t - self.meaningful_steps)
        t = torch.cat([t, dual_t], dim=0)[:n]
        return t

    def training_losses(self, model, x, t, *args, **kwargs):  # pylint: disable=signature-differs
        # NOTE: the caller-supplied `t` is intentionally ignored; timesteps are
        # re-sampled from the SpeeD distribution, and the P2 weights are
        # forwarded to the parent implementation.
        t = self.t_sample(x.shape[0], x.device)
        return super().training_losses(model, x, t, weights=self.weights, *args, **kwargs)

    def sample(self, *args, **kwargs):
        """Unsupported: SpeeDiffusion only accelerates training."""
        raise NotImplementedError("SpeeDiffusion is only for training")
|
||||
|
|
@ -1,3 +1,4 @@
|
|||
import math
|
||||
import random
|
||||
from collections import OrderedDict
|
||||
|
||||
|
|
@ -33,13 +34,14 @@ def update_ema(
|
|||
|
||||
|
||||
class MaskGenerator:
|
||||
def __init__(self, mask_ratios):
|
||||
def __init__(self, mask_ratios, condition_frames=4):
|
||||
self.mask_name = ["mask_no", "mask_random", "mask_head", "mask_tail", "mask_head_tail"]
|
||||
assert len(mask_ratios) == len(self.mask_name)
|
||||
assert sum(mask_ratios) == 1.0
|
||||
assert math.isclose(sum(mask_ratios), 1.0, abs_tol=1e-6)
|
||||
self.mask_prob = mask_ratios
|
||||
print(self.mask_prob)
|
||||
self.mask_acc_prob = [sum(self.mask_prob[: i + 1]) for i in range(len(self.mask_prob))]
|
||||
self.condition_frames = condition_frames
|
||||
|
||||
def get_mask(self, x):
|
||||
mask_type = random.random()
|
||||
|
|
@ -50,18 +52,18 @@ class MaskGenerator:
|
|||
|
||||
mask = torch.ones(x.shape[2], dtype=torch.bool, device=x.device)
|
||||
if mask_name == "mask_random":
|
||||
random_size = random.randint(1, 4)
|
||||
random_size = random.randint(1, self.condition_frames)
|
||||
random_pos = random.randint(0, x.shape[2] - random_size)
|
||||
mask[random_pos : random_pos + random_size] = 0
|
||||
return mask
|
||||
elif mask_name == "mask_head":
|
||||
random_size = random.randint(1, 4)
|
||||
random_size = random.randint(1, self.condition_frames)
|
||||
mask[:random_size] = 0
|
||||
elif mask_name == "mask_tail":
|
||||
random_size = random.randint(1, 4)
|
||||
random_size = random.randint(1, self.condition_frames)
|
||||
mask[-random_size:] = 0
|
||||
elif mask_name == "mask_head_tail":
|
||||
random_size = random.randint(1, 4)
|
||||
random_size = random.randint(1, self.condition_frames)
|
||||
mask[:random_size] = 0
|
||||
mask[-random_size:] = 0
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue