Merge remote-tracking branch 'upstream/main' into hotfix/fix-sp

This commit is contained in:
FrankLeeeee 2024-06-24 09:08:21 +00:00
commit 6bb2c599b6

View file

@@ -499,7 +499,7 @@ class SeqParallelMultiHeadCrossAttention(MultiHeadCrossAttention):
# shape:
# q, k, v: [B, SUB_N, NUM_HEADS, HEAD_DIM]
- q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim)
+ q = self.q_linear(x).view(B, -1, self.num_heads, self.head_dim)
kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim)
kv = split_forward_gather_backward(kv, get_sequence_parallel_group(), dim=3, grad_scale="down")
k, v = kv.unbind(2)