diff --git a/opensora/models/vae/vae_3d_v2.py b/opensora/models/vae/vae_3d_v2.py index 5fdd923..9742211 100644 --- a/opensora/models/vae/vae_3d_v2.py +++ b/opensora/models/vae/vae_3d_v2.py @@ -890,7 +890,7 @@ class VEALoss(nn.Module): # KL Loss self.kl_loss_weight = kl_loss_weight # Perceptual Loss - self.perceptual_loss_fn = LPIPS().eval() + self.perceptual_loss_fn = LPIPS().eval().to(device, dtype) self.perceptual_loss_weight = perceptual_loss_weight self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)