From c20955e5b3f36f27260c20156da9dfd44cdc497a Mon Sep 17 00:00:00 2001 From: Shen-Chenhui Date: Fri, 26 Apr 2024 10:41:40 +0800 Subject: [PATCH] lecam support --- scripts/train-vae-v2.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/scripts/train-vae-v2.py b/scripts/train-vae-v2.py index e9c8ac8..fb18ca5 100644 --- a/scripts/train-vae-v2.py +++ b/scripts/train-vae-v2.py @@ -156,6 +156,9 @@ def main(): f"Trainable discriminator params: {format_numel_str(discriminator_numel_trainable)}, Total model params: {format_numel_str(discriminator_numel)}" ) + # LeCam Initialization + lecam_ema = LeCamEMA(decay=cfg.ema_decay, dtype=dtype, device=device) + # 4.3. move to device if cfg.get("use_pipeline") == True: vae_2d.to(device, dtype).eval() # eval mode, not training! @@ -221,7 +224,12 @@ def main(): booster.load_lr_scheduler(lr_scheduler, os.path.join(cfg.load, "lr_scheduler")) if disc_lr_scheduler is not None: booster.load_lr_scheduler(disc_lr_scheduler, os.path.join(cfg.load, "disc_lr_scheduler")) - + # LeCam EMA for discriminator + lecam_path = os.path.join(cfg.load, "lecam_state.json") + if cfg.lecam_loss_weight is not None and os.path.exists(lecam_path): + lecam_state = load_json(lecam_path) + lecam_ema_real, lecam_ema_fake = lecam_state["lecam_ema_real"], lecam_state["lecam_ema_fake"] + lecam_ema = LeCamEMA(decay=cfg.ema_decay, ema_real=lecam_ema_real, ema_fake=lecam_ema_fake, dtype=dtype, device=device) running_states = load_json(os.path.join(cfg.load, "running_states.json")) dist.barrier() start_epoch, start_step, sampler_start_idx = running_states["epoch"], running_states["step"], running_states["sample_start_index"] @@ -267,7 +275,6 @@ def main(): # lecam_ema_real = torch.tensor(0.0) # lecam_ema_fake = torch.tensor(0.0) - lecam_ema = LeCamEMA(decay=cfg.ema_decay, dtype=dtype, device=device) for epoch in range(start_epoch, cfg.epochs): dataloader.sampler.set_epoch(epoch) @@ -471,8 +478,14 @@ def main(): "global_step": 
global_step+1, "sample_start_index": (step+1) * cfg.batch_size, } + lecam_ema_real, lecam_ema_fake = lecam_ema.get() + lecam_state = { + "lecam_ema_real": lecam_ema_real.item(), + "lecam_ema_fake": lecam_ema_fake.item(), + } if coordinator.is_master(): save_json(running_states, os.path.join(save_dir, "running_states.json")) + save_json(lecam_state, os.path.join(save_dir, "lecam_state.json")) dist.barrier() logger.info( f"Saved checkpoint at epoch {epoch} step {step + 1} global_step {global_step + 1} to {exp_dir}"