mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-15 03:15:20 +02:00
lecam support
This commit is contained in:
parent
576b44d98e
commit
c20955e5b3
|
|
@ -156,6 +156,9 @@ def main():
|
|||
f"Trainable discriminator params: {format_numel_str(discriminator_numel_trainable)}, Total model params: {format_numel_str(discriminator_numel)}"
|
||||
)
|
||||
|
||||
# LeCam Initialization
|
||||
lecam_ema = LeCamEMA(decay=cfg.ema_decay, dtype=dtype, device=device)
|
||||
|
||||
# 4.3. move to device
|
||||
if cfg.get("use_pipeline") == True:
|
||||
vae_2d.to(device, dtype).eval() # eval mode, not training!
|
||||
|
|
@ -221,7 +224,12 @@ def main():
|
|||
booster.load_lr_scheduler(lr_scheduler, os.path.join(cfg.load, "lr_scheduler"))
|
||||
if disc_lr_scheduler is not None:
|
||||
booster.load_lr_scheduler(disc_lr_scheduler, os.path.join(cfg.load, "disc_lr_scheduler"))
|
||||
|
||||
# LeCam EMA for discriminator
|
||||
# Must match the filename written at checkpoint time ("lecam_states.json"); with a mismatched name the exists() check below is always False and the saved LeCam EMA is silently not restored.
lecam_path = os.path.join(cfg.load, "lecam_states.json")
|
||||
if cfg.lecam_loss_weight is not None and os.path.exists(lecam_path):
|
||||
lecam_state = load_json(lecam_path)
|
||||
lecam_ema_real, lecam_ema_fake = lecam_state["lecam_ema_real"], lecam_state["lecam_ema_fake"]
|
||||
lecam_ema = LeCamEMA(decay=cfg.ema_decay, ema_real=lecam_ema_real, ema_fake=lecam_ema_fake, dtype=dtype, device=device)
|
||||
running_states = load_json(os.path.join(cfg.load, "running_states.json"))
|
||||
dist.barrier()
|
||||
start_epoch, start_step, sampler_start_idx = running_states["epoch"], running_states["step"], running_states["sample_start_index"]
|
||||
|
|
@ -267,7 +275,6 @@ def main():
|
|||
|
||||
# lecam_ema_real = torch.tensor(0.0)
|
||||
# lecam_ema_fake = torch.tensor(0.0)
|
||||
lecam_ema = LeCamEMA(decay=cfg.ema_decay, dtype=dtype, device=device)
|
||||
|
||||
for epoch in range(start_epoch, cfg.epochs):
|
||||
dataloader.sampler.set_epoch(epoch)
|
||||
|
|
@ -471,8 +478,14 @@ def main():
|
|||
"global_step": global_step+1,
|
||||
"sample_start_index": (step+1) * cfg.batch_size,
|
||||
}
|
||||
lecam_ema_real, lecam_ema_fake = lecam_ema.get()
|
||||
lecam_state = {
|
||||
"lecam_ema_real": lecam_ema_real.item(),
|
||||
"lecam_ema_fake": lecam_ema_fake.item(),
|
||||
}
|
||||
if coordinator.is_master():
|
||||
save_json(running_states, os.path.join(save_dir, "running_states.json"))
|
||||
save_json(lecam_state, os.path.join(save_dir, "lecam_states.json"))
|
||||
dist.barrier()
|
||||
logger.info(
|
||||
f"Saved checkpoint at epoch {epoch} step {step + 1} global_step {global_step + 1} to {exp_dir}"
|
||||
|
|
|
|||
Loading…
Reference in a new issue