import torch
from torch.distributions import LogisticNormal

from ..iddpm.gaussian_diffusion import _extract_into_tensor, mean_flat

# Some of this code is inspired by
# https://github.com/magic-research/piecewise-rectified-flow/blob/main/scripts/train_perflow.py
# and
# https://github.com/magic-research/piecewise-rectified-flow/blob/main/src/scheduler_perflow.py


def timestep_transform(
    t,
    model_kwargs,
    base_resolution=512 * 512,
    base_num_frames=1,
    scale=1.0,
    num_timesteps=1,
):
    """Warp timesteps according to the spatial and temporal size of the sample.

    The timestep is normalised by ``num_timesteps``, remapped with
    ``ratio * t / (1 + (ratio - 1) * t)`` and scaled back, where
    ``ratio = sqrt(height * width / base_resolution)
    * sqrt(num_frames / base_num_frames) * scale``.

    Args:
        t: Timesteps on the ``[0, num_timesteps]`` scale.
        model_kwargs: Mapping that must contain tensors under the keys
            ``"height"``, ``"width"`` and ``"num_frames"``.
            NOTE: float16 entries under these keys are replaced *in place*
            by their float32 versions, so the caller's dict is modified.
        base_resolution: Reference pixel count used for the spatial ratio.
        base_num_frames: Reference frame count used for the temporal ratio.
        scale: Extra multiplicative factor applied to the ratio.
        num_timesteps: Scale of ``t``; used to normalise and de-normalise.

    Returns:
        The transformed timesteps, on the same ``[0, num_timesteps]`` scale.

    Raises:
        KeyError: If one of the required keys is missing from ``model_kwargs``.
    """
    # Force fp16 input to fp32 to avoid nan output
    for key in ["height", "width", "num_frames"]:
        if model_kwargs[key].dtype == torch.float16:
            model_kwargs[key] = model_kwargs[key].float()

    t = t / num_timesteps
    resolution = model_kwargs["height"] * model_kwargs["width"]
    ratio_space = (resolution / base_resolution).sqrt()

    # NOTE: currently, we do not take fps into account
    # NOTE: temporal_reduction is hardcoded, this should be equal to the
    # temporal reduction factor of the vae
    # Only the first entry is inspected to decide whether the batch is made of
    # single-frame samples.
    if model_kwargs["num_frames"][0] == 1:
        num_frames = torch.ones_like(model_kwargs["num_frames"])
    else:
        num_frames = model_kwargs["num_frames"] // 17 * 5
    ratio_time = (num_frames / base_num_frames).sqrt()

    ratio = ratio_space * ratio_time * scale
    new_t = ratio * t / (1 + (ratio - 1) * t)

    new_t = new_t * num_timesteps
    return new_t


class RFlowScheduler:
    """Rectified-flow training scheduler.

    Samples training timesteps, builds the noised sample as a linear
    interpolation between data and noise, and computes a squared-error loss
    between the model's velocity prediction and ``x_start - noise``.
    """

    def __init__(
        self,
        num_timesteps=1000,
        num_sampling_steps=10,
        use_discrete_timesteps=False,
        sample_method="uniform",
        loc=0.0,
        scale=1.0,
        use_timestep_transform=False,
        transform_scale=1.0,
    ):
        """Configure the scheduler.

        Args:
            num_timesteps: Upper end of the timestep scale.
            num_sampling_steps: Stored on the instance; not used by the
                methods of this class.
            use_discrete_timesteps: Draw integer timesteps with
                ``torch.randint`` instead of continuous ones.
            sample_method: ``"uniform"`` or ``"logit-normal"``.
            loc: Location parameter of the logit-normal distribution.
            scale: Scale parameter of the logit-normal distribution.
            use_timestep_transform: Apply :func:`timestep_transform` to the
                timesteps sampled inside :meth:`training_losses`.
            transform_scale: ``scale`` forwarded to :func:`timestep_transform`.

        Raises:
            AssertionError: If ``sample_method`` is unknown, or if discrete
                timesteps are combined with a non-uniform sample method.
        """
        self.num_timesteps = num_timesteps
        self.num_sampling_steps = num_sampling_steps
        self.use_discrete_timesteps = use_discrete_timesteps

        # sample method
        assert sample_method in ["uniform", "logit-normal"]
        assert (
            sample_method == "uniform" or not use_discrete_timesteps
        ), "Only uniform sampling is supported for discrete timesteps"
        self.sample_method = sample_method
        if sample_method == "logit-normal":
            self.distribution = LogisticNormal(torch.tensor([loc]), torch.tensor([scale]))
            # Draw one value per batch element; only the first component of
            # each LogisticNormal sample is kept.
            self.sample_t = lambda x: self.distribution.sample((x.shape[0],))[:, 0].to(x.device)

        # timestep transform
        self.use_timestep_transform = use_timestep_transform
        self.transform_scale = transform_scale

    def training_losses(self, model, x_start, model_kwargs=None, noise=None, mask=None, weights=None, t=None):
        """Compute training losses for a single timestep.

        Arguments format copied from
        opensora/schedulers/iddpm/gaussian_diffusion.py/training_losses

        Note: t is int tensor and should be rescaled from
        [0, num_timesteps-1] to [1,0]

        Args:
            model: Callable invoked as ``model(x_t, t, **model_kwargs)``. Its
                output is split in two along dim 1 and the first half is used
                as the velocity prediction.
            x_start: Clean samples.
            model_kwargs: Extra keyword arguments for ``model``. ``None`` is
                treated as an empty dict. When the timestep transform is
                enabled and ``t`` is not given, it must provide ``"height"``,
                ``"width"`` and ``"num_frames"``.
            noise: Noise with the same shape as ``x_start``; sampled with
                ``torch.randn_like`` when omitted.
            mask: Optional boolean mask, indexed as
                ``mask[:, None, :, None, None]``. Positions where it is False
                receive the sample noised at ``t = 0`` instead of at ``t``.
            weights: Optional per-timestep weights looked up through
                ``_extract_into_tensor``.
            t: Optional timesteps. When given, no sampling and no timestep
                transform take place.

        Returns:
            A dict with the single key ``"loss"``.
        """
        # Normalise first so that the timestep transform below never receives
        # None (which previously failed with an opaque TypeError).
        if model_kwargs is None:
            model_kwargs = {}

        if t is None:
            if self.use_discrete_timesteps:
                t = torch.randint(0, self.num_timesteps, (x_start.shape[0],), device=x_start.device)
            elif self.sample_method == "uniform":
                t = torch.rand((x_start.shape[0],), device=x_start.device) * self.num_timesteps
            elif self.sample_method == "logit-normal":
                t = self.sample_t(x_start) * self.num_timesteps

            if self.use_timestep_transform:
                t = timestep_transform(t, model_kwargs, scale=self.transform_scale, num_timesteps=self.num_timesteps)

        if noise is None:
            noise = torch.randn_like(x_start)
        assert noise.shape == x_start.shape

        x_t = self.add_noise(x_start, noise, t)
        if mask is not None:
            # add_noise at t = 0 returns the clean sample, so unmasked
            # positions are left un-noised.
            t0 = torch.zeros_like(t)
            x_t0 = self.add_noise(x_start, noise, t0)
            x_t = torch.where(mask[:, None, :, None, None], x_t, x_t0)

        terms = {}
        model_output = model(x_t, t, **model_kwargs)
        velocity_pred = model_output.chunk(2, dim=1)[0]
        if weights is None:
            loss = mean_flat((velocity_pred - (x_start - noise)).pow(2), mask=mask)
        else:
            # NOTE(review): t may be a float tensor here when continuous
            # timesteps are sampled; confirm _extract_into_tensor accepts it.
            weight = _extract_into_tensor(weights, t, x_start.shape)
            loss = mean_flat(weight * (velocity_pred - (x_start - noise)).pow(2), mask=mask)
        terms["loss"] = loss

        return terms

    def add_noise(
        self,
        original_samples: torch.FloatTensor,
        noise: torch.FloatTensor,
        timesteps: torch.IntTensor,
    ) -> torch.FloatTensor:
        """Linearly interpolate between the clean samples and the noise.

        Compatible with diffusers ``add_noise()``.

        With ``a = 1 - timesteps / num_timesteps`` the result is
        ``a * original_samples + (1 - a) * noise``; ``timesteps == 0``
        therefore returns the clean samples.

        Args:
            original_samples: Clean samples, e.g. (bsz, 4, frame, w, h).
            noise: Noise tensor with the same shape as ``original_samples``.
            timesteps: One timestep per batch element, shape (bsz,).

        Returns:
            The noised samples.
        """
        timepoints = 1 - timesteps.float() / self.num_timesteps  # [1,1/1000]

        # timepoint (bsz) noise: (bsz, 4, frame, w, h)
        # Reshape to (bsz, 1, ..., 1) and let broadcasting expand it to the
        # noise shape; this works for any rank and avoids materialising a
        # full-size coefficient tensor.
        timepoints = timepoints.reshape(-1, *([1] * (noise.dim() - 1)))

        return timepoints * original_samples + (1 - timepoints) * noise