mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-11 05:13:31 +02:00
mask and weight support for rf training
This commit is contained in:
parent
07cbe11cb6
commit
fb4aede778
|
|
@ -27,7 +27,7 @@ class RFLOW:
|
|||
guidance_scale = None,
|
||||
# progress = True,
|
||||
):
|
||||
assert mask is None, "mask is not supported in rectified flow yet"
|
||||
assert mask is None, "mask is not supported in rectified flow inference yet"
|
||||
# if no specific guidance scale is provided, use the default scale when initializing the scheduler
|
||||
if guidance_scale is None:
|
||||
guidance_scale = self.cfg_scale
|
||||
|
|
@ -46,7 +46,6 @@ class RFLOW:
|
|||
|
||||
for i, t in enumerate(timesteps):
|
||||
z_in = torch.cat([z, z], 0)
|
||||
print(z_in.shape, torch.tensor([t]* z_in.shape[0], device = device).shape)
|
||||
pred = model(z_in, torch.tensor([t]* z_in.shape[0], device = device), **model_args).chunk(2, dim = 1)[0]
|
||||
pred_cond, pred_uncond = pred.chunk(2, dim = 0)
|
||||
v_pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,9 @@
|
|||
import torch
|
||||
import numpy as np
|
||||
from typing import Union
|
||||
from einops import rearrange
|
||||
from typing import List
|
||||
from ..iddpm.gaussian_diffusion import _extract_into_tensor, mean_flat
|
||||
|
||||
# some code is inspired by https://github.com/magic-research/piecewise-rectified-flow/blob/main/scripts/train_perflow.py
|
||||
# and https://github.com/magic-research/piecewise-rectified-flow/blob/main/src/scheduler_perflow.py
|
||||
|
|
@ -22,7 +25,6 @@ class RFlowScheduler:
|
|||
Arguments format copied from opensora/schedulers/iddpm/gaussian_diffusion.py/training_losses
|
||||
Note: t is an int tensor and should be rescaled from [0, num_timesteps-1] to [1, 0]
|
||||
'''
|
||||
assert mask is None, "mask not support for rectified flow yet"
|
||||
assert weights is None, "weights not support for rectified flow yet"
|
||||
|
||||
if model_kwargs is None:
|
||||
|
|
@ -32,11 +34,20 @@ class RFlowScheduler:
|
|||
assert noise.shape == x_start.shape
|
||||
|
||||
x_t = self.add_noise(x_start, noise, t)
|
||||
if mask is not None:
|
||||
t0 = torch.zeros_like(t)
|
||||
x_t0 = self.add_noise(x_start, noise, t0)
|
||||
x_t = torch.where(mask[:, None, :, None, None], x_t, x_t0)
|
||||
|
||||
|
||||
terms = {}
|
||||
model_output = model(x_t, t, **model_kwargs)
|
||||
velocity_pred = model_output.chunk(2, dim = 1)[0]
|
||||
loss = (velocity_pred - (x_start - noise)).pow(2).mean()
|
||||
if weights is None:
|
||||
loss = mean_flat((velocity_pred - (x_start - noise)).pow(2), mask = mask)
|
||||
else:
|
||||
weight = _extract_into_tensor(weights, t, x_start.shape)
|
||||
loss = mean_flat(weight * (velocity_pred - (x_start - noise)).pow(2), mask = mask)
|
||||
terms['loss'] = loss
|
||||
|
||||
return terms
|
||||
|
|
@ -61,19 +72,6 @@ class RFlowScheduler:
|
|||
|
||||
return timepoints * original_samples + (1 - timepoints) * noise
|
||||
|
||||
# def step(
|
||||
# self,
|
||||
# model_output: torch.FloatTensor,
|
||||
# timestep: Union[int, torch.IntTensor],
|
||||
# sample: torch.FloatTensor,
|
||||
# ) -> torch.FloatTensor:
|
||||
# '''
|
||||
# take an Euler step sampling
|
||||
# '''
|
||||
|
||||
# dt = 1 / self.num_sampling_steps
|
||||
|
||||
# return sample + dt * model_output
|
||||
|
||||
|
||||
|
||||
Loading…
Reference in a new issue