mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-15 03:15:20 +02:00
rflow coarse finish, no check yet
This commit is contained in:
parent
9a14b9e23c
commit
6fa92af355
|
|
@ -6,5 +6,56 @@ from functools import partial
|
|||
|
||||
from opensora.registry import SCHEDULERS
|
||||
|
||||
# @SCHEDULERS.register_module("rflow")
|
||||
# class RFLOW:
|
||||
@SCHEDULERS.register_module("rflow")
class RFLOW:
    """Rectified-flow sampler.

    Euler-integrates the model's predicted velocity field from pure noise
    (t = num_timesteps) down toward data (t = 0), with classifier-free
    guidance applied at every step.
    """

    def __init__(self, num_sampling_steps=10, num_timesteps=1000, cfg_scale=4.0):
        """
        Args:
            num_sampling_steps: number of Euler integration steps.
            num_timesteps: resolution of the discrete timestep grid.
            cfg_scale: default classifier-free-guidance scale used when
                ``sample`` is called without an explicit ``guidance_scale``.
        """
        self.num_sampling_steps = num_sampling_steps
        self.num_timesteps = num_timesteps
        self.cfg_scale = cfg_scale

        self.scheduler = RFlowScheduler(
            num_timesteps=num_timesteps, num_sampling_steps=num_sampling_steps
        )

    def sample(
        self,
        model,
        text_encoder,
        z,
        prompts,
        device,
        additional_args=None,
        mask=None,
        guidance_scale=None,
        # progress = True,
    ):
        """Draw samples by Euler-integrating the model's velocity prediction.

        Args:
            model: callable ``model(z, t, **model_args)`` returning a velocity
                prediction with the same shape as ``z``.
            text_encoder: provides ``encode(prompts) -> dict`` (with key "y")
                and ``null(n)`` for the unconditional embedding.
            z: initial latent noise; batch size must equal ``len(prompts)``.
            prompts: text prompts.
            device: device on which timestep tensors are created.
            additional_args: optional extra kwargs merged into model_args.
            mask: unsupported; must be None.
            guidance_scale: CFG scale; defaults to ``self.cfg_scale``.

        Returns:
            Denoised latents with the same shape as ``z``.
        """
        assert mask is None, "mask is not supported in rectified flow yet"
        # if no specific guidance scale is provided, use the default scale when initializing the scheduler
        if guidance_scale is None:
            guidance_scale = self.cfg_scale

        n = len(prompts)
        model_args = text_encoder.encode(prompts)
        y_null = text_encoder.null(n)
        # conditional embeddings first, unconditional second; the chunk() in
        # the loop below relies on this ordering
        model_args["y"] = torch.cat([model_args["y"], y_null], 0)
        if additional_args is not None:
            model_args.update(additional_args)

        # timestep grid from num_timesteps down to num_timesteps/num_sampling_steps
        # (generalized: was hard-coded to 1000.)
        timesteps = [
            (1.0 - i / self.num_sampling_steps) * self.num_timesteps
            for i in range(self.num_sampling_steps)
        ]
        # convert float timesteps to the closest int timesteps
        timesteps = [int(round(t)) for t in timesteps]

        for i, t in enumerate(timesteps):
            # duplicate the batch so cond/uncond share one forward pass.
            # Fix: the original doubled z once *outside* the loop, which made
            # the `z + v_pred` update a shape mismatch and returned a
            # doubled-batch result.
            z_in = torch.cat([z, z], 0)
            # Fix: was torch.tensor(t * z.shape[0], ...) -- a single scalar
            # equal to t*batch instead of a batch-sized vector filled with t.
            t_vec = torch.tensor([t] * z_in.shape[0], device=device)
            pred = model(z_in, t_vec, **model_args)
            # Fix: split along the batch dimension (dim 0), matching the
            # torch.cat(..., 0) above; was chunk(2, dim=1).
            pred_cond, pred_uncond = pred.chunk(2, dim=0)
            v_pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond)

            # Euler step. The final step integrates the remaining time down to
            # t = 0 (fix: was a fixed 1/num_timesteps, which left the
            # trajectory short of t = 0).
            if i < len(timesteps) - 1:
                dt = (timesteps[i] - timesteps[i + 1]) / self.num_timesteps
            else:
                dt = timesteps[i] / self.num_timesteps
            z = z + v_pred * dt

        return z
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
@ -1,15 +1,19 @@
|
|||
import torch
|
||||
import numpy as np
|
||||
from typing import Union
|
||||
|
||||
# some code are inspired by https://github.com/magic-research/piecewise-rectified-flow/blob/main/scripts/train_perflow.py
|
||||
# and https://github.com/magic-research/piecewise-rectified-flow/blob/main/src/scheduler_perflow.py
|
||||
|
||||
|
||||
class RFlowScheduler:
|
||||
def __init__(
|
||||
self,
|
||||
num_timesteps = 1000,
|
||||
num_sampling_steps = 10,
|
||||
):
|
||||
self.num_timesteps = num_timesteps
|
||||
self.num_sampling_steps = num_sampling_steps
|
||||
|
||||
|
||||
def training_losses(self, model, x_start, t, model_kwargs=None, noise = None, mask = None, weights = None):
|
||||
|
|
@ -35,15 +39,34 @@ class RFlowScheduler:
|
|||
|
||||
return terms
|
||||
|
||||
def add_noise(self, x0, x1, t):
|
||||
'''
|
||||
x0: sample of dataset
|
||||
x1: sample of gaussian distribution
|
||||
'''
|
||||
# rescale t from [0,num_timesteps] to [0,1]
|
||||
t = t / self.num_timesteps
|
||||
return t * x1 + (1 - t) * x0
|
||||
|
||||
def add_noise(
|
||||
self,
|
||||
original_samples: torch.FloatTensor,
|
||||
noise: torch.FloatTensor,
|
||||
timesteps: torch.IntTensor,
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
compatible with diffusers add_noise()
|
||||
"""
|
||||
timepoints = timesteps.float() / self.num_timesteps # [0, 999/1000]
|
||||
timepoints = 1 - timepoints # [1,1/1000]
|
||||
|
||||
return timepoints * original_samples + (1 - timepoints) * noise
|
||||
|
||||
def step():
|
||||
pass
|
||||
# def step(
|
||||
# self,
|
||||
# model_output: torch.FloatTensor,
|
||||
# timestep: Union[int, torch.IntTensor],
|
||||
# sample: torch.FloatTensor,
|
||||
# ) -> torch.FloatTensor:
|
||||
# '''
|
||||
# take an Euler step sampling
|
||||
# '''
|
||||
|
||||
# dt = 1 / self.num_sampling_steps
|
||||
|
||||
# return sample + dt * model_output
|
||||
|
||||
|
||||
|
||||
Loading…
Reference in a new issue