diff --git a/opensora/models/layers/blocks.py b/opensora/models/layers/blocks.py index 55d874a..b104439 100644 --- a/opensora/models/layers/blocks.py +++ b/opensora/models/layers/blocks.py @@ -141,6 +141,7 @@ class Attention(nn.Module): norm_layer: nn.Module = LlamaRMSNorm, enable_flashattn: bool = False, rope=None, + qk_norm_legacy: bool = False, ) -> None: super().__init__() assert dim % num_heads == 0, "dim should be divisible by num_heads" @@ -153,6 +154,7 @@ class Attention(nn.Module): self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.qk_norm_legacy = qk_norm_legacy self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop) @@ -171,11 +173,17 @@ class Attention(nn.Module): qkv = qkv.view(qkv_shape).permute(2, 0, 3, 1, 4) q, k, v = qkv.unbind(0) - # WARNING: this may be a bug - if self.rope: - q = self.rotary_emb(q) - k = self.rotary_emb(k) - q, k = self.q_norm(q), self.k_norm(k) + if self.qk_norm_legacy: + # WARNING: this may be a bug + if self.rope: + q = self.rotary_emb(q) + k = self.rotary_emb(k) + q, k = self.q_norm(q), self.k_norm(k) + else: + q, k = self.q_norm(q), self.k_norm(k) + if self.rope: + q = self.rotary_emb(q) + k = self.rotary_emb(k) if enable_flashattn: from flash_attn import flash_attn_func @@ -222,6 +230,7 @@ class SeqParallelAttention(Attention): norm_layer: nn.Module = LlamaRMSNorm, enable_flashattn: bool = False, rope=None, + qk_norm_legacy: bool = False, ) -> None: assert rope is None, "Rope is not supported in SeqParallelAttention" super().__init__( diff --git a/opensora/models/stdit/stdit2.py b/opensora/models/stdit/stdit2.py index 73fe276..7eafec9 100644 --- a/opensora/models/stdit/stdit2.py +++ b/opensora/models/stdit/stdit2.py @@ -42,6 +42,7 @@ class STDiT2Block(nn.Module): enable_sequence_parallelism=False, rope=None, qk_norm=False, + qk_norm_legacy=False, ): super().__init__() self.hidden_size = hidden_size @@ -64,6 +65,7 @@ class STDiT2Block(nn.Module): qkv_bias=True, enable_flashattn=enable_flashattn, qk_norm=qk_norm, + qk_norm_legacy=qk_norm_legacy, ) self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size**0.5) @@ -86,6 +88,7 @@ class STDiT2Block(nn.Module): enable_flashattn=self.enable_flashattn, rope=rope, qk_norm=qk_norm, + qk_norm_legacy=qk_norm_legacy, ) self.scale_shift_table_temporal = nn.Parameter(torch.randn(3, hidden_size) / hidden_size**0.5) # new @@ -195,6 +198,7 @@ class STDiT2(nn.Module): dtype=torch.float32, freeze=None, qk_norm=False, + qk_norm_legacy=False, enable_flashattn=False, enable_layernorm_kernel=False, enable_sequence_parallelism=False, @@ -244,6 +248,7 @@ class STDiT2(nn.Module): enable_sequence_parallelism=enable_sequence_parallelism, rope=self.rope.rotate_queries_or_keys, qk_norm=qk_norm, + qk_norm_legacy=qk_norm_legacy, ) for i in range(self.depth) ]