diff --git a/configs/vae_magvit_v2/train/16x128x128_pixabay_8_GPU_pipeline.py b/configs/vae_magvit_v2/train/16x128x128_pixabay_8_GPU_pipeline.py index 8405bfa..86f8711 100644 --- a/configs/vae_magvit_v2/train/16x128x128_pixabay_8_GPU_pipeline.py +++ b/configs/vae_magvit_v2/train/16x128x128_pixabay_8_GPU_pipeline.py @@ -84,9 +84,11 @@ magvit uses about # samples (K) * epochs ~ 2-5 K, num_frames = 4, reso = 128 epochs = 200 log_every = 1 -ckpt_every = 1 # 50 + +ckpt_every = 50 load = None -batch_size = 4 # 32 +batch_size = 32 + lr = 1e-4 grad_clip = 1.0 diff --git a/scripts/train-vae-v2.py b/scripts/train-vae-v2.py index b6cdd25..4967935 100644 --- a/scripts/train-vae-v2.py +++ b/scripts/train-vae-v2.py @@ -16,6 +16,7 @@ from tqdm import tqdm import os from einops import rearrange import numpy as np +from glob import glob from opensora.acceleration.checkpoint import set_grad_checkpoint from opensora.acceleration.parallel_states import ( @@ -53,19 +54,32 @@ def main(): # 1. args & cfg # ====================================================== cfg = parse_configs(training=True) - exp_name, exp_dir = create_experiment_workspace(cfg) - save_training_config(cfg._cfg_dict, exp_dir) # ====================================================== # 2. runtime variables & colossalai launch # ====================================================== assert torch.cuda.is_available(), "Training currently requires at least one GPU." - assert cfg.dtype in ["fp16", "bf16"], f"Unknown mixed precision {cfg.dtype}" - + # 2.1. colossalai init distributed training colossalai.launch_from_torch({}) coordinator = DistCoordinator() + + exp_dir = None + if coordinator.is_master(): # only create directory for master + exp_name, exp_dir = create_experiment_workspace(cfg) + save_training_config(cfg._cfg_dict, exp_dir) + dist.barrier() + + # get exp dir for non-master process + if exp_dir is None: + experiment_index = len(glob(f"{cfg.outputs}/*"))-1 + model_name = cfg.model["type"].replace("/", "-") + exp_name = f"{experiment_index:03d}-F{cfg.num_frames}S{cfg.frame_interval}-{model_name}" + exp_dir = f"{cfg.outputs}/{exp_name}" + assert os.path.exists(exp_dir) + device = get_current_device() + assert cfg.dtype in ["fp16", "bf16"], f"Unknown mixed precision {cfg.dtype}" dtype = to_torch_dtype(cfg.dtype) # 2.2. init logger, tensorboard & wandb @@ -224,14 +238,14 @@ def main(): booster.load_lr_scheduler(lr_scheduler, os.path.join(cfg.load, "lr_scheduler")) if disc_lr_scheduler is not None: booster.load_lr_scheduler(disc_lr_scheduler, os.path.join(cfg.load, "disc_lr_scheduler")) + # LeCam EMA for discriminator lecam_path = os.path.join(cfg.load, "lecam_states.json") if cfg.lecam_loss_weight is not None and os.path.exists(lecam_path): lecam_state = load_json(lecam_path) lecam_ema_real, lecam_ema_fake = lecam_state["lecam_ema_real"], lecam_state["lecam_ema_fake"] lecam_ema = LeCamEMA(decay=cfg.ema_decay, ema_real=lecam_ema_real, ema_fake=lecam_ema_fake, dtype=dtype, device=device) - else: - print(f"lecan not loaded, path: {lecam_path}, lecame loss weight {cfg.lecam_loss_weight}") + running_states = load_json(os.path.join(cfg.load, "running_states.json")) dist.barrier() start_epoch, start_step, sampler_start_idx = running_states["epoch"], running_states["step"], running_states["sample_start_index"] @@ -274,9 +288,6 @@ def main(): disc_time_padding = 0 video_contains_first_frame = cfg.video_contains_first_frame - - # lecam_ema_real = torch.tensor(0.0) - # lecam_ema_fake = torch.tensor(0.0) for epoch in range(start_epoch, cfg.epochs): dataloader.sampler.set_epoch(epoch) @@ -406,9 +417,6 @@ def main(): ) disc_loss = weighted_d_adversarial_loss + lecam_loss + gradient_penalty_loss if cfg.lecam_loss_weight is not None: - # SCH: TODO: is this written properly like this for moving average? e.g. distributed training etc. - # lecam_ema_real = lecam_ema_real * cfg.ema_decay + (1 - cfg.ema_decay) * torch.mean(real_logits.clone().detach()) - # lecam_ema_fake = lecam_ema_fake * cfg.ema_decay + (1 - cfg.ema_decay) * torch.mean(fake_logits.clone().detach()) ema_real = torch.mean(real_logits.clone().detach()).to(device, dtype) ema_fake = torch.mean(fake_logits.clone().detach()).to(device, dtype) all_reduce_mean(ema_real) @@ -463,7 +471,7 @@ def main(): # Save checkpoint if cfg.ckpt_every > 0 and (global_step + 1) % cfg.ckpt_every == 0: save_dir = os.path.join(exp_dir, f"epoch{epoch}-global_step{global_step+1}") - os.makedirs(os.path.join(save_dir, "model"), exist_ok=True) + os.makedirs(os.path.join(save_dir, "model"), exist_ok=True) # already handled in booster save_model booster.save_model(vae, os.path.join(save_dir, "model"), shard=True) booster.save_model(discriminator, os.path.join(save_dir, "discriminator"), shard=True) booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True, size_per_shard=4096) @@ -491,12 +499,11 @@ def main(): if cfg.lecam_loss_weight is not None: save_json(lecam_state, os.path.join(save_dir, "lecam_states.json")) dist.barrier() + logger.info( f"Saved checkpoint at epoch {epoch} step {step + 1} global_step {global_step + 1} to {exp_dir}" ) - # p.step() - # print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1))