From 6fa92af3552253b58006839351a433e775ac0937 Mon Sep 17 00:00:00 2001 From: ChangeFWorld Date: Thu, 11 Apr 2024 11:31:21 +0800 Subject: [PATCH] rflow coarse finish, no check yet --- opensora/schedulers/rf/__init__.py | 55 +++++++++++++++++++++++- opensora/schedulers/rf/rectified_flow.py | 43 +++++++++++++----- 2 files changed, 86 insertions(+), 12 deletions(-) diff --git a/opensora/schedulers/rf/__init__.py b/opensora/schedulers/rf/__init__.py index e8bdc45..7520602 100644 --- a/opensora/schedulers/rf/__init__.py +++ b/opensora/schedulers/rf/__init__.py @@ -6,5 +6,56 @@ from functools import partial from opensora.registry import SCHEDULERS -# @SCHEDULERS.register_module("rflow") -# class RFLOW: \ No newline at end of file +@SCHEDULERS.register_module("rflow") +class RFLOW: + def __init__(self, num_sampling_steps = 10, num_timesteps = 1000, cfg_scale = 4.0): + self.num_sampling_steps = num_sampling_steps + self.num_timesteps = num_timesteps + self.cfg_scale = cfg_scale + + self.scheduler = RFlowScheduler(num_timesteps = num_timesteps, num_sampling_steps = num_sampling_steps) + + def sample( + self, + model, + text_encoder, + z, + prompts, + device, + additional_args=None, + mask=None, + guidance_scale = None, + # progress = True, + ): + assert mask is None, "mask is not supported in rectified flow yet" + # if no specific guidance scale is provided, use the default scale when initializing the scheduler + if guidance_scale is None: + guidance_scale = self.cfg_scale + + n = len(prompts) + z = torch.cat([z, z], 0) + model_args = text_encoder.encode(prompts) + y_null = text_encoder.null(n) + model_args["y"] = torch.cat([model_args["y"], y_null], 0) + if additional_args is not None: + model_args.update(additional_args) + + timesteps = [(1. - i/self.num_sampling_steps) * 1000. for i in range(self.num_sampling_steps)] + + # convert float timesteps to most close int timesteps + timesteps = [int(round(t)) for t in timesteps] + + for i, t in enumerate(timesteps): + pred = model(z, torch.tensor(t * z.shape[0], device = device), **model_args) + pred_cond, pred_uncond = pred.chunk(2, dim = 1) + v_pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond) + + dt = (timesteps[i] - timesteps[i+1])/self.num_timesteps if i < len(timesteps) - 1 else 1/self.num_timesteps + z = z + v_pred * dt + + return z + + + + + \ No newline at end of file diff --git a/opensora/schedulers/rf/rectified_flow.py b/opensora/schedulers/rf/rectified_flow.py index 80732e5..03aef76 100644 --- a/opensora/schedulers/rf/rectified_flow.py +++ b/opensora/schedulers/rf/rectified_flow.py @@ -1,15 +1,19 @@ import torch import numpy as np +from typing import Union # some code are inspired by https://github.com/magic-research/piecewise-rectified-flow/blob/main/scripts/train_perflow.py +# and https://github.com/magic-research/piecewise-rectified-flow/blob/main/src/scheduler_perflow.py class RFlowScheduler: def __init__( self, num_timesteps = 1000, + num_sampling_steps = 10, ): self.num_timesteps = num_timesteps + self.num_sampling_steps = num_sampling_steps def training_losses(self, model, x_start, t, model_kwargs=None, noise = None, mask = None, weights = None): @@ -35,15 +39,34 @@ class RFlowScheduler: return terms - def add_noise(self, x0, x1, t): - ''' - x0: sample of dataset - x1: sample of gaussian distribution - ''' - # rescale t from [0,num_timesteps] to [0,1] - t = t / self.num_timesteps - return t * x1 + (1 - t) * x0 + + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.IntTensor, + ) -> torch.FloatTensor: + """ + compatible with diffusers add_noise() + """ + timepoints = timesteps.float() / self.num_timesteps # [0, 999/1000] + timepoints = 1 - timepoints # [1,1/1000] + + return timepoints * original_samples + (1 - timepoints) * noise - def step(): - pass + # def step( + # self, + # model_output: torch.FloatTensor, + # timestep: Union[int, torch.IntTensor], + # sample: torch.FloatTensor, + # ) -> torch.FloatTensor: + # ''' + # take an Euler step sampling + # ''' + + # dt = 1 / self.num_sampling_steps + + # return sample + dt * model_output + + \ No newline at end of file