mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-14 18:25:35 +02:00
130 lines
5 KiB
Python
130 lines
5 KiB
Python
import torch
|
|
from torch.distributions import LogisticNormal
|
|
|
|
from ..iddpm.gaussian_diffusion import _extract_into_tensor, mean_flat
|
|
|
|
# Some code is inspired by https://github.com/magic-research/piecewise-rectified-flow/blob/main/scripts/train_perflow.py
|
|
# and https://github.com/magic-research/piecewise-rectified-flow/blob/main/src/scheduler_perflow.py
|
|
|
|
|
|
def timestep_transform(
    t,
    model_kwargs,
    base_resolution=512 * 512,
    base_num_frames=1,
    scale=1.0,
    num_timesteps=1,
):
    """Rescale timesteps according to sample resolution and frame count.

    With ``s = t / num_timesteps`` and
    ``ratio = sqrt(H * W / base_resolution) * sqrt(frames / base_num_frames) * scale``,
    this returns ``num_timesteps * ratio * s / (1 + (ratio - 1) * s)``.
    ``s = 0`` and ``s = 1`` are fixed points. For ``ratio > 1`` every ``s`` in
    (0, 1) is moved upward (toward ``num_timesteps``, the noisier end under
    ``RFlowScheduler.add_noise``). For ``ratio < 1`` it is moved downward.

    Args:
        t: Timestep tensor in ``[0, num_timesteps]``.
        model_kwargs: Dict holding tensors ``"height"``, ``"width"`` and
            ``"num_frames"``. They must broadcast against ``t`` (assumed
            per-sample, shape (B,) — TODO confirm against callers). float16
            entries are cast to float32 *in place*, so the caller's dict is
            modified.
        base_resolution: Pixel count at which the spatial ratio is 1.
        base_num_frames: (Latent) frame count at which the temporal ratio is 1.
        scale: Extra multiplier applied to the combined ratio.
        num_timesteps: Range used to normalise ``t`` to ``[0, 1]`` and back.

    Returns:
        Transformed timesteps on the same ``[0, num_timesteps]`` scale as ``t``.
    """
    # Force fp16 input to fp32 to avoid nan output.
    # NOTE: this writes back into model_kwargs, so the caller sees the cast too.
    for key in ["height", "width", "num_frames"]:
        if model_kwargs[key].dtype == torch.float16:
            model_kwargs[key] = model_kwargs[key].float()

    t = t / num_timesteps
    resolution = model_kwargs["height"] * model_kwargs["width"]
    ratio_space = (resolution / base_resolution).sqrt()

    # NOTE: currently, we do not take fps into account
    # NOTE: temporal_reduction is hardcoded, this should be equal to the temporal reduction factor of the vae
    raw_frames = model_kwargs["num_frames"]
    # Decide per sample rather than from element 0 only. Otherwise, in a mixed
    # batch, single-frame samples would map to 1 // 17 * 5 == 0 frames (ratio 0,
    # timestep collapsed to 0), or videos would be treated as single frames.
    # `// 17 * 5` presumably converts pixel frames to VAE latent frames.
    num_frames = torch.where(
        raw_frames == 1,
        torch.ones_like(raw_frames),
        raw_frames // 17 * 5,
    )
    ratio_time = (num_frames / base_num_frames).sqrt()

    ratio = ratio_space * ratio_time * scale
    new_t = ratio * t / (1 + (ratio - 1) * t)

    new_t = new_t * num_timesteps
    return new_t
|
|
|
|
|
|
class RFlowScheduler:
    """Rectified-flow noise scheduler used for training.

    Timesteps live in ``[0, num_timesteps]``. ``add_noise`` linearly
    interpolates between the clean sample (t = 0) and pure noise
    (t = num_timesteps).

    Args:
        num_timesteps: Size of the timestep range.
        num_sampling_steps: Stored on the instance; not read by the methods
            of this class (possibly used by callers).
        use_discrete_timesteps: Sample integer timesteps with ``torch.randint``.
        sample_method: ``"uniform"`` or ``"logit-normal"`` for continuous
            timesteps (discrete timesteps require ``"uniform"``).
        loc: Location of the ``LogisticNormal`` used by ``"logit-normal"``.
        scale: Scale of the ``LogisticNormal`` used by ``"logit-normal"``.
        use_timestep_transform: Apply ``timestep_transform`` to sampled timesteps.
        transform_scale: ``scale`` argument forwarded to ``timestep_transform``.
    """

    def __init__(
        self,
        num_timesteps=1000,
        num_sampling_steps=10,
        use_discrete_timesteps=False,
        sample_method="uniform",
        loc=0.0,
        scale=1.0,
        use_timestep_transform=False,
        transform_scale=1.0,
    ):
        self.num_timesteps = num_timesteps
        self.num_sampling_steps = num_sampling_steps
        self.use_discrete_timesteps = use_discrete_timesteps

        # sample method
        assert sample_method in ["uniform", "logit-normal"]
        assert (
            sample_method == "uniform" or not use_discrete_timesteps
        ), "Only uniform sampling is supported for discrete timesteps"
        self.sample_method = sample_method
        if sample_method == "logit-normal":
            self.distribution = LogisticNormal(torch.tensor([loc]), torch.tensor([scale]))
            # A LogisticNormal over a 1-element base yields 2-component simplex
            # samples. Component 0 is sigmoid(N(loc, scale)), a value in (0, 1).
            self.sample_t = lambda x: self.distribution.sample((x.shape[0],))[:, 0].to(x.device)

        # timestep transform
        self.use_timestep_transform = use_timestep_transform
        self.transform_scale = transform_scale

    def training_losses(self, model, x_start, model_kwargs=None, noise=None, mask=None, weights=None, t=None):
        """
        Compute training losses for a single timestep.
        Arguments format copied from opensora/schedulers/iddpm/gaussian_diffusion.py/training_losses

        Args:
            model: Callable invoked as ``model(x_t, t, **model_kwargs)``. Its
                output is split in two along dim 1, and the first half is used
                as the velocity prediction.
            x_start: Clean samples. The mask indexing below treats them as 5-D
                (assumed (B, C, T, H, W) — TODO confirm).
            model_kwargs: Extra keyword arguments for ``model``. When the
                timestep transform is enabled, this dict also supplies
                ``height``/``width``/``num_frames``.
            noise: Noise with the same shape as ``x_start``; standard normal
                noise is drawn when ``None``.
            mask: Optional boolean mask indexed as ``mask[:, None, :, None, None]``
                (assumed (B, T) — TODO confirm). Where it is False, the clean
                sample (t = 0) is fed instead of the noised one. It is also
                forwarded to ``mean_flat``.
            weights: Optional per-timestep weights gathered via
                ``_extract_into_tensor``.
            t: Timesteps on the ``[0, num_timesteps]`` scale (int when discrete,
                float otherwise). When ``None``, they are sampled and optionally
                transformed. A caller-supplied ``t`` is used as-is.

        Returns:
            dict with key ``"loss"`` holding the (optionally weighted) squared
            velocity error reduced by ``mean_flat``.
        """
        # Default the kwargs first so the timestep transform below never
        # receives None.
        if model_kwargs is None:
            model_kwargs = {}

        if t is None:
            batch_size = x_start.shape[0]
            if self.use_discrete_timesteps:
                t = torch.randint(0, self.num_timesteps, (batch_size,), device=x_start.device)
            elif self.sample_method == "uniform":
                t = torch.rand((batch_size,), device=x_start.device) * self.num_timesteps
            elif self.sample_method == "logit-normal":
                t = self.sample_t(x_start) * self.num_timesteps

            if self.use_timestep_transform:
                t = timestep_transform(t, model_kwargs, scale=self.transform_scale, num_timesteps=self.num_timesteps)

        if noise is None:
            noise = torch.randn_like(x_start)
        assert noise.shape == x_start.shape

        x_t = self.add_noise(x_start, noise, t)
        if mask is not None:
            # Masked-out frames are replaced with the clean sample (t = 0);
            # presumably these serve as conditioning frames.
            t0 = torch.zeros_like(t)
            x_t0 = self.add_noise(x_start, noise, t0)
            x_t = torch.where(mask[:, None, :, None, None], x_t, x_t0)

        terms = {}
        model_output = model(x_t, t, **model_kwargs)
        velocity_pred = model_output.chunk(2, dim=1)[0]
        # Under add_noise, x_t = (1 - s) * x_start + s * noise with s = t / T.
        # The target x_start - noise is therefore -dx_t/ds, the velocity that
        # points from noise toward data.
        target = x_start - noise
        if weights is None:
            loss = mean_flat((velocity_pred - target).pow(2), mask=mask)
        else:
            weight = _extract_into_tensor(weights, t, x_start.shape)
            loss = mean_flat(weight * (velocity_pred - target).pow(2), mask=mask)
        terms["loss"] = loss

        return terms

    def add_noise(
        self,
        original_samples: torch.FloatTensor,
        noise: torch.FloatTensor,
        timesteps: torch.IntTensor,
    ) -> torch.FloatTensor:
        """
        compatible with diffusers add_noise()

        Returns ``(1 - t / T) * original_samples + (t / T) * noise`` with
        ``T = num_timesteps``. ``timesteps`` holds one value per sample
        (shape (B,), or a scalar). It is broadcast over the remaining
        dimensions, so inputs of any rank are supported.
        """
        timepoints = timesteps.float() / self.num_timesteps
        timepoints = 1 - timepoints  # [1, 1/1000] for integer timesteps in [0, 999]

        # Reshape to (B, 1, ..., 1) and let broadcasting do the expansion. This
        # avoids materialising a full-size copy of the timepoints per call.
        timepoints = timepoints.reshape(-1, *([1] * (noise.dim() - 1)))

        return timepoints * original_samples + (1 - timepoints) * noise
|