diff --git a/opensora/models/layers/blocks.py b/opensora/models/layers/blocks.py index 8bc7e72..5e2c13d 100644 --- a/opensora/models/layers/blocks.py +++ b/opensora/models/layers/blocks.py @@ -499,7 +499,7 @@ class SeqParallelMultiHeadCrossAttention(MultiHeadCrossAttention): # shape: # q, k, v: [B, SUB_N, NUM_HEADS, HEAD_DIM] - q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim) + q = self.q_linear(x).view(B, -1, self.num_heads, self.head_dim) kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim) kv = split_forward_gather_backward(kv, get_sequence_parallel_group(), dim=3, grad_scale="down") k, v = kv.unbind(2)