inference v2 working

This commit is contained in:
Shen-Chenhui 2024-04-16 18:30:31 +08:00
parent 179bb7b125
commit 95517d7fb5

View file

@ -70,7 +70,7 @@ def main():
dataset,
batch_size=cfg.batch_size,
num_workers=cfg.num_workers,
shuffle=True,
shuffle=False,
drop_last=True,
pin_memory=True,
process_group=get_data_parallel_group(),
@ -187,7 +187,7 @@ def main():
adversarial_loss = adversarial_loss_fn(
fake_logits,
nll_loss,
vae.module.get_last_layer(),
vae.get_last_layer(),
cfg.discriminator_start+1, # Hack to use discriminator
is_training = vae.training,
)