fix SeqParallelMultiHeadCrossAttention for consistent results in distributed mode (#510)

Jiacheng Yang, 2024-06-24 05:07:49 -04:00, committed by GitHub
parent ea44eb6b9e
commit 00fef1d1af


@@ -499,7 +499,7 @@ class SeqParallelMultiHeadCrossAttention(MultiHeadCrossAttention):
         # shape:
         # q, k, v: [B, SUB_N, NUM_HEADS, HEAD_DIM]
-        q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim)
+        q = self.q_linear(x).view(B, -1, self.num_heads, self.head_dim)
         kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim)
         kv = split_forward_gather_backward(kv, get_sequence_parallel_group(), dim=3, grad_scale="down")
         k, v = kv.unbind(2)
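
The one-line change matters because view(1, -1, ...) folds the batch dimension into the sequence dimension, while the sequence-parallel communication that follows redistributes tokens along that dimension per sample. A minimal sketch of the shape difference (illustrative sizes only; this snippet is not from the commit):

# Sketch: why view(1, -1, ...) and view(B, -1, ...) behave differently
# before sequence-parallel ops. Sizes below are assumed for illustration.
import torch

B, SUB_N, NUM_HEADS, HEAD_DIM = 2, 4, 2, 8
x = torch.randn(B, SUB_N, NUM_HEADS * HEAD_DIM)  # stand-in for self.q_linear(x)

# Old reshape: the batch dimension is folded into the sequence dimension,
# concatenating both samples into a single sequence of length B * SUB_N.
q_old = x.view(1, -1, NUM_HEADS, HEAD_DIM)
print(q_old.shape)  # torch.Size([1, 8, 2, 8])

# Fixed reshape: each sample keeps its own sequence axis, so downstream
# sequence-parallel redistribution splits tokens per sample, and distributed
# results stay consistent with the single-GPU path.
q_new = x.view(B, -1, NUM_HEADS, HEAD_DIM)
print(q_new.shape)  # torch.Size([2, 4, 2, 8])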