diff --git a/opensora/utils/ckpt_utils.py b/opensora/utils/ckpt_utils.py index d6206ae..b788072 100644 --- a/opensora/utils/ckpt_utils.py +++ b/opensora/utils/ckpt_utils.py @@ -249,7 +249,7 @@ def load( load_dir: str, sampler=None, ) -> Tuple[int, int, int]: - booster.load_model(model, os.path.join(load_dir, "model")) + booster.load_model(model, os.path.join(load_dir, model_name)) # ema is not boosted, so we don't use booster.load_model ema.load_state_dict( torch.load(os.path.join(load_dir, "ema.pt"), map_location=torch.device("cpu")), diff --git a/scripts/train.py b/scripts/train.py index f4091b1..b139822 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -98,7 +98,7 @@ def main(): ) if cfg.dataset.type == DEFAULT_DATASET_NAME: dataloader = prepare_dataloader(**dataloader_args) - total_batch_size = cfg.batch_size * dist.get_world_size() // cfg.sp_size + total_batch_size = cfg.batch_size * dist.get_world_size() // cfg.get("sp_size", 1) logger.info("Total batch size: %s", total_batch_size) else: dataloader = prepare_variable_dataloader( diff --git a/scripts/train_vae.py b/scripts/train_vae.py index 5398e04..add7f10 100644 --- a/scripts/train_vae.py +++ b/scripts/train_vae.py @@ -1,12 +1,11 @@ import os import random from datetime import timedelta -from pprint import pprint +from pprint import pformat import torch import torch.distributed as dist from colossalai.booster import Booster -from colossalai.booster.plugin import LowLevelZeroPlugin from colossalai.cluster import DistCoordinator from colossalai.nn.optimizer import HybridAdam from colossalai.utils import get_current_device, set_seed @@ -15,7 +14,7 @@ from tqdm import tqdm import wandb from opensora.acceleration.checkpoint import set_grad_checkpoint -from opensora.acceleration.parallel_states import get_data_parallel_group, set_data_parallel_group +from opensora.acceleration.parallel_states import get_data_parallel_group from opensora.datasets import prepare_dataloader from opensora.models.vae.losses import AdversarialLoss, DiscriminatorLoss, VAELoss from opensora.registry import DATASETS, MODELS, build_module @@ -27,68 +26,64 @@ from opensora.utils.config_utils import ( save_training_config, ) from opensora.utils.misc import all_reduce_mean, format_numel_str, get_model_numel, to_torch_dtype +from opensora.utils.train_utils import create_colossalai_plugin + +DEFAULT_DATASET_NAME = "VideoTextDataset" def main(): # ====================================================== - # 1. args & cfg + # 1. configs & runtime variables & colossalai launch # ====================================================== cfg = parse_configs(training=True) - # ====================================================== - # 2. runtime variables & colossalai launch - # ====================================================== assert torch.cuda.is_available(), "Training currently requires at least one GPU." - assert cfg.dtype in ["fp16", "bf16"], f"Unknown mixed precision {cfg.dtype}" + cfg_dtype = cfg.get("dtype", "bf16") + assert cfg_dtype in ["fp16", "bf16"], f"Unknown mixed precision {cfg_dtype}" + dtype = to_torch_dtype(cfg.get("dtype", "bf16")) - # 2.1. colossalai init distributed training - # we set a very large timeout to avoid some processes exit early + # == colossalai init distributed training == + # NOTE: A very large timeout is set to avoid some processes exit early dist.init_process_group(backend="nccl", timeout=timedelta(hours=24)) torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count()) set_seed(1024) coordinator = DistCoordinator() device = get_current_device() - dtype = to_torch_dtype(cfg.dtype) - # 2.2. init exp_dir, logger, tensorboard & wandb + # == init exp_dir == exp_name, exp_dir = define_experiment_workspace(cfg) coordinator.block_all() if coordinator.is_master(): os.makedirs(exp_dir, exist_ok=True) - save_training_config(cfg._cfg_dict, exp_dir) + save_training_config(cfg.to_dict(), exp_dir) coordinator.block_all() - if not coordinator.is_master(): - logger = create_logger(None) - else: - print("Training configuration:") - pprint(cfg._cfg_dict) + # == init logger, tensorboard & wandb == + if coordinator.is_master(): logger = create_logger(exp_dir) logger.info(f"Experiment directory created at {exp_dir}") - - writer = create_tensorboard_writer(exp_dir) - if cfg.wandb: - wandb.init(project="minisora", name=exp_name, config=cfg._cfg_dict) - - # 2.3. initialize ColossalAI booster - if cfg.plugin == "zero2": - plugin = LowLevelZeroPlugin( - stage=2, - precision=cfg.dtype, - initial_scale=2**16, - max_norm=cfg.grad_clip, - ) - set_data_parallel_group(dist.group.WORLD) + logger.info("Training configuration:\n %s", pformat(cfg.to_dict())) + tb_writer = create_tensorboard_writer(exp_dir) + if cfg.get("wandb", False): + wandb.init(project="minisora", name=exp_name, config=cfg.to_dict()) else: - raise ValueError(f"Unknown plugin {cfg.plugin}") + logger = create_logger(None) + + # == init ColossalAI booster == + plugin = create_colossalai_plugin( + plugin=cfg.get("plugin", "zero2"), + dtype=cfg_dtype, + grad_clip=cfg.get("grad_clip", 0), + sp_size=cfg.get("sp_size", 1), + ) booster = Booster(plugin=plugin) # ====================================================== - # 3. build dataset and dataloader + # 2. build dataset and dataloader # ====================================================== - assert cfg.dataset.type == "VideoTextDataset", "Only support VideoTextDataset for now" dataset = build_module(cfg.dataset, DATASETS) - logger.info(f"Dataset contains {len(dataset)} samples.") + logger.info("Dataset contains %s samples.", len(dataset)) + # == build dataloader == dataloader_args = dict( dataset=dataset, batch_size=cfg.batch_size, @@ -99,76 +94,93 @@ def main(): pin_memory=True, process_group=get_data_parallel_group(), ) - # TODO: use plugin's prepare dataloader - dataloader = prepare_dataloader(**dataloader_args) - total_batch_size = cfg.batch_size * dist.get_world_size() - logger.info(f"Total batch size: {total_batch_size}") + if cfg.dataset.type == DEFAULT_DATASET_NAME: + dataloader = prepare_dataloader(**dataloader_args) + total_batch_size = cfg.batch_size * dist.get_world_size() // cfg.get("sp_size", 1) + logger.info("Total batch size: %s", total_batch_size) + else: + dataloader = prepare_variable_dataloader( + bucket_config=cfg.bucket_config, + num_bucket_build_workers=cfg.num_bucket_build_workers, + **dataloader_args, + ) # ====================================================== - # 4. build model + # 3. build model # ====================================================== - # 4.1. build model - model = build_module(cfg.model, MODELS) - model.to(device, dtype) + # == build vae model == + model = build_module(cfg.model, MODELS).to(device, dtype) + model.train() model_numel, model_numel_trainable = get_model_numel(model) logger.info( - f"Trainable model params: {format_numel_str(model_numel_trainable)}, Total model params: {format_numel_str(model_numel)}" + "Trainable model params: %s, Total model params: %s", + format_numel_str(model_numel_trainable), + format_numel_str(model_numel), ) + # == build discriminator model == + if cfg.get("discriminator", False) != False: + discriminator = build_module(cfg.discriminator, MODELS).to(device, dtype) + discriminator.train() + discriminator_numel, discriminator_numel_trainable = get_model_numel(discriminator) + logger.info( + "Trainable model params: %s, Total model params: %s", + format_numel_str(discriminator_numel_trainable), + format_numel_str(discriminator_numel), + ) - # 4.4 loss functions + # == setup loss functions == vae_loss_fn = VAELoss( logvar_init=cfg.get("logvar_init", 0.0), - perceptual_loss_weight=cfg.perceptual_loss_weight, - kl_loss_weight=cfg.kl_loss_weight, + perceptual_loss_weight=cfg.get("perceptual_loss_weight", 0.1), + kl_loss_weight=cfg.get("kl_loss_weight", 1e-6), device=device, dtype=dtype, ) if cfg.get("discriminator", False) != False: adversarial_loss_fn = AdversarialLoss( - discriminator_factor=cfg.discriminator_factor, - discriminator_start=cfg.discriminator_start, - generator_factor=cfg.generator_factor, - generator_loss_type=cfg.generator_loss_type, + discriminator_factor=cfg.get("discriminator_factor", 1), + discriminator_start=cfg.get("discriminator_start", -1), + generator_factor=cfg.get("generator_factor", 0.5), + generator_loss_type=cfg.get("generator_loss_type", "hinge"), ) disc_loss_fn = DiscriminatorLoss( - discriminator_factor=cfg.discriminator_factor, - discriminator_start=cfg.discriminator_start, - discriminator_loss_type=cfg.discriminator_loss_type, - lecam_loss_weight=cfg.lecam_loss_weight, - gradient_penalty_loss_weight=cfg.gradient_penalty_loss_weight, + discriminator_factor=cfg.get("discriminator_factor", 1), + discriminator_start=cfg.get("discriminator_start", -1), + discriminator_loss_type=cfg.get("discriminator_loss_type", "hinge"), + lecam_loss_weight=cfg.get("lecam_loss_weight", None), + gradient_penalty_loss_weight=cfg.get("gradient_penalty_loss_weight", None), ) - # 4.5. setup optimizer - # vae optimizer + # == setup vae optimizer == optimizer = HybridAdam( - filter(lambda p: p.requires_grad, model.parameters()), lr=cfg.lr, weight_decay=0, adamw_mode=True + filter(lambda p: p.requires_grad, model.parameters()), + adamw_mode=True, + lr=cfg.get("lr", 1e-5), + weight_decay=cfg.get("weight_decay", 0), ) lr_scheduler = None - # 4.6. prepare for training - if cfg.grad_checkpoint: - set_grad_checkpoint(model) - model.train() - - # 4.7 add discriminator if specified in config + # == setup discriminator optimizer == if cfg.get("discriminator", False) != False: - discriminator = build_module(cfg.discriminator, MODELS) - discriminator.to(device, dtype) - discriminator_numel, discriminator_numel_trainable = get_model_numel(discriminator) - logger.info( - f"Trainable model params: {format_numel_str(discriminator_numel_trainable)}, Total model params: {format_numel_str(discriminator_numel)}" - ) disc_optimizer = HybridAdam( - filter(lambda p: p.requires_grad, discriminator.parameters()), lr=cfg.lr, weight_decay=0, adamw_mode=True + filter(lambda p: p.requires_grad, discriminator.parameters()), + adamw_mode=True, + lr=cfg.get("lr", 1e-5), + weight_decay=cfg.get("weight_decay", 0), ) disc_lr_scheduler = None - discriminator.train() + + # == additional preparation == + if cfg.get("grad_checkpoint", False): + set_grad_checkpoint(model) # ======================================================= - # 5. boost model for distributed training with colossalai + # 4. distributed training preparation with colossalai # ======================================================= + # == boosting == + # NOTE: we set dtype first to make initialization of model consistent with the dtype; then reset it to the fp32 as we make diffusion scheduler in fp32 torch.set_default_dtype(dtype) model, optimizer, _, dataloader, lr_scheduler = booster.boost( model=model, @@ -182,49 +194,53 @@ def main(): optimizer=disc_optimizer, lr_scheduler=disc_lr_scheduler, ) - torch.set_default_dtype(torch.float) - num_steps_per_epoch = len(dataloader) - logger.info("Boost model for distributed training") - num_steps_per_epoch = len(dataloader) + logger.info("Boosting model for distributed training") + if cfg.dataset.type == DEFAULT_DATASET_NAME: + num_steps_per_epoch = len(dataloader) + else: + num_steps_per_epoch = dataloader.batch_sampler.get_num_batch() // dist.get_world_size() + None if cfg.get("start_from_scratch ", False) else dataloader.batch_sampler - # ======================================================= - # 6. training loop - # ======================================================= - start_epoch = start_step = log_step = sampler_start_idx = 0 - acc_step = 0 + cfg_epochs = cfg.get("epochs", 1000) + logger.info("Training for %s epochs with %s steps per epoch", cfg_epochs, num_steps_per_epoch) + + # == global variables == + start_epoch = start_step = log_step = sampler_start_idx = acc_step = 0 running_loss = 0.0 running_disc_loss = 0.0 - # 6.1. resume training - if cfg.load is not None: + + # == resume == + if cfg.get("load", None) is not None: logger.info("Loading checkpoint") booster.load_model(model, os.path.join(cfg.load, "model")) booster.load_optimizer(optimizer, os.path.join(cfg.load, "optimizer")) - - running_states = load_json(os.path.join(cfg.load, "running_states.json")) - if cfg.get("discriminator", False) != False and os.path.exists(os.path.join(cfg.load, "discriminator")): booster.load_model(discriminator, os.path.join(cfg.load, "discriminator")) booster.load_optimizer(disc_optimizer, os.path.join(cfg.load, "disc_optimizer")) - - dist.barrier() + running_states = load_json(os.path.join(cfg.load, "running_states.json")) start_epoch, start_step, sampler_start_idx = ( running_states["epoch"], running_states["step"], running_states["sample_start_index"], ) - logger.info(f"Loaded checkpoint {cfg.load} at epoch {start_epoch} step {start_step}") + logger.info("Loaded checkpoint %s at epoch %s step %s", cfg.load, start_epoch, start_step) + dist.barrier() - logger.info(f"Training for {cfg.epochs} epochs with {num_steps_per_epoch} steps per epoch") + if cfg.dataset.type == DEFAULT_DATASET_NAME: + dataloader.sampler.set_start_index(sampler_start_idx) - dataloader.sampler.set_start_index(sampler_start_idx) - - # 6.3. training loop - for epoch in range(start_epoch, cfg.epochs): - dataloader.sampler.set_epoch(epoch) + # ======================================================= + # 5. training loop + # ======================================================= + for epoch in range(start_epoch, cfg_epochs): + # == set dataloader to new epoch == + if cfg.dataset.type == DEFAULT_DATASET_NAME: + dataloader.sampler.set_epoch(epoch) dataloader_iter = iter(dataloader) - logger.info(f"Beginning epoch {epoch}...") + logger.info("Beginning epoch %s...", epoch) + # == training loop in an epoch == with tqdm( enumerate(dataloader_iter, start=start_step), desc=f"Epoch {epoch}", @@ -239,35 +255,36 @@ def main(): length = random.randint(1, x.size(2)) x = x[:, :, :length, :, :] - # ===== VAE ===== + # == vae encoding === x_rec, x_z_rec, z, posterior, x_z = model(x) - # ====== Generator Loss ====== + # == loss initialization == vae_loss = torch.tensor(0.0, device=device, dtype=dtype) disc_loss = torch.tensor(0.0, device=device, dtype=dtype) - log_dict = {} - # real image reconstruction loss + # == real image reconstruction loss computation == nll_loss, weighted_nll_loss, weighted_kl_loss = vae_loss_fn(x, x_rec, posterior) log_dict["kl_loss"] = weighted_kl_loss.item() log_dict["nll_loss"] = weighted_nll_loss.item() if cfg.get("use_real_rec_loss", False): vae_loss += weighted_nll_loss + weighted_kl_loss + # == temporal vae reconstruction loss computation == _, weighted_z_nll_loss, _ = vae_loss_fn(x_z, x_z_rec, posterior, no_perceptual=True) log_dict["z_nll_loss"] = weighted_z_nll_loss.item() - # z reconstruction loss + + # == use z reconstruction loss == if cfg.get("use_z_rec_loss", False): vae_loss += weighted_z_nll_loss - # only for image + # == use image loss only == if cfg.get("use_image_identity_loss", False) and x.size(2) == 1: _, image_identity_loss, _ = vae_loss_fn(x_z, z, posterior, no_perceptual=True) vae_loss += image_identity_loss log_dict["image_identity_loss"] = image_identity_loss.item() - # Adversarial Generator Loss + # == generator adversarial loss == if cfg.get("discriminator", False) != False: recon_video = rearrange(x_rec, "b c t h w -> (b t) c h w").contiguous() global_step = epoch * num_steps_per_epoch + step @@ -281,12 +298,12 @@ def main(): ) vae_loss += adversarial_loss - # Backward & update + # == generator backward & update == optimizer.zero_grad() booster.backward(loss=vae_loss, optimizer=optimizer) optimizer.step() - # Adversarial Discriminator loss + # == discriminator adversarial loss == if cfg.get("discriminator", False) != False: real_video = rearrange(x, "b c t h w -> (b t) c h w").contiguous() fake_video = rearrange(x_rec, "b c t h w -> (b t) c h w").contiguous() @@ -298,31 +315,30 @@ def main(): global_step, ) disc_loss = weighted_d_adversarial_loss - # Backward & update + + # == discriminator backward & update == disc_optimizer.zero_grad() booster.backward(loss=disc_loss, optimizer=disc_optimizer) disc_optimizer.step() all_reduce_mean(disc_loss) running_disc_loss += disc_loss.item() - # Log loss values: + # == update log info == all_reduce_mean(vae_loss) running_loss += vae_loss.item() global_step = epoch * num_steps_per_epoch + step log_step += 1 acc_step += 1 - # Log to tensorboard + # == logging == if coordinator.is_master() and (global_step + 1) % cfg.log_every == 0: avg_loss = running_loss / log_step avg_disc_loss = running_disc_loss / log_step pbar.set_postfix( {"loss": avg_loss, "disc_loss": avg_disc_loss, "step": step, "global_step": global_step} ) - running_loss = 0 - log_step = 0 - running_disc_loss = 0 - writer.add_scalar("loss", vae_loss.item(), global_step) + # tensorboard + tb_writer.add_scalar("loss", vae_loss.item(), global_step) if cfg.wandb: wandb.log( { @@ -335,8 +351,11 @@ def main(): }, step=global_step, ) + running_loss = 0 + running_disc_loss = 0 + log_step = 0 - # Save checkpoint + # == checkpoint saving == if cfg.ckpt_every > 0 and (global_step + 1) % cfg.ckpt_every == 0: save_dir = os.path.join(exp_dir, f"epoch{epoch}-global_step{global_step+1}") os.makedirs(os.path.join(save_dir, "model"), exist_ok=True) # already handled in booster save_model @@ -356,15 +375,25 @@ def main(): "global_step": global_step + 1, "sample_start_index": (step + 1) * cfg.batch_size, } + if coordinator.is_master(): save_json(running_states, os.path.join(save_dir, "running_states.json")) dist.barrier() + logger.info( - f"Saved checkpoint at epoch {epoch} step {step + 1} global_step {global_step + 1} to {exp_dir}" + "Saved checkpoint at epoch %s step %s global_step %s to %s", + epoch, + step + 1, + global_step + 1, + exp_dir, ) # the continue epochs are not resumed, so we need to reset the sampler start index and start step - dataloader.sampler.set_start_index(0) + if cfg.dataset.type == DEFAULT_DATASET_NAME: + dataloader.sampler.set_start_index(0) + else: + dataloader.batch_sampler.set_epoch(epoch + 1) + logger.info("Epoch done, recomputing batch sampler") start_step = 0