mirror of https://github.com/hpcaitech/Open-Sora.git

commit bf1999f9b1 ("issue opt saving"), parent 576b44d98e
```diff
@@ -84,9 +84,9 @@ magvit uses about # samples (K) * epochs ~ 2-5 K, num_frames = 4, reso = 128
 epochs = 200
 log_every = 1
-ckpt_every = 50
+ckpt_every = 5
 load = None
 
-batch_size = 32
+batch_size = 4
 lr = 1e-4
 grad_clip = 1.0
 
```
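These two config changes shrink the debugging loop for the save issue: with a smaller batch and a checkpoint attempt every 5 optimizer steps, the save path instrumented below is hit almost immediately. A rough sketch of the effect, where the dataset size is a made-up example value, not something from the diff:

```python
# Hypothetical numbers to show how often the save path now triggers;
# only batch_size and ckpt_every come from the config above.
dataset_size = 2_000            # assumed sample count, for illustration only
batch_size = 4                  # from the diff (was 32)
ckpt_every = 5                  # from the diff (was 50)

steps_per_epoch = dataset_size // batch_size            # 500
checkpoints_per_epoch = steps_per_epoch // ckpt_every   # 100

# The training loop gates saving with the same condition used later in
# this diff: save after every ckpt_every-th optimizer step.
for global_step in range(steps_per_epoch):
    if ckpt_every > 0 and (global_step + 1) % ckpt_every == 0:
        pass  # checkpoint would be written here
```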
```diff
@@ -53,19 +53,26 @@ def main():
     # 1. args & cfg
     # ======================================================
     cfg = parse_configs(training=True)
-    exp_name, exp_dir = create_experiment_workspace(cfg)
-    save_training_config(cfg._cfg_dict, exp_dir)
 
     # ======================================================
     # 2. runtime variables & colossalai launch
     # ======================================================
     assert torch.cuda.is_available(), "Training currently requires at least one GPU."
-    assert cfg.dtype in ["fp16", "bf16"], f"Unknown mixed precision {cfg.dtype}"
 
     # 2.1. colossalai init distributed training
     colossalai.launch_from_torch({})
     coordinator = DistCoordinator()
+
+    if coordinator.is_master():  # only create directory for master
+        exp_name, exp_dir = create_experiment_workspace(cfg)
+        print("master creating experiment dir:", exp_dir, exp_name)
+        save_training_config(cfg._cfg_dict, exp_dir)
+    print("process going into barrier A")
+    dist.barrier()
+    print("process left barrier A")
+
     device = get_current_device()
+    assert cfg.dtype in ["fp16", "bf16"], f"Unknown mixed precision {cfg.dtype}"
     dtype = to_torch_dtype(cfg.dtype)
 
     # 2.2. init logger, tensorboard & wandb
```
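The hunk moves workspace creation to after the distributed launch and restricts it to rank 0, with a barrier so no other rank races ahead into a directory that does not exist yet. As written, though, `exp_name` and `exp_dir` are only bound on the master rank, while the save block later in this diff joins paths under `exp_dir` on every rank. A minimal sketch of the same pattern that also shares the path, assuming `torch.distributed` is initialized; `create_workspace` and `setup_workspace` are illustrative stand-ins, not names from the repo:

```python
# "Rank 0 creates, everyone waits" pattern, plus broadcasting the resulting
# path so every rank can use it later (e.g. when saving checkpoints).
import torch.distributed as dist

def setup_workspace(rank: int, create_workspace) -> str:
    payload = [None]
    if rank == 0:
        exp_dir = create_workspace()   # mkdir + config dump, master only
        payload[0] = exp_dir
    # Collective: every rank must call this, so all ranks end up with the
    # same directory string and none proceeds before it exists.
    dist.broadcast_object_list(payload, src=0)
    dist.barrier()
    return payload[0]                  # now defined on all ranks
```

Broadcasting (or recomputing a deterministic path on each rank) avoids a `NameError` on non-master ranks once they reach the checkpointing code.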
```diff
@@ -221,9 +228,17 @@ def main():
             booster.load_lr_scheduler(lr_scheduler, os.path.join(cfg.load, "lr_scheduler"))
         if disc_lr_scheduler is not None:
             booster.load_lr_scheduler(disc_lr_scheduler, os.path.join(cfg.load, "disc_lr_scheduler"))
+
+        lecam_path = os.path.join(cfg.load, "lecam_state.json")
+        if cfg.lecam_loss_weight is not None and os.path.exists(lecam_path):
+            lecam_state = load_json(lecam_path)
+            lecam_ema_real, lecam_ema_fake = lecam_state["lecam_ema_real"], lecam_state["lecam_ema_fake"]
+            lecam_ema = LeCamEMA(decay=cfg.ema_decay, ema_real=lecam_ema_real, ema_fake=lecam_ema_fake, dtype=dtype, device=device)
+        else:
+            lecam_ema = LeCamEMA(decay=cfg.ema_decay, dtype=dtype, device=device)
         running_states = load_json(os.path.join(cfg.load, "running_states.json"))
         print('going to barrier B')
         dist.barrier()
         print("left barrier B")
         start_epoch, start_step, sampler_start_idx = running_states["epoch"], running_states["step"], running_states["sample_start_index"]
         logger.info(f"Loaded checkpoint {cfg.load} at epoch {start_epoch} step {start_step}")
         logger.info(f"Training for {cfg.epochs} epochs with {num_steps_per_epoch} steps per epoch")
```
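Restoring the LeCam EMA alongside the model keeps the discriminator regularizer consistent across restarts. One thing to watch in this diff: the load path reads `lecam_state.json`, while the save block further down writes `lecam_states.json`, so as written the state would never be found on resume. A sketch of the round trip with a single shared filename; `LECAM_FILE`, `save_state`, and `load_state` are illustrative stand-ins, not the repo's `save_json`/`load_json` helpers:

```python
# JSON round trip for the LeCam EMA scalars, using one filename constant
# on both the save and the load side so resume actually finds the state.
import json
import os

LECAM_FILE = "lecam_state.json"  # must match on the save and load side

def save_state(save_dir: str, ema_real: float, ema_fake: float) -> None:
    with open(os.path.join(save_dir, LECAM_FILE), "w") as f:
        json.dump({"lecam_ema_real": ema_real, "lecam_ema_fake": ema_fake}, f)

def load_state(load_dir: str) -> tuple[float, float] | None:
    path = os.path.join(load_dir, LECAM_FILE)
    if not os.path.exists(path):
        return None  # fresh run: caller constructs LeCamEMA with defaults
    with open(path) as f:
        state = json.load(f)
    return state["lecam_ema_real"], state["lecam_ema_fake"]
```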
```diff
@@ -267,7 +282,7 @@ def main():
 
-    lecam_ema_real = torch.tensor(0.0)
-    lecam_ema_fake = torch.tensor(0.0)
-
+    # lecam_ema_real = torch.tensor(0.0)
+    # lecam_ema_fake = torch.tensor(0.0)
+    lecam_ema = LeCamEMA(decay=cfg.ema_decay, dtype=dtype, device=device)
 
     for epoch in range(start_epoch, cfg.epochs):
         dataloader.sampler.set_epoch(epoch)
```
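From its call sites in this diff, `LeCamEMA` carries two scalars (`ema_real`, `ema_fake`) seeded through the constructor and read back via `.get()`, replacing the bare tensors used before. A minimal compatible sketch; the `update()` method and its exact rule are assumptions, since the diff only shows the constructor and `get()`:

```python
import torch

class LeCamEMA:
    """EMA of discriminator outputs on real/fake batches, as used by the
    LeCam regularizer. Constructor and get() mirror the call sites in the
    diff; update() is an assumed standard EMA rule, not shown there."""

    def __init__(self, decay=0.999, ema_real=0.0, ema_fake=0.0,
                 dtype=torch.float32, device="cpu"):
        self.decay = decay
        self.ema_real = torch.tensor(ema_real, dtype=dtype, device=device)
        self.ema_fake = torch.tensor(ema_fake, dtype=dtype, device=device)

    def update(self, real_score: torch.Tensor, fake_score: torch.Tensor):
        # standard EMA: new = decay * old + (1 - decay) * observation
        self.ema_real = self.decay * self.ema_real + (1 - self.decay) * real_score.mean()
        self.ema_fake = self.decay * self.ema_fake + (1 - self.decay) * fake_score.mean()

    def get(self):
        # matches `lecam_ema_real, lecam_ema_fake = lecam_ema.get()` and the
        # subsequent .item() calls in the checkpointing code below
        return self.ema_real, self.ema_fake
```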
```diff
@@ -453,32 +468,47 @@ def main():
 
             # Save checkpoint
             if cfg.ckpt_every > 0 and (global_step + 1) % cfg.ckpt_every == 0:
-                save_dir = os.path.join(exp_dir, f"epoch{epoch}-global_step{global_step+1}")
-                os.makedirs(os.path.join(save_dir, "model"), exist_ok=True)
-                booster.save_model(vae, os.path.join(save_dir, "model"), shard=True)
-                booster.save_model(discriminator, os.path.join(save_dir, "discriminator"), shard=True)
-                booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True, size_per_shard=4096)
-                booster.save_optimizer(disc_optimizer, os.path.join(save_dir, "disc_optimizer"), shard=True, size_per_shard=4096)
-
-                if lr_scheduler is not None:
-                    booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler"))
-                if disc_lr_scheduler is not None:
-                    booster.save_lr_scheduler(disc_lr_scheduler, os.path.join(save_dir, "disc_lr_scheduler"))
-
-                running_states = {
-                    "epoch": epoch,
-                    "step": step+1,
-                    "global_step": global_step+1,
-                    "sample_start_index": (step+1) * cfg.batch_size,
-                }
-                if coordinator.is_master():
-                    save_json(running_states, os.path.join(save_dir, "running_states.json"))
-                dist.barrier()
-                logger.info(
-                    f"Saved checkpoint at epoch {epoch} step {step + 1} global_step {global_step + 1} to {exp_dir}"
-                )
+                save_dir = os.path.join(exp_dir, f"epoch{epoch}-global_step{global_step+1}")
+                os.makedirs(os.path.join(save_dir, "model"), exist_ok=True)  # already handled in booster save_model
+                booster.save_model(vae, os.path.join(save_dir, "model"), shard=True)
+                print("model saved")
+                booster.save_model(discriminator, os.path.join(save_dir, "discriminator"), shard=True)
+                print("discriminator saved")
+                booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True, size_per_shard=4096)
+                print("optimizer saved")
+                booster.save_optimizer(disc_optimizer, os.path.join(save_dir, "disc_optimizer"), shard=True, size_per_shard=4096)
+                print("disc opt saved")
+
+                if lr_scheduler is not None:
+                    booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler"))
+                    print("lr scheduler saved")
+                if disc_lr_scheduler is not None:
+                    booster.save_lr_scheduler(disc_lr_scheduler, os.path.join(save_dir, "disc_lr_scheduler"))
+                    print("disc scheduler saved")
+
+                lecam_ema_real, lecam_ema_fake = lecam_ema.get()
+                lecam_state = {
+                    "lecam_ema_real": lecam_ema_real.item(),
+                    "lecam_ema_fake": lecam_ema_fake.item(),
+                }
+                save_json(lecam_state, os.path.join(save_dir, "lecam_states.json"))
+                print("lecam state saved")
+                running_states = {
+                    "epoch": epoch,
+                    "step": step+1,
+                    "global_step": global_step+1,
+                    "sample_start_index": (step+1) * cfg.batch_size,
+                }
+                save_json(running_states, os.path.join(save_dir, "running_states.json"))
+                logger.info(
+                    f"Saved checkpoint at epoch {epoch} step {step + 1} global_step {global_step + 1} to {exp_dir}"
+                )
+                # use barrier to ask non-master processes to wait, lift barrier when master finish saving and reaches here
+                print("process going into barrier C")
+                dist.barrier()
+                print("process left barrier C")
 
             # p.step()
 
             # print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1))
```
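The rewritten save block brackets every save stage with a print and moves the single barrier to the end of the block. The likely rationale: ColossalAI's `booster.save_model` and `booster.save_optimizer` with `shard=True` are collective operations in which every rank writes its shard, so each stage must be entered by all ranks, and a print before and after each one localizes a hang to a specific stage. A sketch of the same instrumenting pattern with rank tags added, which the raw prints in the diff lack; `instrumented_save` and `stages` are illustrative, the real code inlines the booster calls:

```python
# Bracket each (potentially collective) save stage with rank-tagged prints
# so a hang can be pinned to one stage on one rank.
import torch.distributed as dist

def instrumented_save(stages):
    """stages: list of (name, zero-arg callable) run in order on every rank."""
    rank = dist.get_rank()
    for name, fn in stages:
        print(f"[rank {rank}] entering {name}")
        fn()                # collective: all ranks must call it, in the same order
        print(f"[rank {rank}] finished {name}")
    print(f"[rank {rank}] entering final barrier")
    dist.barrier()          # keep ranks aligned before training resumes
    print(f"[rank {rank}] left final barrier")
```

Without the rank tag, interleaved stdout from N processes cannot show which rank stalled. Note also that the new code drops the `is_master()` guard around `save_json`, so every rank writes the same `running_states.json`; that is often harmless on a shared filesystem but is usually worth guarding.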