diff --git a/opensora/models/vae/vae_3d.py b/opensora/models/vae/vae_3d.py index 822b038..ed88b82 100644 --- a/opensora/models/vae/vae_3d.py +++ b/opensora/models/vae/vae_3d.py @@ -130,7 +130,7 @@ class Encoder(nn.Module): self.block_args = dict( # norm_fn=self.norm_fn, conv_fn=self.conv_fn, - dtype=self.dtype, + dtype=dtype, activation_fn=self.activation_fn, use_conv_shortcut=False, num_groups=self.num_groups,