From 41f0ce320909742004df48d2c45c6fd83153ddbb Mon Sep 17 00:00:00 2001 From: Zangwei Zheng Date: Fri, 22 Mar 2024 13:07:21 +0800 Subject: [PATCH] hotfix: dit inference --- opensora/models/dit/dit.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/opensora/models/dit/dit.py b/opensora/models/dit/dit.py index a23dd7b..f264f8e 100644 --- a/opensora/models/dit/dit.py +++ b/opensora/models/dit/dit.py @@ -95,6 +95,7 @@ class DiT(nn.Module): dtype=torch.float32, enable_flashattn=False, enable_layernorm_kernel=False, + enable_sequence_parallelism=False, ): super().__init__() self.learn_sigma = learn_sigma @@ -118,6 +119,7 @@ class DiT(nn.Module): self.no_temporal_pos_emb = no_temporal_pos_emb self.mlp_ratio = mlp_ratio self.depth = depth + assert enable_sequence_parallelism is False, "Sequence parallelism is not supported in DiT" self.register_buffer("pos_embed_spatial", self.get_spatial_pos_embed()) self.register_buffer("pos_embed_temporal", self.get_temporal_pos_embed()) @@ -188,6 +190,8 @@ class DiT(nn.Module): """ # origin inputs should be float32, cast to specified dtype x = x.to(self.dtype) + if self.use_text_encoder: + y = y.to(self.dtype) # embedding x = self.x_embedder(x) # (B, N, D)