diff --git a/scripts/inference-vae-v2.py b/scripts/inference-vae-v2.py index 47e4159..f44945a 100644 --- a/scripts/inference-vae-v2.py +++ b/scripts/inference-vae-v2.py @@ -70,7 +70,7 @@ def main(): dataset, batch_size=cfg.batch_size, num_workers=cfg.num_workers, - shuffle=True, + shuffle=False, drop_last=True, pin_memory=True, process_group=get_data_parallel_group(), @@ -187,7 +187,7 @@ def main(): adversarial_loss = adversarial_loss_fn( fake_logits, nll_loss, - vae.module.get_last_layer(), + vae.get_last_layer(), cfg.discriminator_start+1, # Hack to use discriminator is_training = vae.training, )