Update pixart.py

This commit is contained in:
周钰坤 2024-04-30 13:53:17 +08:00
parent 34711b7bd2
commit bf633f3e68

View file

@ -211,7 +211,6 @@ class PixArt(nn.Module):
# embedding
x = self.x_embedder(x) # (B, N, D)
x = rearrange(x, "b (t s) d -> b t s d", t=self.num_temporal, s=self.num_spatial)
breakpoint()
x = x + self.pos_embed
if not self.no_temporal_pos_emb:
x = rearrange(x, "b t s d -> b s t d")