From 26c583118bf92a101b91dfb5fcc5276fbe5729e8 Mon Sep 17 00:00:00 2001 From: Zangwei Zheng Date: Sun, 24 Mar 2024 16:07:18 +0800 Subject: [PATCH] [feat] new training scheme for mask --- configs/opensora/train/16x256x256-mask.py | 54 ++++++++++++++++ configs/opensora/train/16x512x512-mask.py | 3 +- opensora/models/layers/blocks.py | 18 +++++- opensora/models/stdit/stdit.py | 77 +++++++++++++++++------ 4 files changed, 128 insertions(+), 24 deletions(-) create mode 100644 configs/opensora/train/16x256x256-mask.py diff --git a/configs/opensora/train/16x256x256-mask.py b/configs/opensora/train/16x256x256-mask.py new file mode 100644 index 0000000..65425d7 --- /dev/null +++ b/configs/opensora/train/16x256x256-mask.py @@ -0,0 +1,54 @@ +num_frames = 16 +frame_interval = 3 +image_size = (256, 256) + +# Define dataset +root = None +data_path = "CSV_PATH" +use_image_transform = False +num_workers = 4 + +# Define acceleration +dtype = "bf16" +grad_checkpoint = True +plugin = "zero2" +sp_size = 1 + +# Define model +model = dict( + type="STDiT-XL/2", + space_scale=0.5, + time_scale=1.0, + from_pretrained="PixArt-XL-2-512x512.pth", + enable_flashattn=True, + enable_layernorm_kernel=True, +) +mask_ratios = [0.7, 0.15, 0.05, 0.05, 0.05] +vae = dict( + type="VideoAutoencoderKL", + from_pretrained="stabilityai/sd-vae-ft-ema", +) +text_encoder = dict( + type="t5", + from_pretrained="DeepFloyd/t5-v1_1-xxl", + model_max_length=120, + shardformer=True, +) +scheduler = dict( + type="iddpm", + timestep_respacing="", +) + +# Others +seed = 42 +outputs = "outputs" +wandb = False + +epochs = 1000 +log_every = 10 +ckpt_every = 1000 +load = None + +batch_size = 8 +lr = 2e-5 +grad_clip = 1.0 diff --git a/configs/opensora/train/16x512x512-mask.py b/configs/opensora/train/16x512x512-mask.py index 6bc9bdf..c10655f 100644 --- a/configs/opensora/train/16x512x512-mask.py +++ b/configs/opensora/train/16x512x512-mask.py @@ -19,7 +19,6 @@ model = dict( type="STDiT-XL/2", space_scale=1.0, time_scale=1.0, 
- use_x_mask=True, from_pretrained=None, enable_flashattn=True, enable_layernorm_kernel=True, @@ -48,7 +47,7 @@ wandb = False epochs = 1000 log_every = 10 -ckpt_every = 500 +ckpt_every = 1000 load = None batch_size = 8 diff --git a/opensora/models/layers/blocks.py b/opensora/models/layers/blocks.py index 0f8bd59..3db11de 100644 --- a/opensora/models/layers/blocks.py +++ b/opensora/models/layers/blocks.py @@ -328,7 +328,7 @@ class SeqParallelMultiHeadCrossAttention(MultiHeadCrossAttention): if mask is not None: attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens([N] * B, mask) x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias) - + # apply all to all to gather back attention heads and scatter sequence x = x.view(B, -1, self.num_heads // sp_size, self.head_dim) x = all_to_all(x, sp_group, scatter_dim=1, gather_dim=2) @@ -363,16 +363,28 @@ class T2IFinalLayer(nn.Module): The final layer of PixArt. """ - def __init__(self, hidden_size, num_patch, out_channels): + def __init__(self, hidden_size, num_patch, out_channels, d_t=None, d_s=None): super().__init__() self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.linear = nn.Linear(hidden_size, num_patch * out_channels, bias=True) self.scale_shift_table = nn.Parameter(torch.randn(2, hidden_size) / hidden_size**0.5) self.out_channels = out_channels + self.d_t = d_t + self.d_s = d_s - def forward(self, x, t): + def forward(self, x, t, x_mask=None, t0=None): shift, scale = (self.scale_shift_table[None] + t[:, None]).chunk(2, dim=1) x = t2i_modulate(self.norm_final(x), shift, scale) + if x_mask is not None: + shift_zero, scale_zero = (self.scale_shift_table[None] + t0[:, None]).chunk(2, dim=1) + x_zero = t2i_modulate(self.norm_final(x), shift_zero, scale_zero) + + # t_mask_select + x = rearrange(x, "B (T S) C -> B T S C", T=self.d_t, S=self.d_s) + x_zero = rearrange(x_zero, "B (T S) C -> B T S C", T=self.d_t, S=self.d_s) + x = 
torch.where(x_mask[:, :, None, None], x, x_zero) + x = rearrange(x, "B T S C -> B (T S) C", T=self.d_t, S=self.d_s) + x = self.linear(x) return x diff --git a/opensora/models/stdit/stdit.py b/opensora/models/stdit/stdit.py index 1485c9e..22a459b 100644 --- a/opensora/models/stdit/stdit.py +++ b/opensora/models/stdit/stdit.py @@ -3,9 +3,8 @@ import torch import torch.distributed as dist import torch.nn as nn from einops import rearrange -from timm.models.layers import DropPath +from timm.models.layers import DropPath, trunc_normal_ from timm.models.vision_transformer import Mlp -from timm.models.layers import trunc_normal_ from opensora.acceleration.checkpoint import auto_grad_checkpoint from opensora.acceleration.communications import gather_forward_split_backward, split_forward_gather_backward @@ -86,19 +85,43 @@ class STDiTBlock(nn.Module): enable_flashattn=self.enable_flashattn, ) - def forward(self, x, y, t, mask=None, tpe=None): + def t_mask_select(self, x, masked_x, x_mask): + # x: [B, (T S), C] + # masked_x: [B, (T S), C] + # x_mask: [B, T] + x = rearrange(x, "B (T S) C -> B T S C", T=self.d_t, S=self.d_s) + masked_x = rearrange(masked_x, "B (T S) C -> B T S C", T=self.d_t, S=self.d_s) + x = torch.where(x_mask[:, :, None, None], x, masked_x) + x = rearrange(x, "B T S C -> B (T S) C") + return x + + def forward(self, x, y, t, mask=None, tpe=None, x_mask=None, t0=None): B, N, C = x.shape shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( self.scale_shift_table[None] + t.reshape(B, 6, -1) ).chunk(6, dim=1) x_m = t2i_modulate(self.norm1(x), shift_msa, scale_msa) + if x_mask is not None: + shift_msa_zero, scale_msa_zero, gate_msa_zero, shift_mlp_zero, scale_mlp_zero, gate_mlp_zero = ( + self.scale_shift_table[None] + t0.reshape(B, 6, -1) + ).chunk(6, dim=1) + x_m_zero = t2i_modulate(self.norm1(x), shift_msa_zero, scale_msa_zero) + x_m = self.t_mask_select(x_m, x_m_zero, x_mask) # spatial branch x_s = rearrange(x_m, "B (T S) C -> (B T) S C", 
T=self.d_t, S=self.d_s) x_s = self.attn(x_s) x_s = rearrange(x_s, "(B T) S C -> B (T S) C", T=self.d_t, S=self.d_s) - x = x + self.drop_path(gate_msa * x_s) + + if x_mask is not None: + x_s_zero = gate_msa_zero * x_s + x_s = gate_msa * x_s + x_s = self.t_mask_select(x_s, x_s_zero, x_mask) + else: + x_s = gate_msa * x_s + + x = x + self.drop_path(x_s) # temporal branch x_t = rearrange(x, "B (T S) C -> (B S) T C", T=self.d_t, S=self.d_s) @@ -112,7 +135,20 @@ class STDiTBlock(nn.Module): x = x + self.cross_attn(x, y, mask) # mlp - x = x + self.drop_path(gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp))) + x_m = t2i_modulate(self.norm2(x), shift_mlp, scale_mlp) + if x_mask is not None: + x_m_zero = t2i_modulate(self.norm2(x), shift_mlp_zero, scale_mlp_zero) + x_m = self.t_mask_select(x_m, x_m_zero, x_mask) + + x_mlp = self.mlp(x_m) + if x_mask is not None: + x_mlp_zero = gate_mlp_zero * x_mlp + x_mlp = gate_mlp * x_mlp + x_mlp = self.t_mask_select(x_mlp, x_mlp_zero, x_mask) + else: + x_mlp = gate_mlp * x_mlp + + x = x + self.drop_path(x_mlp) return x @@ -138,7 +174,6 @@ class STDiT(nn.Module): space_scale=1.0, time_scale=1.0, freeze=None, - use_x_mask=False, enable_flashattn=False, enable_layernorm_kernel=False, enable_sequence_parallelism=False, @@ -195,12 +230,13 @@ class STDiT(nn.Module): for i in range(self.depth) ] ) - self.final_layer = T2IFinalLayer(hidden_size, np.prod(self.patch_size), self.out_channels) - - self.use_x_mask = use_x_mask - if self.use_x_mask: - self.x_mask = nn.Parameter(torch.zeros(self.hidden_size)) - trunc_normal_(self.x_mask, std=0.02) + self.final_layer = T2IFinalLayer( + hidden_size, + np.prod(self.patch_size), + self.out_channels, + d_t=self.num_temporal, + d_s=self.num_spatial, + ) # init model self.initialize_weights() @@ -240,10 +276,6 @@ class STDiT(nn.Module): x = self.x_embedder(x) # [B, N, C] x = rearrange(x, "B (T S) C -> B T S C", T=self.num_temporal, S=self.num_spatial) x = x + self.pos_embed - if x_mask is 
not None: - assert self.use_x_mask - # x: [B, T, S, C], mask: [B, T], self.x_mask: [C] - x = x + ~x_mask[:, :, None, None] * self.x_mask x = rearrange(x, "B T S C -> B (T S) C") # shard over the sequence dim if sp is enabled @@ -251,7 +283,14 @@ class STDiT(nn.Module): x = split_forward_gather_backward(x, get_sequence_parallel_group(), dim=1, grad_scale="down") t = self.t_embedder(timestep, dtype=x.dtype) # [B, C] - t0 = self.t_block(t) # [B, C] + t_mlp = self.t_block(t) # [B, C] + if x_mask is not None: + t0_timestep = torch.zeros_like(timestep) + t0 = self.t_embedder(t0_timestep, dtype=x.dtype) + t0_mlp = self.t_block(t0) + else: + t0 = None + t0_mlp = None y = self.y_embedder(y, self.training) # [B, 1, N_token, C] if mask is not None: @@ -275,14 +314,14 @@ class STDiT(nn.Module): tpe = self.pos_embed_temporal else: tpe = None - x = auto_grad_checkpoint(block, x, y, t0, y_lens, tpe) + x = auto_grad_checkpoint(block, x, y, t_mlp, y_lens, tpe, x_mask, t0_mlp) if self.enable_sequence_parallelism: x = gather_forward_split_backward(x, get_sequence_parallel_group(), dim=1, grad_scale="up") # x.shape: [B, N, C] # final process - x = self.final_layer(x, t) # [B, N, C=T_p * H_p * W_p * C_out] + x = self.final_layer(x, t, x_mask, t0) # [B, N, C=T_p * H_p * W_p * C_out] x = self.unpatchify(x) # [B, C_out, T, H, W] # cast to float32 for better accuracy