Open-Sora/scripts/diffusion/train.py

import gc
import math
import os
import subprocess
import warnings
from contextlib import nullcontext
from copy import deepcopy
from pprint import pformat
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)
gc.disable()
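# NOTE: gc.disable() above turns off automatic garbage collection to avoid unpredictable pauses
# during training steps; gc.collect() is invoked manually right before each checkpoint save
# (see the checkpoint section in the training loop below).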
import torch
import torch.distributed as dist
import torch.nn.functional as F
import wandb
from colossalai.booster import Booster
from colossalai.utils import set_seed
from peft import LoraConfig
from tqdm import tqdm
from opensora.acceleration.checkpoint import (
    GLOBAL_ACTIVATION_MANAGER,
    set_grad_checkpoint,
)
from opensora.acceleration.parallel_states import get_data_parallel_group
from opensora.datasets.aspect import bucket_to_shapes
from opensora.datasets.dataloader import prepare_dataloader
from opensora.datasets.pin_memory_cache import PinMemoryCache
from opensora.models.mmdit.distributed import MMDiTPolicy
from opensora.registry import DATASETS, MODELS, build_module
from opensora.utils.ckpt import (
    CheckpointIO,
    model_sharding,
    record_model_param_shape,
    rm_checkpoints,
)
from opensora.utils.config import (
    config_to_name,
    create_experiment_workspace,
    parse_configs,
)
from opensora.utils.logger import create_logger
from opensora.utils.misc import (
    NsysProfiler,
    Timers,
    all_reduce_mean,
    create_tensorboard_writer,
    is_log_process,
    is_pipeline_enabled,
    log_cuda_max_memory,
    log_cuda_memory,
    log_model_params,
    print_mem,
    to_torch_dtype,
)
from opensora.utils.optimizer import create_lr_scheduler, create_optimizer
from opensora.utils.sampling import (
    get_res_lin_function,
    pack,
    prepare,
    prepare_ids,
    time_shift,
)
from opensora.utils.train import (
    create_colossalai_plugin,
    dropout_condition,
    get_batch_loss,
    prepare_visual_condition_causal,
    prepare_visual_condition_uncausal,
    set_eps,
    set_lr,
    setup_device,
    update_ema,
    warmup_ae,
)

torch.backends.cudnn.benchmark = False  # True leads to slow down in conv3d


def main():
    # ======================================================
    # 1. configs & runtime variables
    # ======================================================
    # == parse configs ==
    cfg = parse_configs()

    # == get dtype & device ==
    dtype = to_torch_dtype(cfg.get("dtype", "bf16"))
    device, coordinator = setup_device()

    grad_ckpt_buffer_size = cfg.get("grad_ckpt_buffer_size", 0)
    if grad_ckpt_buffer_size > 0:
        GLOBAL_ACTIVATION_MANAGER.setup_buffer(grad_ckpt_buffer_size, dtype)
    checkpoint_io = CheckpointIO()
    set_seed(cfg.get("seed", 1024))
    PinMemoryCache.force_dtype = dtype
    pin_memory_cache_pre_alloc_numels = cfg.get("pin_memory_cache_pre_alloc_numels", None)
    PinMemoryCache.pre_alloc_numels = pin_memory_cache_pre_alloc_numels
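    # NOTE: PinMemoryCache keeps a pool of pinned host buffers (optionally pre-allocated to the
    # numels configured above) so batches can be moved to the GPU with non_blocking=True; the
    # cache is consumed by fetch_data() in the training loop below (inferred from its usage here).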

    # == init ColossalAI booster ==
    plugin_type = cfg.get("plugin", "zero2")
    plugin_config = cfg.get("plugin_config", {})
    plugin_kwargs = {}
    if plugin_type == "hybrid":
        plugin_kwargs["custom_policy"] = MMDiTPolicy
    plugin = create_colossalai_plugin(
        plugin=plugin_type,
        dtype=cfg.get("dtype", "bf16"),
        grad_clip=cfg.get("grad_clip", 0),
        **plugin_config,
        **plugin_kwargs,
    )
    booster = Booster(plugin=plugin)
    seq_align = plugin_config.get("sp_size", 1)
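    # NOTE: the plugin determines the distributed layout (ZeRO stage, tensor/sequence/pipeline
    # parallelism). seq_align mirrors the sequence-parallel size and is later handed to prepare()
    # so packed token sequences can be padded to a multiple of sp_size (inferred from usage).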

    # == init exp_dir ==
    exp_name, exp_dir = create_experiment_workspace(
        cfg.get("outputs", "./outputs"),
        model_name=config_to_name(cfg),
        config=cfg.to_dict(),
        exp_name=cfg.get("exp_name", None),  # useful for automatic restart to specify the exp_name
    )
    if is_log_process(plugin_type, plugin_config):
        print(f"changing {exp_dir} to share")
        os.system(f"chgrp -R share {exp_dir}")

    # == init logger, tensorboard & wandb ==
    logger = create_logger(exp_dir)
    logger.info("Training configuration:\n %s", pformat(cfg.to_dict()))
    tb_writer = None
    if coordinator.is_master():
        tb_writer = create_tensorboard_writer(exp_dir)
        if cfg.get("wandb", False):
            wandb.init(
                project=cfg.get("wandb_project", "Open-Sora"),
                name=exp_name,
                config=cfg.to_dict(),
                dir=exp_dir,
            )

    num_gpus = dist.get_world_size() if dist.is_initialized() else 1
    tp_size = cfg["plugin_config"].get("tp_size", 1)
    sp_size = cfg["plugin_config"].get("sp_size", 1)
    pp_size = cfg["plugin_config"].get("pp_size", 1)
    num_groups = num_gpus // (tp_size * sp_size * pp_size)
    logger.info("Number of GPUs: %s", num_gpus)
    logger.info("Number of groups: %s", num_groups)

    # ======================================================
    # 2. build dataset and dataloader
    # ======================================================
    logger.info("Building dataset...")
    # == build dataset ==
    dataset = build_module(cfg.dataset, DATASETS)
    logger.info("Dataset contains %s samples.", len(dataset))

    # == build dataloader ==
    cache_pin_memory = pin_memory_cache_pre_alloc_numels is not None
    dataloader_args = dict(
        dataset=dataset,
        batch_size=cfg.get("batch_size", None),
        num_workers=cfg.get("num_workers", 4),
        seed=cfg.get("seed", 1024),
        shuffle=True,
        drop_last=True,
        pin_memory=True,
        process_group=get_data_parallel_group(),
        prefetch_factor=cfg.get("prefetch_factor", None),
        cache_pin_memory=cache_pin_memory,
        num_groups=num_groups,
    )
    print_mem("before prepare_dataloader")
    dataloader, sampler = prepare_dataloader(
        bucket_config=cfg.get("bucket_config", None),
        num_bucket_build_workers=cfg.get("num_bucket_build_workers", 1),
        **dataloader_args,
    )
    print_mem("after prepare_dataloader")
    num_steps_per_epoch = len(dataloader)
    dataset.to_efficient()
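    # NOTE: when a bucket_config is given, prepare_dataloader is expected to group samples into
    # (resolution, frame-count) buckets so each batch has a uniform shape; num_steps_per_epoch
    # then counts bucketed batches rather than raw samples.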

    # ======================================================
    # 3. build model
    # ======================================================
    logger.info("Building models...")
    # == build diffusion model ==
    model = build_module(cfg.model, MODELS, device_map=device, torch_dtype=dtype).train()
    if cfg.get("grad_checkpoint", True):
        set_grad_checkpoint(model)
    log_cuda_memory("diffusion")
    log_model_params(model)

    # == build EMA model ==
    use_lora = cfg.get("lora_config", None) is not None
    if cfg.get("ema_decay", None) is not None and not use_lora:
        ema = deepcopy(model).cpu().eval().requires_grad_(False)
        ema_shape_dict = record_model_param_shape(ema)
        logger.info("EMA model created.")
    else:
        ema = ema_shape_dict = None
        logger.info("No EMA model created.")
    log_cuda_memory("EMA")

    # == enable LoRA ==
    if use_lora:
        lora_config = LoraConfig(**cfg.get("lora_config", None))
        model = booster.enable_lora(
            model=model,
            lora_config=lora_config,
            pretrained_dir=cfg.get("lora_checkpoint", None),
        )
        log_cuda_memory("lora")
        log_model_params(model)

    if not cfg.get("cached_video", False):
        # == build autoencoder ==
        model_ae = build_module(cfg.ae, MODELS, device_map=device, torch_dtype=dtype).eval().requires_grad_(False)
        del model_ae.decoder
        log_cuda_memory("autoencoder")
        log_model_params(model_ae)
        model_ae.encode = torch.compile(model_ae.encoder, dynamic=True)

    if not cfg.get("cached_text", False):
        # == build text encoder (t5) ==
        model_t5 = build_module(cfg.t5, MODELS, device_map=device, torch_dtype=dtype).eval().requires_grad_(False)
        log_cuda_memory("t5")
        log_model_params(model_t5)

        # == build text encoder (clip) ==
        model_clip = build_module(cfg.clip, MODELS, device_map=device, torch_dtype=dtype).eval().requires_grad_(False)
        log_cuda_memory("clip")
        log_model_params(model_clip)

    # == setup optimizer ==
    optimizer = create_optimizer(model, cfg.optim)

    # == setup lr scheduler ==
    lr_scheduler = create_lr_scheduler(
        optimizer=optimizer,
        num_steps_per_epoch=num_steps_per_epoch,
        epochs=cfg.get("epochs", 1000),
        warmup_steps=cfg.get("warmup_steps", None),
        use_cosine_scheduler=cfg.get("use_cosine_scheduler", False),
    )
    log_cuda_memory("optimizer")

    # == prepare null vectors for dropout ==
    if cfg.get("cached_text", False):
        null_txt = torch.load("/mnt/ddn/sora/tmp_load/null_t5.pt", map_location=device)
        null_vec = torch.load("/mnt/ddn/sora/tmp_load/null_clip.pt", map_location=device)
    else:
        null_txt = model_t5("")
        null_vec = model_clip("")
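    # null_txt / null_vec are embeddings of an empty prompt. dropout_condition() in
    # prepare_inputs() swaps them in for the real T5 / CLIP embeddings with the configured
    # probability, i.e. the standard classifier-free guidance training setup.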

    # =======================================================
    # 4. distributed training preparation with colossalai
    # =======================================================
    logger.info("Preparing for distributed training...")
    # == boosting ==
    torch.set_default_dtype(dtype)
    model, optimizer, _, dataloader, lr_scheduler = booster.boost(
        model=model,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        dataloader=dataloader,
    )
    torch.set_default_dtype(torch.float)
    logger.info("Boosted model for distributed training")
    log_cuda_memory("boost")

    # == global variables ==
    cfg_epochs = cfg.get("epochs", 1000)
    log_step = acc_step = 0
    running_loss = 0.0
    timers = Timers(record_time=cfg.get("record_time", False), record_barrier=cfg.get("record_barrier", False))
    nsys = NsysProfiler(
        warmup_steps=cfg.get("nsys_warmup_steps", 2),
        num_steps=cfg.get("nsys_num_steps", 2),
        enabled=cfg.get("nsys", False),
    )
    logger.info("Training for %s epochs with %s steps per epoch", cfg_epochs, num_steps_per_epoch)

    # == resume ==
    load_master_weights = cfg.get("load_master_weights", False)
    save_master_weights = cfg.get("save_master_weights", False)
    start_epoch = cfg.get("start_epoch", None)
    start_step = cfg.get("start_step", None)
    if cfg.get("load", None) is not None:
        logger.info("Loading checkpoint from %s", cfg.load)
        lr_scheduler_to_load = lr_scheduler
        if cfg.get("update_warmup_steps", False):
            lr_scheduler_to_load = None
        ret = checkpoint_io.load(
            booster,
            cfg.load,
            model=model,
            ema=ema,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler_to_load,
            sampler=(
                None if start_step is not None else sampler
            ),  # if start_step is specified, set last_micro_batch_access_index on a fresh sampler instead
            include_master_weights=load_master_weights,
        )
        start_epoch = start_epoch if start_epoch is not None else ret[0]
        start_step = start_step if start_step is not None else ret[1]
        logger.info("Loaded checkpoint %s at epoch %s step %s", cfg.load, ret[0], ret[1])
        # loading the optimizer and lr scheduler overwrites some hyperparameters, so reset them
        set_lr(optimizer, lr_scheduler, cfg.optim.lr, cfg.get("initial_lr", None))
        set_eps(optimizer, cfg.optim.eps)
        if cfg.get("update_warmup_steps", False):
            assert (
                cfg.get("warmup_steps", None) is not None
            ), "you need to set warmup_steps in order to pass --update-warmup-steps True"
            # set_warmup_steps(lr_scheduler, cfg.warmup_steps)
            lr_scheduler.step(start_epoch * num_steps_per_epoch + start_step)
            logger.info("The learning rate starts from %s", optimizer.param_groups[0]["lr"])

    if start_step is not None:
        # if start_step exceeds the data length, move on to the next epoch
        if start_step > num_steps_per_epoch:
            start_epoch = (
                start_epoch + start_step // num_steps_per_epoch
                if start_epoch is not None
                else start_step // num_steps_per_epoch
            )
            start_step = start_step % num_steps_per_epoch
    else:
        start_step = 0
    sampler.set_step(start_step)
    start_epoch = start_epoch if start_epoch is not None else 0
    logger.info("Starting from epoch %s step %s", start_epoch, start_step)

    # == sharding EMA model ==
    if ema is not None:
        model_sharding(ema)
        ema = ema.to(device)
        log_cuda_memory("sharding EMA")

    # == warmup autoencoder ==
    if cfg.get("warmup_ae", False):
        shapes = bucket_to_shapes(cfg.get("bucket_config", None), batch_size=cfg.ae.batch_size)
        warmup_ae(model_ae, shapes, device, dtype)
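    # NOTE: warming up the autoencoder on every bucket shape up front presumably lets the
    # compiled encoder trace/autotune all input shapes before training starts, instead of
    # stalling when a new shape first appears mid-epoch (inferred from the torch.compile above).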

    # =======================================================
    # 5. training iter
    # =======================================================
    sigma_min = cfg.get("sigma_min", 1e-5)
    accumulation_steps = cfg.get("accumulation_steps", 1)
    ckpt_every = cfg.get("ckpt_every", 0)
    if cfg.get("is_causal_vae", False):
        prepare_visual_condition = prepare_visual_condition_causal
    else:
        prepare_visual_condition = prepare_visual_condition_uncausal
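    # prepare_visual_condition_{causal,uncausal} build the visual conditioning latents for
    # i2v / v2v training; which variant applies depends on whether the autoencoder is
    # temporally causal (is_causal_vae).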

    @torch.no_grad()
    def prepare_inputs(batch):
        inp = dict()
        x = batch.pop("video")
        y = batch.pop("text")
        bs = x.shape[0]

        # == encode video ==
        with nsys.range("encode_video"), timers["encode_video"]:
            # == prepare condition ==
            if cfg.get("condition_config", None) is not None:
                # condition for i2v & v2v
                x_0, cond = prepare_visual_condition(x, cfg.condition_config, model_ae)
                cond = pack(cond, patch_size=cfg.get("patch_size", 2))
                inp["cond"] = cond
            else:
                if cfg.get("cached_video", False):
                    x_0 = batch.pop("video_latents").to(device=device, dtype=dtype)
                else:
                    x_0 = model_ae.encode(x)

        # == prepare timestep ==
        # follow SD3 time shift, shift_alpha = 1 for 256px and shift_alpha = 3 for 1024px
        shift_alpha = get_res_lin_function()((x_0.shape[-1] * x_0.shape[-2]) // 4)
        # add temporal influence
        shift_alpha *= math.sqrt(x_0.shape[-3])  # for image, T=1 so no effect
        t = torch.sigmoid(torch.randn((bs), device=device))
        t = time_shift(shift_alpha, t).to(dtype)
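        # t is drawn from a logit-normal distribution (sigmoid of a standard normal) and then
        # shifted by shift_alpha; per the SD3 recipe referenced above, larger resolutions and
        # longer clips are thereby trained at higher noise levels on average.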
if cfg.get("cached_text", False):
# == encode text ==
t5_embedding = batch.pop("text_t5").to(device=device, dtype=dtype)
clip_embedding = batch.pop("text_clip").to(device=device, dtype=dtype)
with nsys.range("encode_text"), timers["encode_text"]:
inp_ = prepare_ids(x_0, t5_embedding, clip_embedding)
inp.update(inp_)
x_0 = pack(x_0, patch_size=cfg.get("patch_size", 2))
else:
# == encode text ==
with nsys.range("encode_text"), timers["encode_text"]:
inp_ = prepare(
model_t5,
model_clip,
x_0,
prompt=y,
seq_align=seq_align,
patch_size=cfg.get("patch_size", 2),
)
inp.update(inp_)
x_0 = pack(x_0, patch_size=cfg.get("patch_size", 2))
# == dropout ==
if cfg.get("dropout_ratio", None) is not None:
cur_null_txt = null_txt
num_pad_null_txt = inp["txt"].shape[1] - cur_null_txt.shape[1]
if num_pad_null_txt > 0:
cur_null_txt = torch.cat([cur_null_txt] + [cur_null_txt[:, -1:]] * num_pad_null_txt, dim=1)
inp["txt"] = dropout_condition(
cfg.dropout_ratio.get("t5", 0.0),
inp["txt"],
cur_null_txt,
)
inp["y_vec"] = dropout_condition(
cfg.dropout_ratio.get("clip", 0.0),
inp["y_vec"],
null_vec,
)
# == prepare noise vector ==
x_1 = torch.randn_like(x_0, dtype=torch.float32).to(device, dtype)
t_rev = 1 - t
x_t = t_rev[:, None, None] * x_0 + (1 - (1 - sigma_min) * t_rev[:, None, None]) * x_1
inp["img"] = x_t
inp["timesteps"] = t.to(dtype)
inp["guidance"] = torch.full((x_t.shape[0],), cfg.get("guidance", 4), device=x_t.device, dtype=x_t.dtype)
return inp, x_0, x_1
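    # run_iter computes the rectified-flow (flow matching) objective: x_t above interpolates
    # between clean latents x_0 and noise x_1, and the model regresses the path's velocity
    # (1 - sigma_min) * x_1 - x_0 (its time derivative) with an MSE or masked loss.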

    def run_iter(inp, x_0, x_1):
        if is_pipeline_enabled(plugin_type, plugin_config):
            inp["target"] = (1 - sigma_min) * x_1 - x_0  # follow MovieGen, modify V_t accordingly
            with nsys.range("forward-backward"), timers["forward-backward"]:
                data_iter = iter([inp])
                if cfg.get("no_i2v_ref_loss", False):
                    loss_fn = (
                        lambda out, input_: get_batch_loss(out, input_["target"], input_.pop("masks", None))
                        / accumulation_steps
                    )
                else:
                    loss_fn = (
                        lambda out, input_: F.mse_loss(out.float(), input_["target"].float(), reduction="mean")
                        / accumulation_steps
                    )
                loss = booster.execute_pipeline(data_iter, model, loss_fn, optimizer)["loss"]
                loss = loss * accumulation_steps if loss is not None else loss
                loss_item = all_reduce_mean(loss.data.clone().detach())
        else:
            with nsys.range("forward"), timers["forward"]:
                model_pred = model(**inp)  # B, T, L
                v_t = (1 - sigma_min) * x_1 - x_0
                if cfg.get("no_i2v_ref_loss", False):
                    loss = get_batch_loss(model_pred, v_t, inp.pop("masks", None))
                else:
                    loss = F.mse_loss(model_pred.float(), v_t.float(), reduction="mean")
                loss_item = all_reduce_mean(loss.data.clone().detach()).item()

            # == backward & update ==
            dist.barrier()
            with nsys.range("backward"), timers["backward"]:
                ctx = (
                    booster.no_sync(model, optimizer)
                    if cfg.get("plugin", "zero2") in ("zero1", "zero1-seq") and (step + 1) % accumulation_steps != 0
                    else nullcontext()
                )
                with ctx:
                    booster.backward(loss=(loss / accumulation_steps), optimizer=optimizer)

        with nsys.range("optim"), timers["optim"]:
            if (step + 1) % accumulation_steps == 0:
                booster.checkpoint_io.synchronize()
                optimizer.step()
                optimizer.zero_grad()
                if lr_scheduler is not None:
                    lr_scheduler.step()

        # == update EMA ==
        if ema is not None:
            with nsys.range("update_ema"), timers["update_ema"]:
                update_ema(
                    ema,
                    model.unwrap(),
                    optimizer=optimizer,
                    decay=cfg.get("ema_decay", 0.9999),
                )
        return loss_item
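    # Gradient accumulation: backward() runs for every micro-batch, but optimizer.step(),
    # zero_grad() and the LR scheduler only fire every `accumulation_steps` steps; for the
    # zero1 plugins, booster.no_sync() skips gradient synchronization on intermediate steps.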

    # =======================================================
    # 6. training loop
    # =======================================================
    dist.barrier()
    for epoch in range(start_epoch, cfg_epochs):
        # == set dataloader to new epoch ==
        sampler.set_epoch(epoch)
        dataloader_iter = iter(dataloader)
        logger.info("Beginning epoch %s...", epoch)

        # == training loop in an epoch ==
        with tqdm(
            enumerate(dataloader_iter, start=start_step),
            desc=f"Epoch {epoch}",
            disable=not is_log_process(plugin_type, plugin_config),
            initial=start_step,
            total=num_steps_per_epoch,
        ) as pbar:
            pbar_iter = iter(pbar)

            # prefetch one batch for non-blocking data loading
            def fetch_data():
                step, batch = next(pbar_iter)
                # print(f"==debug== rank{dist.get_rank()} {dataloader_iter.get_cache_info()}")
                pinned_video = batch["video"]
                batch["video"] = pinned_video.to(device, dtype, non_blocking=True)
                return batch, step, pinned_video
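            # fetch_data() prefetches the next batch: the pinned CPU tensor is copied to the GPU
            # with non_blocking=True so the host-to-device transfer can overlap with the current
            # step's compute; the pinned buffer is released via remove_cache() once consumed.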

            batch_, step_, pinned_video_ = fetch_data()
            for _ in range(start_step, num_steps_per_epoch):
                nsys.step()
                # == load data ==
                with nsys.range("load_data"), timers["load_data"]:
                    batch, step, pinned_video = batch_, step_, pinned_video_
                    if step + 1 < num_steps_per_epoch:
                        # only fetch new data if not last step
                        batch_, step_, pinned_video_ = fetch_data()

                # == run iter ==
                with nsys.range("iter"), timers["iter"]:
                    inp, x_0, x_1 = prepare_inputs(batch)
                    if cache_pin_memory:
                        dataloader_iter.remove_cache(pinned_video)
                    loss = run_iter(inp, x_0, x_1)

                # == update log info ==
                if loss is not None:
                    running_loss += loss

                # == log config ==
                global_step = epoch * num_steps_per_epoch + step
                actual_update_step = (global_step + 1) // accumulation_steps
                log_step += 1
                acc_step += 1
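                # actual_update_step counts optimizer updates rather than micro-batches; the
                # log_every and ckpt_every intervals below are measured against it.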

                # == logging ==
                if (global_step + 1) % accumulation_steps == 0:
                    if actual_update_step % cfg.get("log_every", 1) == 0:
                        if is_log_process(plugin_type, plugin_config):
                            avg_loss = running_loss / log_step
                            # progress bar
                            pbar.set_postfix(
                                {
                                    "loss": avg_loss,
                                    "global_grad_norm": optimizer.get_grad_norm(),
                                    "step": step,
                                    "global_step": global_step,
                                    # "actual_update_step": actual_update_step,
                                    "lr": optimizer.param_groups[0]["lr"],
                                }
                            )
                            # tensorboard
                            if tb_writer is not None:
                                tb_writer.add_scalar("loss", loss, actual_update_step)
                            # wandb
                            if cfg.get("wandb", False):
                                wandb_dict = {
                                    "iter": global_step,
                                    "acc_step": acc_step,
                                    "epoch": epoch,
                                    "loss": loss,
                                    "avg_loss": avg_loss,
                                    "lr": optimizer.param_groups[0]["lr"],
                                    "eps": optimizer.param_groups[0]["eps"],
                                    "global_grad_norm": optimizer.get_grad_norm(),  # test grad norm
                                }
                                if cfg.get("record_time", False):
                                    wandb_dict.update(timers.to_dict())
                                wandb.log(wandb_dict, step=actual_update_step)
                        running_loss = 0.0
                        log_step = 0

                # == checkpoint saving ==
                # forcibly drop the OS page cache before saving (relies on a site-specific `drop_cache` helper)
                with nsys.range("clean_cache"), timers["clean_cache"]:
                    if ckpt_every > 0 and actual_update_step % ckpt_every == 0 and coordinator.is_master():
                        subprocess.run("sudo drop_cache", shell=True)
                with nsys.range("checkpoint"), timers["checkpoint"]:
                    if ckpt_every > 0 and actual_update_step % ckpt_every == 0:
                        # manual garbage collection
                        gc.collect()
                        save_dir = checkpoint_io.save(
                            booster,
                            exp_dir,
                            model=model,
                            ema=ema,
                            optimizer=optimizer,
                            lr_scheduler=lr_scheduler,
                            sampler=sampler,
                            epoch=epoch,
                            step=step + 1,
                            global_step=global_step + 1,
                            batch_size=cfg.get("batch_size", None),
                            lora=use_lora,
                            actual_update_step=actual_update_step,
                            ema_shape_dict=ema_shape_dict,
                            async_io=cfg.get("async_io", False),
                            include_master_weights=save_master_weights,
                        )
                        if is_log_process(plugin_type, plugin_config):
                            os.system(f"chgrp -R share {save_dir}")
                        logger.info(
                            "Saved checkpoint at epoch %s, step %s, global_step %s to %s",
                            epoch,
                            step + 1,
                            actual_update_step,
                            save_dir,
                        )
                        # remove old checkpoints
                        rm_checkpoints(exp_dir, keep_n_latest=cfg.get("keep_n_latest", -1))
                        logger.info("Removed old checkpoints and kept %s latest ones.", cfg.get("keep_n_latest", -1))

                # uncomment the 3 lines below to benchmark checkpointing
                # if ckpt_every > 0 and actual_update_step % ckpt_every == 0:
                #     booster.checkpoint_io._sync_io()
                #     checkpoint_io._sync_io()

                # == terminal timer ==
                if cfg.get("record_time", False):
                    print(timers.to_str(epoch, step))

        sampler.reset()
        start_step = 0

    log_cuda_max_memory("final")


if __name__ == "__main__":
    main()