mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-11 05:13:31 +02:00
[feat] new training scheme for mask
This commit is contained in:
parent
7d478f5094
commit
26c583118b
54
configs/opensora/train/16x256x256-mask.py
Normal file
54
configs/opensora/train/16x256x256-mask.py
Normal file
|
|
@ -0,0 +1,54 @@
|
|||
"""Training config: STDiT-XL/2 on 16-frame 256x256 clips with the masked-training scheme."""

num_frames = 16          # frames sampled per clip
frame_interval = 3       # stride between sampled frames
image_size = (256, 256)  # (height, width) of each frame

# Define dataset
root = None                  # optional path prefix for entries in the CSV; None = paths used as-is
data_path = "CSV_PATH"       # placeholder — replace with the dataset CSV path
use_image_transform = False  # use video transforms rather than single-image ones
num_workers = 4              # dataloader worker processes

# Define acceleration
dtype = "bf16"           # mixed-precision training dtype
grad_checkpoint = True   # trade compute for memory via activation checkpointing
plugin = "zero2"         # distributed-training plugin (presumably ZeRO stage 2 — confirm against trainer)
sp_size = 1              # sequence-parallel group size; 1 disables sequence parallelism

# Define model
model = dict(
    type="STDiT-XL/2",
    space_scale=0.5,     # spatial positional-embedding scale (256px vs the 512px pretrain)
    time_scale=1.0,      # temporal positional-embedding scale
    from_pretrained="PixArt-XL-2-512x512.pth",
    enable_flashattn=True,
    enable_layernorm_kernel=True,
)
# Sampling probabilities for the mask-ratio buckets of the masked-training scheme
# (values sum to 1.0; bucket semantics defined by the training loop).
mask_ratios = [0.7, 0.15, 0.05, 0.05, 0.05]
vae = dict(
    type="VideoAutoencoderKL",
    from_pretrained="stabilityai/sd-vae-ft-ema",
)
text_encoder = dict(
    type="t5",
    from_pretrained="DeepFloyd/t5-v1_1-xxl",
    model_max_length=120,   # max text tokens per caption
    shardformer=True,       # shard the T5 encoder across devices
)
scheduler = dict(
    type="iddpm",
    timestep_respacing="",  # empty string = full timestep schedule
)

# Others
seed = 42
outputs = "outputs"  # output directory for checkpoints/logs
wandb = False

epochs = 1000
log_every = 10     # log every N steps
ckpt_every = 1000  # checkpoint every N steps
load = None        # optional checkpoint path to resume from

batch_size = 8
lr = 2e-5
grad_clip = 1.0  # gradient-norm clipping threshold
|
||||
|
|
@ -19,7 +19,6 @@ model = dict(
|
|||
type="STDiT-XL/2",
|
||||
space_scale=1.0,
|
||||
time_scale=1.0,
|
||||
use_x_mask=True,
|
||||
from_pretrained=None,
|
||||
enable_flashattn=True,
|
||||
enable_layernorm_kernel=True,
|
||||
|
|
@ -48,7 +47,7 @@ wandb = False
|
|||
|
||||
epochs = 1000
|
||||
log_every = 10
|
||||
ckpt_every = 500
|
||||
ckpt_every = 1000
|
||||
load = None
|
||||
|
||||
batch_size = 8
|
||||
|
|
|
|||
|
|
@ -328,7 +328,7 @@ class SeqParallelMultiHeadCrossAttention(MultiHeadCrossAttention):
|
|||
if mask is not None:
|
||||
attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens([N] * B, mask)
|
||||
x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias)
|
||||
|
||||
|
||||
# apply all to all to gather back attention heads and scatter sequence
|
||||
x = x.view(B, -1, self.num_heads // sp_size, self.head_dim)
|
||||
x = all_to_all(x, sp_group, scatter_dim=1, gather_dim=2)
|
||||
|
|
@ -363,16 +363,28 @@ class T2IFinalLayer(nn.Module):
|
|||
The final layer of PixArt.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size, num_patch, out_channels):
|
||||
def __init__(self, hidden_size, num_patch, out_channels, d_t=None, d_s=None):
    """Final layer of PixArt: adaLN-modulated norm + linear patch projection.

    Args:
        hidden_size: transformer hidden dimension.
        num_patch: number of values per patch (product of the patch sizes).
        out_channels: output channels per patch value.
        d_t: temporal sequence length, needed to reshape for x_mask selection in forward.
        d_s: spatial sequence length (tokens per frame).
    """
    super().__init__()
    self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
    self.linear = nn.Linear(hidden_size, num_patch * out_channels, bias=True)
    # Learned (shift, scale) table added to the timestep embedding in forward.
    self.scale_shift_table = nn.Parameter(torch.randn(2, hidden_size) / hidden_size**0.5)
    self.out_channels = out_channels
    self.d_t = d_t
    self.d_s = d_s
|
||||
|
||||
def forward(self, x, t):
|
||||
def forward(self, x, t, x_mask=None, t0=None):
    """Modulate, (optionally) mask-select per frame, and project.

    Args:
        x: token sequence [B, T*S, C].
        t: timestep conditioning for unmasked frames.
        x_mask: optional [B, T] boolean frame mask; where False, the t0
            modulation replaces the t modulation for that frame's tokens.
        t0: timestep conditioning used for masked frames (required when
            x_mask is given).

    Returns:
        Projected tokens [B, N, num_patch * out_channels].
    """
    shift, scale = (self.scale_shift_table[None] + t[:, None]).chunk(2, dim=1)
    x = t2i_modulate(self.norm_final(x), shift, scale)
    if x_mask is not None:
        # Modulation path for masked frames, driven by t0 instead of t.
        shift_zero, scale_zero = (self.scale_shift_table[None] + t0[:, None]).chunk(2, dim=1)
        x_zero = t2i_modulate(self.norm_final(x), shift_zero, scale_zero)

        # t_mask_select: pick per frame between the two modulated sequences.
        x = rearrange(x, "B (T S) C -> B T S C", T=self.d_t, S=self.d_s)
        x_zero = rearrange(x_zero, "B (T S) C -> B T S C", T=self.d_t, S=self.d_s)
        x = torch.where(x_mask[:, :, None, None], x, x_zero)
        x = rearrange(x, "B T S C -> B (T S) C", T=self.d_t, S=self.d_s)
    x = self.linear(x)
    return x
|
||||
|
||||
|
|
|
|||
|
|
@ -3,9 +3,8 @@ import torch
|
|||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from einops import rearrange
|
||||
from timm.models.layers import DropPath
|
||||
from timm.models.layers import DropPath, trunc_normal_
|
||||
from timm.models.vision_transformer import Mlp
|
||||
from timm.models.layers import trunc_normal_
|
||||
|
||||
from opensora.acceleration.checkpoint import auto_grad_checkpoint
|
||||
from opensora.acceleration.communications import gather_forward_split_backward, split_forward_gather_backward
|
||||
|
|
@ -86,19 +85,43 @@ class STDiTBlock(nn.Module):
|
|||
enable_flashattn=self.enable_flashattn,
|
||||
)
|
||||
|
||||
def forward(self, x, y, t, mask=None, tpe=None):
|
||||
def t_mask_select(self, x, masked_x, x_mask):
    """Select per frame between two token sequences.

    Where x_mask[b, t] is True the frame's tokens come from x, otherwise
    from masked_x. Output has the same flattened layout as the inputs.
    """
    # x: [B, T*S, C]
    # masked_x: [B, T*S, C]
    # x_mask: [B, T] (boolean)
    x = rearrange(x, "B (T S) C -> B T S C", T=self.d_t, S=self.d_s)
    masked_x = rearrange(masked_x, "B (T S) C -> B T S C", T=self.d_t, S=self.d_s)
    x = torch.where(x_mask[:, :, None, None], x, masked_x)
    x = rearrange(x, "B T S C -> B (T S) C")
    return x
|
||||
|
||||
def forward(self, x, y, t, mask=None, tpe=None, x_mask=None, t0=None):
|
||||
B, N, C = x.shape
|
||||
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
|
||||
self.scale_shift_table[None] + t.reshape(B, 6, -1)
|
||||
).chunk(6, dim=1)
|
||||
x_m = t2i_modulate(self.norm1(x), shift_msa, scale_msa)
|
||||
if x_mask is not None:
|
||||
shift_msa_zero, scale_msa_zero, gate_msa_zero, shift_mlp_zero, scale_mlp_zero, gate_mlp_zero = (
|
||||
self.scale_shift_table[None] + t0.reshape(B, 6, -1)
|
||||
).chunk(6, dim=1)
|
||||
x_m_zero = t2i_modulate(self.norm1(x), shift_msa_zero, scale_msa_zero)
|
||||
x_m = self.t_mask_select(x_m, x_m_zero, x_mask)
|
||||
|
||||
# spatial branch
|
||||
x_s = rearrange(x_m, "B (T S) C -> (B T) S C", T=self.d_t, S=self.d_s)
|
||||
x_s = self.attn(x_s)
|
||||
x_s = rearrange(x_s, "(B T) S C -> B (T S) C", T=self.d_t, S=self.d_s)
|
||||
x = x + self.drop_path(gate_msa * x_s)
|
||||
|
||||
if x_mask is not None:
|
||||
x_s_zero = gate_msa_zero * x_s
|
||||
x_s = gate_msa * x_s
|
||||
x_s = self.t_mask_select(x_s, x_s_zero, x_mask)
|
||||
else:
|
||||
x_s = gate_msa * x_s
|
||||
|
||||
x = x + self.drop_path(x_s)
|
||||
|
||||
# temporal branch
|
||||
x_t = rearrange(x, "B (T S) C -> (B S) T C", T=self.d_t, S=self.d_s)
|
||||
|
|
@ -112,7 +135,20 @@ class STDiTBlock(nn.Module):
|
|||
x = x + self.cross_attn(x, y, mask)
|
||||
|
||||
# mlp
|
||||
x = x + self.drop_path(gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp)))
|
||||
x_m = t2i_modulate(self.norm2(x), shift_mlp, scale_mlp)
|
||||
if x_mask is not None:
|
||||
x_m_zero = t2i_modulate(self.norm2(x), shift_mlp_zero, scale_mlp_zero)
|
||||
x_m = self.t_mask_select(x_m, x_m_zero, x_mask)
|
||||
|
||||
x_mlp = self.mlp(x_m)
|
||||
if x_mask is not None:
|
||||
x_mlp_zero = gate_mlp_zero * x_mlp
|
||||
x_mlp = gate_mlp * x_mlp
|
||||
x_mlp = self.t_mask_select(x_mlp, x_mlp_zero, x_mask)
|
||||
else:
|
||||
x_mlp = gate_mlp * x_mlp
|
||||
|
||||
x = x + self.drop_path(x_mlp)
|
||||
|
||||
return x
|
||||
|
||||
|
|
@ -138,7 +174,6 @@ class STDiT(nn.Module):
|
|||
space_scale=1.0,
|
||||
time_scale=1.0,
|
||||
freeze=None,
|
||||
use_x_mask=False,
|
||||
enable_flashattn=False,
|
||||
enable_layernorm_kernel=False,
|
||||
enable_sequence_parallelism=False,
|
||||
|
|
@ -195,12 +230,13 @@ class STDiT(nn.Module):
|
|||
for i in range(self.depth)
|
||||
]
|
||||
)
|
||||
self.final_layer = T2IFinalLayer(hidden_size, np.prod(self.patch_size), self.out_channels)
|
||||
|
||||
self.use_x_mask = use_x_mask
|
||||
if self.use_x_mask:
|
||||
self.x_mask = nn.Parameter(torch.zeros(self.hidden_size))
|
||||
trunc_normal_(self.x_mask, std=0.02)
|
||||
self.final_layer = T2IFinalLayer(
|
||||
hidden_size,
|
||||
np.prod(self.patch_size),
|
||||
self.out_channels,
|
||||
d_t=self.num_temporal,
|
||||
d_s=self.num_spatial,
|
||||
)
|
||||
|
||||
# init model
|
||||
self.initialize_weights()
|
||||
|
|
@ -240,10 +276,6 @@ class STDiT(nn.Module):
|
|||
x = self.x_embedder(x) # [B, N, C]
|
||||
x = rearrange(x, "B (T S) C -> B T S C", T=self.num_temporal, S=self.num_spatial)
|
||||
x = x + self.pos_embed
|
||||
if x_mask is not None:
|
||||
assert self.use_x_mask
|
||||
# x: [B, T, S, C], mask: [B, T], self.x_mask: [C]
|
||||
x = x + ~x_mask[:, :, None, None] * self.x_mask
|
||||
x = rearrange(x, "B T S C -> B (T S) C")
|
||||
|
||||
# shard over the sequence dim if sp is enabled
|
||||
|
|
@ -251,7 +283,14 @@ class STDiT(nn.Module):
|
|||
x = split_forward_gather_backward(x, get_sequence_parallel_group(), dim=1, grad_scale="down")
|
||||
|
||||
t = self.t_embedder(timestep, dtype=x.dtype) # [B, C]
|
||||
t0 = self.t_block(t) # [B, C]
|
||||
t_mlp = self.t_block(t) # [B, C]
|
||||
if x_mask is not None:
|
||||
t0_timestep = torch.zeros_like(timestep)
|
||||
t0 = self.t_embedder(t0_timestep, dtype=x.dtype)
|
||||
t0_mlp = self.t_block(t0)
|
||||
else:
|
||||
t0 = None
|
||||
t0_mlp = None
|
||||
y = self.y_embedder(y, self.training) # [B, 1, N_token, C]
|
||||
|
||||
if mask is not None:
|
||||
|
|
@ -275,14 +314,14 @@ class STDiT(nn.Module):
|
|||
tpe = self.pos_embed_temporal
|
||||
else:
|
||||
tpe = None
|
||||
x = auto_grad_checkpoint(block, x, y, t0, y_lens, tpe)
|
||||
x = auto_grad_checkpoint(block, x, y, t_mlp, y_lens, tpe, x_mask, t0_mlp)
|
||||
|
||||
if self.enable_sequence_parallelism:
|
||||
x = gather_forward_split_backward(x, get_sequence_parallel_group(), dim=1, grad_scale="up")
|
||||
# x.shape: [B, N, C]
|
||||
|
||||
# final process
|
||||
x = self.final_layer(x, t) # [B, N, C=T_p * H_p * W_p * C_out]
|
||||
x = self.final_layer(x, t, x_mask, t0) # [B, N, C=T_p * H_p * W_p * C_out]
|
||||
x = self.unpatchify(x) # [B, C_out, T, H, W]
|
||||
|
||||
# cast to float32 for better accuracy
|
||||
|
|
|
|||
Loading…
Reference in a new issue