mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-05-20 17:35:58 +02:00
fix qknorm bug
This commit is contained in:
parent
335e8a3eed
commit
7fcf0c6bfd
|
|
@ -141,6 +141,7 @@ class Attention(nn.Module):
|
|||
norm_layer: nn.Module = LlamaRMSNorm,
|
||||
enable_flashattn: bool = False,
|
||||
rope=None,
|
||||
qk_norm_legacy: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
assert dim % num_heads == 0, "dim should be divisible by num_heads"
|
||||
|
|
@ -153,6 +154,7 @@ class Attention(nn.Module):
|
|||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
||||
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
||||
self.qk_norm_legacy = qk_norm_legacy
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
|
|
@ -171,11 +173,17 @@ class Attention(nn.Module):
|
|||
|
||||
qkv = qkv.view(qkv_shape).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv.unbind(0)
|
||||
# WARNING: this may be a bug
|
||||
if self.rope:
|
||||
q = self.rotary_emb(q)
|
||||
k = self.rotary_emb(k)
|
||||
q, k = self.q_norm(q), self.k_norm(k)
|
||||
if self.qk_norm_legacy:
|
||||
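# WARNING: this legacy order may be a bug: rotary embedding is applied before q/k norm here, whereas the non-legacy branch below applies q/k norm first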
|
||||
if self.rope:
|
||||
q = self.rotary_emb(q)
|
||||
k = self.rotary_emb(k)
|
||||
q, k = self.q_norm(q), self.k_norm(k)
|
||||
else:
|
||||
q, k = self.q_norm(q), self.k_norm(k)
|
||||
if self.rope:
|
||||
q = self.rotary_emb(q)
|
||||
k = self.rotary_emb(k)
|
||||
|
||||
if enable_flashattn:
|
||||
from flash_attn import flash_attn_func
|
||||
|
|
@ -222,6 +230,7 @@ class SeqParallelAttention(Attention):
|
|||
norm_layer: nn.Module = LlamaRMSNorm,
|
||||
enable_flashattn: bool = False,
|
||||
rope=None,
|
||||
qk_norm_legacy: bool = False,
|
||||
) -> None:
|
||||
assert rope is None, "Rope is not supported in SeqParallelAttention"
|
||||
super().__init__(
|
||||
|
|
|
|||
|
|
@ -42,6 +42,7 @@ class STDiT2Block(nn.Module):
|
|||
enable_sequence_parallelism=False,
|
||||
rope=None,
|
||||
qk_norm=False,
|
||||
qk_norm_legacy=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
|
|
@ -64,6 +65,7 @@ class STDiT2Block(nn.Module):
|
|||
qkv_bias=True,
|
||||
enable_flashattn=enable_flashattn,
|
||||
qk_norm=qk_norm,
|
||||
qk_norm_legacy=qk_norm_legacy,
|
||||
)
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size**0.5)
|
||||
|
||||
|
|
@ -86,6 +88,7 @@ class STDiT2Block(nn.Module):
|
|||
enable_flashattn=self.enable_flashattn,
|
||||
rope=rope,
|
||||
qk_norm=qk_norm,
|
||||
qk_norm_legacy=qk_norm_legacy,
|
||||
)
|
||||
self.scale_shift_table_temporal = nn.Parameter(torch.randn(3, hidden_size) / hidden_size**0.5) # new
|
||||
|
||||
|
|
@ -195,6 +198,7 @@ class STDiT2(nn.Module):
|
|||
dtype=torch.float32,
|
||||
freeze=None,
|
||||
qk_norm=False,
|
||||
qk_norm_legacy=False,
|
||||
enable_flashattn=False,
|
||||
enable_layernorm_kernel=False,
|
||||
enable_sequence_parallelism=False,
|
||||
|
|
@ -244,6 +248,7 @@ class STDiT2(nn.Module):
|
|||
enable_sequence_parallelism=enable_sequence_parallelism,
|
||||
rope=self.rope.rotate_queries_or_keys,
|
||||
qk_norm=qk_norm,
|
||||
qk_norm_legacy=qk_norm_legacy,
|
||||
)
|
||||
for i in range(self.depth)
|
||||
]
|
||||
|
|
|
|||
Loading…
Reference in a new issue