diff --git a/opensora/acceleration/checkpoint.py b/opensora/acceleration/checkpoint.py
index d832a01..66ba530 100644
--- a/opensora/acceleration/checkpoint.py
+++ b/opensora/acceleration/checkpoint.py
@@ -18,7 +18,7 @@ def set_grad_checkpoint(model, use_fp32_attention=False, gc_step=1):
 def auto_grad_checkpoint(module, *args, **kwargs):
     if getattr(module, "grad_checkpointing", False):
         if not isinstance(module, Iterable):
-            return checkpoint(module, *args, **kwargs)
+            return checkpoint(module, *args, use_reentrant=False, **kwargs)
         gc_step = module[0].grad_checkpointing_step
-        return checkpoint_sequential(module, gc_step, *args, **kwargs)
+        return checkpoint_sequential(module, gc_step, *args, use_reentrant=False, **kwargs)
     return module(*args, **kwargs)
diff --git a/opensora/utils/train_utils.py b/opensora/utils/train_utils.py
index f961514..d59aee2 100644
--- a/opensora/utils/train_utils.py
+++ b/opensora/utils/train_utils.py
@@ -3,6 +3,49 @@ import random
 from collections import OrderedDict
 
 import torch
+import torch.distributed as dist
+from colossalai.booster.plugin import LowLevelZeroPlugin
+
+from opensora.acceleration.parallel_states import set_data_parallel_group, set_sequence_parallel_group
+from opensora.acceleration.plugin import ZeroSeqParallelPlugin
+
+
+def create_colossalai_plugin(plugin, dtype, grad_clip, sp_size):
+    """Build the ColossalAI plugin and register the matching process groups.
+
+    Args:
+        plugin: plugin name, either "zero2" or "zero2-seq".
+        dtype: precision string forwarded to the plugin as ``precision``.
+        grad_clip: forwarded to the plugin as ``max_norm``.
+        sp_size: sequence-parallel size; only used by "zero2-seq".
+
+    Returns:
+        The constructed plugin instance.
+
+    Raises:
+        ValueError: if ``plugin`` is not a known plugin name.
+    """
+    if plugin == "zero2":
+        plugin = LowLevelZeroPlugin(
+            stage=2,
+            precision=dtype,
+            initial_scale=2**16,
+            max_norm=grad_clip,
+        )
+        set_data_parallel_group(dist.group.WORLD)
+    elif plugin == "zero2-seq":
+        plugin = ZeroSeqParallelPlugin(
+            sp_size=sp_size,
+            stage=2,
+            precision=dtype,
+            initial_scale=2**16,
+            max_norm=grad_clip,
+        )
+        set_sequence_parallel_group(plugin.sp_group)
+        set_data_parallel_group(plugin.dp_group)
+    else:
+        raise ValueError(f"Unknown plugin {plugin}")
+    return plugin
 
 
 @torch.no_grad()
@@ -18,7 +61,7 @@ def update_ema(
     for name, param in model_params.items():
         if name == "pos_embed":
             continue
-        if param.requires_grad == False:
+        if not param.requires_grad:
             continue
         if not sharded:
             param_data = param.data
diff --git a/scripts/train.py b/scripts/train.py
index 93fc3b6..f4091b1 100644
--- a/scripts/train.py
+++ b/scripts/train.py
@@ -1,12 +1,11 @@
 import os
 from copy import deepcopy
 from datetime import timedelta
-from pprint import pprint
+from pprint import pformat
 
 import torch
 import torch.distributed as dist
 from colossalai.booster import Booster
-from colossalai.booster.plugin import LowLevelZeroPlugin
 from colossalai.cluster import DistCoordinator
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.utils import get_current_device, set_seed
@@ -14,12 +13,7 @@ from tqdm import tqdm
 import wandb
 
 from opensora.acceleration.checkpoint import set_grad_checkpoint
-from opensora.acceleration.parallel_states import (
-    get_data_parallel_group,
-    set_data_parallel_group,
-    set_sequence_parallel_group,
-)
-from opensora.acceleration.plugin import ZeroSeqParallelPlugin
+from opensora.acceleration.parallel_states import get_data_parallel_group
 from opensora.datasets import prepare_dataloader, prepare_variable_dataloader
 from opensora.registry import DATASETS, MODELS, SCHEDULERS, build_module
 from opensora.utils.ckpt_utils import create_logger, load, model_sharding, record_model_param_shape, save
@@ -30,78 +24,68 @@ from opensora.utils.config_utils import (
     save_training_config,
 )
 from opensora.utils.misc import all_reduce_mean, format_numel_str, get_model_numel, requires_grad, to_torch_dtype
-from opensora.utils.train_utils import MaskGenerator, update_ema
+from opensora.utils.train_utils import MaskGenerator, create_colossalai_plugin, update_ema
+
+DEFAULT_DATASET_NAME = "VideoTextDataset"
 
 
 def main():
     # ======================================================
-    # 1. args & cfg
+    # 1. configs & runtime variables & colossalai launch
     # ======================================================
+    # == parse configs ==
     cfg = parse_configs(training=True)
 
-    # ======================================================
-    # 2. runtime variables & colossalai launch
-    # ======================================================
     assert torch.cuda.is_available(), "Training currently requires at least one GPU."
-    assert cfg.dtype in ["fp16", "bf16"], f"Unknown mixed precision {cfg.dtype}"
+    cfg_dtype = cfg.get("dtype", "bf16")
+    assert cfg_dtype in ["fp16", "bf16"], f"Unknown mixed precision {cfg_dtype}"
+    dtype = to_torch_dtype(cfg_dtype)
 
-    # 2.1. colossalai init distributed training
-    # we set a very large timeout to avoid some processes exit early
+    # == colossalai init distributed training ==
+    # NOTE: A very large timeout is set to avoid some processes exit early
     dist.init_process_group(backend="nccl", timeout=timedelta(hours=24))
     torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
     set_seed(1024)
     coordinator = DistCoordinator()
     device = get_current_device()
-    dtype = to_torch_dtype(cfg.dtype)
 
-    # 2.2. init exp_dir, logger, tensorboard & wandb
+    # == init exp_dir ==
     exp_name, exp_dir = define_experiment_workspace(cfg)
     coordinator.block_all()
     if coordinator.is_master():
         os.makedirs(exp_dir, exist_ok=True)
-        save_training_config(cfg._cfg_dict, exp_dir)
+        save_training_config(cfg.to_dict(), exp_dir)
     coordinator.block_all()
 
-    if not coordinator.is_master():
-        logger = create_logger(None)
-    else:
-        print("Training configuration:")
-        pprint(cfg._cfg_dict)
+    # == init logger, tensorboard & wandb ==
+    if coordinator.is_master():
         logger = create_logger(exp_dir)
-        logger.info(f"Experiment directory created at {exp_dir}")
+        logger.info("Experiment directory created at %s", exp_dir)
+        logger.info("Training configuration:\n %s", pformat(cfg.to_dict()))
 
-        writer = create_tensorboard_writer(exp_dir)
-        if cfg.wandb:
-            wandb.init(project="minisora", name=exp_name, config=cfg._cfg_dict)
-
-    # 2.3. initialize ColossalAI booster
-    if cfg.plugin == "zero2":
-        plugin = LowLevelZeroPlugin(
-            stage=2,
-            precision=cfg.dtype,
-            initial_scale=2**16,
-            max_norm=cfg.grad_clip,
-        )
-        set_data_parallel_group(dist.group.WORLD)
-    elif cfg.plugin == "zero2-seq":
-        plugin = ZeroSeqParallelPlugin(
-            sp_size=cfg.get("sp_size", 1),
-            stage=2,
-            precision=cfg.dtype,
-            initial_scale=2**16,
-            max_norm=cfg.grad_clip,
-        )
-        set_sequence_parallel_group(plugin.sp_group)
-        set_data_parallel_group(plugin.dp_group)
+        tb_writer = create_tensorboard_writer(exp_dir)
+        if cfg.get("wandb", False):
+            wandb.init(project="minisora", name=exp_name, config=cfg.to_dict())
     else:
-        raise ValueError(f"Unknown plugin {cfg.plugin}")
+        logger = create_logger(None)
+
+    # == init ColossalAI booster ==
+    plugin = create_colossalai_plugin(
+        plugin=cfg.get("plugin", "zero2"),
+        dtype=cfg_dtype,
+        grad_clip=cfg.get("grad_clip", 0),
+        sp_size=cfg.get("sp_size", 1),
+    )
     booster = Booster(plugin=plugin)
 
     # ======================================================
-    # 3. build dataset and dataloader
+    # 2. build dataset and dataloader
     # ======================================================
+    # == build dataset ==
     dataset = build_module(cfg.dataset, DATASETS)
-    logger.info(f"Dataset contains {len(dataset)} samples.")
+    logger.info("Dataset contains %s samples.", len(dataset))
+
+    # == build dataloader ==
     dataloader_args = dict(
         dataset=dataset,
         batch_size=cfg.batch_size,
@@ -112,25 +96,26 @@ def main():
         pin_memory=True,
         process_group=get_data_parallel_group(),
     )
-    # TODO: use plugin's prepare dataloader
-    if cfg.bucket_config is None:
+    if cfg.dataset.type == DEFAULT_DATASET_NAME:
         dataloader = prepare_dataloader(**dataloader_args)
+        total_batch_size = cfg.batch_size * dist.get_world_size() // cfg.get("sp_size", 1)
+        logger.info("Total batch size: %s", total_batch_size)
     else:
         dataloader = prepare_variable_dataloader(
             bucket_config=cfg.bucket_config,
             num_bucket_build_workers=cfg.num_bucket_build_workers,
             **dataloader_args,
         )
-    if cfg.dataset.type == "VideoTextDataset":
-        total_batch_size = cfg.batch_size * dist.get_world_size() // cfg.sp_size
-        logger.info(f"Total batch size: {total_batch_size}")
 
     # ======================================================
-    # 4. build model
+    # 3. build model
     # ======================================================
-    # 4.1. build model
+    # == build text-encoder and vae ==
     text_encoder = build_module(cfg.text_encoder, MODELS, device=device)
-    vae = build_module(cfg.vae, MODELS)
+    vae = build_module(cfg.vae, MODELS).to(device, dtype)
+    vae.eval()
+
+    # == build diffusion model ==
     input_size = (dataset.num_frames, *dataset.image_size)
     latent_size = vae.get_latent_size(input_size)
     model = build_module(
@@ -140,46 +125,46 @@ def main():
         in_channels=vae.out_channels,
         caption_channels=text_encoder.output_dim,
         model_max_length=text_encoder.model_max_length,
-    )
+    ).to(device, dtype)
+    model.train()
     model_numel, model_numel_trainable = get_model_numel(model)
     logger.info(
-        f"Trainable model params: {format_numel_str(model_numel_trainable)}, Total model params: {format_numel_str(model_numel)}"
+        "Trainable model params: %s, Total model params: %s",
+        format_numel_str(model_numel_trainable),
+        format_numel_str(model_numel),
     )
 
-    # 4.2. create ema
+    # == build ema for diffusion model ==
     ema = deepcopy(model).to(torch.float32).to(device)
     requires_grad(ema, False)
     ema_shape_dict = record_model_param_shape(ema)
+    ema.eval()
+    update_ema(ema, model, decay=0, sharded=False)
 
-    # 4.3. move to device
-    vae = vae.to(device, dtype)
-    model = model.to(device, dtype)
-
-    # 4.4. build scheduler
+    # == build scheduler ==
     scheduler = build_module(cfg.scheduler, SCHEDULERS)
 
-    # 4.5. setup optimizer
+    # == setup optimizer ==
     optimizer = HybridAdam(
         filter(lambda p: p.requires_grad, model.parameters()),
-        lr=cfg.lr,
-        weight_decay=0,
         adamw_mode=True,
+        lr=cfg.get("lr", 1e-4),
+        weight_decay=cfg.get("weight_decay", 0),
         eps=cfg.get("adam_eps", 1e-8),
     )
     lr_scheduler = None
 
-    # 4.6. prepare for training
-    if cfg.grad_checkpoint:
+    # == additional preparation ==
+    if cfg.get("grad_checkpoint", False):
         set_grad_checkpoint(model)
-    model.train()
-    update_ema(ema, model, decay=0, sharded=False)
-    ema.eval()
-    if cfg.mask_ratios is not None:
+    if cfg.get("mask_ratios", None) is not None:
         mask_generator = MaskGenerator(cfg.mask_ratios)
 
     # =======================================================
-    # 5. boost model for distributed training with colossalai
+    # 4. distributed training preparation with colossalai
     # =======================================================
+    # == boosting ==
+    # NOTE: we set dtype first to make initialization of model consistent with the dtype; then reset it to the fp32 as we make diffusion scheduler in fp32
     torch.set_default_dtype(dtype)
     model, optimizer, _, dataloader, lr_scheduler = booster.boost(
         model=model,
@@ -188,20 +173,23 @@ def main():
         dataloader=dataloader,
     )
     torch.set_default_dtype(torch.float)
-    logger.info("Boost model for distributed training")
-    if cfg.dataset.type == "VariableVideoTextDataset":
-        num_steps_per_epoch = dataloader.batch_sampler.get_num_batch() // dist.get_world_size()
-    else:
+    logger.info("Boosting model for distributed training")
+    if cfg.dataset.type == DEFAULT_DATASET_NAME:
         num_steps_per_epoch = len(dataloader)
+        sampler_to_io = None
+    else:
+        num_steps_per_epoch = dataloader.batch_sampler.get_num_batch() // dist.get_world_size()
+        sampler_to_io = dataloader.batch_sampler
 
-    # =======================================================
-    # 6. training loop
-    # =======================================================
+    cfg_epochs = cfg.get("epochs", 1000)
+    logger.info("Training for %s epochs with %s steps per epoch", cfg_epochs, num_steps_per_epoch)
+
+    # == global variables ==
     start_epoch = start_step = log_step = sampler_start_idx = acc_step = 0
     running_loss = 0.0
-    sampler_to_io = dataloader.batch_sampler if cfg.dataset.type == "VariableVideoTextDataset" else None
-    # 6.1. resume training
-    if cfg.load is not None:
+
+    # == resume ==
+    if cfg.get("load", None) is not None:
         logger.info("Loading checkpoint")
         ret = load(
             booster,
@@ -210,24 +198,27 @@ def main():
             optimizer,
             lr_scheduler,
             cfg.load,
-            sampler=sampler_to_io if not cfg.start_from_scratch else None,
+            sampler=None if cfg.get("start_from_scratch", False) else sampler_to_io,
         )
-        if not cfg.start_from_scratch:
+        if not cfg.get("start_from_scratch", False):
             start_epoch, start_step, sampler_start_idx = ret
-        logger.info(f"Loaded checkpoint {cfg.load} at epoch {start_epoch} step {start_step}")
-    logger.info(f"Training for {cfg.epochs} epochs with {num_steps_per_epoch} steps per epoch")
-
-    if cfg.dataset.type == "VideoTextDataset":
+        logger.info("Loaded checkpoint %s at epoch %s step %s", cfg.load, start_epoch, start_step)
+    if cfg.dataset.type == DEFAULT_DATASET_NAME:
         dataloader.sampler.set_start_index(sampler_start_idx)
+
     model_sharding(ema)
 
-    # 6.2. training loop
-    for epoch in range(start_epoch, cfg.epochs):
-        if cfg.dataset.type == "VideoTextDataset":
+    # =======================================================
+    # 5. training loop
+    # =======================================================
+    for epoch in range(start_epoch, cfg_epochs):
+        # == set dataloader to new epoch ==
+        if cfg.dataset.type == DEFAULT_DATASET_NAME:
             dataloader.sampler.set_epoch(epoch)
         dataloader_iter = iter(dataloader)
-        logger.info(f"Beginning epoch {epoch}...")
+        logger.info("Beginning epoch %s...", epoch)
 
+        # == training loop in an epoch ==
         with tqdm(
             enumerate(dataloader_iter, start=start_step),
             desc=f"Epoch {epoch}",
@@ -238,50 +229,51 @@ def main():
             for step, batch in pbar:
                 x = batch.pop("video").to(device, dtype)  # [B, C, T, H, W]
                 y = batch.pop("text")
-                # Visual and text encoding
+
+                # == visual and text encoding ==
                 with torch.no_grad():
                     # Prepare visual inputs
                     x = vae.encode(x)  # [B, C, T, H/P, W/P]
                     # Prepare text inputs
                     model_args = text_encoder.encode(y)
 
-                # Mask
-                if cfg.mask_ratios is not None:
+                # == mask ==
+                mask = None
+                if cfg.get("mask_ratios", None) is not None:
                     mask = mask_generator.get_masks(x)
                     model_args["x_mask"] = mask
-                else:
-                    mask = None
 
-                # Video info
+                # == video meta info ==
                 for k, v in batch.items():
                     model_args[k] = v.to(device, dtype)
 
-                # Diffusion
+                # == diffusion loss computation ==
                 loss_dict = scheduler.training_losses(model, x, model_args, mask=mask)
 
-                # Backward & update
+                # == backward & update ==
                 loss = loss_dict["loss"].mean()
                 booster.backward(loss=loss, optimizer=optimizer)
                 optimizer.step()
                 optimizer.zero_grad()
 
-                # Update EMA
+                # == update EMA ==
                 update_ema(ema, model.module, optimizer=optimizer, decay=cfg.get("ema_decay", 0.9999))
 
-                # Log loss values:
+                # == update log info ==
                 all_reduce_mean(loss)
                 running_loss += loss.item()
                 global_step = epoch * num_steps_per_epoch + step
                 log_step += 1
                 acc_step += 1
 
-                # Log to tensorboard
+                # == logging ==
                 if coordinator.is_master() and (global_step + 1) % cfg.log_every == 0:
                     avg_loss = running_loss / log_step
+                    # progress bar
                     pbar.set_postfix({"loss": avg_loss, "step": step, "global_step": global_step})
-                    running_loss = 0
-                    log_step = 0
-                    writer.add_scalar("loss", loss.item(), global_step)
-                    if cfg.wandb:
+                    # tensorboard
+                    tb_writer.add_scalar("loss", loss.item(), global_step)
+                    # wandb
+                    if cfg.get("wandb", False):
                         wandb.log(
                             {
@@ -294,7 +286,10 @@ def main():
                             step=global_step,
                         )
 
-                # Save checkpoint
+                    running_loss = 0
+                    log_step = 0
+
+                # == checkpoint saving ==
                 if cfg.ckpt_every > 0 and (global_step + 1) % cfg.ckpt_every == 0:
                     save(
                         booster,
@@ -312,15 +307,19 @@ def main():
                         sampler=sampler_to_io,
                     )
                     logger.info(
-                        f"Saved checkpoint at epoch {epoch} step {step + 1} global_step {global_step + 1} to {exp_dir}"
+                        "Saved checkpoint at epoch %s step %s global_step %s to %s",
+                        epoch,
+                        step + 1,
+                        global_step + 1,
+                        exp_dir,
                     )
 
-        # the continue epochs are not resumed, so we need to reset the sampler start index and start step
-        if cfg.dataset.type == "VideoTextDataset":
+        # NOTE: the continue epochs are not resumed, so we need to reset the sampler start index and start step
+        if cfg.dataset.type == DEFAULT_DATASET_NAME:
             dataloader.sampler.set_start_index(0)
-        if cfg.dataset.type == "VariableVideoTextDataset":
+        else:
             dataloader.batch_sampler.set_epoch(epoch + 1)
-            print("Epoch done, recomputing batch sampler")
+            logger.info("Epoch done, recomputing batch sampler")
         start_step = 0
 
 
diff --git a/tools/datasets/datautil.py b/tools/datasets/datautil.py
index 475b847..bf34aa9 100644
--- a/tools/datasets/datautil.py
+++ b/tools/datasets/datautil.py
@@ -486,13 +486,15 @@ def main(args):
     if args.refine_llm_caption:
         assert "text" in data.columns
         data["text"] = apply(data["text"], remove_caption_prefix)
+    if args.append_text is not None:
+        assert "text" in data.columns
+        data["text"] = data["text"] + args.append_text
     if args.clean_caption:
         assert "text" in data.columns
         data["text"] = apply(
             data["text"],
             partial(text_preprocessing, use_text_preprocessing=True),
         )
-
     if args.count_num_token is not None:
         assert "text" in data.columns
         data["text_len"] = apply(data["text"], lambda x: len(tokenizer(x)["input_ids"]))
@@ -597,6 +599,7 @@ def parse_args():
     parser.add_argument(
         "--count-num-token", type=str, choices=["t5"], default=None, help="Count the number of tokens in the caption"
     )
+    parser.add_argument("--append-text", type=str, default=None, help="append text to the caption")
 
     # score filtering
     parser.add_argument("--fmin", type=int, default=None, help="filter the dataset by minimum number of frames")
@@ -661,6 +664,8 @@ def get_output_path(args, input_name):
         name += "_cmcaption"
     if args.count_num_token:
         name += "_ntoken"
+    if args.append_text is not None:
+        name += "_appendtext"
 
     # score filtering
     if args.fmin is not None: