diff --git a/configs/opensora-v1-1/inference/16x256x256.py b/configs/opensora-v1-1/inference/16x256x256.py new file mode 100644 index 0000000..b68602f --- /dev/null +++ b/configs/opensora-v1-1/inference/16x256x256.py @@ -0,0 +1,39 @@ +num_frames = 16 +fps = 24 // 3 +image_size = (256, 256) + +# Define model +model = dict( + type="STDiT2-XL/2", + space_scale=0.5, + time_scale=1.0, + from_pretrained="PixArt-XL-2-1024-MS.pth", + enable_flashattn=True, + enable_layernorm_kernel=True, +) +vae = dict( + type="VideoAutoencoderKL", + from_pretrained="stabilityai/sd-vae-ft-ema", + micro_batch_size=4, +) +text_encoder = dict( + type="t5", + from_pretrained="DeepFloyd/t5-v1_1-xxl", + model_max_length=120, +) +scheduler = dict( + type="iddpm", + num_sampling_steps=100, + cfg_scale=7.0, + cfg_channel=3, # or None +) +dtype = "fp16" + +# Condition +prompt_path = "./assets/texts/t2v_samples.txt" +prompt = None # prompt has higher priority than prompt_path + +# Others +batch_size = 1 +seed = 42 +save_dir = "./outputs/samples/" diff --git a/configs/opensora-v1-1/train/16x256x256.py b/configs/opensora-v1-1/train/16x256x256.py new file mode 100644 index 0000000..7905042 --- /dev/null +++ b/configs/opensora-v1-1/train/16x256x256.py @@ -0,0 +1,54 @@ +num_frames = 16 +frame_interval = 3 +image_size = (256, 256) + +# Define dataset +root = None +data_path = "CSV_PATH" +use_image_transform = False +num_workers = 4 + +# Define acceleration +dtype = "bf16" +grad_checkpoint = True +plugin = "zero2" +sp_size = 1 + +# Define model +model = dict( + type="STDiT2-XL/2", + space_scale=0.5, + time_scale=1.0, + from_pretrained="PixArt-XL-2-1024-MS.pth", + enable_flashattn=True, + enable_layernorm_kernel=True, +) +vae = dict( + type="VideoAutoencoderKL", + from_pretrained="stabilityai/sd-vae-ft-ema", + micro_batch_size=4, +) +text_encoder = dict( + type="t5", + from_pretrained="DeepFloyd/t5-v1_1-xxl", + model_max_length=120, + shardformer=True, +) +scheduler = dict( + type="iddpm-speed", + timestep_respacing="", +) + +# Others +seed = 42 +outputs = "outputs" +wandb = False + +epochs = 1000 +log_every = 10 +ckpt_every = 1000 +load = None + +batch_size = 8 +lr = 2e-5 +grad_clip = 1.0 diff --git a/opensora/datasets/video_transforms.py b/opensora/datasets/video_transforms.py index a0d1cec..8d7d095 100644 --- a/opensora/datasets/video_transforms.py +++ b/opensora/datasets/video_transforms.py @@ -18,7 +18,9 @@ import numbers import random +import numpy as np import torch +from PIL import Image def _is_tensor_video_clip(clip): diff --git a/opensora/models/layers/blocks.py b/opensora/models/layers/blocks.py index 3db11de..dc5f39c 100644 --- a/opensora/models/layers/blocks.py +++ b/opensora/models/layers/blocks.py @@ -121,6 +121,7 @@ class Attention(nn.Module): proj_drop: float = 0.0, norm_layer: nn.Module = nn.LayerNorm, enable_flashattn: bool = False, + rope=None, ) -> None: super().__init__() assert dim % num_heads == 0, "dim should be divisible by num_heads" @@ -137,16 +138,27 @@ class Attention(nn.Module): self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop) + self.rope = False + if rope is not None: + self.rope = True + self.rotary_emb = rope + def forward(self, x: torch.Tensor) -> torch.Tensor: B, N, C = x.shape qkv = self.qkv(x) qkv_shape = (B, N, 3, self.num_heads, self.head_dim) - if self.enable_flashattn: - qkv_permute_shape = (2, 0, 1, 3, 4) - else: - qkv_permute_shape = (2, 0, 3, 1, 4) - qkv = qkv.view(qkv_shape).permute(qkv_permute_shape) + + qkv = qkv.view(qkv_shape).permute(2, 0, 3, 1, 4) q, k, v = 
qkv.unbind(0)
+        if self.rope:
+            q = self.rotary_emb(q)
+            k = self.rotary_emb(k)
+        if self.enable_flashattn:
+            # (B, #heads, N, #dim) -> (B, N, #heads, #dim)
+            q = q.permute(0, 2, 1, 3)
+            k = k.permute(0, 2, 1, 3)
+            v = v.permute(0, 2, 1, 3)
+
         q, k = self.q_norm(q), self.k_norm(k)
         if self.enable_flashattn:
             from flash_attn import flash_attn_func
@@ -188,7 +200,9 @@ class SeqParallelAttention(Attention):
         proj_drop: float = 0.0,
         norm_layer: nn.Module = nn.LayerNorm,
         enable_flashattn: bool = False,
+        rope=None,
     ) -> None:
+        assert rope is None, "Rope is not supported in SeqParallelAttention"
         super().__init__(
             dim=dim,
             num_heads=num_heads,
@@ -372,19 +386,23 @@ class T2IFinalLayer(nn.Module):
         self.d_t = d_t
         self.d_s = d_s
+    def t_mask_select(self, x_mask, x, masked_x):
+        # x: [B, (T, S), C]
+        # masked_x: [B, (T, S), C]
+        # x_mask: [B, T]
+        x = rearrange(x, "B (T S) C -> B T S C", T=self.d_t, S=self.d_s)
+        masked_x = rearrange(masked_x, "B (T S) C -> B T S C", T=self.d_t, S=self.d_s)
+        x = torch.where(x_mask[:, :, None, None], x, masked_x)
+        x = rearrange(x, "B T S C -> B (T S) C")
+        return x
+
     def forward(self, x, t, x_mask=None, t0=None):
         shift, scale = (self.scale_shift_table[None] + t[:, None]).chunk(2, dim=1)
         x = t2i_modulate(self.norm_final(x), shift, scale)
         if x_mask is not None:
             shift_zero, scale_zero = (self.scale_shift_table[None] + t0[:, None]).chunk(2, dim=1)
             x_zero = t2i_modulate(self.norm_final(x), shift_zero, scale_zero)
-
-            # t_mask_select
-            x = rearrange(x, "B (T S) C -> B T S C", T=self.d_t, S=self.d_s)
-            x_zero = rearrange(x_zero, "B (T S) C -> B T S C", T=self.d_t, S=self.d_s)
-            x = torch.where(x_mask[:, :, None, None], x, x_zero)
-            x = rearrange(x, "B T S C -> B (T S) C", T=self.d_t, S=self.d_s)
-
+            x = self.t_mask_select(x_mask, x, x_zero)
         x = self.linear(x)
         return x
diff --git a/opensora/models/stdit/__init__.py b/opensora/models/stdit/__init__.py
index 5ca2cc9..605159e 100644
--- a/opensora/models/stdit/__init__.py
+++ b/opensora/models/stdit/__init__.py
@@ -1 +1,2 @@
 from .stdit import STDiT
+from .stdit2 import STDiT2
diff --git a/opensora/models/stdit/stdit.py b/opensora/models/stdit/stdit.py
index 70c5343..6e16058 100644
--- a/opensora/models/stdit/stdit.py
+++ b/opensora/models/stdit/stdit.py
@@ -3,7 +3,7 @@ import torch
 import torch.distributed as dist
 import torch.nn as nn
 from einops import rearrange
-from timm.models.layers import DropPath, trunc_normal_
+from timm.models.layers import DropPath
 from timm.models.vision_transformer import Mlp
 from opensora.acceleration.checkpoint import auto_grad_checkpoint
diff --git a/opensora/models/stdit/stdit2.py b/opensora/models/stdit/stdit2.py
new file mode 100644
index 0000000..3c875ac
--- /dev/null
+++ b/opensora/models/stdit/stdit2.py
@@ -0,0 +1,486 @@
+import numpy as np
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from einops import rearrange
+from rotary_embedding_torch import RotaryEmbedding
+from timm.models.layers import DropPath
+from timm.models.vision_transformer import Mlp
+
+from opensora.acceleration.checkpoint import auto_grad_checkpoint
+from opensora.acceleration.communications import gather_forward_split_backward, split_forward_gather_backward
+from opensora.acceleration.parallel_states import get_sequence_parallel_group
+from opensora.models.layers.blocks import (
+    Attention,
+    CaptionEmbedder,
+    MultiHeadCrossAttention,
+    PatchEmbed3D,
+    SeqParallelAttention,
+    SeqParallelMultiHeadCrossAttention,
+    SizeEmbedder,
+    T2IFinalLayer,
+    TimestepEmbedder,
+    approx_gelu,
+    get_1d_sincos_pos_embed,
+    get_2d_sincos_pos_embed,
+    get_layernorm,
+    t2i_modulate,
+)
+from opensora.registry import MODELS
+from opensora.utils.ckpt_utils import load_checkpoint
+
+
+class STDiT2Block(nn.Module):
+    def __init__(
+        self,
+        hidden_size,
+        num_heads,
+        d_s=None,
+        d_t=None,
+        mlp_ratio=4.0,
+        drop_path=0.0,
+        enable_flashattn=False,
+        enable_layernorm_kernel=False,
+        enable_sequence_parallelism=False,
+        rope=None,
+    ):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.enable_flashattn = enable_flashattn
+        self._enable_sequence_parallelism = enable_sequence_parallelism
+
+        if enable_sequence_parallelism:
+            self.attn_cls = SeqParallelAttention
+            self.mha_cls = SeqParallelMultiHeadCrossAttention
+        else:
+            self.attn_cls = Attention
+            self.mha_cls = MultiHeadCrossAttention
+
+        self.norm1 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel)
+        self.attn = self.attn_cls(
+            hidden_size,
+            num_heads=num_heads,
+            qkv_bias=True,
+            enable_flashattn=enable_flashattn,
+        )
+        self.cross_attn = self.mha_cls(hidden_size, num_heads)
+        self.norm2 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel)
+        self.mlp = Mlp(
+            in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0
+        )
+        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+        self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size**0.5)  # new
+
+        # temporal attention
+        self.d_s = d_s
+        self.d_t = d_t
+
+        if self._enable_sequence_parallelism:
+            sp_size = dist.get_world_size(get_sequence_parallel_group())
+            # make sure d_t is divisible by sp_size
+            assert d_t % sp_size == 0
+            self.d_t = d_t // sp_size
+
+        self.attn_temp = self.attn_cls(
+            hidden_size,
+            num_heads=num_heads,
+            qkv_bias=True,
+            enable_flashattn=self.enable_flashattn,
+            rope=rope,
+        )
+        self.scale_shift_table_temporal = nn.Parameter(torch.randn(3, hidden_size) / hidden_size**0.5)  # new
+        self.norm_temp = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel)  # new
+
+    def t_mask_select(self, x_mask, x, masked_x):
+        # x: [B, (T, S), C]
+        # masked_x: [B, (T, S), C]
+        # x_mask: [B, T]
+        x = rearrange(x, "B (T S) C -> B T S C", T=self.d_t, S=self.d_s)
+        masked_x = rearrange(masked_x, "B (T S) C -> B T S C", T=self.d_t, S=self.d_s)
+        x = torch.where(x_mask[:, :, None, None], x, masked_x)
+        x = rearrange(x, "B T S C -> B (T S) C")
+        return x
+
+    def forward(self, x, y, t, t_tmp, mask=None, x_mask=None, t0=None, t0_tmp=None):
+        B, N, C = x.shape
+
+        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
+            self.scale_shift_table[None] + t.reshape(B, 6, -1)
+        ).chunk(6, dim=1)
+        shift_tmp, scale_tmp, gate_tmp = (self.scale_shift_table_temporal[None] + t_tmp.reshape(B, 3, -1)).chunk(
+            3, dim=1
+        )
+        if x_mask is not None:
+            shift_msa_zero, scale_msa_zero, gate_msa_zero, shift_mlp_zero, scale_mlp_zero, gate_mlp_zero = (
+                self.scale_shift_table[None] + t0.reshape(B, 6, -1)
+            ).chunk(6, dim=1)
+            shift_tmp_zero, scale_tmp_zero, gate_tmp_zero = (
+                self.scale_shift_table_temporal[None] + t0_tmp.reshape(B, 3, -1)
+            ).chunk(3, dim=1)
+
+        # modulate
+        x_m = t2i_modulate(self.norm1(x), shift_msa, scale_msa)
+        if x_mask is not None:
+            x_m_zero = t2i_modulate(self.norm1(x), shift_msa_zero, scale_msa_zero)
+            x_m = self.t_mask_select(x_mask, x_m, x_m_zero)
+
+        # spatial branch
+        x_s = rearrange(x_m, "B (T S) C -> (B T) S C", T=self.d_t, S=self.d_s)
+        x_s = self.attn(x_s)
+        x_s = 
rearrange(x_s, "(B T) S C -> B (T S) C", T=self.d_t, S=self.d_s) + if x_mask is not None: + x_s_zero = gate_msa_zero * x_s + x_s = gate_msa * x_s + x_s = self.t_mask_select(x_mask, x_s, x_s_zero) + else: + x_s = gate_msa * x_s + x = x + self.drop_path(x_s) + + # modulate + x_m = t2i_modulate(self.norm_temp(x), shift_tmp, scale_tmp) + if x_mask is not None: + x_m_zero = t2i_modulate(self.norm_temp(x), shift_tmp_zero, scale_tmp_zero) + x_m = self.t_mask_select(x_mask, x_m, x_m_zero) + + # temporal branch + x_t = rearrange(x_m, "B (T S) C -> (B S) T C", T=self.d_t, S=self.d_s) + x_t = self.attn_temp(x_t) + x_t = rearrange(x_t, "(B S) T C -> B (T S) C", T=self.d_t, S=self.d_s) + if x_mask is not None: + x_t_zero = gate_tmp_zero * x_t + x_t = gate_tmp * x_t + x_t = self.t_mask_select(x_mask, x_t, x_t_zero) + else: + x_t = gate_tmp * x_t + x = x + self.drop_path(x_t) + + # cross attn + x = x + self.cross_attn(x, y, mask) + + # modulate + x_m = t2i_modulate(self.norm2(x), shift_mlp, scale_mlp) + if x_mask is not None: + x_m_zero = t2i_modulate(self.norm2(x), shift_mlp_zero, scale_mlp_zero) + x_m = self.t_mask_select(x_mask, x_m, x_m_zero) + + # mlp + x_mlp = self.mlp(x_m) + if x_mask is not None: + x_mlp_zero = gate_mlp_zero * x_mlp + x_mlp = gate_mlp * x_mlp + x_mlp = self.t_mask_select(x_mask, x_mlp, x_mlp_zero) + else: + x_mlp = gate_mlp * x_mlp + x = x + self.drop_path(x_mlp) + + return x + + +@MODELS.register_module() +class STDiT2(nn.Module): + def __init__( + self, + input_size=(1, 32, 32), + in_channels=4, + patch_size=(1, 2, 2), + hidden_size=1152, + depth=28, + num_heads=16, + mlp_ratio=4.0, + class_dropout_prob=0.1, + pred_sigma=True, + drop_path=0.0, + no_temporal_pos_emb=False, + caption_channels=4096, + model_max_length=120, + dtype=torch.float32, + space_scale=1.0, + time_scale=1.0, + freeze=None, + enable_flashattn=False, + enable_layernorm_kernel=False, + enable_sequence_parallelism=False, + ): + super().__init__() + self.pred_sigma = pred_sigma + self.in_channels = in_channels + self.out_channels = in_channels * 2 if pred_sigma else in_channels + self.hidden_size = hidden_size + self.patch_size = patch_size + self.input_size = input_size + num_patches = np.prod([input_size[i] // patch_size[i] for i in range(3)]) + self.num_patches = num_patches + self.num_temporal = input_size[0] // patch_size[0] + self.num_spatial = num_patches // self.num_temporal + self.num_heads = num_heads + self.dtype = dtype + self.no_temporal_pos_emb = no_temporal_pos_emb + self.depth = depth + self.mlp_ratio = mlp_ratio + self.enable_flashattn = enable_flashattn + self.enable_layernorm_kernel = enable_layernorm_kernel + self.space_scale = space_scale + self.time_scale = time_scale + + self.register_buffer("pos_embed", self.get_spatial_pos_embed()) + # self.register_buffer("pos_embed_temporal", self.get_temporal_pos_embed()) + + self.x_embedder = PatchEmbed3D(patch_size, in_channels, hidden_size) + self.t_embedder = TimestepEmbedder(hidden_size) + self.t_block = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True)) + self.t_block_temp = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 3 * hidden_size, bias=True)) # new + self.y_embedder = CaptionEmbedder( + in_channels=caption_channels, + hidden_size=hidden_size, + uncond_prob=class_dropout_prob, + act_layer=approx_gelu, + token_num=model_max_length, + ) + + drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)] + self.rope = RotaryEmbedding(dim=self.hidden_size // self.num_heads) # new + self.blocks = nn.ModuleList( + 
[ + STDiT2Block( + self.hidden_size, + self.num_heads, + mlp_ratio=self.mlp_ratio, + drop_path=drop_path[i], + enable_flashattn=self.enable_flashattn, + enable_layernorm_kernel=self.enable_layernorm_kernel, + enable_sequence_parallelism=enable_sequence_parallelism, + d_t=self.num_temporal, + d_s=self.num_spatial, + rope=self.rope.rotate_queries_or_keys, + ) + for i in range(self.depth) + ] + ) + self.final_layer = T2IFinalLayer(hidden_size, np.prod(self.patch_size), self.out_channels) + + # multi_res + assert self.hidden_size % 3 == 0, "hidden_size must be divisible by 3" + self.csize_embedder = SizeEmbedder(self.hidden_size // 3) + self.ar_embedder = SizeEmbedder(self.hidden_size // 3) + self.fl_embedder = SizeEmbedder(self.hidden_size) # new + + # init model + self.initialize_weights() + self.initialize_temporal() + if freeze is not None: + assert freeze in ["not_temporal", "text"] + if freeze == "not_temporal": + self.freeze_not_temporal() + elif freeze == "text": + self.freeze_text() + + # sequence parallel related configs + self.enable_sequence_parallelism = enable_sequence_parallelism + if enable_sequence_parallelism: + self.sp_rank = dist.get_rank(get_sequence_parallel_group()) + else: + self.sp_rank = None + + def forward(self, x, timestep, y, mask=None, x_mask=None): + """ + Forward pass of STDiT. + Args: + x (torch.Tensor): latent representation of video; of shape [B, C, T, H, W] + timestep (torch.Tensor): diffusion time steps; of shape [B] + y (torch.Tensor): representation of prompts; of shape [B, 1, N_token, C] + mask (torch.Tensor): mask for selecting prompt tokens; of shape [B, N_token] + + Returns: + x (torch.Tensor): output latent representation; of shape [B, C, T, H, W] + """ + B = x.shape[0] + x = x.to(self.dtype) + timestep = timestep.to(self.dtype) + y = y.to(self.dtype) + + # TODO: hard-coded for now + hw = torch.tensor([self.input_size[1], self.input_size[2]], device=x.device, dtype=x.dtype).repeat(B, 1) + ar = torch.tensor([[self.input_size[1] / self.input_size[2]]], device=x.device, dtype=x.dtype).repeat(B, 1) + fl = torch.tensor([self.input_size[0]], device=x.device, dtype=x.dtype).repeat(B, 1) + csize = self.csize_embedder(hw, B) + ar = self.ar_embedder(ar, B) + fl = self.fl_embedder(fl, B) + data_info = torch.cat([csize, ar], dim=1) + + # embedding + x = self.x_embedder(x) # [B, N, C] + x = rearrange(x, "B (T S) C -> B T S C", T=self.num_temporal, S=self.num_spatial) + x = x + self.pos_embed + x = rearrange(x, "B T S C -> B (T S) C") + + # shard over the sequence dim if sp is enabled + if self.enable_sequence_parallelism: + x = split_forward_gather_backward(x, get_sequence_parallel_group(), dim=1, grad_scale="down") + + # prepare adaIN + t = self.t_embedder(timestep, dtype=x.dtype) # [B, C] + t_spc = t + data_info # [B, C] + t_tmp = t + fl # [B, C] + t_spc_mlp = self.t_block(t_spc) # [B, 6*C] + t_tmp_mlp = self.t_block_temp(t + fl) # [B, 3*C] + if x_mask is not None: + t0_timestep = torch.zeros_like(timestep) + t0 = self.t_embedder(t0_timestep, dtype=x.dtype) + t0_spc = t0 + data_info + t0_tmp = t0 + fl + t0_spc_mlp = self.t_block(t0_spc) + t0_tmp_mlp = self.t_block_temp(t0_tmp) + else: + t0_spc = None + t0_tmp = None + t0_spc_mlp = None + t0_tmp_mlp = None + + # prepare y + y = self.y_embedder(y, self.training) # [B, 1, N_token, C] + + if mask is not None: + if mask.shape[0] != y.shape[0]: + mask = mask.repeat(y.shape[0] // mask.shape[0], 1) + mask = mask.squeeze(1).squeeze(1) + y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1]) + 
y_lens = mask.sum(dim=1).tolist() + else: + y_lens = [y.shape[2]] * y.shape[0] + y = y.squeeze(1).view(1, -1, x.shape[-1]) + + # blocks + for _, block in enumerate(self.blocks): + x = auto_grad_checkpoint( + block, + x, + y, + t_spc_mlp, + t_tmp_mlp, + y_lens, + x_mask, + t0_spc_mlp, + t0_tmp_mlp, + ) + + if self.enable_sequence_parallelism: + x = gather_forward_split_backward(x, get_sequence_parallel_group(), dim=1, grad_scale="up") + # x.shape: [B, N, C] + + # final process + x = self.final_layer(x, t, x_mask, t0_spc) # [B, N, C=T_p * H_p * W_p * C_out] + x = self.unpatchify(x) # [B, C_out, T, H, W] + + # cast to float32 for better accuracy + x = x.to(torch.float32) + return x + + def unpatchify(self, x): + """ + Args: + x (torch.Tensor): of shape [B, N, C] + + Return: + x (torch.Tensor): of shape [B, C_out, T, H, W] + """ + + N_t, N_h, N_w = [self.input_size[i] // self.patch_size[i] for i in range(3)] + T_p, H_p, W_p = self.patch_size + x = rearrange( + x, + "B (N_t N_h N_w) (T_p H_p W_p C_out) -> B C_out (N_t T_p) (N_h H_p) (N_w W_p)", + N_t=N_t, + N_h=N_h, + N_w=N_w, + T_p=T_p, + H_p=H_p, + W_p=W_p, + C_out=self.out_channels, + ) + return x + + def unpatchify_old(self, x): + c = self.out_channels + t, h, w = [self.input_size[i] // self.patch_size[i] for i in range(3)] + pt, ph, pw = self.patch_size + + x = x.reshape(shape=(x.shape[0], t, h, w, pt, ph, pw, c)) + x = rearrange(x, "n t h w r p q c -> n c t r h p w q") + imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw)) + return imgs + + def get_spatial_pos_embed(self, grid_size=None): + if grid_size is None: + grid_size = self.input_size[1:] + pos_embed = get_2d_sincos_pos_embed( + self.hidden_size, + (grid_size[0] // self.patch_size[1], grid_size[1] // self.patch_size[2]), + scale=self.space_scale, + ) + pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False) + return pos_embed + + def get_temporal_pos_embed(self): + pos_embed = get_1d_sincos_pos_embed( + self.hidden_size, + self.input_size[0] // self.patch_size[0], + scale=self.time_scale, + ) + pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False) + return pos_embed + + def freeze_not_temporal(self): + for n, p in self.named_parameters(): + if "attn_temp" not in n: + p.requires_grad = False + + def freeze_text(self): + for n, p in self.named_parameters(): + if "cross_attn" in n: + p.requires_grad = False + + def initialize_temporal(self): + for block in self.blocks: + nn.init.constant_(block.attn_temp.proj.weight, 0) + nn.init.constant_(block.attn_temp.proj.bias, 0) + + def initialize_weights(self): + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + nn.init.normal_(self.t_block[1].weight, std=0.02) + nn.init.normal_(self.t_block_temp[1].weight, std=0.02) + + # Initialize caption embedding MLP: + nn.init.normal_(self.y_embedder.y_proj.fc1.weight, std=0.02) + nn.init.normal_(self.y_embedder.y_proj.fc2.weight, std=0.02) + + # Zero-out adaLN modulation layers in PixArt blocks: + for block in self.blocks: + 
nn.init.constant_(block.cross_attn.proj.weight, 0) + nn.init.constant_(block.cross_attn.proj.bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + +@MODELS.register_module("STDiT2-XL/2") +def STDiT2_XL_2(from_pretrained=None, **kwargs): + model = STDiT2(depth=28, hidden_size=1152, patch_size=(1, 2, 2), num_heads=16, **kwargs) + if from_pretrained is not None: + load_checkpoint(model, from_pretrained) + return model
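Note: the following is a minimal, standalone sketch of the frame-level mask selection that the new t_mask_select helpers in T2IFinalLayer and STDiT2Block perform, shown here with toy shapes. It is illustrative only and not part of the patch: for each batch element, a boolean mask over the T temporal patches chooses, frame by frame, between features modulated with the real timestep t and features modulated with the zero timestep t0.

import torch
from einops import rearrange

B, T, S, C = 2, 4, 6, 8              # batch, temporal patches, spatial patches, channels
x = torch.randn(B, T * S, C)         # branch modulated with the real timestep t
masked_x = torch.randn(B, T * S, C)  # branch modulated with the zero timestep t0
x_mask = torch.tensor([[True, True, False, False],
                       [True, False, True, False]])  # [B, T]; True keeps the t branch for that frame

def t_mask_select(x_mask, x, masked_x, d_t, d_s):
    # expose the temporal axis, pick per frame, then flatten back to [B, (T S), C]
    x = rearrange(x, "B (T S) C -> B T S C", T=d_t, S=d_s)
    masked_x = rearrange(masked_x, "B (T S) C -> B T S C", T=d_t, S=d_s)
    out = torch.where(x_mask[:, :, None, None], x, masked_x)
    return rearrange(out, "B T S C -> B (T S) C")

print(t_mask_select(x_mask, x, masked_x, T, S).shape)  # torch.Size([2, 24, 8])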