[feat] new training scheme for mask

This commit is contained in:
Zangwei Zheng 2024-03-24 16:07:18 +08:00
parent 7d478f5094
commit 26c583118b
4 changed files with 128 additions and 24 deletions

View file

@ -0,0 +1,54 @@
# Training configuration for STDiT with the masked-frame training scheme.
# All names below are read as module-level attributes by the training script.

# Video sampling
num_frames = 16
frame_interval = 3
image_size = (256, 256)

# Define dataset
root = None
data_path = "CSV_PATH"  # CSV index of training clips
use_image_transform = False
num_workers = 4

# Define acceleration
dtype = "bf16"
grad_checkpoint = True
plugin = "zero2"  # ColossalAI ZeRO stage-2 plugin
sp_size = 1  # sequence-parallel group size

# Define model
model = {
    "type": "STDiT-XL/2",
    "space_scale": 0.5,
    "time_scale": 1.0,
    "from_pretrained": "PixArt-XL-2-512x512.pth",
    "enable_flashattn": True,
    "enable_layernorm_kernel": True,
}
# Sampling distribution over mask strategies (must sum to 1).
mask_ratios = [0.7, 0.15, 0.05, 0.05, 0.05]
vae = {
    "type": "VideoAutoencoderKL",
    "from_pretrained": "stabilityai/sd-vae-ft-ema",
}
text_encoder = {
    "type": "t5",
    "from_pretrained": "DeepFloyd/t5-v1_1-xxl",
    "model_max_length": 120,
    "shardformer": True,
}
scheduler = {
    "type": "iddpm",
    "timestep_respacing": "",
}

# Others
seed = 42
outputs = "outputs"
wandb = False

epochs = 1000
log_every = 10
ckpt_every = 1000
load = None

batch_size = 8
lr = 2e-5
grad_clip = 1.0

View file

@ -19,7 +19,6 @@ model = dict(
type="STDiT-XL/2",
space_scale=1.0,
time_scale=1.0,
use_x_mask=True,
from_pretrained=None,
enable_flashattn=True,
enable_layernorm_kernel=True,
@ -48,7 +47,7 @@ wandb = False
epochs = 1000
log_every = 10
ckpt_every = 500
ckpt_every = 1000
load = None
batch_size = 8

View file

@ -328,7 +328,7 @@ class SeqParallelMultiHeadCrossAttention(MultiHeadCrossAttention):
if mask is not None:
attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens([N] * B, mask)
x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias)
# apply all to all to gather back attention heads and scatter sequence
x = x.view(B, -1, self.num_heads // sp_size, self.head_dim)
x = all_to_all(x, sp_group, scatter_dim=1, gather_dim=2)
@ -363,16 +363,28 @@ class T2IFinalLayer(nn.Module):
The final layer of PixArt.
"""
def __init__(self, hidden_size, num_patch, out_channels):
def __init__(self, hidden_size, num_patch, out_channels, d_t=None, d_s=None):
super().__init__()
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(hidden_size, num_patch * out_channels, bias=True)
self.scale_shift_table = nn.Parameter(torch.randn(2, hidden_size) / hidden_size**0.5)
self.out_channels = out_channels
self.d_t = d_t
self.d_s = d_s
def forward(self, x, t):
def forward(self, x, t, x_mask=None, t0=None):
shift, scale = (self.scale_shift_table[None] + t[:, None]).chunk(2, dim=1)
x = t2i_modulate(self.norm_final(x), shift, scale)
if x_mask is not None:
shift_zero, scale_zero = (self.scale_shift_table[None] + t0[:, None]).chunk(2, dim=1)
x_zero = t2i_modulate(self.norm_final(x), shift_zero, scale_zero)
# t_mask_select
x = rearrange(x, "B (T S) C -> B T S C", T=self.d_t, S=self.d_s)
x_zero = rearrange(x_zero, "B (T S) C -> B T S C", T=self.d_t, S=self.d_s)
x = torch.where(x_mask[:, :, None, None], x, x_zero)
x = rearrange(x, "B T S C -> B (T S) C", T=self.d_t, S=self.d_s)
x = self.linear(x)
return x

View file

@ -3,9 +3,8 @@ import torch
import torch.distributed as dist
import torch.nn as nn
from einops import rearrange
from timm.models.layers import DropPath
from timm.models.layers import DropPath, trunc_normal_
from timm.models.vision_transformer import Mlp
from timm.models.layers import trunc_normal_
from opensora.acceleration.checkpoint import auto_grad_checkpoint
from opensora.acceleration.communications import gather_forward_split_backward, split_forward_gather_backward
@ -86,19 +85,43 @@ class STDiTBlock(nn.Module):
enable_flashattn=self.enable_flashattn,
)
def forward(self, x, y, t, mask=None, tpe=None):
def t_mask_select(self, x, masked_x, x_mask):
# x: [B, T, S, C]
# masked_x: [B, T, S, C]
# x_mask: [B, T]
x = rearrange(x, "B (T S) C -> B T S C", T=self.d_t, S=self.d_s)
masked_x = rearrange(masked_x, "B (T S) C -> B T S C", T=self.d_t, S=self.d_s)
x = torch.where(x_mask[:, :, None, None], x, masked_x)
x = rearrange(x, "B T S C -> B (T S) C")
return x
def forward(self, x, y, t, mask=None, tpe=None, x_mask=None, t0=None):
B, N, C = x.shape
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
self.scale_shift_table[None] + t.reshape(B, 6, -1)
).chunk(6, dim=1)
x_m = t2i_modulate(self.norm1(x), shift_msa, scale_msa)
if x_mask is not None:
shift_msa_zero, scale_msa_zero, gate_msa_zero, shift_mlp_zero, scale_mlp_zero, gate_mlp_zero = (
self.scale_shift_table[None] + t0.reshape(B, 6, -1)
).chunk(6, dim=1)
x_m_zero = t2i_modulate(self.norm1(x), shift_msa_zero, scale_msa_zero)
x_m = self.t_mask_select(x_m, x_m_zero, x_mask)
# spatial branch
x_s = rearrange(x_m, "B (T S) C -> (B T) S C", T=self.d_t, S=self.d_s)
x_s = self.attn(x_s)
x_s = rearrange(x_s, "(B T) S C -> B (T S) C", T=self.d_t, S=self.d_s)
x = x + self.drop_path(gate_msa * x_s)
if x_mask is not None:
x_s_zero = gate_msa_zero * x_s
x_s = gate_msa * x_s
x_s = self.t_mask_select(x_s, x_s_zero, x_mask)
else:
x_s = gate_msa * x_s
x = x + self.drop_path(x_s)
# temporal branch
x_t = rearrange(x, "B (T S) C -> (B S) T C", T=self.d_t, S=self.d_s)
@ -112,7 +135,20 @@ class STDiTBlock(nn.Module):
x = x + self.cross_attn(x, y, mask)
# mlp
x = x + self.drop_path(gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp)))
x_m = t2i_modulate(self.norm2(x), shift_mlp, scale_mlp)
if x_mask is not None:
x_m_zero = t2i_modulate(self.norm2(x), shift_mlp_zero, scale_mlp_zero)
x_m = self.t_mask_select(x_m, x_m_zero, x_mask)
x_mlp = self.mlp(x_m)
if x_mask is not None:
x_mlp_zero = gate_mlp_zero * x_mlp
x_mlp = gate_mlp * x_mlp
x_mlp = self.t_mask_select(x_mlp, x_mlp_zero, x_mask)
else:
x_mlp = gate_mlp * x_mlp
x = x + self.drop_path(x_mlp)
return x
@ -138,7 +174,6 @@ class STDiT(nn.Module):
space_scale=1.0,
time_scale=1.0,
freeze=None,
use_x_mask=False,
enable_flashattn=False,
enable_layernorm_kernel=False,
enable_sequence_parallelism=False,
@ -195,12 +230,13 @@ class STDiT(nn.Module):
for i in range(self.depth)
]
)
self.final_layer = T2IFinalLayer(hidden_size, np.prod(self.patch_size), self.out_channels)
self.use_x_mask = use_x_mask
if self.use_x_mask:
self.x_mask = nn.Parameter(torch.zeros(self.hidden_size))
trunc_normal_(self.x_mask, std=0.02)
self.final_layer = T2IFinalLayer(
hidden_size,
np.prod(self.patch_size),
self.out_channels,
d_t=self.num_temporal,
d_s=self.num_spatial,
)
# init model
self.initialize_weights()
@ -240,10 +276,6 @@ class STDiT(nn.Module):
x = self.x_embedder(x) # [B, N, C]
x = rearrange(x, "B (T S) C -> B T S C", T=self.num_temporal, S=self.num_spatial)
x = x + self.pos_embed
if x_mask is not None:
assert self.use_x_mask
# x: [B, T, S, C], mask: [B, T], self.x_mask: [C]
x = x + ~x_mask[:, :, None, None] * self.x_mask
x = rearrange(x, "B T S C -> B (T S) C")
# shard over the sequence dim if sp is enabled
@ -251,7 +283,14 @@ class STDiT(nn.Module):
x = split_forward_gather_backward(x, get_sequence_parallel_group(), dim=1, grad_scale="down")
t = self.t_embedder(timestep, dtype=x.dtype) # [B, C]
t0 = self.t_block(t) # [B, C]
t_mlp = self.t_block(t) # [B, C]
if x_mask is not None:
t0_timestep = torch.zeros_like(timestep)
t0 = self.t_embedder(t0_timestep, dtype=x.dtype)
t0_mlp = self.t_block(t0)
else:
t0 = None
t0_mlp = None
y = self.y_embedder(y, self.training) # [B, 1, N_token, C]
if mask is not None:
@ -275,14 +314,14 @@ class STDiT(nn.Module):
tpe = self.pos_embed_temporal
else:
tpe = None
x = auto_grad_checkpoint(block, x, y, t0, y_lens, tpe)
x = auto_grad_checkpoint(block, x, y, t_mlp, y_lens, tpe, x_mask, t0_mlp)
if self.enable_sequence_parallelism:
x = gather_forward_split_backward(x, get_sequence_parallel_group(), dim=1, grad_scale="up")
# x.shape: [B, N, C]
# final process
x = self.final_layer(x, t) # [B, N, C=T_p * H_p * W_p * C_out]
x = self.final_layer(x, t, x_mask, t0) # [B, N, C=T_p * H_p * W_p * C_out]
x = self.unpatchify(x) # [B, C_out, T, H, W]
# cast to float32 for better accuracy