diff --git a/opensora/schedulers/rf/__init__.py b/opensora/schedulers/rf/__init__.py index aa0e086..b6239de 100644 --- a/opensora/schedulers/rf/__init__.py +++ b/opensora/schedulers/rf/__init__.py @@ -27,7 +27,7 @@ class RFLOW: guidance_scale = None, # progress = True, ): - assert mask is None, "mask is not supported in rectified flow yet" + assert mask is None, "mask is not supported in rectified flow inference yet" # if no specific guidance scale is provided, use the default scale when initializing the scheduler if guidance_scale is None: guidance_scale = self.cfg_scale @@ -46,7 +46,6 @@ class RFLOW: for i, t in enumerate(timesteps): z_in = torch.cat([z, z], 0) - print(z_in.shape, torch.tensor([t]* z_in.shape[0], device = device).shape) pred = model(z_in, torch.tensor([t]* z_in.shape[0], device = device), **model_args).chunk(2, dim = 1)[0] pred_cond, pred_uncond = pred.chunk(2, dim = 0) v_pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond) diff --git a/opensora/schedulers/rf/rectified_flow.py b/opensora/schedulers/rf/rectified_flow.py index 9f85c10..c4ad216 100644 --- a/opensora/schedulers/rf/rectified_flow.py +++ b/opensora/schedulers/rf/rectified_flow.py @@ -1,6 +1,9 @@ import torch import numpy as np from typing import Union +from einops import rearrange +from typing import List +from ..iddpm.gaussian_diffusion import _extract_into_tensor, mean_flat # some code are inspired by https://github.com/magic-research/piecewise-rectified-flow/blob/main/scripts/train_perflow.py # and https://github.com/magic-research/piecewise-rectified-flow/blob/main/src/scheduler_perflow.py @@ -22,7 +25,6 @@ class RFlowScheduler: Arguments format copied from opensora/schedulers/iddpm/gaussian_diffusion.py/training_losses Note: t is int tensor and should be rescaled from [0, num_timesteps-1] to [1,0] ''' - assert mask is None, "mask not support for rectified flow yet" assert weights is None, "weights not support for rectified flow yet" if model_kwargs is None: @@ -32,11 +34,20 @@ class RFlowScheduler: assert noise.shape == x_start.shape x_t = self.add_noise(x_start, noise, t) + if mask is not None: + t0 = torch.zeros_like(t) + x_t0 = self.add_noise(x_start, noise, t0) + x_t = torch.where(mask[:, None, :, None, None], x_t, x_t0) + terms = {} model_output = model(x_t, t, **model_kwargs) velocity_pred = model_output.chunk(2, dim = 1)[0] - loss = (velocity_pred - (x_start - noise)).pow(2).mean() + if weights is None: + loss = mean_flat((velocity_pred - (x_start - noise)).pow(2), mask = mask) + else: + weight = _extract_into_tensor(weights, t, x_start.shape) + loss = mean_flat(weight * (velocity_pred - (x_start - noise)).pow(2), mask = mask) terms['loss'] = loss return terms @@ -61,19 +72,6 @@ class RFlowScheduler: return timepoints * original_samples + (1 - timepoints) * noise - # def step( - # self, - # model_output: torch.FloatTensor, - # timestep: Union[int, torch.IntTensor], - # sample: torch.FloatTensor, - # ) -> torch.FloatTensor: - # ''' - # take an Euler step sampling - # ''' - - # dt = 1 / self.num_sampling_steps - - # return sample + dt * model_output \ No newline at end of file