mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-05-02 19:57:45 +02:00
inference v2 working
This commit is contained in:
parent
179bb7b125
commit
95517d7fb5
|
|
@ -70,7 +70,7 @@ def main():
|
|||
dataset,
|
||||
batch_size=cfg.batch_size,
|
||||
num_workers=cfg.num_workers,
|
||||
shuffle=True,
|
||||
shuffle=False,
|
||||
drop_last=True,
|
||||
pin_memory=True,
|
||||
process_group=get_data_parallel_group(),
|
||||
|
|
@ -187,7 +187,7 @@ def main():
|
|||
adversarial_loss = adversarial_loss_fn(
|
||||
fake_logits,
|
||||
nll_loss,
|
||||
vae.module.get_last_layer(),
|
||||
vae.get_last_layer(),
|
||||
cfg.discriminator_start+1, # Hack to use discriminator
|
||||
is_training = vae.training,
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in a new issue