rflow coarse finish, no check yet

This commit is contained in:
ChangeFWorld 2024-04-11 11:31:21 +08:00
parent 9a14b9e23c
commit 6fa92af355
2 changed files with 86 additions and 12 deletions

View file

@ -6,5 +6,56 @@ from functools import partial
from opensora.registry import SCHEDULERS
# @SCHEDULERS.register_module("rflow")
# class RFLOW:
@SCHEDULERS.register_module("rflow")
class RFLOW:
    """Rectified-flow sampler using Euler steps and classifier-free guidance.

    Args:
        num_sampling_steps: number of Euler steps taken by ``sample``.
        num_timesteps: size of the timestep scale; timesteps run from
            ``num_timesteps`` down towards 0.
        cfg_scale: guidance scale used when ``sample`` is not given one.
    """

    def __init__(self, num_sampling_steps=10, num_timesteps=1000, cfg_scale=4.0):
        self.num_sampling_steps = num_sampling_steps
        self.num_timesteps = num_timesteps
        self.cfg_scale = cfg_scale
        self.scheduler = RFlowScheduler(
            num_timesteps=num_timesteps,
            num_sampling_steps=num_sampling_steps,
        )

    def sample(
        self,
        model,
        text_encoder,
        z,
        prompts,
        device,
        additional_args=None,
        mask=None,
        guidance_scale=None,
        # progress = True,
    ):
        """Integrate the latent ``z`` along the predicted velocity field.

        Args:
            model: callable invoked as ``model(latents, timesteps, **model_args)``.
            text_encoder: object providing ``encode(prompts)`` (a dict holding
                a ``"y"`` entry) and ``null(n)``.
            z: starting latent, one entry per prompt along dim 0.
            prompts: sequence of prompts; its length sets the batch size.
            device: device on which the timestep tensor is created.
            additional_args: optional extra keyword arguments merged into the
                model arguments.
            mask: must be ``None``; masking is not supported here.
            guidance_scale: classifier-free guidance scale; ``None`` falls
                back to ``self.cfg_scale``.

        Returns:
            The latent after ``num_sampling_steps`` Euler steps, with the same
            batch size as the input ``z``.
        """
        assert mask is None, "mask is not supported in rectified flow yet"
        # If no specific guidance scale is provided, use the one given at construction.
        if guidance_scale is None:
            guidance_scale = self.cfg_scale

        n = len(prompts)
        model_args = text_encoder.encode(prompts)
        y_null = text_encoder.null(n)
        # Conditional embeddings first, null embeddings second: the model
        # output is split in the same order below.
        model_args["y"] = torch.cat([model_args["y"], y_null], 0)
        if additional_args is not None:
            model_args.update(additional_args)

        # Evenly spaced timesteps from num_timesteps downwards, rounded to
        # the nearest integer.
        timesteps = [
            int(round((1.0 - i / self.num_sampling_steps) * self.num_timesteps))
            for i in range(self.num_sampling_steps)
        ]

        for i, t in enumerate(timesteps):
            # Duplicate the current latent for the conditional/unconditional
            # pass; the running state `z` itself keeps its original batch size.
            z_in = torch.cat([z, z], 0)
            # One timestep value per batch entry.
            t_in = torch.tensor([t] * z_in.shape[0], device=device)
            pred = model(z_in, t_in, **model_args)
            # NOTE(review): assumes a model that returns twice the latent
            # channels carries the velocity in the first half (e.g. an extra
            # variance head) -- confirm against the model definition.
            if pred.shape[1] == 2 * z.shape[1]:
                pred = pred.chunk(2, dim=1)[0]
            # The two guidance halves were concatenated along the batch axis.
            pred_cond, pred_uncond = pred.chunk(2, dim=0)
            v_pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond)

            # The final step covers whatever remains down to timestep 0.
            next_t = timesteps[i + 1] if i + 1 < len(timesteps) else 0
            dt = (t - next_t) / self.num_timesteps
            z = z + v_pred * dt
        return z

View file

@ -1,15 +1,19 @@
import torch
import numpy as np
from typing import Union
# Some of this code is inspired by https://github.com/magic-research/piecewise-rectified-flow/blob/main/scripts/train_perflow.py
# and https://github.com/magic-research/piecewise-rectified-flow/blob/main/src/scheduler_perflow.py
class RFlowScheduler:
def __init__(
self,
num_timesteps = 1000,
num_sampling_steps = 10,
):
self.num_timesteps = num_timesteps
self.num_sampling_steps = num_sampling_steps
def training_losses(self, model, x_start, t, model_kwargs=None, noise = None, mask = None, weights = None):
@ -35,15 +39,34 @@ class RFlowScheduler:
return terms
def add_noise(self, x0, x1, t):
'''
x0: sample of dataset
x1: sample of gaussian distribution
'''
# rescale t from [0,num_timesteps] to [0,1]
t = t / self.num_timesteps
return t * x1 + (1 - t) * x0
def add_noise(
self,
original_samples: torch.FloatTensor,
noise: torch.FloatTensor,
timesteps: torch.IntTensor,
) -> torch.FloatTensor:
"""
compatible with diffusers add_noise()
"""
timepoints = timesteps.float() / self.num_timesteps # [0, 999/1000]
timepoints = 1 - timepoints # [1,1/1000]
return timepoints * original_samples + (1 - timepoints) * noise
def step():
    """Placeholder for a sampling step; currently does nothing."""
    # NOTE(review): this signature takes no `self`. If it is meant to be a
    # method of the scheduler class it cannot be called on an instance -- confirm.
    return None
# def step(
# self,
# model_output: torch.FloatTensor,
# timestep: Union[int, torch.IntTensor],
# sample: torch.FloatTensor,
# ) -> torch.FloatTensor:
# '''
# take an Euler step sampling
# '''
# dt = 1 / self.num_sampling_steps
# return sample + dt * model_output