tianyi 2024-04-19 11:18:29 +08:00
parent 0d98c1d02d
commit 0045d8b7b0
6 changed files with 285 additions and 3 deletions

@@ -0,0 +1,88 @@
# Define dataset
# dataset = dict(
#     type="VariableVideoTextDataset",
#     data_path=None,
#     num_frames=None,
#     frame_interval=3,
#     image_size=(None, None),
#     transform_name="resize_crop",
# )
dataset = dict(
    type="VideoTextDataset",
    data_path=None,
    num_frames=1,
    frame_interval=1,
    image_size=(256, 256),
    transform_name="center",
)
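# Assumed format (not stated in this commit): bucket_config maps a resolution
# to {num_frames: (keep probability, batch size)}, so "256": {1: (1.0, 256)}
# keeps every 1-frame sample at 256px and trains it with batch size 256.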
bucket_config = {  # 6s/it
    "256": {1: (1.0, 256)},
    "512": {1: (1.0, 80)},
    "480p": {1: (1.0, 52)},
    "1024": {1: (1.0, 20)},
    "1080p": {1: (1.0, 8)},
}
# Define acceleration
num_workers = 16
dtype = "bf16"
grad_checkpoint = True
plugin = "zero2"
sp_size = 1
# Define model
# model = dict(
#     type="DiT-XL/2",
#     from_pretrained="/home/zhaowangbo/wangbo/PixArt-alpha/pretrained_models/PixArt-XL-2-512x512.pth",
#     # input_sq_size=512,  # pretrained model is trained on 512x512
#     enable_flashattn=True,
#     enable_layernorm_kernel=True,
# )
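# Single-frame (image) training: num_frames=1 in the dataset above, so the
# temporal position embedding is disabled via no_temporal_pos_emb below.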
model = dict(
    type="PixArt-XL/2",
    space_scale=1.0,
    time_scale=1.0,
    no_temporal_pos_emb=True,
    from_pretrained="PixArt-XL-2-512x512.pth",
    enable_flashattn=True,
    enable_layernorm_kernel=True,
)
# model = dict(
#     type="DiT-XL/2",
#     # space_scale=1.0,
#     # time_scale=1.0,
#     no_temporal_pos_emb=True,
#     # from_pretrained="PixArt-XL-2-512x512.pth",
#     from_pretrained="/home/zhaowangbo/wangbo/PixArt-alpha/pretrained_models/PixArt-XL-2-512x512.pth",
#     enable_flashattn=True,
#     enable_layernorm_kernel=True,
# )
vae = dict(
    type="VideoAutoencoderKL",
    from_pretrained="stabilityai/sd-vae-ft-ema",
    micro_batch_size=4,
)
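# micro_batch_size presumably makes the VAE encode frames in chunks of 4 to
# bound peak activation memory.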
text_encoder = dict(
    type="t5",
    from_pretrained="DeepFloyd/t5-v1_1-xxl",
    model_max_length=200,
    shardformer=True,
)
scheduler = dict(
    type="rflow",
    # timestep_respacing="",
)
# Others
seed = 42
outputs = "outputs"
wandb = False
epochs = 10
log_every = 10
ckpt_every = 500
load = None
batch_size = 100  # only for logging
lr = 2e-5
grad_clip = 1.0
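
These files are flat Python configs: every option is a module-level variable. A minimal, hypothetical sketch (not part of this commit) of how such a file can be loaded as a plain dict, assuming the entrypoint simply executes it and reads the module-level names; the path and helper name are illustrative:

import importlib.util

def load_config(path: str) -> dict:
    # Execute the config file as a module and collect its top-level variables.
    spec = importlib.util.spec_from_file_location("config", path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    # Drop dunder attributes; everything else is a config option.
    return {k: v for k, v in vars(module).items() if not k.startswith("__")}

cfg = load_config("train_image.py")  # hypothetical file name
print(cfg["model"]["type"], cfg["lr"])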

@@ -0,0 +1,35 @@
num_frames = 16
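# 24 fps footage sampled every 3rd frame gives an effective 8 fps; 16 frames is a 2-second clip.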
fps = 24 // 3
image_size = (512, 512)
# Define model
model = dict(
    type="STDiT-XL/2",
    space_scale=1.0,
    time_scale=1.0,
    enable_flashattn=True,
    enable_layernorm_kernel=True,
    from_pretrained="PRETRAINED_MODEL",
)
vae = dict(
    type="VideoAutoencoderKL",
    from_pretrained="stabilityai/sd-vae-ft-ema",
    micro_batch_size=2,
)
text_encoder = dict(
    type="t5",
    from_pretrained="DeepFloyd/t5-v1_1-xxl",
    model_max_length=120,
)
scheduler = dict(
    type="rflow",
    num_sampling_steps=10,
    cfg_scale=7.0,
)
dtype = "bf16"
# Others
batch_size = 2
seed = 42
prompt_path = "./assets/texts/t2v_samples.txt"
save_dir = "./outputs/samples/"

@@ -0,0 +1,64 @@
# Define dataset
dataset = dict(
    type="VideoTextDataset",
    data_path=None,
    num_frames=16,
    frame_interval=3,
    image_size=(256, 256),
)
# Define acceleration
num_workers = 4
dtype = "bf16"
grad_checkpoint = True
plugin = "zero2"
sp_size = 1
# Define model
model = dict(
    type="STDiT-XL/2",
    space_scale=0.5,
    time_scale=1.0,
    # from_pretrained="PixArt-XL-2-512x512.pth",
    # from_pretrained="/home/zhaowangbo/wangbo/PixArt-alpha/pretrained_models/OpenSora-v1-HQ-16x512x512.pth",
    # from_pretrained="OpenSora-v1-HQ-16x512x512.pth",
    from_pretrained="PRETRAINED_MODEL",
    enable_flashattn=True,
    enable_layernorm_kernel=True,
)
# mask_ratios = [0.5, 0.29, 0.07, 0.07, 0.07]
# mask_ratios = {
#     "mask_no": 0.9,
#     "mask_random": 0.06,
#     "mask_head": 0.01,
#     "mask_tail": 0.01,
#     "mask_head_tail": 0.02,
# }
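# If enabled, mask_ratios presumably assigns a sampling probability to each
# frame-masking strategy (none / random / head / tail / both ends); the values
# above sum to 1.0.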
vae = dict(
    type="VideoAutoencoderKL",
    from_pretrained="stabilityai/sd-vae-ft-ema",
)
text_encoder = dict(
    type="t5",
    from_pretrained="DeepFloyd/t5-v1_1-xxl",
    model_max_length=120,
    shardformer=True,
)
scheduler = dict(
    type="rflow",
    # timestep_respacing="",
)
# Others
seed = 42
outputs = "outputs"
wandb = True
epochs = 1
log_every = 10
ckpt_every = 1000
load = None
batch_size = 16
lr = 2e-5
grad_clip = 1.0

@@ -0,0 +1,39 @@
num_frames = 1
fps = 1
image_size = (512, 512)
# Define model
model = dict(
    type="PixArt-XL/2",
    space_scale=1.0,
    time_scale=1.0,
    no_temporal_pos_emb=True,
    from_pretrained="PRETRAINED_MODEL",
)
vae = dict(
    type="VideoAutoencoderKL",
    from_pretrained="stabilityai/sd-vae-ft-ema",
)
text_encoder = dict(
    type="t5",
    from_pretrained="DeepFloyd/t5-v1_1-xxl",
    model_max_length=120,
)
scheduler = dict(
    type="rflow",
    num_sampling_steps=20,
    cfg_scale=7.0,
)
dtype = "bf16"
# prompt_path = "./assets/texts/t2i_samples.txt"
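# Prompts are given inline below instead of through the commented prompt_path above.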
prompt = [
    "Pirate ship trapped in a cosmic maelstrom nebula.",
    "A small cactus with a happy face in the Sahara desert.",
    "A small cactus with a sad face in the Sahara desert.",
]
# Others
batch_size = 2
seed = 42
save_dir = "./outputs/samples2/"

@@ -0,0 +1,55 @@
# Define dataset
dataset = dict(
    type="VideoTextDataset",
    data_path=None,
    num_frames=1,
    frame_interval=3,
    image_size=(512, 512),
)
# Define acceleration
num_workers = 4
dtype = "bf16"
grad_checkpoint = True
plugin = "zero2"
sp_size = 1
# Define model
model = dict(
    type="PixArt-XL/2",
    space_scale=1.0,
    time_scale=1.0,
    no_temporal_pos_emb=True,
    # from_pretrained="PixArt-XL-2-512x512.pth",
    from_pretrained="PRETRAINED_MODEL",
    enable_flashattn=True,
    enable_layernorm_kernel=True,
)
vae = dict(
    type="VideoAutoencoderKL",
    from_pretrained="stabilityai/sd-vae-ft-ema",
)
text_encoder = dict(
    type="t5",
    from_pretrained="DeepFloyd/t5-v1_1-xxl",
    model_max_length=120,
    shardformer=True,
)
scheduler = dict(
    type="rflow",
    # timestep_respacing="",
)
# Others
seed = 42
outputs = "outputs"
wandb = True
epochs = 2
log_every = 10
ckpt_every = 1000
load = None
batch_size = 64
lr = 2e-5
grad_clip = 1.0

@@ -33,7 +33,6 @@ class RFLOW:
         guidance_scale = self.cfg_scale
         n = len(prompts)
-        z = torch.cat([z, z], 0)
         model_args = text_encoder.encode(prompts)
         y_null = text_encoder.null(n)
         model_args["y"] = torch.cat([model_args["y"], y_null], 0)
@@ -46,8 +45,10 @@ class RFLOW:
         timesteps = [int(round(t)) for t in timesteps]
         for i, t in enumerate(timesteps):
-            pred = model(z, torch.tensor(t * z.shape[0], device=device), **model_args)
-            pred_cond, pred_uncond = pred.chunk(2, dim=1)
+            z_in = torch.cat([z, z], 0)
+            print(z_in.shape, torch.tensor([t] * z_in.shape[0], device=device).shape)
+            pred = model(z_in, torch.tensor([t] * z_in.shape[0], device=device), **model_args).chunk(2, dim=1)[0]
+            pred_cond, pred_uncond = pred.chunk(2, dim=0)
             v_pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond)
             dt = (timesteps[i] - timesteps[i + 1]) / self.num_timesteps if i < len(timesteps) - 1 else 1 / self.num_timesteps
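
Read together, the two hunks move classifier-free guidance batching inside the sampling loop: instead of duplicating z once up front, each step duplicates the current latents, runs one batched forward pass for the conditional and unconditional halves, and mixes them into a guided velocity. A minimal self-contained sketch of that step, assuming a velocity-prediction model whose output stacks [velocity, learned variance] along dim 1; names such as model, z, model_args mirror the diff, while the final z update is not shown in the hunk and is assumed to be a plain Euler step:

import torch

@torch.no_grad()
def rflow_cfg_sample(model, z, model_args, timesteps, num_timesteps, guidance_scale, device):
    # timesteps: descending integer times, e.g. [1000, 900, ..., 100];
    # model_args["y"] already holds conditional embeddings stacked over null ones.
    for i, t in enumerate(timesteps):
        z_in = torch.cat([z, z], 0)  # conditional + unconditional halves in one batch
        t_in = torch.tensor([t] * z_in.shape[0], device=device)
        # Keep the velocity half of the output, dropping the learned variance.
        pred = model(z_in, t_in, **model_args).chunk(2, dim=1)[0]
        pred_cond, pred_uncond = pred.chunk(2, dim=0)
        v_pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond)
        dt = (timesteps[i] - timesteps[i + 1]) / num_timesteps if i < len(timesteps) - 1 else 1 / num_timesteps
        # Euler step along the guided velocity (assumed update; not part of the hunk).
        z = z + v_pred * dt
    return z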