diff --git a/configs/vae_magvit_v2/train/16x128x128_pixabay_8_GPU_pipeline.py b/configs/vae_magvit_v2/train/16x128x128_pixabay_8_GPU_pipeline.py
index b356c46..8405bfa 100644
--- a/configs/vae_magvit_v2/train/16x128x128_pixabay_8_GPU_pipeline.py
+++ b/configs/vae_magvit_v2/train/16x128x128_pixabay_8_GPU_pipeline.py
@@ -84,9 +84,9 @@ magvit uses about # samples (K) * epochs ~ 2-5 K, num_frames = 4, reso = 128
 epochs = 200
 log_every = 1
-ckpt_every = 50
+ckpt_every = 1  # 50
 load = None
-batch_size = 32
+batch_size = 4  # 32
 lr = 1e-4
 grad_clip = 1.0
diff --git a/scripts/train-vae-v2.py b/scripts/train-vae-v2.py
index fb18ca5..b6cdd25 100644
--- a/scripts/train-vae-v2.py
+++ b/scripts/train-vae-v2.py
@@ -225,11 +225,13 @@ def main():
         if disc_lr_scheduler is not None:
             booster.load_lr_scheduler(disc_lr_scheduler, os.path.join(cfg.load, "disc_lr_scheduler"))
         # LeCam EMA for discriminator
-        lecam_path = os.path.join(cfg.load, "lecam_state.json")
+        lecam_path = os.path.join(cfg.load, "lecam_states.json")
         if cfg.lecam_loss_weight is not None and os.path.exists(lecam_path):
             lecam_state = load_json(lecam_path)
             lecam_ema_real, lecam_ema_fake = lecam_state["lecam_ema_real"], lecam_state["lecam_ema_fake"]
             lecam_ema = LeCamEMA(decay=cfg.ema_decay, ema_real=lecam_ema_real, ema_fake=lecam_ema_fake, dtype=dtype, device=device)
+        else:
+            print(f"LeCam state not loaded, path: {lecam_path}, lecam_loss_weight: {cfg.lecam_loss_weight}")
         running_states = load_json(os.path.join(cfg.load, "running_states.json"))
         dist.barrier()
         start_epoch, start_step, sampler_start_idx = running_states["epoch"], running_states["step"], running_states["sample_start_index"]
@@ -403,7 +405,7 @@ def main():
                     real_video = real_video if cfg.gradient_penalty_loss_weight is not None else None,
                 )
                 disc_loss = weighted_d_adversarial_loss + lecam_loss + gradient_penalty_loss
-                if cfg.ema_decay is not None:
+                if cfg.lecam_loss_weight is not None:
                     # SCH: TODO: is this written properly like this for moving average? e.g. distributed training etc.
                     # lecam_ema_real = lecam_ema_real * cfg.ema_decay + (1 - cfg.ema_decay) * torch.mean(real_logits.clone().detach())
                     # lecam_ema_fake = lecam_ema_fake * cfg.ema_decay + (1 - cfg.ema_decay) * torch.mean(fake_logits.clone().detach())
@@ -478,6 +480,7 @@ def main():
                         "global_step": global_step + 1,
                         "sample_start_index": (step + 1) * cfg.batch_size,
                     }
+                    lecam_ema_real, lecam_ema_fake = lecam_ema.get()
                     lecam_state = {
                         "lecam_ema_real": lecam_ema_real.item(),
@@ -485,7 +488,8 @@ def main():
                     }
                     if coordinator.is_master():
                         save_json(running_states, os.path.join(save_dir, "running_states.json"))
-                        save_json(lecam_state, os.path.join(save_dir, "lecam_states.json"))
+                        if cfg.lecam_loss_weight is not None:
+                            save_json(lecam_state, os.path.join(save_dir, "lecam_states.json"))
                     dist.barrier()
                     logger.info(
                         f"Saved checkpoint at epoch {epoch} step {step + 1} global_step {global_step + 1} to {exp_dir}"
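
For context, the load/save changes above revolve around the `LeCamEMA` helper and the `lecam_states.json` round-trip. Below is a minimal sketch of what such a helper and the LeCam regularization term typically look like, assuming the interface implied by the call sites in the diff (`decay=`, `ema_real=`, `ema_fake=`, `.get()`); the actual class in this repo may differ, and `lecam_reg` follows the reference LeCam-GAN formulation (Tseng et al., 2021) rather than this repo's exact loss code.

```python
import torch


class LeCamEMA:
    """Tracks exponential moving averages of the discriminator's mean
    real/fake logits for the LeCam regularizer (Tseng et al., 2021)."""

    def __init__(self, decay=0.999, ema_real=0.0, ema_fake=0.0,
                 dtype=torch.float32, device="cpu"):
        self.decay = decay
        # Stored as tensors so .item() works when serializing to JSON,
        # matching lecam_ema.get() -> lecam_ema_real.item() in the diff.
        self.ema_real = torch.tensor(ema_real, dtype=dtype, device=device)
        self.ema_fake = torch.tensor(ema_fake, dtype=dtype, device=device)

    def update(self, real_logits, fake_logits):
        # EMA of the mean logits; detach so no gradient flows into the EMA.
        d = self.decay
        self.ema_real = d * self.ema_real + (1 - d) * real_logits.detach().mean()
        self.ema_fake = d * self.ema_fake + (1 - d) * fake_logits.detach().mean()

    def get(self):
        return self.ema_real, self.ema_fake


def lecam_reg(real_logits, fake_logits, ema):
    # Reference LeCam formulation: penalize real logits that exceed the
    # fake EMA and fake logits that fall below the real EMA, which keeps
    # the discriminator's outputs from drifting apart under limited data.
    ema_real, ema_fake = ema.get()
    return (torch.mean(torch.relu(real_logits - ema_fake) ** 2)
            + torch.mean(torch.relu(ema_real - fake_logits) ** 2))


# Hypothetical call sites mirroring the discriminator step in the diff:
#   lecam_ema.update(real_logits, fake_logits)
#   lecam_loss = cfg.lecam_loss_weight * lecam_reg(real_logits, fake_logits, lecam_ema)
```

Note how this interface explains the two guard changes in the patch: the EMA only exists when `cfg.lecam_loss_weight` is set, so both the update branch and the checkpoint `save_json` must key off `lecam_loss_weight` rather than `ema_decay`, and `lecam_ema.get()` must be called before building `lecam_state` so the saved values reflect the current step.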