mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-10 12:49:38 +02:00
Merge remote-tracking branch 'upstream/main' into hotfix/fix-sp
This commit is contained in:
commit
6bb2c599b6
|
|
@ -499,7 +499,7 @@ class SeqParallelMultiHeadCrossAttention(MultiHeadCrossAttention):
|
||||||
|
|
||||||
# shape:
|
# shape:
|
||||||
# q, k, v: [B, SUB_N, NUM_HEADS, HEAD_DIM]
|
# q, k, v: [B, SUB_N, NUM_HEADS, HEAD_DIM]
|
||||||
q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim)
|
q = self.q_linear(x).view(B, -1, self.num_heads, self.head_dim)
|
||||||
kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim)
|
kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim)
|
||||||
kv = split_forward_gather_backward(kv, get_sequence_parallel_group(), dim=3, grad_scale="down")
|
kv = split_forward_gather_backward(kv, get_sequence_parallel_group(), dim=3, grad_scale="down")
|
||||||
k, v = kv.unbind(2)
|
k, v = kv.unbind(2)
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue