diff --git a/scripts/inference-debug.py b/scripts/inference-debug.py index b75c2a5..179cd89 100644 --- a/scripts/inference-debug.py +++ b/scripts/inference-debug.py @@ -121,7 +121,7 @@ def main(): # latent_size = vae.get_latent_size(input_size) # 3.2. move to device & eval - vae = vae.to(device, dtype) + vae = vae.to(device, dtype).eval() # # 4.5. setup optimizer # optimizer = HybridAdam(