mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-10 12:49:38 +02:00
fix SeqParallelMultiHeadCrossAttention for consistent results in distributed mode (#510)
This commit is contained in:
parent
ea44eb6b9e
commit
00fef1d1af
|
|
@@ -499,7 +499,7 @@ class SeqParallelMultiHeadCrossAttention(MultiHeadCrossAttention):
|
||||||
|
|
||||||
# shape:
|
# shape:
|
||||||
# q, k, v: [B, SUB_N, NUM_HEADS, HEAD_DIM]
|
# q, k, v: [B, SUB_N, NUM_HEADS, HEAD_DIM]
|
||||||
q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim)
|
q = self.q_linear(x).view(B, -1, self.num_heads, self.head_dim)
|
||||||
kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim)
|
kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim)
|
||||||
kv = split_forward_gather_backward(kv, get_sequence_parallel_group(), dim=3, grad_scale="down")
|
kv = split_forward_gather_backward(kv, get_sequence_parallel_group(), dim=3, grad_scale="down")
|
||||||
k, v = kv.unbind(2)
|
k, v = kv.unbind(2)
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue