diff --git a/opensora/models/vae/vae_3d_v2.py b/opensora/models/vae/vae_3d_v2.py index 949e11d..1b0ae0b 100644 --- a/opensora/models/vae/vae_3d_v2.py +++ b/opensora/models/vae/vae_3d_v2.py @@ -276,7 +276,9 @@ class StyleGANDiscriminator(nn.Module): num_groups=32, dtype = torch.bfloat16, device="cpu", - ): + ): + super().__init__() + self.dtype = dtype self.input_size = cast_tuple(image_size, 2) self.filters = discriminator_filters