Open-Sora/scripts/train.py
Zheng Zangwei (Alex Zheng) f1c6b8b88e open-sora v1.3 code upload (#786)
Co-authored-by: gxyes <gxynoz@gmail.com>
2025-02-20 16:50:24 +08:00

521 lines
21 KiB
Python

import os
import random
from contextlib import nullcontext
from copy import deepcopy
from datetime import timedelta
from pprint import pformat
import torch
import torch.distributed as dist
import wandb
from colossalai.booster import Booster
from colossalai.cluster import DistCoordinator
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device, set_seed
from tqdm import tqdm
from opensora.acceleration.checkpoint import set_grad_checkpoint
from opensora.acceleration.parallel_states import get_data_parallel_group
from opensora.datasets.dataloader import prepare_dataloader
from opensora.registry import DATASETS, MODELS, SCHEDULERS, build_module
from opensora.utils.ckpt_utils import load, model_gathering, model_sharding, record_model_param_shape, save
from opensora.utils.config_utils import define_experiment_workspace, parse_configs, save_training_config
from opensora.utils.lr_scheduler import LinearWarmupLR
from opensora.utils.misc import (
Timer,
all_reduce_mean,
create_logger,
create_tensorboard_writer,
format_numel_str,
get_model_numel,
requires_grad,
to_torch_dtype,
)
from opensora.utils.train_utils import (
MaskGenerator,
aug_x,
create_colossalai_plugin,
get_mask_cond,
get_mask_index,
update_ema,
)
def main():
    """Train an Open-Sora diffusion model with ColossalAI.

    All settings come from ``parse_configs(training=True)``. The function:

    1. Parses the config, initialises the NCCL process group, creates the
       experiment directory and logger, and (master rank only) a TensorBoard
       writer and optional wandb run.
    2. Builds the dataset and a distributed dataloader/sampler.
    3. Builds the optional text encoder and VAE, the diffusion model, an fp32
       EMA copy of it, the diffusion scheduler, a HybridAdam optimizer and an
       optional LR scheduler (linear warmup, or cosine annealing with warmup).
    4. Boosts model/optimizer/LR scheduler/dataloader with the ColossalAI
       booster and optionally resumes from ``cfg.load``.
    5. Runs the epoch/step loop: encodes inputs, computes
       ``scheduler.training_losses``, back-propagates, updates the EMA, logs
       and periodically saves checkpoints. With ``record_time`` set, per-step
       timings are appended to a per-rank ``record_time_r{rank}.txt`` file.

    Gradient accumulation: the loss is divided by ``accumulation_steps``;
    ``optimizer.step()``, ``lr_scheduler.step()`` and the EMA update only run
    every ``accumulation_steps``-th micro-step, so the cosine schedule length
    is counted in optimizer steps rather than micro-steps.
    """
    # ======================================================
    # 1. configs & runtime variables
    # ======================================================
    # == parse configs ==
    cfg = parse_configs(training=True)
    record_time = cfg.get("record_time", False)

    # == device and dtype ==
    assert torch.cuda.is_available(), "Training currently requires at least one GPU."
    cfg_dtype = cfg.get("dtype", "bf16")
    assert cfg_dtype in ["fp16", "bf16"], f"Unknown mixed precision {cfg_dtype}"
    dtype = to_torch_dtype(cfg_dtype)

    # == colossalai init distributed training ==
    # NOTE: A very large timeout is set to avoid some processes exit early
    dist.init_process_group(backend="nccl", timeout=timedelta(hours=24))
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
    set_seed(cfg.get("seed", 1024))
    coordinator = DistCoordinator()
    device = get_current_device()

    # == init exp_dir ==
    exp_name, exp_dir = define_experiment_workspace(cfg)
    coordinator.block_all()
    if coordinator.is_master():
        os.makedirs(exp_dir, exist_ok=True)
        save_training_config(cfg.to_dict(), exp_dir)
    coordinator.block_all()

    # == init logger, tensorboard & wandb ==
    logger = create_logger(exp_dir)
    logger.info("Experiment directory created at %s", exp_dir)
    logger.info("Training configuration:\n %s", pformat(cfg.to_dict()))
    use_wandb = cfg.get("wandb", False)
    if coordinator.is_master():
        # tb_writer only exists on the master rank; every use below is guarded by is_master().
        tb_writer = create_tensorboard_writer(exp_dir)
        if use_wandb:
            wandb.init(project="Open-Sora", name=exp_name, config=cfg.to_dict(), dir=exp_dir)

    # == init ColossalAI booster ==
    plugin_name = cfg.get("plugin", "zero2")
    plugin = create_colossalai_plugin(
        plugin=plugin_name,
        dtype=cfg_dtype,
        grad_clip=cfg.get("grad_clip", 0),
        sp_size=cfg.get("sp_size", 1),
        reduce_bucket_size_in_m=cfg.get("reduce_bucket_size_in_m", 20),
    )
    booster = Booster(plugin=plugin)
    torch.set_num_threads(1)

    # == build text-encoder ==
    text_encoder = build_module(cfg.get("text_encoder", None), MODELS, device=device, dtype=dtype)
    if text_encoder is not None:
        text_encoder_output_dim = text_encoder.output_dim
        text_encoder_model_max_length = text_encoder.model_max_length
        # Presumably lets the dataset tokenize captions; batches are later read for
        # input_ids / attention_mask.
        cfg.dataset.tokenize_fn = text_encoder.tokenize_fn
    else:
        text_encoder_output_dim = cfg.get("text_encoder_output_dim", 4096)
        text_encoder_model_max_length = cfg.get("text_encoder_model_max_length", 300)

    # ======================================================
    # 2. build dataset and dataloader
    # ======================================================
    logger.info("Building dataset...")
    # == build dataset ==
    dataset = build_module(cfg.dataset, DATASETS)
    logger.info("Dataset contains %s samples.", len(dataset))

    # == build dataloader ==
    dataloader_args = dict(
        dataset=dataset,
        batch_size=cfg.get("batch_size", None),
        num_workers=cfg.get("num_workers", 4),
        seed=cfg.get("seed", 1024),
        shuffle=True,
        drop_last=True,
        pin_memory=True,
        process_group=get_data_parallel_group(),
        prefetch_factor=cfg.get("prefetch_factor", None),
    )
    dataloader, sampler = prepare_dataloader(
        bucket_config=cfg.get("bucket_config", None),
        num_bucket_build_workers=cfg.get("num_bucket_build_workers", 1),
        **dataloader_args,
    )
    num_steps_per_epoch = len(dataloader)

    # ======================================================
    # 3. build model
    # ======================================================
    logger.info("Building models...")
    # == build vae ==
    vae = build_module(cfg.get("vae", None), MODELS)
    if vae is not None:
        vae = vae.to(device, dtype).eval()
        input_size = (dataset.num_frames, *dataset.image_size)
        latent_size = vae.get_latent_size(input_size)
        vae_out_channels = vae.out_channels
    else:
        # Without a VAE the latent geometry is unknown here; channels come from the config.
        latent_size = (None, None, None)
        vae_out_channels = cfg.get("vae_out_channels", 4)

    # == build diffusion model ==
    model = (
        build_module(
            cfg.model,
            MODELS,
            input_size=latent_size,
            in_channels=vae_out_channels,
            caption_channels=text_encoder_output_dim,
            model_max_length=text_encoder_model_max_length,
            enable_sequence_parallelism=cfg.get("sp_size", 1) > 1,
        )
        .to(device, dtype)
        .train()
    )
    model_numel, model_numel_trainable = get_model_numel(model)
    logger.info(
        "[Diffusion] Trainable model params: %s, Total model params: %s",
        format_numel_str(model_numel_trainable),
        format_numel_str(model_numel),
    )

    # == build ema for diffusion model ==
    # The EMA is kept in fp32 and frozen; decay=0 copies the current weights into it.
    ema = deepcopy(model).to(torch.float32).to(device)
    requires_grad(ema, False)
    ema_shape_dict = record_model_param_shape(ema)
    ema.eval()
    update_ema(ema, model, decay=0, sharded=False)

    # == setup loss function, build scheduler ==
    scheduler = build_module(cfg.scheduler, SCHEDULERS)

    # == setup optimizer ==
    optimizer = HybridAdam(
        filter(lambda p: p.requires_grad, model.parameters()),
        adamw_mode=True,
        lr=cfg.get("lr", 1e-4),
        weight_decay=cfg.get("weight_decay", 0),
        eps=cfg.get("adam_eps", 1e-8),
    )

    # == epochs & gradient accumulation ==
    cfg_epochs = cfg.get("epochs", 1000)
    accumulation_steps = cfg.get("accumulation_steps", 1)
    assert accumulation_steps >= 1, f"accumulation_steps must be >= 1, got {accumulation_steps}"

    # == lr scheduler ==
    warmup_steps = cfg.get("warmup_steps", None)
    use_cosine_scheduler = cfg.get("use_cosine_scheduler", False)
    if warmup_steps is None and not use_cosine_scheduler:
        lr_scheduler = None
    elif use_cosine_scheduler:
        # lr_scheduler.step() runs once per optimizer step, i.e. every `accumulation_steps`
        # micro-steps (the step index restarts at 0 every epoch), so the schedule length
        # must be counted in optimizer steps or the cosine decay never completes.
        optimizer_steps_per_epoch = num_steps_per_epoch // accumulation_steps
        lr_scheduler = CosineAnnealingWarmupLR(
            optimizer,
            total_steps=max(1, optimizer_steps_per_epoch * cfg_epochs),
            warmup_steps=warmup_steps if warmup_steps is not None else 0,
        )
    else:
        lr_scheduler = LinearWarmupLR(optimizer, initial_lr=1e-6, warmup_steps=warmup_steps)

    # == additional preparation ==
    if cfg.get("grad_checkpoint", False):
        set_grad_checkpoint(model)
    use_mask_ratios = cfg.get("mask_ratios", None) is not None
    if use_mask_ratios:
        mask_generator = MaskGenerator(cfg.mask_ratios)

    # =======================================================
    # 4. distributed training preparation with colossalai
    # =======================================================
    logger.info("Preparing for distributed training...")
    # == boosting ==
    # NOTE: we set dtype first to make initialization of model consistent with the dtype; then reset it to the fp32 as we make diffusion scheduler in fp32
    torch.set_default_dtype(dtype)
    model, optimizer, _, dataloader, lr_scheduler = booster.boost(
        model=model,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        dataloader=dataloader,
    )
    torch.set_default_dtype(torch.float)
    logger.info("Boosting model for distributed training")

    # == global variables ==
    start_epoch = start_step = log_step = acc_step = 0
    running_loss = 0.0
    logger.info("Training for %s epochs with %s steps per epoch", cfg_epochs, num_steps_per_epoch)

    # == resume ==
    if cfg.get("load", None) is not None:
        logger.info("Loading checkpoint")
        ret = load(
            booster,
            cfg.load,
            model=model,
            ema=ema,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            sampler=None if cfg.get("start_from_scratch", False) else sampler,
        )
        if not cfg.get("start_from_scratch", False):
            start_epoch, start_step = ret
        logger.info("Loaded checkpoint %s at epoch %s step %s", cfg.load, start_epoch, start_step)

    # The EMA stays sharded during training; it is gathered only around checkpoint saves.
    model_sharding(ema)

    # == mask ==
    mask_types = cfg.get("mask_types", None)
    if mask_types is not None:
        # Per-rank RNG so each rank draws its own i2v/v2v mask conditions.
        mask_randgen = random.Random(dist.get_rank())

    # =======================================================
    # 5. training loop
    # =======================================================
    dist.barrier()
    timer_keys = [
        "move_data",
        "mask_index",
        "encode",
        "mask",
        "move_args",
        "diffusion",
        "backward",
        "update_ema",
        "reduce_loss",
        "log",
        "checkpoint",
    ]
    # Real timers only when profiling; otherwise reusable no-op contexts whose `as` target is None.
    timers = {key: Timer(key, coordinator=coordinator) if record_time else nullcontext() for key in timer_keys}

    # Loop-invariant configuration, read once instead of on every step.
    load_video_features = cfg.get("load_video_features", False)
    load_text_features = cfg.get("load_text_features", False)
    noise_augmentation = cfg.get("noise_augmentation", False)
    noise_disable_threshold = cfg.get("noise_disable_threshold", None)
    default_text_uncond_prob = cfg.model.get("class_dropout_prob", 0.1)
    v2v_noise_min_weight = cfg.model.get("v2v_noise_min_weight", 0.1)
    v2v_noise_max_weight = cfg.model.get("v2v_noise_max_weight", 0.3)
    use_no_sync = plugin_name in ("zero1", "zero1-seq")
    ema_decay = cfg.get("ema_decay", 0.9999)
    log_every = cfg.get("log_every", 1)
    ckpt_every = cfg.get("ckpt_every", 0)

    # The per-rank timing file is opened only when record_time is set; the `with`
    # guarantees it is closed when training ends or raises.
    record_ctx = (
        open(os.path.join(exp_dir, f"record_time_r{dist.get_rank()}.txt"), "w", encoding="utf-8")
        if record_time
        else nullcontext()
    )
    with record_ctx as record_file:
        for epoch in range(start_epoch, cfg_epochs):
            # == set dataloader to new epoch ==
            sampler.set_epoch(epoch)
            dataloader_iter = iter(dataloader)
            logger.info("Beginning epoch %s...", epoch)
            # NOTE(review): if num_steps_per_epoch is not a multiple of accumulation_steps, the
            # gradients of the trailing micro-steps are carried into the next epoch's first
            # optimizer step — confirm this is intended.

            # == training loop in an epoch ==
            with tqdm(
                enumerate(dataloader_iter, start=start_step),
                desc=f"Epoch {epoch}",
                disable=not coordinator.is_master(),
                initial=start_step,
                total=num_steps_per_epoch,
            ) as pbar:
                for step, batch in pbar:
                    timer_list = []
                    paths = batch.pop("path")

                    with timers["move_data"] as move_data_t:
                        x = batch.pop("video").to(device, dtype)  # [B, C, T, H, W]
                        y = batch.pop("text")
                        # Tokenized captions may be absent (e.g. no text encoder configured);
                        # they are only consumed by the text-encoder branch below.
                        input_ids = batch.pop("input_ids", None)
                        attention_mask = batch.pop("attention_mask", None)
                    if record_time:
                        timer_list.append(move_data_t)

                    # == prepare i2v&v2v mask_index ==
                    with timers["mask_index"] as mask_index_t:
                        num_frames = x.shape[2]
                        if vae is None or load_video_features:
                            # x is not VAE-encoded below in this case (it is expected to already
                            # hold latents), so its temporal size is already the latent length.
                            latent_t = num_frames
                        else:
                            latent_t = vae.get_latent_size(x.shape[2:])[0]
                        mask_cond = None
                        mask_index = []
                        text_uncond_prob = default_text_uncond_prob
                        if mask_types is not None:
                            mask_cond = get_mask_cond(mask_randgen, mask_types)
                            if num_frames > 1:  # NOTE: only use mask_indx for video
                                mask_index = get_mask_index(mask_cond, latent_t)
                        if len(mask_index) > 0:
                            # No text dropout for i2v/v2v samples.
                            text_uncond_prob = 0.0
                    if record_time:
                        timer_list.append(mask_index_t)

                    # == visual and text encoding ==
                    with timers["encode"] as encode_t:
                        x_noisy_ref = None  # for v2v, add a little noise to video's referenced part
                        with torch.no_grad():
                            # Prepare visual inputs
                            if load_video_features:
                                # x was already moved to (device, dtype) above.
                                x_gt = x
                                # NOTE: x_noisy_ref is skipped for now
                            elif noise_augmentation and x.shape[2] > 1:
                                x, x_gt = aug_x(
                                    x,
                                    vae,
                                    cfg.get("noise_prob", {}),
                                    cfg.get("noise_strength", {}),
                                )
                                # NOTE: x_noisy_ref is skipped for now
                            else:
                                # mask_cond is only inspected when mask_index is non-empty,
                                # which implies it was drawn above.
                                if 0 in mask_index and "noisy" in mask_cond:
                                    # Blend pixels with Gaussian noise at a ratio drawn uniformly
                                    # from [v2v_noise_min_weight, v2v_noise_max_weight].
                                    v2v_noise_ratio = v2v_noise_min_weight + random.uniform(0, 1) * (
                                        v2v_noise_max_weight - v2v_noise_min_weight
                                    )
                                    x_noisy = v2v_noise_ratio * torch.randn_like(x) + (1 - v2v_noise_ratio) * x
                                else:
                                    x_noisy = x
                                x_noisy_ref = vae.encode(x_noisy)
                                x_gt = x = vae.encode(x)  # [B, C, T, H/P, W/P]

                            # Prepare text inputs
                            if load_text_features:
                                model_args = {"y": y.to(device, dtype)}
                                mask = batch.pop("mask")
                                if isinstance(mask, torch.Tensor):
                                    mask = mask.to(device, dtype)
                                model_args["mask"] = mask
                            else:
                                model_args = text_encoder.encode(input_ids, attention_mask=attention_mask)
                    if record_time:
                        timer_list.append(encode_t)

                    # == mask ==
                    with timers["mask"] as mask_t:
                        mask = None
                        if use_mask_ratios:
                            mask = mask_generator.get_masks(x)
                            model_args["x_mask"] = mask
                    if record_time:
                        timer_list.append(mask_t)

                    # == video meta info ==
                    # Every remaining tensor in the batch is forwarded to the model as a kwarg.
                    with timers["move_args"] as move_args_t:
                        for k, v in batch.items():
                            if isinstance(v, torch.Tensor):
                                model_args[k] = v.to(device, dtype)
                    if record_time:
                        timer_list.append(move_args_t)

                    # == diffusion loss computation ==
                    with timers["diffusion"] as loss_t:
                        if len(mask_index) > 0:  # i2v and v2v training
                            model_args["x_mask"] = None  # Don't use any other input masks
                            mask = None
                        loss_dict = scheduler.training_losses(
                            model,
                            x,
                            model_args=model_args,
                            mask=mask,
                            mask_index=mask_index,
                            x_gt=x_gt,
                            noise_disable_threshold=noise_disable_threshold,
                            text_uncond_prob=text_uncond_prob,
                            x_noisy_ref=x_noisy_ref,
                        )
                    if record_time:
                        timer_list.append(loss_t)

                    # == backward & update ==
                    with timers["backward"] as backward_t:
                        loss = loss_dict["loss"].mean()
                        loss = loss / accumulation_steps
                        is_update_step = (step + 1) % accumulation_steps == 0
                        # For zero1 plugins, skip gradient synchronisation on pure accumulation steps.
                        ctx = booster.no_sync(model, optimizer) if use_no_sync and not is_update_step else nullcontext()
                        with ctx:
                            booster.backward(loss=loss, optimizer=optimizer)
                        if is_update_step:
                            optimizer.step()
                            optimizer.zero_grad()
                            # update learning rate
                            if lr_scheduler is not None:
                                lr_scheduler.step()
                    if record_time:
                        timer_list.append(backward_t)

                    # == update EMA ==
                    with timers["update_ema"] as ema_t:
                        # Only update after the weights actually changed: on accumulation-only
                        # micro-steps an extra update would silently shrink the effective decay.
                        if is_update_step:
                            update_ema(ema, model.module, optimizer=optimizer, decay=ema_decay)
                    if record_time:
                        timer_list.append(ema_t)

                    # == update log info ==
                    with timers["reduce_loss"] as reduce_loss_t:
                        all_reduce_mean(loss.data)
                        # Undo the accumulation scaling so logged values are per-sample losses.
                        running_loss += loss.item() * accumulation_steps
                        global_step = epoch * num_steps_per_epoch + step
                        log_step += 1
                        acc_step += 1
                    if record_time:
                        timer_list.append(reduce_loss_t)

                    with timers["log"] as log_t:
                        # == logging ==
                        if coordinator.is_master() and (global_step + 1) % log_every == 0:
                            avg_loss = running_loss / log_step
                            step_loss = loss.item() * accumulation_steps
                            # progress bar
                            pbar.set_postfix({"loss": avg_loss, "step": step, "global_step": global_step})
                            # tensorboard
                            tb_writer.add_scalar("loss", step_loss, global_step)
                            # wandb
                            if use_wandb:
                                wandb_dict = {
                                    "iter": global_step,
                                    "acc_step": acc_step,
                                    "epoch": epoch,
                                    "loss": step_loss,
                                    "avg_loss": avg_loss,
                                    "lr": optimizer.param_groups[0]["lr"],
                                }
                                if record_time:
                                    wandb_dict.update(
                                        {
                                            "debug/move_data_time": move_data_t.elapsed_time,
                                            "debug/encode_time": encode_t.elapsed_time,
                                            "debug/mask_time": mask_t.elapsed_time,
                                            "debug/diffusion_time": loss_t.elapsed_time,
                                            "debug/backward_time": backward_t.elapsed_time,
                                            "debug/update_ema_time": ema_t.elapsed_time,
                                            "debug/reduce_loss_time": reduce_loss_t.elapsed_time,
                                        }
                                    )
                                wandb.log(wandb_dict, step=global_step)
                            running_loss = 0.0
                            log_step = 0
                    if record_time:
                        timer_list.append(log_t)

                    # == checkpoint saving ==
                    with timers["checkpoint"] as checkpoint_t:
                        if ckpt_every > 0 and (global_step + 1) % ckpt_every == 0:
                            # Gather the sharded EMA for saving, then re-shard it on rank 0.
                            model_gathering(ema, ema_shape_dict)
                            save_dir = save(
                                booster,
                                exp_dir,
                                model=model,
                                ema=ema,
                                optimizer=optimizer,
                                lr_scheduler=lr_scheduler,
                                sampler=sampler,
                                epoch=epoch,
                                step=step + 1,
                                global_step=global_step + 1,
                                batch_size=cfg.get("batch_size", None),
                            )
                            if dist.get_rank() == 0:
                                model_sharding(ema)
                            logger.info(
                                "Saved checkpoint at epoch %s, step %s, global_step %s to %s",
                                epoch,
                                step + 1,
                                global_step + 1,
                                save_dir,
                            )
                    if record_time:
                        timer_list.append(checkpoint_t)

                    if record_time:
                        total_step_time = sum(timer.elapsed_time for timer in timer_list)
                        log_str = (
                            f"Rank {dist.get_rank()} | Epoch {epoch} | Step {step} | Step time: {total_step_time:.3f}s | "
                        )
                        log_str += "".join(f"{timer.name}: {timer.elapsed_time:.3f}s | " for timer in timer_list)
                        log_str += f"path: {paths}"
                        record_file.write(log_str + "\n")
                        record_file.flush()

            sampler.reset()
            # Only the first (resumed) epoch starts mid-way.
            start_step = 0
if __name__ == "__main__":
    # Script entry point; main() sets up the NCCL process group itself
    # (dist.init_process_group), so nothing else is initialised here.
    main()