diff --git a/configs/vae_magvit_v2/train/16x128x128_pixabay_8_GPU_pipeline.py b/configs/vae_magvit_v2/train/16x128x128_pixabay_8_GPU_pipeline.py
index b356c46..e3aa372 100644
--- a/configs/vae_magvit_v2/train/16x128x128_pixabay_8_GPU_pipeline.py
+++ b/configs/vae_magvit_v2/train/16x128x128_pixabay_8_GPU_pipeline.py
@@ -84,9 +84,9 @@ magvit uses about # samples (K) * epochs ~ 2-5 K, num_frames = 4, reso = 128
 epochs = 200
 log_every = 1
-ckpt_every = 50
+ckpt_every = 5
 load = None
-batch_size = 32
+batch_size = 4
 lr = 1e-4
 grad_clip = 1.0
diff --git a/scripts/train-vae-v2.py b/scripts/train-vae-v2.py
index e9c8ac8..74b6a77 100644
--- a/scripts/train-vae-v2.py
+++ b/scripts/train-vae-v2.py
@@ -53,19 +53,26 @@ def main():
     # 1. args & cfg
     # ======================================================
     cfg = parse_configs(training=True)
-    exp_name, exp_dir = create_experiment_workspace(cfg)
-    save_training_config(cfg._cfg_dict, exp_dir)
 
     # ======================================================
     # 2. runtime variables & colossalai launch
     # ======================================================
     assert torch.cuda.is_available(), "Training currently requires at least one GPU."
-    assert cfg.dtype in ["fp16", "bf16"], f"Unknown mixed precision {cfg.dtype}"
-
+
     # 2.1. colossalai init distributed training
     colossalai.launch_from_torch({})
     coordinator = DistCoordinator()
+
+    if coordinator.is_master():  # only create the experiment directory on the master process
+        exp_name, exp_dir = create_experiment_workspace(cfg)
+        print("master creating experiment dir:", exp_dir, exp_name)
+        save_training_config(cfg._cfg_dict, exp_dir)
+    print("process going into barrier A")
+    dist.barrier()
+    print("process left barrier A")
+
+    device = get_current_device()
+    assert cfg.dtype in ["fp16", "bf16"], f"Unknown mixed precision {cfg.dtype}"
     dtype = to_torch_dtype(cfg.dtype)
 
     # 2.2. init logger, tensorboard & wandb
@@ -221,9 +228,17 @@ def main():
             booster.load_lr_scheduler(lr_scheduler, os.path.join(cfg.load, "lr_scheduler"))
         if disc_lr_scheduler is not None:
             booster.load_lr_scheduler(disc_lr_scheduler, os.path.join(cfg.load, "disc_lr_scheduler"))
-
+        lecam_path = os.path.join(cfg.load, "lecam_states.json")  # must match the filename written at save time
+        if cfg.lecam_loss_weight is not None and os.path.exists(lecam_path):
+            lecam_state = load_json(lecam_path)
+            lecam_ema_real, lecam_ema_fake = lecam_state["lecam_ema_real"], lecam_state["lecam_ema_fake"]
+            lecam_ema = LeCamEMA(decay=cfg.ema_decay, ema_real=lecam_ema_real, ema_fake=lecam_ema_fake, dtype=dtype, device=device)
+        else:
+            lecam_ema = LeCamEMA(decay=cfg.ema_decay, dtype=dtype, device=device)
         running_states = load_json(os.path.join(cfg.load, "running_states.json"))
+        print("process going into barrier B")
         dist.barrier()
+        print("process left barrier B")
         start_epoch, start_step, sampler_start_idx = running_states["epoch"], running_states["step"], running_states["sample_start_index"]
         logger.info(f"Loaded checkpoint {cfg.load} at epoch {start_epoch} step {start_step}")
     logger.info(f"Training for {cfg.epochs} epochs with {num_steps_per_epoch} steps per epoch")
@@ -267,7 +282,8 @@ def main():
     # lecam_ema_real = torch.tensor(0.0)
     # lecam_ema_fake = torch.tensor(0.0)
-    lecam_ema = LeCamEMA(decay=cfg.ema_decay, dtype=dtype, device=device)
+    if cfg.load is None:  # when resuming, lecam_ema was already rebuilt from the checkpoint above
+        lecam_ema = LeCamEMA(decay=cfg.ema_decay, dtype=dtype, device=device)
 
     for epoch in range(start_epoch, cfg.epochs):
         dataloader.sampler.set_epoch(epoch)
@@ -453,32 +468,47 @@ def main():
                 # Save checkpoint
                 if cfg.ckpt_every > 0 and (global_step + 1) % cfg.ckpt_every == 0:
-                    save_dir = os.path.join(exp_dir, f"epoch{epoch}-global_step{global_step+1}")
-                    os.makedirs(os.path.join(save_dir, "model"), exist_ok=True)
-                    booster.save_model(vae, os.path.join(save_dir, "model"), shard=True)
-                    booster.save_model(discriminator, os.path.join(save_dir, "discriminator"), shard=True)
-                    booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True, size_per_shard=4096)
-                    booster.save_optimizer(disc_optimizer, os.path.join(save_dir, "disc_optimizer"), shard=True, size_per_shard=4096)
-
-                    if lr_scheduler is not None:
-                        booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler"))
-                    if disc_lr_scheduler is not None:
-                        booster.save_lr_scheduler(disc_lr_scheduler, os.path.join(save_dir, "disc_lr_scheduler"))
-
-                    running_states = {
-                        "epoch": epoch,
-                        "step": step+1,
-                        "global_step": global_step+1,
-                        "sample_start_index": (step+1) * cfg.batch_size,
-                    }
                     if coordinator.is_master():
-                        save_json(running_states, os.path.join(save_dir, "running_states.json"))
-                    dist.barrier()
-                    logger.info(
-                        f"Saved checkpoint at epoch {epoch} step {step + 1} global_step {global_step + 1} to {exp_dir}"
-                    )
+                        save_dir = os.path.join(exp_dir, f"epoch{epoch}-global_step{global_step+1}")
+                        os.makedirs(os.path.join(save_dir, "model"), exist_ok=True)  # note: booster.save_model also creates this directory
+                        booster.save_model(vae, os.path.join(save_dir, "model"), shard=True)
+                        print("model saved")
+                        booster.save_model(discriminator, os.path.join(save_dir, "discriminator"), shard=True)
+                        print("discriminator saved")
+                        booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True, size_per_shard=4096)
+                        print("optimizer saved")
+                        booster.save_optimizer(disc_optimizer, os.path.join(save_dir, "disc_optimizer"), shard=True, size_per_shard=4096)
+                        print("disc opt saved")
+
+                        if lr_scheduler is not None:
+                            booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler"))
+                            print("lr scheduler saved")
+                        if disc_lr_scheduler is not None:
+                            booster.save_lr_scheduler(disc_lr_scheduler, os.path.join(save_dir, "disc_lr_scheduler"))
+                            print("disc scheduler saved")
+
+                        lecam_ema_real, lecam_ema_fake = lecam_ema.get()
+                        lecam_state = {
+                            "lecam_ema_real": lecam_ema_real.item(),
+                            "lecam_ema_fake": lecam_ema_fake.item(),
+                        }
+                        save_json(lecam_state, os.path.join(save_dir, "lecam_states.json"))
+                        print("lecam state saved")
+                        running_states = {
+                            "epoch": epoch,
+                            "step": step+1,
+                            "global_step": global_step+1,
+                            "sample_start_index": (step+1) * cfg.batch_size,
+                        }
+                        save_json(running_states, os.path.join(save_dir, "running_states.json"))
+                        logger.info(
+                            f"Saved checkpoint at epoch {epoch} step {step + 1} global_step {global_step + 1} to {exp_dir}"
+                        )
+                    # barrier: non-master processes wait here until the master has finished saving
+                    print("process going into barrier C")
+                    dist.barrier()
+                    print("process left barrier C")
 
-                    # p.step()
 
         # print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1))
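
A note on the workspace hunk: after the change, exp_name and exp_dir are defined only on the master rank, which works because every later use sits inside a coordinator.is_master() block. For reference, here is a minimal sketch (not part of the patch) of one way to make the workspace path visible on every rank, assuming torch.distributed is already initialized by colossalai.launch_from_torch and that the repo helpers create_experiment_workspace and save_training_config are in scope; setup_workspace is a hypothetical name:

import torch.distributed as dist

def setup_workspace(cfg, coordinator):
    """Hypothetical helper: master creates the workspace, every rank learns its path."""
    holder = [None, None]  # [exp_name, exp_dir]
    if coordinator.is_master():
        exp_name, exp_dir = create_experiment_workspace(cfg)  # repo helper
        save_training_config(cfg._cfg_dict, exp_dir)          # repo helper
        holder = [exp_name, exp_dir]
    # broadcast_object_list pickles the master's values to every rank,
    # so non-master processes also know where the experiment dir lives
    dist.broadcast_object_list(holder, src=0)
    dist.barrier()  # plays the role of barrier A: directory exists before anyone proceeds
    return holder[0], holder[1]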
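
The LeCam pieces of the patch form a save/load round-trip: lecam_ema.get() is serialized to lecam_states.json at checkpoint time and fed back through the LeCamEMA constructor on resume. Below is a sketch of that round-trip; the LeCamEMA class is a minimal stand-in inferred from the constructor and .get() usage in the patch (the real class lives in this repo and tracks EMAs of the discriminator's real/fake outputs for LeCam regularization), and save_lecam/load_lecam are illustrative names:

import json
import os
import torch

class LeCamEMA:
    """Stand-in with the interface the patch relies on."""
    def __init__(self, decay=0.999, ema_real=0.0, ema_fake=0.0,
                 dtype=torch.float32, device="cpu"):
        self.decay = decay
        self.ema_real = torch.tensor(ema_real, dtype=dtype, device=device)
        self.ema_fake = torch.tensor(ema_fake, dtype=dtype, device=device)

    def update(self, d_real, d_fake):
        # exponential moving averages of the discriminator outputs
        self.ema_real = self.decay * self.ema_real + (1 - self.decay) * d_real.mean()
        self.ema_fake = self.decay * self.ema_fake + (1 - self.decay) * d_fake.mean()

    def get(self):
        return self.ema_real, self.ema_fake

def save_lecam(lecam_ema, save_dir):
    real, fake = lecam_ema.get()
    state = {"lecam_ema_real": real.item(), "lecam_ema_fake": fake.item()}
    with open(os.path.join(save_dir, "lecam_states.json"), "w") as f:
        json.dump(state, f)

def load_lecam(load_dir, decay, dtype, device):
    path = os.path.join(load_dir, "lecam_states.json")  # must match the name used on save
    if not os.path.exists(path):
        return LeCamEMA(decay=decay, dtype=dtype, device=device)
    with open(path) as f:
        state = json.load(f)
    return LeCamEMA(decay=decay, ema_real=state["lecam_ema_real"],
                    ema_fake=state["lecam_ema_fake"], dtype=dtype, device=device)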
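
Finally, the checkpoint hunk switches to a "master writes, everyone waits" pattern: only rank 0 enters the save block, and barrier C keeps the other ranks from running ahead while files are still being written. One assumption worth verifying against the booster plugin in use: with some ColossalAI plugins, save_model/save_optimizer with shard=True are collective calls that every rank must enter, in which case they would need to stay outside the is_master() block. A minimal sketch of the bookkeeping half of the pattern, assuming exp_dir is on storage shared by all ranks and with write_json standing in for the repo's save_json helper:

import json
import os
import torch.distributed as dist

def write_json(obj, path):
    with open(path, "w") as f:
        json.dump(obj, f, indent=2)

def checkpoint_bookkeeping(is_master, exp_dir, epoch, step, global_step, batch_size):
    if is_master:
        save_dir = os.path.join(exp_dir, f"epoch{epoch}-global_step{global_step + 1}")
        os.makedirs(save_dir, exist_ok=True)
        write_json(
            {
                "epoch": epoch,
                "step": step + 1,
                "global_step": global_step + 1,
                # lets the sampler resume mid-epoch from the right sample
                "sample_start_index": (step + 1) * batch_size,
            },
            os.path.join(save_dir, "running_states.json"),
        )
    # all ranks meet here; non-master ranks proceed only once the files exist
    dist.barrier()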