Dev/load batch (#105)

* add train_load_batch

* update train_load_batch

---------

Co-authored-by: pxy <pexure@gmail.com>
Zheng Zangwei (Alex Zheng) committed 2024-05-20 17:40:05 +08:00 via GitHub
parent 5b9e753039, commit 9cf394884c
7 changed files with 617 additions and 2 deletions

@@ -0,0 +1,68 @@
# Dataset settings
dataset = dict(
    type="BatchDataset",
)
grad_checkpoint = True

# Acceleration settings
num_workers = 8
dtype = "bf16"
plugin = "zero2"

# Model settings
model = dict(
    type="STDiT3-XL/2",
    from_pretrained=None,
    qk_norm=True,
    enable_flash_attn=True,
    enable_layernorm_kernel=True,
    freeze_y_embedder=True,
)
vae = dict(
    type="OpenSoraVAE_V1_2",
    from_pretrained="pretrained_models/vae-pipeline",
    micro_frame_size=17,
    micro_batch_size=4,
)
text_encoder = dict(
    type="t5",
    from_pretrained="DeepFloyd/t5-v1_1-xxl",
    model_max_length=300,
    shardformer=True,
    local_files_only=True,
)
scheduler = dict(
    type="rflow",
    use_timestep_transform=True,
    sample_method="logit-normal",
)

# Mask settings
mask_ratios = {
    "random": 0.2,
    "intepolate": 0.01,  # (sic) key name expected by MaskGenerator
    "quarter_random": 0.01,
    "quarter_head": 0.01,
    "quarter_tail": 0.01,
    "quarter_head_tail": 0.01,
    "image_random": 0.05,
    "image_head": 0.1,
    "image_tail": 0.05,
    "image_head_tail": 0.05,
}

# Log settings
seed = 42
outputs = "outputs"
wandb = False
epochs = 1
log_every = 10
ckpt_every = 500

# Optimization settings
load = None
grad_clip = 1.0
lr = 2e-4
ema_decay = 0.99
adam_eps = 1e-15
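
A quick sanity check on the mask settings: the ratios above sum to 0.5, and (as a hedged reading of how MaskGenerator consumes them, with the remaining probability mass treated as "no mask") they must not exceed 1. A minimal sketch:

mask_ratios_total = sum(mask_ratios.values())  # 0.5 for the values above
assert 0.0 <= mask_ratios_total <= 1.0, f"mask_ratios sum to {mask_ratios_total}, expected <= 1"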

@@ -1,2 +1,2 @@
-from .datasets import IMG_FPS, VariableVideoTextDataset, VideoTextDataset
+from .datasets import IMG_FPS, VariableVideoTextDataset, VideoTextDataset, BatchDataset
from .utils import get_transforms_image, get_transforms_video, is_img, is_vid, save_sample

@@ -8,7 +8,7 @@ from torch.distributed.distributed_c10d import _get_default_group
from torch.utils.data import DataLoader
from .datasets import VariableVideoTextDataset, VideoTextDataset
-from .sampler import StatefulDistributedSampler, VariableVideoBatchSampler
+from .sampler import StatefulDistributedSampler, VariableVideoBatchSampler, BatchDistributedSampler

# Deterministic dataloader
@@ -82,3 +82,95 @@ def prepare_dataloader(
        )
    else:
        raise ValueError(f"Unsupported dataset type: {type(dataset)}")


def prepare_variable_dataloader(
    dataset,
    batch_size,
    bucket_config,
    shuffle=False,
    seed=1024,
    drop_last=False,
    pin_memory=False,
    num_workers=0,
    process_group=None,
    num_bucket_build_workers=1,
    **kwargs,
):
    _kwargs = kwargs.copy()
    process_group = process_group or _get_default_group()
    batch_sampler = VariableVideoBatchSampler(
        dataset,
        bucket_config,
        num_replicas=process_group.size(),
        rank=process_group.rank(),
        shuffle=shuffle,
        seed=seed,
        drop_last=drop_last,
        verbose=True,
        num_bucket_build_workers=num_bucket_build_workers,
    )

    # Deterministic dataloader
    def seed_worker(worker_id):
        worker_seed = seed
        np.random.seed(worker_seed)
        torch.manual_seed(worker_seed)
        random.seed(worker_seed)

    return torch.utils.data.DataLoader(
        dataset,
        batch_sampler=batch_sampler,
        worker_init_fn=seed_worker,
        pin_memory=pin_memory,
        num_workers=num_workers,
        **_kwargs,
    )


def build_batch_dataloader(
    dataset,
    # batch_size=1,
    # shuffle=False,
    seed=1024,
    # drop_last=False,
    pin_memory=False,
    num_workers=0,
    process_group: Optional[ProcessGroup] = None,
    distributed=True,
    **kwargs,
):
    r"""
    Prepare a dataloader for distributed training. The dataloader will be wrapped by
    `torch.utils.data.DataLoader` and `BatchDistributedSampler`.

    `batch_size` must be 1, and shuffling is not supported yet.
    """
    _kwargs = kwargs.copy()
    if distributed:
        process_group = process_group or _get_default_group()
        sampler = BatchDistributedSampler(
            dataset,
            num_replicas=process_group.size(),
            rank=process_group.rank(),
        )
    else:
        # Non-distributed loading is not supported yet.
        raise NotImplementedError

    # Deterministic dataloader
    def seed_worker(worker_id):
        worker_seed = seed
        np.random.seed(worker_seed)
        torch.manual_seed(worker_seed)
        random.seed(worker_seed)

    return DataLoader(
        dataset,
        batch_size=1,  # each dataset item is already a full batch
        sampler=sampler,
        worker_init_fn=seed_worker,
        # drop_last=drop_last,
        pin_memory=pin_memory,
        num_workers=num_workers,
        **_kwargs,
    )
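
For reference, a minimal sketch of a call site (argument values are illustrative; `collate_fn_batch` is the collate function added in the utils diff below and is forwarded to `DataLoader` through `**kwargs`):

dataloader = build_batch_dataloader(
    dataset,
    num_workers=4,
    pin_memory=True,
    collate_fn=collate_fn_batch,  # undoes the extra dim introduced by batch_size=1
)
for batch in dataloader:
    ...  # each `batch` is one pre-built batch dict loaded from a .bin buffer

Note that `torch.distributed` must already be initialized, since the default process group is used when none is passed.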

@@ -185,3 +185,50 @@ class VariableVideoTextDataset(VideoTextDataset):
        # we return None here in case of erroneous data
        # the collate function will handle it
        return None


@DATASETS.register_module()
class BatchDataset(torch.utils.data.Dataset):
    """
    The dataset is composed of multiple .bin files.
    Each .bin file is a list of batch data (like a buffer). All .bin files have the same length.
    In each training iteration, one batch is fetched from the current buffer.
    Once a buffer is consumed, another one is loaded.
    Avoid loading the same .bin on two different GPUs, i.e., one .bin is assigned to one GPU only.
    """

    def __init__(self):
        # self.meta = read_file(data_path)
        # self.path_list = self.meta['path'].tolist()
        self.path_list = [f'/mnt/nfs-207/sora_data/webvid-10M/feat_text/data/{idx}.bin' for idx in range(5)]
        self._len_buffer = len(torch.load(self.path_list[0]))  # load once to discover the buffer length
        self._num_buffers = len(self.path_list)
        self.num_samples = self.len_buffer * len(self.path_list)

        self.cur_file_idx = -1
        self.cur_buffer = None

    @property
    def num_buffers(self):
        return self._num_buffers

    @property
    def len_buffer(self):
        return self._len_buffer

    def _load_buffer(self, idx):
        file_idx = idx // self.len_buffer
        if file_idx == self.cur_file_idx:
            # the requested batch is already in memory
            return
        self.cur_file_idx = file_idx
        self.cur_buffer = torch.load(self.path_list[file_idx])

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        self._load_buffer(idx)
        batch = self.cur_buffer[idx % self.len_buffer]  # dict; keys are {'x', 'fps'} and text related
        return batch
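
To make the buffer bookkeeping concrete, here is how `_load_buffer` and `__getitem__` split a flat index into a (file, offset) pair; the buffer length of 4 is purely illustrative:

len_buffer = 4                 # illustrative; the real value is len(torch.load(path_list[0]))
idx = 9
file_idx = idx // len_buffer   # -> 2: the third .bin file
offset = idx % len_buffer      # -> 1: the second batch inside that buffer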

@@ -1,7 +1,11 @@
import math
import warnings
from collections import OrderedDict, defaultdict
from pprint import pformat
from typing import Iterator, List, Optional

import numpy as np
import torch
import torch.distributed as dist
from torch.utils.data import Dataset, DistributedSampler
@@ -283,3 +287,24 @@ class VariableVideoBatchSampler(DistributedSampler):
    def load_state_dict(self, state_dict: dict) -> None:
        self.__dict__.update(state_dict)


class BatchDistributedSampler(DistributedSampler):
    """
    Used with BatchDataset.

    Suppose len_buffer == 5, num_buffers == 6, #GPUs == 3, then
           | buffer {i}          | buffer {i+1}
    rank 0 |  0,  1,  2,  3,  4, |  5,  6,  7,  8,  9
    rank 1 | 10, 11, 12, 13, 14, | 15, 16, 17, 18, 19
    rank 2 | 20, 21, 22, 23, 24, | 25, 26, 27, 28, 29
    """

    def __iter__(self):
        num_buffers = self.dataset.num_buffers
        len_buffer = self.dataset.len_buffer
        num_buffers_i = num_buffers // self.num_replicas  # leftover buffers are dropped
        num_samples_i = len_buffer * num_buffers_i

        indices_i = np.arange(num_samples_i) + self.rank * num_samples_i
        indices_i = indices_i.tolist()

        return iter(indices_i)
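
A runnable sketch reproducing the docstring's example (len_buffer == 5, num_buffers == 6, 3 GPUs):

import numpy as np

len_buffer, num_buffers, num_replicas = 5, 6, 3
num_buffers_i = num_buffers // num_replicas  # 2 buffers per rank
num_samples_i = len_buffer * num_buffers_i   # 10 samples per rank
for rank in range(num_replicas):
    print(rank, (np.arange(num_samples_i) + rank * num_samples_i).tolist())
# 0 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
# 1 [10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
# 2 [20, 21, 22, 23, 24, 25, 26, 27, 28, 29]

Each rank thus reads a disjoint, contiguous range of indices, which is exactly the one-.bin-per-GPU property the BatchDataset docstring asks for.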

@@ -1,5 +1,6 @@
import os
import re
+import collections

import numpy as np
import pandas as pd
@@ -221,3 +222,25 @@ def collate_fn_ignore_none(batch):
    # None value is returned when the get_item fails for an index
    batch = [val for val in batch if val is not None]
    return torch.utils.data.default_collate(batch)


def collate_fn_batch(batch):
    """
    Used only with BatchDistributedSampler.
    """
    res = torch.utils.data.default_collate(batch)

    # squeeze the first dimension, which is due to torch.stack() in default_collate()
    if isinstance(res, collections.abc.Mapping):
        for k, v in res.items():
            if isinstance(v, torch.Tensor):
                res[k] = v.squeeze(0)
    elif isinstance(res, collections.abc.Sequence):
        res = [x.squeeze(0) if isinstance(x, torch.Tensor) else x for x in res]
    elif isinstance(res, torch.Tensor):
        res = res.squeeze(0)
    else:
        raise TypeError(f"Unsupported collate output type: {type(res)}")

    return res
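
Why the squeeze is needed: each dataset item is already a full batch, and DataLoader(batch_size=1) hands default_collate a singleton list, so a leading dimension of size 1 is added. A small sketch (shapes illustrative):

import torch

sample = {"x": torch.randn(4, 8), "fps": torch.tensor([24.0] * 4)}  # pre-batched, B=4
out = collate_fn_batch([sample])  # DataLoader passes a list of length 1
assert out["x"].shape == (4, 8)   # without squeeze(0) this would be (1, 4, 8)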

scripts/train_load_batch.py (new file, 360 lines)

@@ -0,0 +1,360 @@
import os
from copy import deepcopy
from datetime import timedelta
from pprint import pformat

import torch
import torch.distributed as dist
import wandb
from colossalai.booster import Booster
from colossalai.cluster import DistCoordinator
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device, set_seed
from tqdm import tqdm

from opensora.acceleration.checkpoint import set_grad_checkpoint
from opensora.acceleration.parallel_states import get_data_parallel_group
from opensora.datasets import build_batch_dataloader
from opensora.datasets.utils import collate_fn_batch
from opensora.registry import DATASETS, MODELS, SCHEDULERS, build_module
from opensora.utils.ckpt_utils import load, model_gathering, model_sharding, record_model_param_shape, save
from opensora.utils.config_utils import define_experiment_workspace, parse_configs, save_training_config
from opensora.utils.misc import (
    all_reduce_mean,
    create_logger,
    create_tensorboard_writer,
    format_numel_str,
    get_model_numel,
    requires_grad,
    to_torch_dtype,
)
from opensora.utils.train_utils import MaskGenerator, create_colossalai_plugin, update_ema

DEFAULT_DATASET_NAME = "VideoTextDataset"


def main():
    # ======================================================
    # 1. configs & runtime variables
    # ======================================================
    # == parse configs ==
    cfg = parse_configs(training=True)

    # == device and dtype ==
    assert torch.cuda.is_available(), "Training currently requires at least one GPU."
    cfg_dtype = cfg.get("dtype", "bf16")
    assert cfg_dtype in ["fp16", "bf16"], f"Unknown mixed precision {cfg_dtype}"
    dtype = to_torch_dtype(cfg_dtype)

    # == colossalai init distributed training ==
    # NOTE: A very large timeout is set to avoid some processes exiting early
    dist.init_process_group(backend="nccl", timeout=timedelta(hours=24))
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
    set_seed(cfg.get("seed", 1024))
    coordinator = DistCoordinator()
    device = get_current_device()

    # == init exp_dir ==
    exp_name, exp_dir = define_experiment_workspace(cfg)
    coordinator.block_all()
    if coordinator.is_master():
        os.makedirs(exp_dir, exist_ok=True)
        save_training_config(cfg.to_dict(), exp_dir)
    coordinator.block_all()

    # == init logger, tensorboard & wandb ==
    logger = create_logger(exp_dir)
    logger.info("Experiment directory created at %s", exp_dir)
    logger.info("Training configuration:\n %s", pformat(cfg.to_dict()))
    if coordinator.is_master():
        tb_writer = create_tensorboard_writer(exp_dir)
        if cfg.get("wandb", False):
            wandb.init(project="minisora", name=exp_name, config=cfg.to_dict(), dir="./outputs/wandb")

    # == init ColossalAI booster ==
    plugin = create_colossalai_plugin(
        plugin=cfg.get("plugin", "zero2"),
        dtype=cfg_dtype,
        grad_clip=cfg.get("grad_clip", 0),
        sp_size=cfg.get("sp_size", 1),
    )
    booster = Booster(plugin=plugin)
    # ======================================================
    # 2. build dataset and dataloader
    # ======================================================
    logger.info("Building dataset...")
    # == build dataset ==
    dataset = build_module(cfg.dataset, DATASETS)
    logger.info("Dataset contains %s samples.", len(dataset))

    # == build dataloader ==
    # modify here
    dataloader_args = dict(
        dataset=dataset,
        # batch_size=cfg.get("batch_size", 1),
        num_workers=cfg.get("num_workers", 4),
        seed=cfg.get("seed", 1024),
        # shuffle=True,
        # drop_last=True,
        pin_memory=True,
        process_group=get_data_parallel_group(),
        collate_fn=collate_fn_batch,
    )
    dataloader = build_batch_dataloader(**dataloader_args)
    num_steps_per_epoch = len(dataset) // dist.get_world_size()
    sampler_to_io = None
    # TODO:
    #   - prefetch
    #   - collate fn
    #   - resume
    #   - sampler_to_io ?
    #   - remove text_encoder & caption_embedder
    #   - currently only supports 1 epoch; every epoch is the same
    # if cfg.dataset.type == DEFAULT_DATASET_NAME:
    #     dataloader = prepare_dataloader(**dataloader_args)
    #     total_batch_size = cfg.batch_size * dist.get_world_size() // cfg.get("sp_size", 1)
    #     logger.info("Total batch size: %s", total_batch_size)
    #     num_steps_per_epoch = len(dataloader)
    #     sampler_to_io = None
    # else:
    #     dataloader = prepare_variable_dataloader(
    #         bucket_config=cfg.get("bucket_config", None),
    #         num_bucket_build_workers=cfg.get("num_bucket_build_workers", 1),
    #         **dataloader_args,
    #     )
    #     num_steps_per_epoch = dataloader.batch_sampler.get_num_batch() // dist.get_world_size()
    #     sampler_to_io = None if cfg.get("start_from_scratch", False) else dataloader.batch_sampler
    # ======================================================
    # 3. build model
    # ======================================================
    logger.info("Building models...")
    # == build text-encoder and vae ==
    text_encoder = build_module(cfg.text_encoder, MODELS, device=device, dtype=dtype)
    vae = build_module(cfg.vae, MODELS).to(device, dtype).eval()

    # == build diffusion model ==
    # modify here
    # input_size = (dataset.num_frames, *dataset.image_size)
    # latent_size = vae.get_latent_size(input_size)
    latent_size = (None, None, None)
    model = (
        build_module(
            cfg.model,
            MODELS,
            input_size=latent_size,
            in_channels=vae.out_channels,
            caption_channels=text_encoder.output_dim,
            model_max_length=text_encoder.model_max_length,
        )
        .to(device, dtype)
        .train()
    )
    model_numel, model_numel_trainable = get_model_numel(model)
    logger.info(
        "[Diffusion] Trainable model params: %s, Total model params: %s",
        format_numel_str(model_numel_trainable),
        format_numel_str(model_numel),
    )

    # == build ema for diffusion model ==
    ema = deepcopy(model).to(torch.float32).to(device)
    requires_grad(ema, False)
    ema_shape_dict = record_model_param_shape(ema)
    ema.eval()
    update_ema(ema, model, decay=0, sharded=False)

    # == setup loss function, build scheduler ==
    scheduler = build_module(cfg.scheduler, SCHEDULERS)

    # == setup optimizer ==
    optimizer = HybridAdam(
        filter(lambda p: p.requires_grad, model.parameters()),
        adamw_mode=True,
        lr=cfg.get("lr", 1e-4),
        weight_decay=cfg.get("weight_decay", 0),
        eps=cfg.get("adam_eps", 1e-8),
    )
    lr_scheduler = None

    # == additional preparation ==
    if cfg.get("grad_checkpoint", False):
        set_grad_checkpoint(model)
    if cfg.get("mask_ratios", None) is not None:
        mask_generator = MaskGenerator(cfg.mask_ratios)
    # =======================================================
    # 4. distributed training preparation with colossalai
    # =======================================================
    logger.info("Preparing for distributed training...")
    # == boosting ==
    # NOTE: we set dtype first to make initialization of the model consistent with the dtype;
    # then reset it to fp32, as we run the diffusion scheduler in fp32
    torch.set_default_dtype(dtype)
    model, optimizer, _, dataloader, lr_scheduler = booster.boost(
        model=model,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        dataloader=dataloader,
    )
    torch.set_default_dtype(torch.float)
    logger.info("Boosting model for distributed training")

    # == global variables ==
    # modify here
    cfg_epochs = cfg.get("epochs", 1)
    assert cfg_epochs == 1, "BatchDataset currently supports a single epoch only"
    start_epoch = start_step = log_step = sampler_start_idx = acc_step = 0
    running_loss = 0.0
    logger.info("Training for %s epochs with %s steps per epoch", cfg_epochs, num_steps_per_epoch)

    # == resume ==
    if cfg.get("load", None) is not None:
        logger.info("Loading checkpoint")
        ret = load(
            booster,
            cfg.load,
            model=model,
            ema=ema,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            sampler=sampler_to_io,
        )
        if not cfg.get("start_from_scratch", False):
            start_epoch, start_step, sampler_start_idx = ret
        logger.info("Loaded checkpoint %s at epoch %s step %s", cfg.load, start_epoch, start_step)
        if cfg.dataset.type == DEFAULT_DATASET_NAME:
            dataloader.sampler.set_start_index(sampler_start_idx)

    model_sharding(ema)
    # =======================================================
    # 5. training loop
    # =======================================================
    dist.barrier()
    for epoch in range(start_epoch, cfg_epochs):
        # == set dataloader to new epoch ==
        if cfg.dataset.type == DEFAULT_DATASET_NAME:
            dataloader.sampler.set_epoch(epoch)
        dataloader_iter = iter(dataloader)
        logger.info("Beginning epoch %s...", epoch)

        # == training loop in an epoch ==
        with tqdm(
            enumerate(dataloader_iter, start=start_step),
            desc=f"Epoch {epoch}",
            disable=not coordinator.is_master(),
            initial=start_step,
            total=num_steps_per_epoch,
        ) as pbar:
            for step, batch in pbar:
                # modify here
                # pop "x" so the meta-info loop below does not re-add it to model_args
                x = batch.pop("x").to(device, dtype)  # pre-computed VAE encoder features
                # debug: print(step, dist.get_rank(), x.shape)
                # x = batch.pop("video").to(device, dtype)  # [B, C, T, H, W]
                # y = batch.pop("text")

                # == visual and text encoding ==
                # with torch.no_grad():
                #     # Prepare visual inputs
                #     x = vae.encode(x)  # [B, C, T, H/P, W/P]
                #     # Prepare text inputs
                #     model_args = text_encoder.encode(y)
                model_args = {}

                # == mask ==
                mask = None
                if cfg.get("mask_ratios", None) is not None:
                    mask = mask_generator.get_masks(x)
                    model_args["x_mask"] = mask

                # == video meta info ==
                for k, v in batch.items():
                    # NOTE: non-tensor entries (if any) are passed through unchanged
                    model_args[k] = v.to(device, dtype) if isinstance(v, torch.Tensor) else v

                # == diffusion loss computation ==
                loss_dict = scheduler.training_losses(model, x, model_args, mask=mask)

                # == backward & update ==
                loss = loss_dict["loss"].mean()
                booster.backward(loss=loss, optimizer=optimizer)
                optimizer.step()
                optimizer.zero_grad()

                # == update EMA ==
                update_ema(ema, model.module, optimizer=optimizer, decay=cfg.get("ema_decay", 0.9999))

                # == update log info ==
                all_reduce_mean(loss)
                running_loss += loss.item()
                global_step = epoch * num_steps_per_epoch + step
                log_step += 1
                acc_step += 1

                # == logging ==
                if coordinator.is_master() and (global_step + 1) % cfg.get("log_every", 1) == 0:
                    avg_loss = running_loss / log_step
                    # progress bar
                    pbar.set_postfix({"loss": avg_loss, "step": step, "global_step": global_step})
                    # tensorboard
                    tb_writer.add_scalar("loss", loss.item(), global_step)
                    # wandb
                    if cfg.get("wandb", False):
                        wandb.log(
                            {
                                "iter": global_step,
                                "epoch": epoch,
                                "loss": loss.item(),
                                "avg_loss": avg_loss,
                                "acc_step": acc_step,
                            },
                            step=global_step,
                        )
                    running_loss = 0.0
                    log_step = 0

                # == checkpoint saving ==
                ckpt_every = cfg.get("ckpt_every", 0)
                if ckpt_every > 0 and (global_step + 1) % ckpt_every == 0:
                    model_gathering(ema, ema_shape_dict)
                    save(
                        booster,
                        exp_dir,
                        model=model,
                        ema=ema,
                        optimizer=optimizer,
                        lr_scheduler=lr_scheduler,
                        sampler=sampler_to_io,
                        epoch=epoch,
                        step=step + 1,
                        global_step=global_step + 1,
                        batch_size=cfg.get("batch_size", None),
                    )
                    if dist.get_rank() == 0:
                        model_sharding(ema)
                    logger.info(
                        "Saved checkpoint at epoch %s step %s global_step %s to %s",
                        epoch,
                        step + 1,
                        global_step + 1,
                        exp_dir,
                    )

        # NOTE: subsequent epochs are not resumed, so we need to reset the sampler start index and start step
        if cfg.dataset.type == DEFAULT_DATASET_NAME:
            dataloader.sampler.set_start_index(0)
        else:
            # NOTE: BatchDataset supports a single epoch only (asserted above); the
            # variable-dataset sampler reset does not apply here.
            # dataloader.batch_sampler.set_epoch(epoch + 1)
            pass
        logger.info("Epoch done, recomputing batch sampler")
        start_step = 0


if __name__ == "__main__":
    main()
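
For completeness, a hedged sketch of what a compatible .bin buffer might contain, based only on how BatchDataset and this training loop consume it (keys beyond 'x' and 'fps' mentioned in the code, and all shapes, are assumptions):

import torch

buffer = []
for _ in range(100):  # e.g. 100 pre-built batches per .bin file
    buffer.append({
        "x": torch.randn(4, 4, 17, 32, 32),  # pre-computed VAE latents (shape illustrative)
        "fps": torch.tensor([24.0] * 4),
        # ...plus the pre-computed text-encoder outputs (assumed names, e.g. "y", "mask")
    })
torch.save(buffer, "0.bin")

The script is presumably launched like the repo's other training scripts, e.g. `torchrun --standalone --nproc_per_node 8 scripts/train_load_batch.py <config>.py`.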