From d00f6810bb660e7479e5bfcf4aef943ed3eb22a5 Mon Sep 17 00:00:00 2001 From: ChangeFWorld Date: Wed, 10 Apr 2024 10:07:21 +0800 Subject: [PATCH 1/7] typo fix --- opensora/schedulers/dpms/__init__.py | 2 +- opensora/schedulers/rf/__init__.py | 0 2 files changed, 1 insertion(+), 1 deletion(-) create mode 100644 opensora/schedulers/rf/__init__.py diff --git a/opensora/schedulers/dpms/__init__.py b/opensora/schedulers/dpms/__init__.py index ba2f648..ed74427 100644 --- a/opensora/schedulers/dpms/__init__.py +++ b/opensora/schedulers/dpms/__init__.py @@ -8,7 +8,7 @@ from .dpm_solver import DPMS @SCHEDULERS.register_module("dpm-solver") -class DMP_SOLVER: +class DPM_SOLVER: def __init__(self, num_sampling_steps=None, cfg_scale=4.0): self.num_sampling_steps = num_sampling_steps self.cfg_scale = cfg_scale diff --git a/opensora/schedulers/rf/__init__.py b/opensora/schedulers/rf/__init__.py new file mode 100644 index 0000000..e69de29 From 9a14b9e23cacbfe51fb659a89cd9573a26bbb197 Mon Sep 17 00:00:00 2001 From: ChangeFWorld Date: Wed, 10 Apr 2024 18:27:43 +0800 Subject: [PATCH 2/7] rflow scheduler unfinished(40%) --- opensora/schedulers/rf/__init__.py | 10 +++++ opensora/schedulers/rf/rectified_flow.py | 49 ++++++++++++++++++++++++ 2 files changed, 59 insertions(+) create mode 100644 opensora/schedulers/rf/rectified_flow.py diff --git a/opensora/schedulers/rf/__init__.py b/opensora/schedulers/rf/__init__.py index e69de29..e8bdc45 100644 --- a/opensora/schedulers/rf/__init__.py +++ b/opensora/schedulers/rf/__init__.py @@ -0,0 +1,10 @@ +# should have property num_timesteps, +# method sample() training_losses() +import torch +from .rectified_flow import RFlowScheduler +from functools import partial + +from opensora.registry import SCHEDULERS + +# @SCHEDULERS.register_module("rflow") +# class RFLOW: \ No newline at end of file diff --git a/opensora/schedulers/rf/rectified_flow.py b/opensora/schedulers/rf/rectified_flow.py new file mode 100644 index 0000000..80732e5 --- /dev/null +++ b/opensora/schedulers/rf/rectified_flow.py @@ -0,0 +1,49 @@ +import torch +import numpy as np + +# some code are inspired by https://github.com/magic-research/piecewise-rectified-flow/blob/main/scripts/train_perflow.py + + +class RFlowScheduler: + def __init__( + self, + num_timesteps = 1000, + ): + self.num_timesteps = num_timesteps + + + def training_losses(self, model, x_start, t, model_kwargs=None, noise = None, mask = None, weights = None): + ''' + Compute training losses for a single timestep. + Arguments format copied from opensora/schedulers/iddpm/gaussian_diffusion.py/training_losses + Note: t is int tensor and should be rescaled from [0, num_timesteps-1] to [1,0] + ''' + assert mask is None, "mask not support for rectified flow yet" + assert weights is None, "weights not support for rectified flow yet" + + if model_kwargs is None: + model_kwargs = {} + if noise is None: + noise = torch.randn_like(x_start) + + x_t = self.add_noise(x_start, noise, t) + + terms = {} + velocity_pred = model(x_t, t, **model_kwargs) + loss = (velocity_pred - (x_start - noise)).pow(2).mean() + terms['loss'] = loss + + return terms + + def add_noise(self, x0, x1, t): + ''' + x0: sample of dataset + x1: sample of gaussian distribution + ''' + # rescale t from [0,num_timesteps] to [0,1] + t = t / self.num_timesteps + return t * x1 + (1 - t) * x0 + + def step(): + pass + \ No newline at end of file From 6fa92af3552253b58006839351a433e775ac0937 Mon Sep 17 00:00:00 2001 From: ChangeFWorld Date: Thu, 11 Apr 2024 11:31:21 +0800 Subject: [PATCH 3/7] rflow coarse finish, no check yet --- opensora/schedulers/rf/__init__.py | 55 +++++++++++++++++++++++- opensora/schedulers/rf/rectified_flow.py | 43 +++++++++++++----- 2 files changed, 86 insertions(+), 12 deletions(-) diff --git a/opensora/schedulers/rf/__init__.py b/opensora/schedulers/rf/__init__.py index e8bdc45..7520602 100644 --- a/opensora/schedulers/rf/__init__.py +++ b/opensora/schedulers/rf/__init__.py @@ -6,5 +6,56 @@ from functools import partial from opensora.registry import SCHEDULERS -# @SCHEDULERS.register_module("rflow") -# class RFLOW: \ No newline at end of file +@SCHEDULERS.register_module("rflow") +class RFLOW: + def __init__(self, num_sampling_steps = 10, num_timesteps = 1000, cfg_scale = 4.0): + self.num_sampling_steps = num_sampling_steps + self.num_timesteps = num_timesteps + self.cfg_scale = cfg_scale + + self.scheduler = RFlowScheduler(num_timesteps = num_timesteps, num_sampling_steps = num_sampling_steps) + + def sample( + self, + model, + text_encoder, + z, + prompts, + device, + additional_args=None, + mask=None, + guidance_scale = None, + # progress = True, + ): + assert mask is None, "mask is not supported in rectified flow yet" + # if no specific guidance scale is provided, use the default scale when initializing the scheduler + if guidance_scale is None: + guidance_scale = self.cfg_scale + + n = len(prompts) + z = torch.cat([z, z], 0) + model_args = text_encoder.encode(prompts) + y_null = text_encoder.null(n) + model_args["y"] = torch.cat([model_args["y"], y_null], 0) + if additional_args is not None: + model_args.update(additional_args) + + timesteps = [(1. - i/self.num_sampling_steps) * 1000. for i in range(self.num_sampling_steps)] + + # convert float timesteps to most close int timesteps + timesteps = [int(round(t)) for t in timesteps] + + for i, t in enumerate(timesteps): + pred = model(z, torch.tensor(t * z.shape[0], device = device), **model_args) + pred_cond, pred_uncond = pred.chunk(2, dim = 1) + v_pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond) + + dt = (timesteps[i] - timesteps[i+1])/self.num_timesteps if i < len(timesteps) - 1 else 1/self.num_timesteps + z = z + v_pred * dt + + return z + + + + + \ No newline at end of file diff --git a/opensora/schedulers/rf/rectified_flow.py b/opensora/schedulers/rf/rectified_flow.py index 80732e5..03aef76 100644 --- a/opensora/schedulers/rf/rectified_flow.py +++ b/opensora/schedulers/rf/rectified_flow.py @@ -1,15 +1,19 @@ import torch import numpy as np +from typing import Union # some code are inspired by https://github.com/magic-research/piecewise-rectified-flow/blob/main/scripts/train_perflow.py +# and https://github.com/magic-research/piecewise-rectified-flow/blob/main/src/scheduler_perflow.py class RFlowScheduler: def __init__( self, num_timesteps = 1000, + num_sampling_steps = 10, ): self.num_timesteps = num_timesteps + self.num_sampling_steps = num_sampling_steps def training_losses(self, model, x_start, t, model_kwargs=None, noise = None, mask = None, weights = None): @@ -35,15 +39,34 @@ class RFlowScheduler: return terms - def add_noise(self, x0, x1, t): - ''' - x0: sample of dataset - x1: sample of gaussian distribution - ''' - # rescale t from [0,num_timesteps] to [0,1] - t = t / self.num_timesteps - return t * x1 + (1 - t) * x0 + + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.IntTensor, + ) -> torch.FloatTensor: + """ + compatible with diffusers add_noise() + """ + timepoints = timesteps.float() / self.num_timesteps # [0, 999/1000] + timepoints = 1 - timepoints # [1,1/1000] + + return timepoints * original_samples + (1 - timepoints) * noise - def step(): - pass + # def step( + # self, + # model_output: torch.FloatTensor, + # timestep: Union[int, torch.IntTensor], + # sample: torch.FloatTensor, + # ) -> torch.FloatTensor: + # ''' + # take an Euler step sampling + # ''' + + # dt = 1 / self.num_sampling_steps + + # return sample + dt * model_output + + \ No newline at end of file From 2d1590b5821f55b568fc52f6f9c5d530bed1824e Mon Sep 17 00:00:00 2001 From: tianyi Date: Tue, 16 Apr 2024 16:07:30 +0800 Subject: [PATCH 4/7] fix bugs, runable --- .gitignore | 1 + opensora/schedulers/__init__.py | 1 + opensora/schedulers/rf/__init__.py | 3 +++ opensora/schedulers/rf/rectified_flow.py | 9 ++++++++- 4 files changed, 13 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 55daee5..8797ae7 100644 --- a/.gitignore +++ b/.gitignore @@ -177,3 +177,4 @@ pretrained_models/ # Secret files hostfile gradio_cached_examples/ +wandb/ diff --git a/opensora/schedulers/__init__.py b/opensora/schedulers/__init__.py index 97ea76f..c7733a4 100644 --- a/opensora/schedulers/__init__.py +++ b/opensora/schedulers/__init__.py @@ -1,2 +1,3 @@ from .dpms import DPMS from .iddpm import IDDPM +from .rf import RFLOW diff --git a/opensora/schedulers/rf/__init__.py b/opensora/schedulers/rf/__init__.py index 7520602..dd1009e 100644 --- a/opensora/schedulers/rf/__init__.py +++ b/opensora/schedulers/rf/__init__.py @@ -54,6 +54,9 @@ class RFLOW: z = z + v_pred * dt return z + + def training_losses(self, model, x_start, t, model_kwargs=None, noise = None, mask = None, weights = None): + return self.scheduler.training_losses(model, x_start, t, model_kwargs, noise, mask, weights) diff --git a/opensora/schedulers/rf/rectified_flow.py b/opensora/schedulers/rf/rectified_flow.py index 03aef76..9f85c10 100644 --- a/opensora/schedulers/rf/rectified_flow.py +++ b/opensora/schedulers/rf/rectified_flow.py @@ -29,11 +29,13 @@ class RFlowScheduler: model_kwargs = {} if noise is None: noise = torch.randn_like(x_start) + assert noise.shape == x_start.shape x_t = self.add_noise(x_start, noise, t) terms = {} - velocity_pred = model(x_t, t, **model_kwargs) + model_output = model(x_t, t, **model_kwargs) + velocity_pred = model_output.chunk(2, dim = 1)[0] loss = (velocity_pred - (x_start - noise)).pow(2).mean() terms['loss'] = loss @@ -52,6 +54,11 @@ class RFlowScheduler: timepoints = timesteps.float() / self.num_timesteps # [0, 999/1000] timepoints = 1 - timepoints # [1,1/1000] + # timepoint (bsz) noise: (bsz, 4, frame, w ,h) + # expand timepoint to noise shape + timepoints = timepoints.unsqueeze(1).unsqueeze(1).unsqueeze(1).unsqueeze(1) + timepoints = timepoints.repeat(1, noise.shape[1], noise.shape[2], noise.shape[3], noise.shape[4]) + return timepoints * original_samples + (1 - timepoints) * noise # def step( From 0045d8b7b0385133516c0e3b3e078035ec62a4db Mon Sep 17 00:00:00 2001 From: tianyi Date: Fri, 19 Apr 2024 11:18:29 +0800 Subject: [PATCH 5/7] fix --- configs/opensora-v1-1/train/image_rflow.py | 88 +++++++++++++++++++ .../opensora/inference/16x512x512-rflow.py | 35 ++++++++ .../opensora/train/16x256x256-spee-rflow.py | 64 ++++++++++++++ configs/pixart/inference/1x512x512-rflow.py | 39 ++++++++ configs/pixart/train/1x512x512-rflow.py | 55 ++++++++++++ opensora/schedulers/rf/__init__.py | 7 +- 6 files changed, 285 insertions(+), 3 deletions(-) create mode 100644 configs/opensora-v1-1/train/image_rflow.py create mode 100644 configs/opensora/inference/16x512x512-rflow.py create mode 100644 configs/opensora/train/16x256x256-spee-rflow.py create mode 100644 configs/pixart/inference/1x512x512-rflow.py create mode 100644 configs/pixart/train/1x512x512-rflow.py diff --git a/configs/opensora-v1-1/train/image_rflow.py b/configs/opensora-v1-1/train/image_rflow.py new file mode 100644 index 0000000..66c7b56 --- /dev/null +++ b/configs/opensora-v1-1/train/image_rflow.py @@ -0,0 +1,88 @@ +# Define dataset +# dataset = dict( +# type="VariableVideoTextDataset", +# data_path=None, +# num_frames=None, +# frame_interval=3, +# image_size=(None, None), +# transform_name="resize_crop", +# ) +dataset = dict( + type="VideoTextDataset", + data_path=None, + num_frames=1, + frame_interval=1, + image_size=(256, 256), + transform_name="center", +) +bucket_config = { # 6s/it + "256": {1: (1.0, 256)}, + "512": {1: (1.0, 80)}, + "480p": {1: (1.0, 52)}, + "1024": {1: (1.0, 20)}, + "1080p": {1: (1.0, 8)}, +} + +# Define acceleration +num_workers = 16 +dtype = "bf16" +grad_checkpoint = True +plugin = "zero2" +sp_size = 1 + +# Define model +# model = dict( +# type="DiT-XL/2", +# from_pretrained="/home/zhaowangbo/wangbo/PixArt-alpha/pretrained_models/PixArt-XL-2-512x512.pth", +# # input_sq_size=512, # pretrained model is trained on 512x512 +# enable_flashattn=True, +# enable_layernorm_kernel=True, +# ) +model = dict( + type="PixArt-XL/2", + space_scale=1.0, + time_scale=1.0, + no_temporal_pos_emb=True, + from_pretrained="PixArt-XL-2-512x512.pth", + enable_flashattn=True, + enable_layernorm_kernel=True, +) +# model = dict( +# type="DiT-XL/2", +# # space_scale=1.0, +# # time_scale=1.0, +# no_temporal_pos_emb=True, +# # from_pretrained="PixArt-XL-2-512x512.pth", +# from_pretrained="/home/zhaowangbo/wangbo/PixArt-alpha/pretrained_models/PixArt-XL-2-512x512.pth", +# enable_flashattn=True, +# enable_layernorm_kernel=True, +# ) +vae = dict( + type="VideoAutoencoderKL", + from_pretrained="stabilityai/sd-vae-ft-ema", + micro_batch_size=4, +) +text_encoder = dict( + type="t5", + from_pretrained="DeepFloyd/t5-v1_1-xxl", + model_max_length=200, + shardformer=True, +) +scheduler = dict( + type="rflow", + # timestep_respacing="", +) + +# Others +seed = 42 +outputs = "outputs" +wandb = False + +epochs = 10 +log_every = 10 +ckpt_every = 500 +load = None + +batch_size = 100 # only for logging +lr = 2e-5 +grad_clip = 1.0 diff --git a/configs/opensora/inference/16x512x512-rflow.py b/configs/opensora/inference/16x512x512-rflow.py new file mode 100644 index 0000000..35804d5 --- /dev/null +++ b/configs/opensora/inference/16x512x512-rflow.py @@ -0,0 +1,35 @@ +num_frames = 16 +fps = 24 // 3 +image_size = (512, 512) + +# Define model +model = dict( + type="STDiT-XL/2", + space_scale=1.0, + time_scale=1.0, + enable_flashattn=True, + enable_layernorm_kernel=True, + from_pretrained="PRETRAINED_MODEL", +) +vae = dict( + type="VideoAutoencoderKL", + from_pretrained="stabilityai/sd-vae-ft-ema", + micro_batch_size=2, +) +text_encoder = dict( + type="t5", + from_pretrained="DeepFloyd/t5-v1_1-xxl", + model_max_length=120, +) +scheduler = dict( + type="rflow", + num_sampling_steps=10, + cfg_scale=7.0, +) +dtype = "bf16" + +# Others +batch_size = 2 +seed = 42 +prompt_path = "./assets/texts/t2v_samples.txt" +save_dir = "./outputs/samples/" diff --git a/configs/opensora/train/16x256x256-spee-rflow.py b/configs/opensora/train/16x256x256-spee-rflow.py new file mode 100644 index 0000000..cbb929c --- /dev/null +++ b/configs/opensora/train/16x256x256-spee-rflow.py @@ -0,0 +1,64 @@ +# Define dataset +dataset = dict( + type="VideoTextDataset", + data_path=None, + num_frames=16, + frame_interval=3, + image_size=(256, 256), +) + +# Define acceleration +num_workers = 4 +dtype = "bf16" +grad_checkpoint = True +plugin = "zero2" +sp_size = 1 + +# Define model +model = dict( + type="STDiT-XL/2", + space_scale=0.5, + time_scale=1.0, + # from_pretrained="PixArt-XL-2-512x512.pth", + # from_pretrained = "/home/zhaowangbo/wangbo/PixArt-alpha/pretrained_models/OpenSora-v1-HQ-16x512x512.pth", + # from_pretrained = "OpenSora-v1-HQ-16x512x512.pth", + from_pretrained = "PRETRAINED_MODEL", + enable_flashattn=True, + enable_layernorm_kernel=True, +) +# mask_ratios = [0.5, 0.29, 0.07, 0.07, 0.07] +# mask_ratios = { +# "mask_no": 0.9, +# "mask_random": 0.06, +# "mask_head": 0.01, +# "mask_tail": 0.01, +# "mask_head_tail": 0.02, +# } +vae = dict( + type="VideoAutoencoderKL", + from_pretrained="stabilityai/sd-vae-ft-ema", +) +text_encoder = dict( + type="t5", + from_pretrained="DeepFloyd/t5-v1_1-xxl", + model_max_length=120, + shardformer=True, +) +scheduler = dict( + type="rflow", + # timestep_respacing="", +) + +# Others +seed = 42 +outputs = "outputs" +wandb = True + +epochs = 1 +log_every = 10 +ckpt_every = 1000 +load = None + +batch_size = 16 +lr = 2e-5 +grad_clip = 1.0 diff --git a/configs/pixart/inference/1x512x512-rflow.py b/configs/pixart/inference/1x512x512-rflow.py new file mode 100644 index 0000000..7bce7e2 --- /dev/null +++ b/configs/pixart/inference/1x512x512-rflow.py @@ -0,0 +1,39 @@ +num_frames = 1 +fps = 1 +image_size = (512, 512) + +# Define model +model = dict( + type="PixArt-XL/2", + space_scale=1.0, + time_scale=1.0, + no_temporal_pos_emb=True, + from_pretrained="PRETRAINED_MODEL", +) +vae = dict( + type="VideoAutoencoderKL", + from_pretrained="stabilityai/sd-vae-ft-ema", +) +text_encoder = dict( + type="t5", + from_pretrained="DeepFloyd/t5-v1_1-xxl", + model_max_length=120, +) +scheduler = dict( + type="rflow", + num_sampling_steps=20, + cfg_scale=7.0, +) +dtype = "bf16" + +# prompt_path = "./assets/texts/t2i_samples.txt" +prompt = [ + "Pirate ship trapped in a cosmic maelstrom nebula.", + "A small cactus with a happy face in the Sahara desert.", + "A small cactus with a sad face in the Sahara desert.", +] + +# Others +batch_size = 2 +seed = 42 +save_dir = "./outputs/samples2/" diff --git a/configs/pixart/train/1x512x512-rflow.py b/configs/pixart/train/1x512x512-rflow.py new file mode 100644 index 0000000..8925cdb --- /dev/null +++ b/configs/pixart/train/1x512x512-rflow.py @@ -0,0 +1,55 @@ +# Define dataset +dataset = dict( + type="VideoTextDataset", + data_path=None, + num_frames=1, + frame_interval=3, + image_size=(512, 512), +) + +# Define acceleration +num_workers = 4 +dtype = "bf16" +grad_checkpoint = True +plugin = "zero2" +sp_size = 1 + +# Define model +model = dict( + type="PixArt-XL/2", + space_scale=1.0, + time_scale=1.0, + no_temporal_pos_emb=True, + # from_pretrained="PixArt-XL-2-512x512.pth", + from_pretrained = "PRETRAINED_MODEL", + enable_flashattn=True, + enable_layernorm_kernel=True, +) +vae = dict( + type="VideoAutoencoderKL", + from_pretrained="stabilityai/sd-vae-ft-ema", +) +text_encoder = dict( + type="t5", + from_pretrained="DeepFloyd/t5-v1_1-xxl", + model_max_length=120, + shardformer=True, +) +scheduler = dict( + type="rflow", + # timestep_respacing="", +) + +# Others +seed = 42 +outputs = "outputs" +wandb = True + +epochs = 2 +log_every = 10 +ckpt_every = 1000 +load = None + +batch_size = 64 +lr = 2e-5 +grad_clip = 1.0 diff --git a/opensora/schedulers/rf/__init__.py b/opensora/schedulers/rf/__init__.py index dd1009e..aa0e086 100644 --- a/opensora/schedulers/rf/__init__.py +++ b/opensora/schedulers/rf/__init__.py @@ -33,7 +33,6 @@ class RFLOW: guidance_scale = self.cfg_scale n = len(prompts) - z = torch.cat([z, z], 0) model_args = text_encoder.encode(prompts) y_null = text_encoder.null(n) model_args["y"] = torch.cat([model_args["y"], y_null], 0) @@ -46,8 +45,10 @@ class RFLOW: timesteps = [int(round(t)) for t in timesteps] for i, t in enumerate(timesteps): - pred = model(z, torch.tensor(t * z.shape[0], device = device), **model_args) - pred_cond, pred_uncond = pred.chunk(2, dim = 1) + z_in = torch.cat([z, z], 0) + print(z_in.shape, torch.tensor([t]* z_in.shape[0], device = device).shape) + pred = model(z_in, torch.tensor([t]* z_in.shape[0], device = device), **model_args).chunk(2, dim = 1)[0] + pred_cond, pred_uncond = pred.chunk(2, dim = 0) v_pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond) dt = (timesteps[i] - timesteps[i+1])/self.num_timesteps if i < len(timesteps) - 1 else 1/self.num_timesteps From fb4aede778a7d272a3d996c88cf43c11d62fd884 Mon Sep 17 00:00:00 2001 From: Tianyi Date: Fri, 3 May 2024 05:48:14 +0000 Subject: [PATCH 6/7] mask and weight support for rf training --- opensora/schedulers/rf/__init__.py | 3 +-- opensora/schedulers/rf/rectified_flow.py | 28 +++++++++++------------- 2 files changed, 14 insertions(+), 17 deletions(-) diff --git a/opensora/schedulers/rf/__init__.py b/opensora/schedulers/rf/__init__.py index aa0e086..b6239de 100644 --- a/opensora/schedulers/rf/__init__.py +++ b/opensora/schedulers/rf/__init__.py @@ -27,7 +27,7 @@ class RFLOW: guidance_scale = None, # progress = True, ): - assert mask is None, "mask is not supported in rectified flow yet" + assert mask is None, "mask is not supported in rectified flow inference yet" # if no specific guidance scale is provided, use the default scale when initializing the scheduler if guidance_scale is None: guidance_scale = self.cfg_scale @@ -46,7 +46,6 @@ class RFLOW: for i, t in enumerate(timesteps): z_in = torch.cat([z, z], 0) - print(z_in.shape, torch.tensor([t]* z_in.shape[0], device = device).shape) pred = model(z_in, torch.tensor([t]* z_in.shape[0], device = device), **model_args).chunk(2, dim = 1)[0] pred_cond, pred_uncond = pred.chunk(2, dim = 0) v_pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond) diff --git a/opensora/schedulers/rf/rectified_flow.py b/opensora/schedulers/rf/rectified_flow.py index 9f85c10..c4ad216 100644 --- a/opensora/schedulers/rf/rectified_flow.py +++ b/opensora/schedulers/rf/rectified_flow.py @@ -1,6 +1,9 @@ import torch import numpy as np from typing import Union +from einops import rearrange +from typing import List +from ..iddpm.gaussian_diffusion import _extract_into_tensor, mean_flat # some code are inspired by https://github.com/magic-research/piecewise-rectified-flow/blob/main/scripts/train_perflow.py # and https://github.com/magic-research/piecewise-rectified-flow/blob/main/src/scheduler_perflow.py @@ -22,7 +25,6 @@ class RFlowScheduler: Arguments format copied from opensora/schedulers/iddpm/gaussian_diffusion.py/training_losses Note: t is int tensor and should be rescaled from [0, num_timesteps-1] to [1,0] ''' - assert mask is None, "mask not support for rectified flow yet" assert weights is None, "weights not support for rectified flow yet" if model_kwargs is None: @@ -32,11 +34,20 @@ class RFlowScheduler: assert noise.shape == x_start.shape x_t = self.add_noise(x_start, noise, t) + if mask is not None: + t0 = torch.zeros_like(t) + x_t0 = self.add_noise(x_start, noise, t0) + x_t = torch.where(mask[:, None, :, None, None], x_t, x_t0) + terms = {} model_output = model(x_t, t, **model_kwargs) velocity_pred = model_output.chunk(2, dim = 1)[0] - loss = (velocity_pred - (x_start - noise)).pow(2).mean() + if weights is None: + loss = mean_flat((velocity_pred - (x_start - noise)).pow(2), mask = mask) + else: + weight = _extract_into_tensor(weights, t, x_start.shape) + loss = mean_flat(weight * (velocity_pred - (x_start - noise)).pow(2), mask = mask) terms['loss'] = loss return terms @@ -61,19 +72,6 @@ class RFlowScheduler: return timepoints * original_samples + (1 - timepoints) * noise - # def step( - # self, - # model_output: torch.FloatTensor, - # timestep: Union[int, torch.IntTensor], - # sample: torch.FloatTensor, - # ) -> torch.FloatTensor: - # ''' - # take an Euler step sampling - # ''' - - # dt = 1 / self.num_sampling_steps - - # return sample + dt * model_output \ No newline at end of file From 659e054e3ae58cfc896c9b948606bfe0c274ed18 Mon Sep 17 00:00:00 2001 From: Tianyi Date: Fri, 3 May 2024 05:54:03 +0000 Subject: [PATCH 7/7] assert remove --- opensora/schedulers/rf/rectified_flow.py | 1 - 1 file changed, 1 deletion(-) diff --git a/opensora/schedulers/rf/rectified_flow.py b/opensora/schedulers/rf/rectified_flow.py index c4ad216..05357ac 100644 --- a/opensora/schedulers/rf/rectified_flow.py +++ b/opensora/schedulers/rf/rectified_flow.py @@ -25,7 +25,6 @@ class RFlowScheduler: Arguments format copied from opensora/schedulers/iddpm/gaussian_diffusion.py/training_losses Note: t is int tensor and should be rescaled from [0, num_timesteps-1] to [1,0] ''' - assert weights is None, "weights not support for rectified flow yet" if model_kwargs is None: model_kwargs = {}