mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-15 03:15:20 +02:00
add lecam support
This commit is contained in:
parent
c20955e5b3
commit
1d6cee302f
|
|
@@ -84,9 +84,9 @@ magvit uses about # samples (K) * epochs ~ 2-5 K, num_frames = 4, reso = 128
|
||||||
|
|
||||||
epochs = 200
|
epochs = 200
|
||||||
log_every = 1
|
log_every = 1
|
||||||
ckpt_every = 50
|
ckpt_every = 1 # 50
|
||||||
load = None
|
load = None
|
||||||
|
|
||||||
batch_size = 32
|
batch_size = 4 # 32
|
||||||
lr = 1e-4
|
lr = 1e-4
|
||||||
grad_clip = 1.0
|
grad_clip = 1.0
|
||||||
|
|
|
||||||
|
|
@@ -225,11 +225,13 @@ def main():
|
||||||
if disc_lr_scheduler is not None:
|
if disc_lr_scheduler is not None:
|
||||||
booster.load_lr_scheduler(disc_lr_scheduler, os.path.join(cfg.load, "disc_lr_scheduler"))
|
booster.load_lr_scheduler(disc_lr_scheduler, os.path.join(cfg.load, "disc_lr_scheduler"))
|
||||||
# LeCam EMA for discriminator
|
# LeCam EMA for discriminator
|
||||||
lecam_path = os.path.join(cfg.load, "lecam_state.json")
|
lecam_path = os.path.join(cfg.load, "lecam_states.json")
|
||||||
if cfg.lecam_loss_weight is not None and os.path.exists(lecam_path):
|
if cfg.lecam_loss_weight is not None and os.path.exists(lecam_path):
|
||||||
lecam_state = load_json(lecam_path)
|
lecam_state = load_json(lecam_path)
|
||||||
lecam_ema_real, lecam_ema_fake = lecam_state["lecam_ema_real"], lecam_state["lecam_ema_fake"]
|
lecam_ema_real, lecam_ema_fake = lecam_state["lecam_ema_real"], lecam_state["lecam_ema_fake"]
|
||||||
lecam_ema = LeCamEMA(decay=cfg.ema_decay, ema_real=lecam_ema_real, ema_fake=lecam_ema_fake, dtype=dtype, device=device)
|
lecam_ema = LeCamEMA(decay=cfg.ema_decay, ema_real=lecam_ema_real, ema_fake=lecam_ema_fake, dtype=dtype, device=device)
|
||||||
|
else:
|
||||||
|
print(f"lecan not loaded, path: {lecam_path}, lecame loss weight {cfg.lecam_loss_weight}")
|
||||||
running_states = load_json(os.path.join(cfg.load, "running_states.json"))
|
running_states = load_json(os.path.join(cfg.load, "running_states.json"))
|
||||||
dist.barrier()
|
dist.barrier()
|
||||||
start_epoch, start_step, sampler_start_idx = running_states["epoch"], running_states["step"], running_states["sample_start_index"]
|
start_epoch, start_step, sampler_start_idx = running_states["epoch"], running_states["step"], running_states["sample_start_index"]
|
||||||
|
|
@@ -403,7 +405,7 @@ def main():
|
||||||
real_video = real_video if cfg.gradient_penalty_loss_weight is not None else None,
|
real_video = real_video if cfg.gradient_penalty_loss_weight is not None else None,
|
||||||
)
|
)
|
||||||
disc_loss = weighted_d_adversarial_loss + lecam_loss + gradient_penalty_loss
|
disc_loss = weighted_d_adversarial_loss + lecam_loss + gradient_penalty_loss
|
||||||
if cfg.ema_decay is not None:
|
if cfg.lecam_loss_weight is not None:
|
||||||
# SCH: TODO: is this written properly like this for moving average? e.g. distributed training etc.
|
# SCH: TODO: is this written properly like this for moving average? e.g. distributed training etc.
|
||||||
# lecam_ema_real = lecam_ema_real * cfg.ema_decay + (1 - cfg.ema_decay) * torch.mean(real_logits.clone().detach())
|
# lecam_ema_real = lecam_ema_real * cfg.ema_decay + (1 - cfg.ema_decay) * torch.mean(real_logits.clone().detach())
|
||||||
# lecam_ema_fake = lecam_ema_fake * cfg.ema_decay + (1 - cfg.ema_decay) * torch.mean(fake_logits.clone().detach())
|
# lecam_ema_fake = lecam_ema_fake * cfg.ema_decay + (1 - cfg.ema_decay) * torch.mean(fake_logits.clone().detach())
|
||||||
|
|
@@ -478,6 +480,7 @@ def main():
|
||||||
"global_step": global_step+1,
|
"global_step": global_step+1,
|
||||||
"sample_start_index": (step+1) * cfg.batch_size,
|
"sample_start_index": (step+1) * cfg.batch_size,
|
||||||
}
|
}
|
||||||
|
|
||||||
lecam_ema_real, lecam_ema_fake = lecam_ema.get()
|
lecam_ema_real, lecam_ema_fake = lecam_ema.get()
|
||||||
lecam_state = {
|
lecam_state = {
|
||||||
"lecam_ema_real": lecam_ema_real.item(),
|
"lecam_ema_real": lecam_ema_real.item(),
|
||||||
|
|
@@ -485,6 +488,7 @@ def main():
|
||||||
}
|
}
|
||||||
if coordinator.is_master():
|
if coordinator.is_master():
|
||||||
save_json(running_states, os.path.join(save_dir, "running_states.json"))
|
save_json(running_states, os.path.join(save_dir, "running_states.json"))
|
||||||
|
if cfg.lecam_loss_weight is not None:
|
||||||
save_json(lecam_state, os.path.join(save_dir, "lecam_states.json"))
|
save_json(lecam_state, os.path.join(save_dir, "lecam_states.json"))
|
||||||
dist.barrier()
|
dist.barrier()
|
||||||
logger.info(
|
logger.info(
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue