diff --git a/opensora/models/vae/vae_3d.py b/opensora/models/vae/vae_3d.py index e7f81c9..44ec554 100644 --- a/opensora/models/vae/vae_3d.py +++ b/opensora/models/vae/vae_3d.py @@ -222,6 +222,7 @@ class Decoder(nn.Module): self.channel_multipliers = channel_multipliers self.temporal_downsample = temporal_downsample self.num_groups = num_groups + self.dtype = dtype if isinstance(self.temporal_downsample, int): self.temporal_downsample = _get_selected_flags(