diff --git a/configs/opensora-v1-1/train/image_rflow.py b/configs/opensora-v1-1/train/image_rflow.py new file mode 100644 index 0000000..66c7b56 --- /dev/null +++ b/configs/opensora-v1-1/train/image_rflow.py @@ -0,0 +1,88 @@ +# Define dataset +# dataset = dict( +# type="VariableVideoTextDataset", +# data_path=None, +# num_frames=None, +# frame_interval=3, +# image_size=(None, None), +# transform_name="resize_crop", +# ) +dataset = dict( + type="VideoTextDataset", + data_path=None, + num_frames=1, + frame_interval=1, + image_size=(256, 256), + transform_name="center", +) +bucket_config = { # 6s/it + "256": {1: (1.0, 256)}, + "512": {1: (1.0, 80)}, + "480p": {1: (1.0, 52)}, + "1024": {1: (1.0, 20)}, + "1080p": {1: (1.0, 8)}, +} + +# Define acceleration +num_workers = 16 +dtype = "bf16" +grad_checkpoint = True +plugin = "zero2" +sp_size = 1 + +# Define model +# model = dict( +# type="DiT-XL/2", +# from_pretrained="/home/zhaowangbo/wangbo/PixArt-alpha/pretrained_models/PixArt-XL-2-512x512.pth", +# # input_sq_size=512, # pretrained model is trained on 512x512 +# enable_flashattn=True, +# enable_layernorm_kernel=True, +# ) +model = dict( + type="PixArt-XL/2", + space_scale=1.0, + time_scale=1.0, + no_temporal_pos_emb=True, + from_pretrained="PixArt-XL-2-512x512.pth", + enable_flashattn=True, + enable_layernorm_kernel=True, +) +# model = dict( +# type="DiT-XL/2", +# # space_scale=1.0, +# # time_scale=1.0, +# no_temporal_pos_emb=True, +# # from_pretrained="PixArt-XL-2-512x512.pth", +# from_pretrained="/home/zhaowangbo/wangbo/PixArt-alpha/pretrained_models/PixArt-XL-2-512x512.pth", +# enable_flashattn=True, +# enable_layernorm_kernel=True, +# ) +vae = dict( + type="VideoAutoencoderKL", + from_pretrained="stabilityai/sd-vae-ft-ema", + micro_batch_size=4, +) +text_encoder = dict( + type="t5", + from_pretrained="DeepFloyd/t5-v1_1-xxl", + model_max_length=200, + shardformer=True, +) +scheduler = dict( + type="rflow", + # timestep_respacing="", +) + +# Others +seed = 42 +outputs 
= "outputs" +wandb = False + +epochs = 10 +log_every = 10 +ckpt_every = 500 +load = None + +batch_size = 100 # only for logging +lr = 2e-5 +grad_clip = 1.0 diff --git a/configs/opensora/inference/16x512x512-rflow.py b/configs/opensora/inference/16x512x512-rflow.py new file mode 100644 index 0000000..35804d5 --- /dev/null +++ b/configs/opensora/inference/16x512x512-rflow.py @@ -0,0 +1,35 @@ +num_frames = 16 +fps = 24 // 3 +image_size = (512, 512) + +# Define model +model = dict( + type="STDiT-XL/2", + space_scale=1.0, + time_scale=1.0, + enable_flashattn=True, + enable_layernorm_kernel=True, + from_pretrained="PRETRAINED_MODEL", +) +vae = dict( + type="VideoAutoencoderKL", + from_pretrained="stabilityai/sd-vae-ft-ema", + micro_batch_size=2, +) +text_encoder = dict( + type="t5", + from_pretrained="DeepFloyd/t5-v1_1-xxl", + model_max_length=120, +) +scheduler = dict( + type="rflow", + num_sampling_steps=10, + cfg_scale=7.0, +) +dtype = "bf16" + +# Others +batch_size = 2 +seed = 42 +prompt_path = "./assets/texts/t2v_samples.txt" +save_dir = "./outputs/samples/" diff --git a/configs/opensora/train/16x256x256-spee-rflow.py b/configs/opensora/train/16x256x256-spee-rflow.py new file mode 100644 index 0000000..cbb929c --- /dev/null +++ b/configs/opensora/train/16x256x256-spee-rflow.py @@ -0,0 +1,64 @@ +# Define dataset +dataset = dict( + type="VideoTextDataset", + data_path=None, + num_frames=16, + frame_interval=3, + image_size=(256, 256), +) + +# Define acceleration +num_workers = 4 +dtype = "bf16" +grad_checkpoint = True +plugin = "zero2" +sp_size = 1 + +# Define model +model = dict( + type="STDiT-XL/2", + space_scale=0.5, + time_scale=1.0, + # from_pretrained="PixArt-XL-2-512x512.pth", + # from_pretrained = "/home/zhaowangbo/wangbo/PixArt-alpha/pretrained_models/OpenSora-v1-HQ-16x512x512.pth", + # from_pretrained = "OpenSora-v1-HQ-16x512x512.pth", + from_pretrained = "PRETRAINED_MODEL", + enable_flashattn=True, + enable_layernorm_kernel=True, +) +# mask_ratios = [0.5, 
0.29, 0.07, 0.07, 0.07] +# mask_ratios = { +# "mask_no": 0.9, +# "mask_random": 0.06, +# "mask_head": 0.01, +# "mask_tail": 0.01, +# "mask_head_tail": 0.02, +# } +vae = dict( + type="VideoAutoencoderKL", + from_pretrained="stabilityai/sd-vae-ft-ema", +) +text_encoder = dict( + type="t5", + from_pretrained="DeepFloyd/t5-v1_1-xxl", + model_max_length=120, + shardformer=True, +) +scheduler = dict( + type="rflow", + # timestep_respacing="", +) + +# Others +seed = 42 +outputs = "outputs" +wandb = True + +epochs = 1 +log_every = 10 +ckpt_every = 1000 +load = None + +batch_size = 16 +lr = 2e-5 +grad_clip = 1.0 diff --git a/configs/pixart/inference/1x512x512-rflow.py b/configs/pixart/inference/1x512x512-rflow.py new file mode 100644 index 0000000..7bce7e2 --- /dev/null +++ b/configs/pixart/inference/1x512x512-rflow.py @@ -0,0 +1,39 @@ +num_frames = 1 +fps = 1 +image_size = (512, 512) + +# Define model +model = dict( + type="PixArt-XL/2", + space_scale=1.0, + time_scale=1.0, + no_temporal_pos_emb=True, + from_pretrained="PRETRAINED_MODEL", +) +vae = dict( + type="VideoAutoencoderKL", + from_pretrained="stabilityai/sd-vae-ft-ema", +) +text_encoder = dict( + type="t5", + from_pretrained="DeepFloyd/t5-v1_1-xxl", + model_max_length=120, +) +scheduler = dict( + type="rflow", + num_sampling_steps=20, + cfg_scale=7.0, +) +dtype = "bf16" + +# prompt_path = "./assets/texts/t2i_samples.txt" +prompt = [ + "Pirate ship trapped in a cosmic maelstrom nebula.", + "A small cactus with a happy face in the Sahara desert.", + "A small cactus with a sad face in the Sahara desert.", +] + +# Others +batch_size = 2 +seed = 42 +save_dir = "./outputs/samples2/" diff --git a/configs/pixart/train/1x512x512-rflow.py b/configs/pixart/train/1x512x512-rflow.py new file mode 100644 index 0000000..8925cdb --- /dev/null +++ b/configs/pixart/train/1x512x512-rflow.py @@ -0,0 +1,55 @@ +# Define dataset +dataset = dict( + type="VideoTextDataset", + data_path=None, + num_frames=1, + frame_interval=3, + 
image_size=(512, 512), +) + +# Define acceleration +num_workers = 4 +dtype = "bf16" +grad_checkpoint = True +plugin = "zero2" +sp_size = 1 + +# Define model +model = dict( + type="PixArt-XL/2", + space_scale=1.0, + time_scale=1.0, + no_temporal_pos_emb=True, + # from_pretrained="PixArt-XL-2-512x512.pth", + from_pretrained = "PRETRAINED_MODEL", + enable_flashattn=True, + enable_layernorm_kernel=True, +) +vae = dict( + type="VideoAutoencoderKL", + from_pretrained="stabilityai/sd-vae-ft-ema", +) +text_encoder = dict( + type="t5", + from_pretrained="DeepFloyd/t5-v1_1-xxl", + model_max_length=120, + shardformer=True, +) +scheduler = dict( + type="rflow", + # timestep_respacing="", +) + +# Others +seed = 42 +outputs = "outputs" +wandb = True + +epochs = 2 +log_every = 10 +ckpt_every = 1000 +load = None + +batch_size = 64 +lr = 2e-5 +grad_clip = 1.0 diff --git a/opensora/schedulers/__init__.py b/opensora/schedulers/__init__.py index 97ea76f..c7733a4 100644 --- a/opensora/schedulers/__init__.py +++ b/opensora/schedulers/__init__.py @@ -1,2 +1,3 @@ from .dpms import DPMS from .iddpm import IDDPM +from .rf import RFLOW diff --git a/opensora/schedulers/dpms/__init__.py b/opensora/schedulers/dpms/__init__.py index ba2f648..ed74427 100644 --- a/opensora/schedulers/dpms/__init__.py +++ b/opensora/schedulers/dpms/__init__.py @@ -8,7 +8,7 @@ from .dpm_solver import DPMS @SCHEDULERS.register_module("dpm-solver") -class DMP_SOLVER: +class DPM_SOLVER: def __init__(self, num_sampling_steps=None, cfg_scale=4.0): self.num_sampling_steps = num_sampling_steps self.cfg_scale = cfg_scale diff --git a/opensora/schedulers/rf/__init__.py b/opensora/schedulers/rf/__init__.py new file mode 100644 index 0000000..b6239de --- /dev/null +++ b/opensora/schedulers/rf/__init__.py @@ -0,0 +1,64 @@ +# should have property num_timesteps, +# method sample() training_losses() +import torch +from .rectified_flow import RFlowScheduler +from functools import partial + +from opensora.registry import SCHEDULERS 
+ +@SCHEDULERS.register_module("rflow") +class RFLOW: + def __init__(self, num_sampling_steps = 10, num_timesteps = 1000, cfg_scale = 4.0): + self.num_sampling_steps = num_sampling_steps + self.num_timesteps = num_timesteps + self.cfg_scale = cfg_scale + + self.scheduler = RFlowScheduler(num_timesteps = num_timesteps, num_sampling_steps = num_sampling_steps) + + def sample( + self, + model, + text_encoder, + z, + prompts, + device, + additional_args=None, + mask=None, + guidance_scale = None, + # progress = True, + ): + assert mask is None, "mask is not supported in rectified flow inference yet" + # if no specific guidance scale is provided, use the default scale when initializing the scheduler + if guidance_scale is None: + guidance_scale = self.cfg_scale + + n = len(prompts) + model_args = text_encoder.encode(prompts) + y_null = text_encoder.null(n) + model_args["y"] = torch.cat([model_args["y"], y_null], 0) + if additional_args is not None: + model_args.update(additional_args) + + timesteps = [(1. - i/self.num_sampling_steps) * self.num_timesteps 
for i in range(self.num_sampling_steps)] + + # convert float timesteps to closest int timesteps + timesteps = [int(round(t)) for t in timesteps] + + for i, t in enumerate(timesteps): + z_in = torch.cat([z, z], 0) + pred = model(z_in, torch.tensor([t]* z_in.shape[0], device = device), **model_args).chunk(2, dim = 1)[0] + pred_cond, pred_uncond = pred.chunk(2, dim = 0) + v_pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond) + + dt = (timesteps[i] - timesteps[i+1])/self.num_timesteps if i < len(timesteps) - 1 else timesteps[i]/self.num_timesteps + z = z + v_pred * dt + + return z + + def training_losses(self, model, x_start, t, model_kwargs=None, noise = None, mask = None, weights = None): + return self.scheduler.training_losses(model, x_start, t, model_kwargs, noise, mask, weights) + + + + + \ No newline at end of file diff --git a/opensora/schedulers/rf/rectified_flow.py b/opensora/schedulers/rf/rectified_flow.py new file mode 100644 index 0000000..05357ac --- /dev/null +++ b/opensora/schedulers/rf/rectified_flow.py @@ -0,0 +1,76 @@ +import torch +import numpy as np +from typing import Union +from einops import rearrange +from typing import List +from ..iddpm.gaussian_diffusion import _extract_into_tensor, mean_flat + +# some code is inspired by https://github.com/magic-research/piecewise-rectified-flow/blob/main/scripts/train_perflow.py +# and https://github.com/magic-research/piecewise-rectified-flow/blob/main/src/scheduler_perflow.py + + +class RFlowScheduler: + def __init__( + self, + num_timesteps = 1000, + num_sampling_steps = 10, + ): + self.num_timesteps = num_timesteps + self.num_sampling_steps = num_sampling_steps + + + def training_losses(self, model, x_start, t, model_kwargs=None, noise = None, mask = None, weights = None): + ''' + Compute training losses for a single timestep. 
+ Arguments format copied from opensora/schedulers/iddpm/gaussian_diffusion.py/training_losses + Note: t is int tensor and should be rescaled from [0, num_timesteps-1] to [1,0] + ''' + + if model_kwargs is None: + model_kwargs = {} + if noise is None: + noise = torch.randn_like(x_start) + assert noise.shape == x_start.shape + + x_t = self.add_noise(x_start, noise, t) + if mask is not None: + t0 = torch.zeros_like(t) + x_t0 = self.add_noise(x_start, noise, t0) + x_t = torch.where(mask[:, None, :, None, None], x_t, x_t0) + + + terms = {} + model_output = model(x_t, t, **model_kwargs) + velocity_pred = model_output.chunk(2, dim = 1)[0] + if weights is None: + loss = mean_flat((velocity_pred - (x_start - noise)).pow(2), mask = mask) + else: + weight = _extract_into_tensor(weights, t, x_start.shape) + loss = mean_flat(weight * (velocity_pred - (x_start - noise)).pow(2), mask = mask) + terms['loss'] = loss + + return terms + + + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.IntTensor, + ) -> torch.FloatTensor: + """ + compatible with diffusers add_noise() + """ + timepoints = timesteps.float() / self.num_timesteps # [0, 999/1000] + timepoints = 1 - timepoints # [1,1/1000] + + # timepoint (bsz) noise: (bsz, 4, frame, w ,h) + # expand timepoint to noise shape + timepoints = timepoints.unsqueeze(1).unsqueeze(1).unsqueeze(1).unsqueeze(1) + timepoints = timepoints.repeat(1, noise.shape[1], noise.shape[2], noise.shape[3], noise.shape[4]) + + return timepoints * original_samples + (1 - timepoints) * noise + + + + \ No newline at end of file