from copy import deepcopy
from datetime import timedelta
from pprint import pprint

import torch
import torch.distributed as dist
import wandb
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.cluster import DistCoordinator
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device, set_seed
from tqdm import tqdm

from opensora.acceleration.checkpoint import set_grad_checkpoint
from opensora.acceleration.parallel_states import (
    get_data_parallel_group,
    set_data_parallel_group,
    set_sequence_parallel_group,
)
from opensora.acceleration.plugin import ZeroSeqParallelPlugin
from opensora.datasets import prepare_dataloader, prepare_variable_dataloader
from opensora.registry import DATASETS, MODELS, SCHEDULERS, build_module
from opensora.utils.ckpt_utils import create_logger, load, model_sharding, record_model_param_shape, save
from opensora.utils.config_utils import (
    create_experiment_workspace,
    create_tensorboard_writer,
    parse_configs,
    save_training_config,
)
from opensora.utils.misc import all_reduce_mean, format_numel_str, get_model_numel, requires_grad, to_torch_dtype
from opensora.utils.train_utils import MaskGenerator, update_ema


def main():
    # ======================================================
    # 1. args & cfg
    # ======================================================
    cfg = parse_configs(training=True)
    exp_name, exp_dir = create_experiment_workspace(cfg)
    save_training_config(cfg._cfg_dict, exp_dir)

    # ======================================================
    # 2. runtime variables & colossalai launch
    # ======================================================
    assert torch.cuda.is_available(), "Training currently requires at least one GPU."
    assert cfg.dtype in ["fp16", "bf16"], f"Unknown mixed precision {cfg.dtype}"

    # 2.1. colossalai init distributed training
    # we set a very large timeout to avoid some processes exiting early
    dist.init_process_group(backend="nccl", timeout=timedelta(hours=24))
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
    set_seed(1024)
    coordinator = DistCoordinator()
    device = get_current_device()
    dtype = to_torch_dtype(cfg.dtype)

    # 2.2. init logger, tensorboard & wandb
    if not coordinator.is_master():
        logger = create_logger(None)
    else:
        print("Training configuration:")
        pprint(cfg._cfg_dict)
        logger = create_logger(exp_dir)
        logger.info(f"Experiment directory created at {exp_dir}")
        writer = create_tensorboard_writer(exp_dir)
        if cfg.wandb:
            wandb.init(project="minisora", name=exp_name, config=cfg._cfg_dict)

    # 2.3. initialize ColossalAI booster
    if cfg.plugin == "zero2":
        plugin = LowLevelZeroPlugin(
            stage=2,
            precision=cfg.dtype,
            initial_scale=2**16,
            max_norm=cfg.grad_clip,
        )
        set_data_parallel_group(dist.group.WORLD)
    elif cfg.plugin == "zero2-seq":
        plugin = ZeroSeqParallelPlugin(
            sp_size=cfg.sp_size,
            stage=2,
            precision=cfg.dtype,
            initial_scale=2**16,
            max_norm=cfg.grad_clip,
        )
        set_sequence_parallel_group(plugin.sp_group)
        set_data_parallel_group(plugin.dp_group)
    else:
        raise ValueError(f"Unknown plugin {cfg.plugin}")
    booster = Booster(plugin=plugin)

    # ======================================================
    # 3. build dataset and dataloader
    # ======================================================
    dataset = build_module(cfg.dataset, DATASETS)
    logger.info(f"Dataset contains {len(dataset)} samples.")
    dataloader_args = dict(
        dataset=dataset,
        batch_size=cfg.batch_size,
        num_workers=cfg.num_workers,
        seed=cfg.seed,
        shuffle=True,
        drop_last=True,
        pin_memory=True,
        process_group=get_data_parallel_group(),
    )
    # TODO: use plugin's prepare dataloader
    if cfg.bucket_config is None:
        dataloader = prepare_dataloader(**dataloader_args)
    else:
        dataloader = prepare_variable_dataloader(
            bucket_config=cfg.bucket_config,
            num_bucket_build_workers=cfg.num_bucket_build_workers,
            **dataloader_args,
        )
    if cfg.dataset.type == "VideoTextDataset":
        total_batch_size = cfg.batch_size * dist.get_world_size() // cfg.sp_size
        logger.info(f"Total batch size: {total_batch_size}")

    # ======================================================
    # 4. build model
    # ======================================================
    # 4.1. build model
    text_encoder = build_module(cfg.text_encoder, MODELS, device=device)
    vae = build_module(cfg.vae, MODELS)
    input_size = (dataset.num_frames, *dataset.image_size)
    latent_size = vae.get_latent_size(input_size)
    model = build_module(
        cfg.model,
        MODELS,
        input_size=latent_size,
        in_channels=vae.out_channels,
        caption_channels=text_encoder.output_dim,
        model_max_length=text_encoder.model_max_length,
        dtype=dtype,
    )
    model_numel, model_numel_trainable = get_model_numel(model)
    logger.info(
        f"Trainable model params: {format_numel_str(model_numel_trainable)}, "
        f"Total model params: {format_numel_str(model_numel)}"
    )

    # 4.2. create ema
    ema = deepcopy(model).to(torch.float32).to(device)
    requires_grad(ema, False)
    ema_shape_dict = record_model_param_shape(ema)

    # 4.3. move to device
    vae = vae.to(device, dtype)
    model = model.to(device, dtype)

    # 4.4. build scheduler
    scheduler = build_module(cfg.scheduler, SCHEDULERS)

    # 4.5. setup optimizer
    optimizer = HybridAdam(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=cfg.lr,
        weight_decay=0,
        adamw_mode=True,
    )
    lr_scheduler = None

    # 4.6. prepare for training
    if cfg.grad_checkpoint:
        set_grad_checkpoint(model)
    model.train()
    update_ema(ema, model, decay=0, sharded=False)
    ema.eval()
    if cfg.mask_ratios is not None:
        mask_generator = MaskGenerator(cfg.mask_ratios)

    # =======================================================
    # 5. boost model for distributed training with colossalai
    # =======================================================
    torch.set_default_dtype(dtype)
    model, optimizer, _, dataloader, lr_scheduler = booster.boost(
        model=model,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        dataloader=dataloader,
    )
    torch.set_default_dtype(torch.float)
    logger.info("Boost model for distributed training")

    if cfg.dataset.type == "VariableVideoTextDataset":
        num_steps_per_epoch = dataloader.batch_sampler.get_num_batch() // dist.get_world_size()
    else:
        num_steps_per_epoch = len(dataloader)

    # =======================================================
    # 6. training loop
    # =======================================================
    start_epoch = start_step = log_step = sampler_start_idx = acc_step = 0
    running_loss = 0.0
    sampler_to_io = dataloader.batch_sampler if cfg.dataset.type == "VariableVideoTextDataset" else None

    # 6.1. resume training
    if cfg.load is not None:
        logger.info("Loading checkpoint")
        ret = load(
            booster,
            model,
            ema,
            optimizer,
            lr_scheduler,
            cfg.load,
            sampler=sampler_to_io if not cfg.start_from_scratch else None,
        )
        if not cfg.start_from_scratch:
            start_epoch, start_step, sampler_start_idx = ret
        logger.info(f"Loaded checkpoint {cfg.load} at epoch {start_epoch} step {start_step}")
    logger.info(f"Training for {cfg.epochs} epochs with {num_steps_per_epoch} steps per epoch")

    if cfg.dataset.type == "VideoTextDataset":
        dataloader.sampler.set_start_index(sampler_start_idx)
    model_sharding(ema)

    # 6.2. training loop
    for epoch in range(start_epoch, cfg.epochs):
        if cfg.dataset.type == "VideoTextDataset":
            dataloader.sampler.set_epoch(epoch)
        dataloader_iter = iter(dataloader)
        logger.info(f"Beginning epoch {epoch}...")

        with tqdm(
            enumerate(dataloader_iter, start=start_step),
            desc=f"Epoch {epoch}",
            disable=not coordinator.is_master(),
            initial=start_step,
            total=num_steps_per_epoch,
        ) as pbar:
            for step, batch in pbar:
                x = batch.pop("video").to(device, dtype)  # [B, C, T, H, W]
                y = batch.pop("text")

                # Visual and text encoding
                with torch.no_grad():
                    # Prepare visual inputs
                    x = vae.encode(x)  # [B, C, T, H/P, W/P]
                    # Prepare text inputs
                    model_args = text_encoder.encode(y)

                # Mask
                if cfg.mask_ratios is not None:
                    mask = mask_generator.get_masks(x)
                    model_args["x_mask"] = mask
                else:
                    mask = None

                # Video info
                for k, v in batch.items():
                    model_args[k] = v.to(device, dtype)

                # Diffusion
                t = torch.randint(0, scheduler.num_timesteps, (x.shape[0],), device=device)
                loss_dict = scheduler.training_losses(model, x, t, model_args, mask=mask)

                # Backward & update
                loss = loss_dict["loss"].mean()
                booster.backward(loss=loss, optimizer=optimizer)
                optimizer.step()
                optimizer.zero_grad()

                # Update EMA
                update_ema(ema, model.module, optimizer=optimizer)

                # Log loss values
                all_reduce_mean(loss)
                running_loss += loss.item()
                global_step = epoch * num_steps_per_epoch + step
                log_step += 1
                acc_step += 1

                # Log to tensorboard
                if coordinator.is_master() and (global_step + 1) % cfg.log_every == 0:
                    avg_loss = running_loss / log_step
                    pbar.set_postfix({"loss": avg_loss, "step": step, "global_step": global_step})
                    running_loss = 0
                    log_step = 0
                    writer.add_scalar("loss", loss.item(), global_step)
                    if cfg.wandb:
                        wandb.log(
                            {
                                "iter": global_step,
                                "epoch": epoch,
                                "loss": loss.item(),
                                "avg_loss": avg_loss,
                                "acc_step": acc_step,
                            },
                            step=global_step,
                        )

                # Save checkpoint
                if cfg.ckpt_every > 0 and (global_step + 1) % cfg.ckpt_every == 0:
                    save(
                        booster,
                        model,
                        ema,
                        optimizer,
                        lr_scheduler,
                        epoch,
                        step + 1,
                        global_step + 1,
                        cfg.batch_size,
                        coordinator,
                        exp_dir,
                        ema_shape_dict,
                        sampler=sampler_to_io,
                    )
                    logger.info(
                        f"Saved checkpoint at epoch {epoch} step {step + 1} global_step {global_step + 1} to {exp_dir}"
                    )

        # subsequent epochs are not resumed mid-epoch, so reset the sampler start index and start step
        if cfg.dataset.type == "VideoTextDataset":
            dataloader.sampler.set_start_index(0)
        if cfg.dataset.type == "VariableVideoTextDataset":
            dataloader.batch_sampler.set_epoch(epoch + 1)
            print("Epoch done, recomputing batch sampler")
        start_step = 0


if __name__ == "__main__":
    main()
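# --------------------------------------------------------------------------
# Usage sketch (an illustration, not part of the original script): the script
# expects one process per GPU, and `parse_configs(training=True)` presumably
# reads the training config path from the command line. The config path and
# GPU count below are hypothetical placeholders.
#
#   torchrun --standalone --nproc_per_node=8 train.py path/to/train_config.py
#
# `dist.init_process_group(backend="nccl")` defaults to the env:// init
# method, so it picks up the RANK / WORLD_SIZE / MASTER_ADDR / MASTER_PORT
# environment variables that torchrun sets for each process.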