Mirror of https://github.com/hpcaitech/Open-Sora.git
hotfix: dit inference
commit 41f0ce3209
parent 966a0d0b97
The fix touches three spots in the DiT model class: a new constructor flag, a guard asserting that flag stays off, and a dtype cast at the top of the forward pass.

```diff
@@ -95,6 +95,7 @@ class DiT(nn.Module):
         dtype=torch.float32,
         enable_flashattn=False,
         enable_layernorm_kernel=False,
+        enable_sequence_parallelism=False,
     ):
         super().__init__()
         self.learn_sigma = learn_sigma
@@ -118,6 +119,7 @@ class DiT(nn.Module):
         self.no_temporal_pos_emb = no_temporal_pos_emb
         self.mlp_ratio = mlp_ratio
         self.depth = depth
+        assert enable_sequence_parallelism is False, "Sequence parallelism is not supported in DiT"
 
         self.register_buffer("pos_embed_spatial", self.get_spatial_pos_embed())
         self.register_buffer("pos_embed_temporal", self.get_temporal_pos_embed())
@@ -188,6 +190,8 @@ class DiT(nn.Module):
         """
+        # origin inputs should be float32, cast to specified dtype
+        x = x.to(self.dtype)
         if self.use_text_encoder:
             y = y.to(self.dtype)
 
         # embedding
         x = self.x_embedder(x)  # (B, N, D)
```
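For context, here is a minimal runnable sketch of the pattern the forward-pass hunk establishes. TinyDiT (and its dim/proj internals) is a hypothetical stand-in, not the Open-Sora DiT class; it assumes only what the hunks above show: inputs arrive as float32 and are cast to the module's self.dtype before any embedding, y is cast only when a text encoder is in use, and sequence parallelism is rejected up front.

```python
import torch
import torch.nn as nn


class TinyDiT(nn.Module):
    """Hypothetical stand-in illustrating the hotfix's casting pattern."""

    def __init__(self, dim=16, dtype=torch.bfloat16,
                 use_text_encoder=False, enable_sequence_parallelism=False):
        super().__init__()
        # Same guard the commit adds: DiT has no sequence-parallel code path.
        assert enable_sequence_parallelism is False, "Sequence parallelism is not supported in DiT"
        self.dtype = dtype
        self.use_text_encoder = use_text_encoder
        # Reduced-precision weights; bfloat16 keeps the demo CPU-friendly.
        self.proj = nn.Linear(dim, dim).to(dtype)

    def forward(self, x, y=None):
        # origin inputs should be float32, cast to specified dtype
        x = x.to(self.dtype)
        if self.use_text_encoder:
            y = y.to(self.dtype)
        return self.proj(x)


model = TinyDiT()
out = model(torch.randn(2, 16))  # float32 input no longer mismatches the weights
print(out.dtype)                 # torch.bfloat16
```

Without the cast, float32 activations reaching reduced-precision weights raise a dtype-mismatch error in the first projection, which is the kind of inference-time failure a cast like this guards against. Casting y only under use_text_encoder matters because when no text encoder is configured, y is presumably not a float tensor (e.g. integer class labels), and converting it to a float dtype would be wrong.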