diff --git a/configs/opensora-v1-2/train/train_load_batch.py b/configs/opensora-v1-2/train/train_load_batch.py
new file mode 100644
index 0000000..ae20146
--- /dev/null
+++ b/configs/opensora-v1-2/train/train_load_batch.py
@@ -0,0 +1,68 @@
+# Dataset settings
+dataset = dict(
+    type="BatchDataset",
+)
+
+grad_checkpoint = True
+
+# Acceleration settings
+num_workers = 8
+dtype = "bf16"
+plugin = "zero2"
+
+# Model settings
+model = dict(
+    type="STDiT3-XL/2",
+    from_pretrained=None,
+    qk_norm=True,
+    enable_flash_attn=True,
+    enable_layernorm_kernel=True,
+    freeze_y_embedder=True,
+)
+vae = dict(
+    type="OpenSoraVAE_V1_2",
+    from_pretrained="pretrained_models/vae-pipeline",
+    micro_frame_size=17,
+    micro_batch_size=4,
+)
+text_encoder = dict(
+    type="t5",
+    from_pretrained="DeepFloyd/t5-v1_1-xxl",
+    model_max_length=300,
+    shardformer=True,
+    local_files_only=True,
+)
+scheduler = dict(
+    type="rflow",
+    use_timestep_transform=True,
+    sample_method="logit-normal",
+)
+
+# Mask settings
+mask_ratios = {
+    "random": 0.2,
+    "intepolate": 0.01,  # (sic) the key must match the mask name expected by MaskGenerator
+    "quarter_random": 0.01,
+    "quarter_head": 0.01,
+    "quarter_tail": 0.01,
+    "quarter_head_tail": 0.01,
+    "image_random": 0.05,
+    "image_head": 0.1,
+    "image_tail": 0.05,
+    "image_head_tail": 0.05,
+}
+
+# Log settings
+seed = 42
+outputs = "outputs"
+wandb = False
+epochs = 1
+log_every = 10
+ckpt_every = 500
+
+# Optimization settings
+load = None
+grad_clip = 1.0
+lr = 2e-4
+ema_decay = 0.99
+adam_eps = 1e-15
diff --git a/opensora/datasets/__init__.py b/opensora/datasets/__init__.py
index 68220b9..fa4c798 100644
--- a/opensora/datasets/__init__.py
+++ b/opensora/datasets/__init__.py
@@ -1,2 +1,3 @@
-from .datasets import IMG_FPS, VariableVideoTextDataset, VideoTextDataset
+from .dataloader import build_batch_dataloader
+from .datasets import IMG_FPS, BatchDataset, VariableVideoTextDataset, VideoTextDataset
 from .utils import get_transforms_image, get_transforms_video, is_img, is_vid, save_sample
diff --git a/opensora/datasets/dataloader.py b/opensora/datasets/dataloader.py
index 2a7c7de..1832f20 100644
--- a/opensora/datasets/dataloader.py
+++ b/opensora/datasets/dataloader.py
@@ -8,7 +8,7 @@ from torch.distributed.distributed_c10d import _get_default_group
 from torch.utils.data import DataLoader
 
 from .datasets import VariableVideoTextDataset, VideoTextDataset
-from .sampler import StatefulDistributedSampler, VariableVideoBatchSampler
+from .sampler import BatchDistributedSampler, StatefulDistributedSampler, VariableVideoBatchSampler
 
 
 # Deterministic dataloader
@@ -82,3 +82,93 @@ def prepare_dataloader(
         )
     else:
         raise ValueError(f"Unsupported dataset type: {type(dataset)}")
+
+
+def prepare_variable_dataloader(
+    dataset,
+    batch_size,
+    bucket_config,
+    shuffle=False,
+    seed=1024,
+    drop_last=False,
+    pin_memory=False,
+    num_workers=0,
+    process_group=None,
+    num_bucket_build_workers=1,
+    **kwargs,
+):
+    _kwargs = kwargs.copy()
+    process_group = process_group or _get_default_group()
+    batch_sampler = VariableVideoBatchSampler(
+        dataset,
+        bucket_config,
+        num_replicas=process_group.size(),
+        rank=process_group.rank(),
+        shuffle=shuffle,
+        seed=seed,
+        drop_last=drop_last,
+        verbose=True,
+        num_bucket_build_workers=num_bucket_build_workers,
+    )
+
+    # Deterministic dataloader
+    def seed_worker(worker_id):
+        worker_seed = seed
+        np.random.seed(worker_seed)
+        torch.manual_seed(worker_seed)
+        random.seed(worker_seed)
+
+    return torch.utils.data.DataLoader(
+        dataset,
+        batch_sampler=batch_sampler,
+        worker_init_fn=seed_worker,
+        pin_memory=pin_memory,
+        num_workers=num_workers,
+        **_kwargs,
+    )
+
+
+def build_batch_dataloader(
+    dataset,
+    seed=1024,
+    pin_memory=False,
+    num_workers=0,
+    process_group: Optional[ProcessGroup] = None,
+    distributed=True,
+    **kwargs,
+):
+    r"""
+    Prepare a dataloader for distributed training, combining
+    `torch.utils.data.DataLoader` with `BatchDistributedSampler`.
+
+    Each dataset item is already a full batch, so `batch_size` is fixed to 1;
+    shuffling is not supported so far.
+    """
+    _kwargs = kwargs.copy()
+    if distributed:
+        process_group = process_group or _get_default_group()
+        sampler = BatchDistributedSampler(
+            dataset,
+            num_replicas=process_group.size(),
+            rank=process_group.rank(),
+        )
+    else:
+        # Non-distributed loading is not supported yet.
+        raise NotImplementedError
+
+    # Deterministic dataloader
+    def seed_worker(worker_id):
+        worker_seed = seed
+        np.random.seed(worker_seed)
+        torch.manual_seed(worker_seed)
+        random.seed(worker_seed)
+
+    return DataLoader(
+        dataset,
+        batch_size=1,
+        sampler=sampler,
+        worker_init_fn=seed_worker,
+        pin_memory=pin_memory,
+        num_workers=num_workers,
+        **_kwargs,
+    )
\ No newline at end of file
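Note: a minimal usage sketch of the new loader. The call pattern mirrors how scripts/train_load_batch.py (added later in this patch) wires things up, and it assumes torch.distributed has already been initialized (e.g. by torchrun):

    from opensora.datasets import build_batch_dataloader
    from opensora.datasets.datasets import BatchDataset
    from opensora.datasets.utils import collate_fn_batch  # added below in this patch

    dataset = BatchDataset()  # each __getitem__ returns one pre-collated batch
    dataloader = build_batch_dataloader(
        dataset,
        seed=1024,
        num_workers=4,
        pin_memory=True,
        collate_fn=collate_fn_batch,  # squeezes out the dummy batch_size=1 dim
    )
    for batch in dataloader:
        x = batch["x"]  # VAE features for one full training batch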
diff --git a/opensora/datasets/datasets.py b/opensora/datasets/datasets.py
index 68cfbe7..084a455 100644
--- a/opensora/datasets/datasets.py
+++ b/opensora/datasets/datasets.py
@@ -190,3 +190,50 @@ class VariableVideoTextDataset(VideoTextDataset):
         # we return None here in case of errorneous data
         # the collate function will handle it
         return None
+
+
+@DATASETS.register_module()
+class BatchDataset(torch.utils.data.Dataset):
+    """
+    The dataset is composed of multiple .bin files.
+    Each .bin file holds a list of pre-collated batches (a buffer); all .bin files have the same length.
+    In each training iteration, one batch is fetched from the current buffer;
+    once a buffer is consumed, the next one is loaded.
+    The same .bin file is never loaded on two different GPUs, i.e., each .bin file is assigned to exactly one GPU.
+    """
+
+    def __init__(self):
+        # TODO: read the .bin paths from a metadata file instead of hardcoding them
+        # self.meta = read_file(data_path)
+        # self.path_list = self.meta['path'].tolist()
+        self.path_list = [f'/mnt/nfs-207/sora_data/webvid-10M/feat_text/data/{idx}.bin' for idx in range(5)]
+
+        self._len_buffer = len(torch.load(self.path_list[0]))
+        self._num_buffers = len(self.path_list)
+        self.num_samples = self.len_buffer * len(self.path_list)
+
+        self.cur_file_idx = -1
+
+    @property
+    def num_buffers(self):
+        return self._num_buffers
+
+    @property
+    def len_buffer(self):
+        return self._len_buffer
+
+    def _load_buffer(self, idx):
+        file_idx = idx // self.len_buffer
+        if file_idx == self.cur_file_idx:
+            return
+        self.cur_file_idx = file_idx
+        self.cur_buffer = torch.load(self.path_list[file_idx])
+
+    def __len__(self):
+        return self.num_samples
+
+    def __getitem__(self, idx):
+        self._load_buffer(idx)
+
+        batch = self.cur_buffer[idx % self.len_buffer]  # dict with key 'x', 'fps' and text-related keys
+        return batch
+
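Note: to make the buffer bookkeeping concrete, this is the index arithmetic that _load_buffer/__getitem__ perform, shown with a hypothetical buffer size (len_buffer == 4):

    len_buffer = 4
    idx = 9
    file_idx = idx // len_buffer  # 2 -> path_list[2] is torch.load()-ed unless already cached
    offset = idx % len_buffer     # 1 -> the second batch inside that buffer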
+ """ + + def __init__(self): + # self.meta = read_file(data_path) + # self.path_list = self.meta['path'].tolist() + self.path_list = [f'/mnt/nfs-207/sora_data/webvid-10M/feat_text/data/{idx}.bin' for idx in range(5)] + + self._len_buffer = len(torch.load(self.path_list[0])) + self._num_buffers = len(self.path_list) + self.num_samples = self.len_buffer * len(self.path_list) + + self.cur_file_idx = -1 + + @property + def num_buffers(self): + return self._num_buffers + + @property + def len_buffer(self): + return self._len_buffer + + def _load_buffer(self, idx): + file_idx = idx // self.len_buffer + if file_idx == self.cur_file_idx: + return + self.cur_file_idx = file_idx + self.cur_buffer = torch.load(self.path_list[file_idx]) + + def __len__(self): + return self.num_samples + + def __getitem__(self, idx): + self._load_buffer(idx) + + batch = self.cur_buffer[idx % self.len_buffer] # dict; keys are {'x', 'fps'} and text related + return batch + diff --git a/opensora/datasets/sampler.py b/opensora/datasets/sampler.py index 3822499..bf2c642 100644 --- a/opensora/datasets/sampler.py +++ b/opensora/datasets/sampler.py @@ -1,7 +1,11 @@ +import math +import warnings + from collections import OrderedDict, defaultdict from pprint import pformat from typing import Iterator, List, Optional +import numpy as np import torch import torch.distributed as dist from torch.utils.data import Dataset, DistributedSampler @@ -283,3 +287,24 @@ class VariableVideoBatchSampler(DistributedSampler): def load_state_dict(self, state_dict: dict) -> None: self.__dict__.update(state_dict) + + +class BatchDistributedSampler(DistributedSampler): + """ + Used with BatchDataset; + Suppose len_buffer == 5, num_buffers == 6, #GPUs == 3, then + | buffer {i} | buffer {i+1} + rank 0 | 0, 1, 2, 3, 4, | 5, 6, 7, 8, 9 + rank 1 | 10, 11, 12, 13, 14, | 15, 16, 17, 18, 19 + rank 2 | 20, 21, 22, 23, 24, | 25, 26, 27, 28, 29 + """ + def __iter__(self): + num_buffers = self.dataset.num_buffers + len_buffer = self.dataset.len_buffer + num_buffers_i = num_buffers // self.num_replicas + num_samples_i = len_buffer * num_buffers_i + + indices_i = np.arange(num_samples_i) + self.rank * num_samples_i + indices_i = indices_i.tolist() + + return iter(indices_i) diff --git a/opensora/datasets/utils.py b/opensora/datasets/utils.py index 7f1be89..520a0ce 100644 --- a/opensora/datasets/utils.py +++ b/opensora/datasets/utils.py @@ -1,5 +1,6 @@ import os import re +import collections import numpy as np import pandas as pd @@ -221,3 +222,25 @@ def collate_fn_ignore_none(batch): # None value is returned when the get_item fails for an index batch = [val for val in batch if val is not None] return torch.utils.data.default_collate(batch) + + +def collate_fn_batch(batch): + """ + Used only with BatchDistributedSampler + """ + res = torch.utils.data.default_collate(batch) + + # squeeze the first dimension, which is due to torch.stack() in default_collate() + if isinstance(res, collections.abc.Mapping): + for k, v in res.items(): + if isinstance(v, torch.Tensor): + res[k] = v.squeeze(0) + elif isinstance(res, collections.abc.Sequence): + res = [x.squeeze(0) if isinstance(x, torch.Tensor) else x for x in res] + elif isinstance(res, torch.Tensor): + res = res.squeeze(0) + else: + raise TypeError + + return res + diff --git a/scripts/train_load_batch.py b/scripts/train_load_batch.py new file mode 100644 index 0000000..b316cd5 --- /dev/null +++ b/scripts/train_load_batch.py @@ -0,0 +1,360 @@ +import os +from copy import deepcopy +from datetime import timedelta +from 
diff --git a/scripts/train_load_batch.py b/scripts/train_load_batch.py
new file mode 100644
index 0000000..b316cd5
--- /dev/null
+++ b/scripts/train_load_batch.py
@@ -0,0 +1,360 @@
+import os
+from copy import deepcopy
+from datetime import timedelta
+from pprint import pformat
+
+import torch
+import torch.distributed as dist
+import wandb
+from colossalai.booster import Booster
+from colossalai.cluster import DistCoordinator
+from colossalai.nn.optimizer import HybridAdam
+from colossalai.utils import get_current_device, set_seed
+from tqdm import tqdm
+
+from opensora.acceleration.checkpoint import set_grad_checkpoint
+from opensora.acceleration.parallel_states import get_data_parallel_group
+from opensora.datasets import build_batch_dataloader
+from opensora.datasets.utils import collate_fn_batch
+from opensora.registry import DATASETS, MODELS, SCHEDULERS, build_module
+from opensora.utils.ckpt_utils import load, model_gathering, model_sharding, record_model_param_shape, save
+from opensora.utils.config_utils import define_experiment_workspace, parse_configs, save_training_config
+from opensora.utils.misc import (
+    all_reduce_mean,
+    create_logger,
+    create_tensorboard_writer,
+    format_numel_str,
+    get_model_numel,
+    requires_grad,
+    to_torch_dtype,
+)
+from opensora.utils.train_utils import MaskGenerator, create_colossalai_plugin, update_ema
+
+DEFAULT_DATASET_NAME = "VideoTextDataset"
+
+
+def main():
+    # ======================================================
+    # 1. configs & runtime variables
+    # ======================================================
+    # == parse configs ==
+    cfg = parse_configs(training=True)
+
+    # == device and dtype ==
+    assert torch.cuda.is_available(), "Training currently requires at least one GPU."
+    cfg_dtype = cfg.get("dtype", "bf16")
+    assert cfg_dtype in ["fp16", "bf16"], f"Unknown mixed precision {cfg_dtype}"
+    dtype = to_torch_dtype(cfg_dtype)
+
+    # == colossalai init distributed training ==
+    # NOTE: A very large timeout is set to avoid some processes exiting early
+    dist.init_process_group(backend="nccl", timeout=timedelta(hours=24))
+    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
+    set_seed(cfg.get("seed", 1024))
+    coordinator = DistCoordinator()
+    device = get_current_device()
+
+    # == init exp_dir ==
+    exp_name, exp_dir = define_experiment_workspace(cfg)
+    coordinator.block_all()
+    if coordinator.is_master():
+        os.makedirs(exp_dir, exist_ok=True)
+        save_training_config(cfg.to_dict(), exp_dir)
+    coordinator.block_all()
+
+    # == init logger, tensorboard & wandb ==
+    logger = create_logger(exp_dir)
+    logger.info("Experiment directory created at %s", exp_dir)
+    logger.info("Training configuration:\n %s", pformat(cfg.to_dict()))
+    if coordinator.is_master():
+        tb_writer = create_tensorboard_writer(exp_dir)
+        if cfg.get("wandb", False):
+            wandb.init(project="minisora", name=exp_name, config=cfg.to_dict(), dir="./outputs/wandb")
+
+    # == init ColossalAI booster ==
+    plugin = create_colossalai_plugin(
+        plugin=cfg.get("plugin", "zero2"),
+        dtype=cfg_dtype,
+        grad_clip=cfg.get("grad_clip", 0),
+        sp_size=cfg.get("sp_size", 1),
+    )
+    booster = Booster(plugin=plugin)
+
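+    # NOTE: unlike scripts/train.py, this script consumes batches whose VAE latents
+    # and text embeddings were precomputed offline into .bin buffers (see
+    # BatchDataset), so no visual/text encoding is needed in the training loop.
+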
+    # ======================================================
+    # 2. build dataset and dataloader
+    # ======================================================
+    logger.info("Building dataset...")
+    # == build dataset ==
+    dataset = build_module(cfg.dataset, DATASETS)
+    logger.info("Dataset contains %s samples.", len(dataset))
+
+    # == build dataloader ==
+    # changed from scripts/train.py: BatchDataset needs no batch_size/shuffle/drop_last
+    dataloader_args = dict(
+        dataset=dataset,
+        num_workers=cfg.get("num_workers", 4),
+        seed=cfg.get("seed", 1024),
+        pin_memory=True,
+        process_group=get_data_parallel_group(),
+        collate_fn=collate_fn_batch,
+    )
+    dataloader = build_batch_dataloader(**dataloader_args)
+    # each dataset item is one pre-collated batch, so one item == one training step
+    num_steps_per_epoch = len(dataset) // dist.get_world_size()
+    sampler_to_io = None
+
+    # TODO:
+    # - prefetch
+    # - collate fn
+    # - resume
+    # - sampler_to_io ?
+    # - remove text_encoder & caption_embedder
+    # - currently only one epoch is supported; every epoch would be identical
+
+    # if cfg.dataset.type == DEFAULT_DATASET_NAME:
+    #     dataloader = prepare_dataloader(**dataloader_args)
+    #     total_batch_size = cfg.batch_size * dist.get_world_size() // cfg.get("sp_size", 1)
+    #     logger.info("Total batch size: %s", total_batch_size)
+    #     num_steps_per_epoch = len(dataloader)
+    #     sampler_to_io = None
+    # else:
+    #     dataloader = prepare_variable_dataloader(
+    #         bucket_config=cfg.get("bucket_config", None),
+    #         num_bucket_build_workers=cfg.get("num_bucket_build_workers", 1),
+    #         **dataloader_args,
+    #     )
+    #     num_steps_per_epoch = dataloader.batch_sampler.get_num_batch() // dist.get_world_size()
+    #     sampler_to_io = None if cfg.get("start_from_scratch", False) else dataloader.batch_sampler
+
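+    # Shape contract for one dataloader element (after collate_fn_batch), e.g.:
+    #   batch["x"]   -> [B, C, T, H/P, W/P] precomputed VAE latents
+    #   batch["fps"] -> [B] per-clip frame rates
+    #   plus text-embedding entries that are passed through as model kwargs
+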
+    # ======================================================
+    # 3. build model
+    # ======================================================
+    logger.info("Building models...")
+    # == build text-encoder and vae ==
+    text_encoder = build_module(cfg.text_encoder, MODELS, device=device, dtype=dtype)
+    vae = build_module(cfg.vae, MODELS).to(device, dtype).eval()
+
+    # == build diffusion model ==
+    # changed from scripts/train.py: BatchDataset has no fixed num_frames/image_size,
+    # so the latent size is left undetermined here
+    # input_size = (dataset.num_frames, *dataset.image_size)
+    # latent_size = vae.get_latent_size(input_size)
+    latent_size = (None, None, None)
+    model = (
+        build_module(
+            cfg.model,
+            MODELS,
+            input_size=latent_size,
+            in_channels=vae.out_channels,
+            caption_channels=text_encoder.output_dim,
+            model_max_length=text_encoder.model_max_length,
+        )
+        .to(device, dtype)
+        .train()
+    )
+    model_numel, model_numel_trainable = get_model_numel(model)
+    logger.info(
+        "[Diffusion] Trainable model params: %s, Total model params: %s",
+        format_numel_str(model_numel_trainable),
+        format_numel_str(model_numel),
+    )
+
+    # == build ema for diffusion model ==
+    ema = deepcopy(model).to(torch.float32).to(device)
+    requires_grad(ema, False)
+    ema_shape_dict = record_model_param_shape(ema)
+    ema.eval()
+    update_ema(ema, model, decay=0, sharded=False)
+
+    # == setup loss function, build scheduler ==
+    scheduler = build_module(cfg.scheduler, SCHEDULERS)
+
+    # == setup optimizer ==
+    optimizer = HybridAdam(
+        filter(lambda p: p.requires_grad, model.parameters()),
+        adamw_mode=True,
+        lr=cfg.get("lr", 1e-4),
+        weight_decay=cfg.get("weight_decay", 0),
+        eps=cfg.get("adam_eps", 1e-8),
+    )
+    lr_scheduler = None
+
+    # == additional preparation ==
+    if cfg.get("grad_checkpoint", False):
+        set_grad_checkpoint(model)
+    if cfg.get("mask_ratios", None) is not None:
+        mask_generator = MaskGenerator(cfg.mask_ratios)
+
+    # =======================================================
+    # 4. distributed training preparation with colossalai
+    # =======================================================
+    logger.info("Preparing for distributed training...")
+    # == boosting ==
+    # NOTE: we set the default dtype first so model initialization is consistent with it,
+    # then reset to fp32 because the diffusion scheduler runs in fp32
+    torch.set_default_dtype(dtype)
+    model, optimizer, _, dataloader, lr_scheduler = booster.boost(
+        model=model,
+        optimizer=optimizer,
+        lr_scheduler=lr_scheduler,
+        dataloader=dataloader,
+    )
+    torch.set_default_dtype(torch.float)
+    logger.info("Boosting model for distributed training")
+
+    # == global variables ==
+    cfg_epochs = cfg.get("epochs", 1)
+    assert cfg_epochs == 1, "BatchDataset currently supports a single epoch only"
+    start_epoch = start_step = log_step = sampler_start_idx = acc_step = 0
+    running_loss = 0.0
+    logger.info("Training for %s epochs with %s steps per epoch", cfg_epochs, num_steps_per_epoch)
+
+    # == resume ==
+    if cfg.get("load", None) is not None:
+        logger.info("Loading checkpoint")
+        ret = load(
+            booster,
+            cfg.load,
+            model=model,
+            ema=ema,
+            optimizer=optimizer,
+            lr_scheduler=lr_scheduler,
+            sampler=sampler_to_io,
+        )
+        if not cfg.get("start_from_scratch", False):
+            start_epoch, start_step, sampler_start_idx = ret
+        logger.info("Loaded checkpoint %s at epoch %s step %s", cfg.load, start_epoch, start_step)
+        if cfg.dataset.type == DEFAULT_DATASET_NAME:
+            dataloader.sampler.set_start_index(sampler_start_idx)
+
+    model_sharding(ema)
+
+    # =======================================================
+    # 5. training loop
+    # =======================================================
+    dist.barrier()
+    for epoch in range(start_epoch, cfg_epochs):
+        # == set dataloader to new epoch ==
+        if cfg.dataset.type == DEFAULT_DATASET_NAME:
+            dataloader.sampler.set_epoch(epoch)
+        dataloader_iter = iter(dataloader)
+        logger.info("Beginning epoch %s...", epoch)
+
+        # == training loop in an epoch ==
+        with tqdm(
+            enumerate(dataloader_iter, start=start_step),
+            desc=f"Epoch {epoch}",
+            disable=not coordinator.is_master(),
+            initial=start_step,
+            total=num_steps_per_epoch,
+        ) as pbar:
+            for step, batch in pbar:
+                x = batch.pop("x").to(device, dtype)  # precomputed VAE latent features
+                # debugging only: verify batch loading, then skip the actual training step
+                print(step, dist.get_rank(), x.shape)
+                continue
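+
+                # NOTE: the rest of the loop body is the intended full training
+                # step; it is currently unreachable because of the `continue`
+                # above, which lets this script exercise only the dataloading path.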
+
+                # x = batch.pop("video").to(device, dtype)  # [B, C, T, H, W]
+                # y = batch.pop("text")
+
+                # == visual and text encoding ==
+                # (not needed here: the VAE latents and text embeddings are precomputed)
+                # with torch.no_grad():
+                #     # Prepare visual inputs
+                #     x = vae.encode(x)  # [B, C, T, H/P, W/P]
+                #     # Prepare text inputs
+                #     model_args = text_encoder.encode(y)
+
+                model_args = {}
+                # == mask ==
+                mask = None
+                if cfg.get("mask_ratios", None) is not None:
+                    mask = mask_generator.get_masks(x)
+                    model_args["x_mask"] = mask
+
+                # == video meta info ==
+                for k, v in batch.items():
+                    model_args[k] = v.to(device, dtype)
+
+                # == diffusion loss computation ==
+                loss_dict = scheduler.training_losses(model, x, model_args, mask=mask)
+
+                # == backward & update ==
+                loss = loss_dict["loss"].mean()
+                booster.backward(loss=loss, optimizer=optimizer)
+                optimizer.step()
+                optimizer.zero_grad()
+
+                # == update EMA ==
+                update_ema(ema, model.module, optimizer=optimizer, decay=cfg.get("ema_decay", 0.9999))
+
+                # == update log info ==
+                all_reduce_mean(loss)
+                running_loss += loss.item()
+                global_step = epoch * num_steps_per_epoch + step
+                log_step += 1
+                acc_step += 1
+
+                # == logging ==
+                if coordinator.is_master() and (global_step + 1) % cfg.get("log_every", 1) == 0:
+                    avg_loss = running_loss / log_step
+                    # progress bar
+                    pbar.set_postfix({"loss": avg_loss, "step": step, "global_step": global_step})
+                    # tensorboard
+                    tb_writer.add_scalar("loss", loss.item(), global_step)
+                    # wandb
+                    if cfg.get("wandb", False):
+                        wandb.log(
+                            {
+                                "iter": global_step,
+                                "epoch": epoch,
+                                "loss": loss.item(),
+                                "avg_loss": avg_loss,
+                                "acc_step": acc_step,
+                            },
+                            step=global_step,
+                        )
+
+                    running_loss = 0.0
+                    log_step = 0
+
+                # == checkpoint saving ==
+                ckpt_every = cfg.get("ckpt_every", 0)
+                if ckpt_every > 0 and (global_step + 1) % ckpt_every == 0:
+                    model_gathering(ema, ema_shape_dict)
+                    save(
+                        booster,
+                        exp_dir,
+                        model=model,
+                        ema=ema,
+                        optimizer=optimizer,
+                        lr_scheduler=lr_scheduler,
+                        sampler=sampler_to_io,
+                        epoch=epoch,
+                        step=step + 1,
+                        global_step=global_step + 1,
+                        batch_size=cfg.get("batch_size", None),
+                    )
+                    if dist.get_rank() == 0:
+                        model_sharding(ema)
+                    logger.info(
+                        "Saved checkpoint at epoch %s step %s global_step %s to %s",
+                        epoch,
+                        step + 1,
+                        global_step + 1,
+                        exp_dir,
+                    )
+
+        # NOTE: the remaining epochs are not resumed, so reset the sampler start index and start step
+        if cfg.dataset.type == DEFAULT_DATASET_NAME:
+            dataloader.sampler.set_start_index(0)
+        elif cfg.dataset.type == "BatchDataset":
+            pass  # BatchDistributedSampler keeps no epoch state to reset
+        else:
+            dataloader.batch_sampler.set_epoch(epoch + 1)
+            logger.info("Epoch done, recomputing batch sampler")
+        start_step = 0
+
+
+if __name__ == "__main__":
+    main()
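Note: assuming the usual OpenSora entry-point conventions (config path as the first positional argument to parse_configs), the script would be launched along these lines; the GPU count is illustrative:

    torchrun --standalone --nproc_per_node 8 scripts/train_load_batch.py \
        configs/opensora-v1-2/train/train_load_batch.py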