diff --git a/opensora/models/dit/dit.py b/opensora/models/dit/dit.py
index a23dd7b..f264f8e 100644
--- a/opensora/models/dit/dit.py
+++ b/opensora/models/dit/dit.py
@@ -95,6 +95,7 @@ class DiT(nn.Module):
         dtype=torch.float32,
         enable_flashattn=False,
         enable_layernorm_kernel=False,
+        enable_sequence_parallelism=False,
     ):
         super().__init__()
         self.learn_sigma = learn_sigma
@@ -118,6 +119,7 @@ class DiT(nn.Module):
         self.no_temporal_pos_emb = no_temporal_pos_emb
         self.mlp_ratio = mlp_ratio
         self.depth = depth
+        assert enable_sequence_parallelism is False, "Sequence parallelism is not supported in DiT"
 
         self.register_buffer("pos_embed_spatial", self.get_spatial_pos_embed())
         self.register_buffer("pos_embed_temporal", self.get_temporal_pos_embed())
@@ -188,6 +190,8 @@ class DiT(nn.Module):
         """
         # origin inputs should be float32, cast to specified dtype
         x = x.to(self.dtype)
+        if self.use_text_encoder:
+            y = y.to(self.dtype)
 
         # embedding
         x = self.x_embedder(x)  # (B, N, D)