[refactor] clean train.py code (#94)

This commit is contained in:
Zheng Zangwei (Alex Zheng) 2024-05-08 16:07:57 +08:00 committed by GitHub
parent 28dec22d2c
commit f047d5786a
4 changed files with 154 additions and 121 deletions

View file

@ -18,7 +18,7 @@ def set_grad_checkpoint(model, use_fp32_attention=False, gc_step=1):
def auto_grad_checkpoint(module, *args, **kwargs):
if getattr(module, "grad_checkpointing", False):
if not isinstance(module, Iterable):
return checkpoint(module, *args, **kwargs)
return checkpoint(module, *args, use_reentrant=False, **kwargs)
gc_step = module[0].grad_checkpointing_step
return checkpoint_sequential(module, gc_step, *args, **kwargs)
return checkpoint_sequential(module, gc_step, *args, use_reentrant=False, **kwargs)
return module(*args, **kwargs)

View file

@ -3,6 +3,35 @@ import random
from collections import OrderedDict
import torch
import torch.distributed as dist
from colossalai.booster.plugin import LowLevelZeroPlugin
from opensora.acceleration.parallel_states import set_data_parallel_group, set_sequence_parallel_group
from opensora.acceleration.plugin import ZeroSeqParallelPlugin
def create_colossalai_plugin(plugin, dtype, grad_clip, sp_size):
if plugin == "zero2":
plugin = LowLevelZeroPlugin(
stage=2,
precision=dtype,
initial_scale=2**16,
max_norm=grad_clip,
)
set_data_parallel_group(dist.group.WORLD)
elif plugin == "zero2-seq":
plugin = ZeroSeqParallelPlugin(
sp_size=sp_size,
stage=2,
precision=dtype,
initial_scale=2**16,
max_norm=grad_clip,
)
set_sequence_parallel_group(plugin.sp_group)
set_data_parallel_group(plugin.dp_group)
else:
raise ValueError(f"Unknown plugin {plugin}")
return plugin
@torch.no_grad()
@ -18,7 +47,7 @@ def update_ema(
for name, param in model_params.items():
if name == "pos_embed":
continue
if param.requires_grad == False:
if not param.requires_grad:
continue
if not sharded:
param_data = param.data

View file

@ -1,12 +1,11 @@
import os
from copy import deepcopy
from datetime import timedelta
from pprint import pprint
from pprint import pformat
import torch
import torch.distributed as dist
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.cluster import DistCoordinator
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device, set_seed
@ -14,12 +13,7 @@ from tqdm import tqdm
import wandb
from opensora.acceleration.checkpoint import set_grad_checkpoint
from opensora.acceleration.parallel_states import (
get_data_parallel_group,
set_data_parallel_group,
set_sequence_parallel_group,
)
from opensora.acceleration.plugin import ZeroSeqParallelPlugin
from opensora.acceleration.parallel_states import get_data_parallel_group
from opensora.datasets import prepare_dataloader, prepare_variable_dataloader
from opensora.registry import DATASETS, MODELS, SCHEDULERS, build_module
from opensora.utils.ckpt_utils import create_logger, load, model_sharding, record_model_param_shape, save
@ -30,78 +24,68 @@ from opensora.utils.config_utils import (
save_training_config,
)
from opensora.utils.misc import all_reduce_mean, format_numel_str, get_model_numel, requires_grad, to_torch_dtype
from opensora.utils.train_utils import MaskGenerator, update_ema
from opensora.utils.train_utils import MaskGenerator, create_colossalai_plugin, update_ema
DEFAULT_DATASET_NAME = "VideoTextDataset"
def main():
# ======================================================
# 1. args & cfg
# 1. configs & runtime variables & colossalai launch
# ======================================================
# == parse configs ==
cfg = parse_configs(training=True)
# ======================================================
# 2. runtime variables & colossalai launch
# ======================================================
assert torch.cuda.is_available(), "Training currently requires at least one GPU."
assert cfg.dtype in ["fp16", "bf16"], f"Unknown mixed precision {cfg.dtype}"
cfg_dtype = cfg.get("dtype", "bf16")
assert cfg_dtype in ["fp16", "bf16"], f"Unknown mixed precision {cfg_dtype}"
dtype = to_torch_dtype(cfg.get("dtype", "bf16"))
# 2.1. colossalai init distributed training
# we set a very large timeout to avoid some processes exit early
# == colossalai init distributed training ==
# NOTE: A very large timeout is set to avoid some processes exit early
dist.init_process_group(backend="nccl", timeout=timedelta(hours=24))
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
set_seed(1024)
coordinator = DistCoordinator()
device = get_current_device()
dtype = to_torch_dtype(cfg.dtype)
# 2.2. init exp_dir, logger, tensorboard & wandb
# == init exp_dir ==
exp_name, exp_dir = define_experiment_workspace(cfg)
coordinator.block_all()
if coordinator.is_master():
os.makedirs(exp_dir, exist_ok=True)
save_training_config(cfg._cfg_dict, exp_dir)
save_training_config(cfg.to_dict(), exp_dir)
coordinator.block_all()
if not coordinator.is_master():
logger = create_logger(None)
else:
print("Training configuration:")
pprint(cfg._cfg_dict)
# == init logger, tensorboard & wandb ==
if coordinator.is_master():
logger = create_logger(exp_dir)
logger.info(f"Experiment directory created at {exp_dir}")
logger.info("Experiment directory created at %s", exp_dir)
logger.info("Training configuration:\n %s", pformat(cfg.to_dict()))
writer = create_tensorboard_writer(exp_dir)
if cfg.wandb:
wandb.init(project="minisora", name=exp_name, config=cfg._cfg_dict)
# 2.3. initialize ColossalAI booster
if cfg.plugin == "zero2":
plugin = LowLevelZeroPlugin(
stage=2,
precision=cfg.dtype,
initial_scale=2**16,
max_norm=cfg.grad_clip,
)
set_data_parallel_group(dist.group.WORLD)
elif cfg.plugin == "zero2-seq":
plugin = ZeroSeqParallelPlugin(
sp_size=cfg.get("sp_size", 1),
stage=2,
precision=cfg.dtype,
initial_scale=2**16,
max_norm=cfg.grad_clip,
)
set_sequence_parallel_group(plugin.sp_group)
set_data_parallel_group(plugin.dp_group)
tb_writer = create_tensorboard_writer(exp_dir)
if cfg.get("wandb", False):
wandb.init(project="minisora", name=exp_name, config=cfg.to_dict())
else:
raise ValueError(f"Unknown plugin {cfg.plugin}")
logger = create_logger(None)
# == init ColossalAI booster ==
plugin = create_colossalai_plugin(
plugin=cfg.get("plugin", "zero2"),
dtype=cfg_dtype,
grad_clip=cfg.get("grad_clip", 0),
sp_size=cfg.get("sp_size", 1),
)
booster = Booster(plugin=plugin)
# ======================================================
# 3. build dataset and dataloader
# 2. build dataset and dataloader
# ======================================================
# == build dataset ==
dataset = build_module(cfg.dataset, DATASETS)
logger.info(f"Dataset contains {len(dataset)} samples.")
logger.info("Dataset contains %s samples.", len(dataset))
# == build dataloader ==
dataloader_args = dict(
dataset=dataset,
batch_size=cfg.batch_size,
@ -112,25 +96,26 @@ def main():
pin_memory=True,
process_group=get_data_parallel_group(),
)
# TODO: use plugin's prepare dataloader
if cfg.bucket_config is None:
if cfg.dataset.type == DEFAULT_DATASET_NAME:
dataloader = prepare_dataloader(**dataloader_args)
total_batch_size = cfg.batch_size * dist.get_world_size() // cfg.sp_size
logger.info("Total batch size: %s", total_batch_size)
else:
dataloader = prepare_variable_dataloader(
bucket_config=cfg.bucket_config,
num_bucket_build_workers=cfg.num_bucket_build_workers,
**dataloader_args,
)
if cfg.dataset.type == "VideoTextDataset":
total_batch_size = cfg.batch_size * dist.get_world_size() // cfg.sp_size
logger.info(f"Total batch size: {total_batch_size}")
# ======================================================
# 4. build model
# 3. build model
# ======================================================
# 4.1. build model
# == build text-encoder and vae ==
text_encoder = build_module(cfg.text_encoder, MODELS, device=device)
vae = build_module(cfg.vae, MODELS)
vae = build_module(cfg.vae, MODELS).to(device, dtype)
vae.eval()
# == build diffusion model ==
input_size = (dataset.num_frames, *dataset.image_size)
latent_size = vae.get_latent_size(input_size)
model = build_module(
@ -140,46 +125,46 @@ def main():
in_channels=vae.out_channels,
caption_channels=text_encoder.output_dim,
model_max_length=text_encoder.model_max_length,
)
).to(device, dtype)
model.train()
model_numel, model_numel_trainable = get_model_numel(model)
logger.info(
f"Trainable model params: {format_numel_str(model_numel_trainable)}, Total model params: {format_numel_str(model_numel)}"
"Trainable model params: %s, Total model params: %s",
format_numel_str(model_numel_trainable),
format_numel_str(model_numel),
)
# 4.2. create ema
# == build ema for diffusion model ==
ema = deepcopy(model).to(torch.float32).to(device)
requires_grad(ema, False)
ema_shape_dict = record_model_param_shape(ema)
ema.eval()
update_ema(ema, model, decay=0, sharded=False)
# 4.3. move to device
vae = vae.to(device, dtype)
model = model.to(device, dtype)
# 4.4. build scheduler
# == build scheduler ==
scheduler = build_module(cfg.scheduler, SCHEDULERS)
# 4.5. setup optimizer
# == setup optimizer ==
optimizer = HybridAdam(
filter(lambda p: p.requires_grad, model.parameters()),
lr=cfg.lr,
weight_decay=0,
adamw_mode=True,
lr=cfg.get("lr", 1e-4),
weight_decay=cfg.get("weight_decay", 0),
eps=cfg.get("adam_eps", 1e-8),
)
lr_scheduler = None
# 4.6. prepare for training
if cfg.grad_checkpoint:
# == additional preparation ==
if cfg.get("grad_checkpoint", False):
set_grad_checkpoint(model)
model.train()
update_ema(ema, model, decay=0, sharded=False)
ema.eval()
if cfg.mask_ratios is not None:
if cfg.get("mask_ratios", None) is not None:
mask_generator = MaskGenerator(cfg.mask_ratios)
# =======================================================
# 5. boost model for distributed training with colossalai
# 5. distributed training preparation with colossalai
# =======================================================
# == boosting ==
# NOTE: we set dtype first to make initialization of model consistent with the dtype; then reset it to the fp32 as we make diffusion scheduler in fp32
torch.set_default_dtype(dtype)
model, optimizer, _, dataloader, lr_scheduler = booster.boost(
model=model,
@ -188,20 +173,23 @@ def main():
dataloader=dataloader,
)
torch.set_default_dtype(torch.float)
logger.info("Boost model for distributed training")
if cfg.dataset.type == "VariableVideoTextDataset":
num_steps_per_epoch = dataloader.batch_sampler.get_num_batch() // dist.get_world_size()
else:
logger.info("Boosting model for distributed training")
if cfg.dataset.type == DEFAULT_DATASET_NAME:
num_steps_per_epoch = len(dataloader)
sampler_to_io = None
else:
num_steps_per_epoch = dataloader.batch_sampler.get_num_batch() // dist.get_world_size()
sampler_to_io = None if cfg.get("start_from_scratch ", False) else dataloader.batch_sampler
# =======================================================
# 6. training loop
# =======================================================
cfg_epochs = cfg.get("epochs", 1000)
logger.info("Training for %s epochs with %s steps per epoch", cfg_epochs, num_steps_per_epoch)
# == global variables ==
start_epoch = start_step = log_step = sampler_start_idx = acc_step = 0
running_loss = 0.0
sampler_to_io = dataloader.batch_sampler if cfg.dataset.type == "VariableVideoTextDataset" else None
# 6.1. resume training
if cfg.load is not None:
# == resume ==
if cfg.get("load", None) is not None:
logger.info("Loading checkpoint")
ret = load(
booster,
@ -210,24 +198,27 @@ def main():
optimizer,
lr_scheduler,
cfg.load,
sampler=sampler_to_io if not cfg.start_from_scratch else None,
sampler=sampler_to_io,
)
if not cfg.start_from_scratch:
if cfg.get("start_from_scratch ", False):
start_epoch, start_step, sampler_start_idx = ret
logger.info(f"Loaded checkpoint {cfg.load} at epoch {start_epoch} step {start_step}")
logger.info(f"Training for {cfg.epochs} epochs with {num_steps_per_epoch} steps per epoch")
if cfg.dataset.type == "VideoTextDataset":
logger.info("Loaded checkpoint %s at epoch %s step %s", cfg.load, start_epoch, start_step)
if cfg.dataset.type == DEFAULT_DATASET_NAME:
dataloader.sampler.set_start_index(sampler_start_idx)
model_sharding(ema)
# 6.2. training loop
for epoch in range(start_epoch, cfg.epochs):
if cfg.dataset.type == "VideoTextDataset":
# =======================================================
# 6. training loop
# =======================================================
for epoch in range(start_epoch, cfg_epochs):
# == set dataloader to new epoch ==
if cfg.dataset.type == DEFAULT_DATASET_NAME:
dataloader.sampler.set_epoch(epoch)
dataloader_iter = iter(dataloader)
logger.info(f"Beginning epoch {epoch}...")
logger.info("Beginning epoch %s...", epoch)
# == training loop in an epoch ==
with tqdm(
enumerate(dataloader_iter, start=start_step),
desc=f"Epoch {epoch}",
@ -238,50 +229,51 @@ def main():
for step, batch in pbar:
x = batch.pop("video").to(device, dtype) # [B, C, T, H, W]
y = batch.pop("text")
# Visual and text encoding
# == visual and text encoding ==
with torch.no_grad():
# Prepare visual inputs
x = vae.encode(x) # [B, C, T, H/P, W/P]
# Prepare text inputs
model_args = text_encoder.encode(y)
# Mask
if cfg.mask_ratios is not None:
# == mask ==
mask = None
if cfg.get("mask_ratios", None) is not None:
mask = mask_generator.get_masks(x)
model_args["x_mask"] = mask
else:
mask = None
# Video info
# == video meta info ==
for k, v in batch.items():
model_args[k] = v.to(device, dtype)
# Diffusion
# == diffusion loss computation ==
loss_dict = scheduler.training_losses(model, x, model_args, mask=mask)
# Backward & update
# == backward & update ==
loss = loss_dict["loss"].mean()
booster.backward(loss=loss, optimizer=optimizer)
optimizer.step()
optimizer.zero_grad()
# Update EMA
# == update EMA ==
update_ema(ema, model.module, optimizer=optimizer, decay=cfg.get("ema_decay", 0.9999))
# Log loss values:
# == update log info ==
all_reduce_mean(loss)
running_loss += loss.item()
global_step = epoch * num_steps_per_epoch + step
log_step += 1
acc_step += 1
# Log to tensorboard
# == logging ==
if coordinator.is_master() and (global_step + 1) % cfg.log_every == 0:
avg_loss = running_loss / log_step
# progress bar
pbar.set_postfix({"loss": avg_loss, "step": step, "global_step": global_step})
running_loss = 0
log_step = 0
writer.add_scalar("loss", loss.item(), global_step)
# tensorboard
tb_writer.add_scalar("loss", loss.item(), global_step)
# wandb
if cfg.wandb:
wandb.log(
{
@ -294,7 +286,10 @@ def main():
step=global_step,
)
# Save checkpoint
running_loss = 0
log_step = 0
# == checkpoint saving ==
if cfg.ckpt_every > 0 and (global_step + 1) % cfg.ckpt_every == 0:
save(
booster,
@ -312,15 +307,19 @@ def main():
sampler=sampler_to_io,
)
logger.info(
f"Saved checkpoint at epoch {epoch} step {step + 1} global_step {global_step + 1} to {exp_dir}"
"Saved checkpoint at epoch %s step %s global_step %s to %s",
epoch,
step + 1,
global_step + 1,
exp_dir,
)
# the continue epochs are not resumed, so we need to reset the sampler start index and start step
if cfg.dataset.type == "VideoTextDataset":
# NOTE: the continue epochs are not resumed, so we need to reset the sampler start index and start step
if cfg.dataset.type == DEFAULT_DATASET_NAME:
dataloader.sampler.set_start_index(0)
if cfg.dataset.type == "VariableVideoTextDataset":
else:
dataloader.batch_sampler.set_epoch(epoch + 1)
print("Epoch done, recomputing batch sampler")
logger.info("Epoch done, recomputing batch sampler")
start_step = 0

View file

@ -486,13 +486,15 @@ def main(args):
if args.refine_llm_caption:
assert "text" in data.columns
data["text"] = apply(data["text"], remove_caption_prefix)
if args.append_text is not None:
assert "text" in data.columns
data["text"] = data["text"] + args.append_text
if args.clean_caption:
assert "text" in data.columns
data["text"] = apply(
data["text"],
partial(text_preprocessing, use_text_preprocessing=True),
)
if args.count_num_token is not None:
assert "text" in data.columns
data["text_len"] = apply(data["text"], lambda x: len(tokenizer(x)["input_ids"]))
@ -597,6 +599,7 @@ def parse_args():
parser.add_argument(
"--count-num-token", type=str, choices=["t5"], default=None, help="Count the number of tokens in the caption"
)
parser.add_argument("--append-text", type=str, default=None, help="append text to the caption")
# score filtering
parser.add_argument("--fmin", type=int, default=None, help="filter the dataset by minimum number of frames")
@ -661,6 +664,8 @@ def get_output_path(args, input_name):
name += "_cmcaption"
if args.count_num_token:
name += "_ntoken"
if args.append_text is not None:
name += "_appendtext"
# score filtering
if args.fmin is not None: