From 00fef1d1af0b431ffd4dadea684a2d59d5d880f2 Mon Sep 17 00:00:00 2001 From: Jiacheng Yang Date: Mon, 24 Jun 2024 05:07:49 -0400 Subject: [PATCH] fix SeqParallelMultiHeadCrossAttention for consistent results in distributed mode (#510) --- opensora/models/layers/blocks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/opensora/models/layers/blocks.py b/opensora/models/layers/blocks.py index 8bc7e72..5e2c13d 100644 --- a/opensora/models/layers/blocks.py +++ b/opensora/models/layers/blocks.py @@ -499,7 +499,7 @@ class SeqParallelMultiHeadCrossAttention(MultiHeadCrossAttention): # shape: # q, k, v: [B, SUB_N, NUM_HEADS, HEAD_DIM] - q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim) + q = self.q_linear(x).view(B, -1, self.num_heads, self.head_dim) kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim) kv = split_forward_gather_backward(kv, get_sequence_parallel_group(), dim=3, grad_scale="down") k, v = kv.unbind(2)