hotfix: DiT inference

This commit is contained in:
Zangwei Zheng 2024-03-22 13:07:21 +08:00
parent 966a0d0b97
commit 41f0ce3209

View file

@ -95,6 +95,7 @@ class DiT(nn.Module):
dtype=torch.float32,
enable_flashattn=False,
enable_layernorm_kernel=False,
enable_sequence_parallelism=False,
):
super().__init__()
self.learn_sigma = learn_sigma
@ -118,6 +119,7 @@ class DiT(nn.Module):
self.no_temporal_pos_emb = no_temporal_pos_emb
self.mlp_ratio = mlp_ratio
self.depth = depth
assert enable_sequence_parallelism is False, "Sequence parallelism is not supported in DiT"
self.register_buffer("pos_embed_spatial", self.get_spatial_pos_embed())
self.register_buffer("pos_embed_temporal", self.get_temporal_pos_embed())
@ -188,6 +190,8 @@ class DiT(nn.Module):
"""
# origin inputs should be float32, cast to specified dtype
x = x.to(self.dtype)
if self.use_text_encoder:
y = y.to(self.dtype)
# embedding
x = self.x_embedder(x) # (B, N, D)