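Fix QK-norm bug: apply q/k normalization before rotary embedding by default (old order kept behind qk_norm_legacy)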

This commit is contained in:
zhengzangw 2024-04-26 03:59:50 +00:00
parent 335e8a3eed
commit 7fcf0c6bfd
2 changed files with 19 additions and 5 deletions

View file

@ -141,6 +141,7 @@ class Attention(nn.Module):
norm_layer: nn.Module = LlamaRMSNorm,
enable_flashattn: bool = False,
rope=None,
qk_norm_legacy: bool = False,
) -> None:
super().__init__()
assert dim % num_heads == 0, "dim should be divisible by num_heads"
@ -153,6 +154,7 @@ class Attention(nn.Module):
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.qk_norm_legacy = qk_norm_legacy
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
@ -171,11 +173,17 @@ class Attention(nn.Module):
qkv = qkv.view(qkv_shape).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)
# WARNING: this may be a bug
if self.rope:
q = self.rotary_emb(q)
k = self.rotary_emb(k)
q, k = self.q_norm(q), self.k_norm(k)
if self.qk_norm_legacy:
# WARNING: legacy order applies the rotary embedding before q/k norm; this may be a bug
if self.rope:
q = self.rotary_emb(q)
k = self.rotary_emb(k)
q, k = self.q_norm(q), self.k_norm(k)
else:
q, k = self.q_norm(q), self.k_norm(k)
if self.rope:
q = self.rotary_emb(q)
k = self.rotary_emb(k)
if enable_flashattn:
from flash_attn import flash_attn_func
@ -222,6 +230,7 @@ class SeqParallelAttention(Attention):
norm_layer: nn.Module = LlamaRMSNorm,
enable_flashattn: bool = False,
rope=None,
qk_norm_legacy: bool = False,
) -> None:
assert rope is None, "Rope is not supported in SeqParallelAttention"
super().__init__(

View file

@ -42,6 +42,7 @@ class STDiT2Block(nn.Module):
enable_sequence_parallelism=False,
rope=None,
qk_norm=False,
qk_norm_legacy=False,
):
super().__init__()
self.hidden_size = hidden_size
@ -64,6 +65,7 @@ class STDiT2Block(nn.Module):
qkv_bias=True,
enable_flashattn=enable_flashattn,
qk_norm=qk_norm,
qk_norm_legacy=qk_norm_legacy,
)
self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size**0.5)
@ -86,6 +88,7 @@ class STDiT2Block(nn.Module):
enable_flashattn=self.enable_flashattn,
rope=rope,
qk_norm=qk_norm,
qk_norm_legacy=qk_norm_legacy,
)
self.scale_shift_table_temporal = nn.Parameter(torch.randn(3, hidden_size) / hidden_size**0.5) # new
@ -195,6 +198,7 @@ class STDiT2(nn.Module):
dtype=torch.float32,
freeze=None,
qk_norm=False,
qk_norm_legacy=False,
enable_flashattn=False,
enable_layernorm_kernel=False,
enable_sequence_parallelism=False,
@ -244,6 +248,7 @@ class STDiT2(nn.Module):
enable_sequence_parallelism=enable_sequence_parallelism,
rope=self.rope.rotate_queries_or_keys,
qk_norm=qk_norm,
qk_norm_legacy=qk_norm_legacy,
)
for i in range(self.depth)
]