mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-11 05:13:31 +02:00
521 lines
21 KiB
Python
521 lines
21 KiB
Python
import os
|
|
import random
|
|
from contextlib import nullcontext
|
|
from copy import deepcopy
|
|
from datetime import timedelta
|
|
from pprint import pformat
|
|
|
|
import torch
|
|
import torch.distributed as dist
|
|
import wandb
|
|
from colossalai.booster import Booster
|
|
from colossalai.cluster import DistCoordinator
|
|
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
|
|
from colossalai.nn.optimizer import HybridAdam
|
|
from colossalai.utils import get_current_device, set_seed
|
|
from tqdm import tqdm
|
|
|
|
from opensora.acceleration.checkpoint import set_grad_checkpoint
|
|
from opensora.acceleration.parallel_states import get_data_parallel_group
|
|
from opensora.datasets.dataloader import prepare_dataloader
|
|
from opensora.registry import DATASETS, MODELS, SCHEDULERS, build_module
|
|
from opensora.utils.ckpt_utils import load, model_gathering, model_sharding, record_model_param_shape, save
|
|
from opensora.utils.config_utils import define_experiment_workspace, parse_configs, save_training_config
|
|
from opensora.utils.lr_scheduler import LinearWarmupLR
|
|
from opensora.utils.misc import (
|
|
Timer,
|
|
all_reduce_mean,
|
|
create_logger,
|
|
create_tensorboard_writer,
|
|
format_numel_str,
|
|
get_model_numel,
|
|
requires_grad,
|
|
to_torch_dtype,
|
|
)
|
|
from opensora.utils.train_utils import (
|
|
MaskGenerator,
|
|
aug_x,
|
|
create_colossalai_plugin,
|
|
get_mask_cond,
|
|
get_mask_index,
|
|
update_ema,
|
|
)
|
|
|
|
|
|
def main():
|
|
# ======================================================
|
|
# 1. configs & runtime variables
|
|
# ======================================================
|
|
# == parse configs ==
|
|
cfg = parse_configs(training=True)
|
|
record_time = cfg.get("record_time", False)
|
|
|
|
# == device and dtype ==
|
|
assert torch.cuda.is_available(), "Training currently requires at least one GPU."
|
|
cfg_dtype = cfg.get("dtype", "bf16")
|
|
assert cfg_dtype in ["fp16", "bf16"], f"Unknown mixed precision {cfg_dtype}"
|
|
dtype = to_torch_dtype(cfg.get("dtype", "bf16"))
|
|
|
|
# == colossalai init distributed training ==
|
|
# NOTE: A very large timeout is set to avoid some processes exit early
|
|
dist.init_process_group(backend="nccl", timeout=timedelta(hours=24))
|
|
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
|
|
set_seed(cfg.get("seed", 1024))
|
|
coordinator = DistCoordinator()
|
|
device = get_current_device()
|
|
|
|
# == init exp_dir ==
|
|
exp_name, exp_dir = define_experiment_workspace(cfg)
|
|
coordinator.block_all()
|
|
if coordinator.is_master():
|
|
os.makedirs(exp_dir, exist_ok=True)
|
|
save_training_config(cfg.to_dict(), exp_dir)
|
|
coordinator.block_all()
|
|
|
|
# == init logger, tensorboard & wandb ==
|
|
logger = create_logger(exp_dir)
|
|
logger.info("Experiment directory created at %s", exp_dir)
|
|
logger.info("Training configuration:\n %s", pformat(cfg.to_dict()))
|
|
if coordinator.is_master():
|
|
tb_writer = create_tensorboard_writer(exp_dir)
|
|
if cfg.get("wandb", False):
|
|
wandb.init(project="Open-Sora", name=exp_name, config=cfg.to_dict(), dir=exp_dir)
|
|
|
|
# == init ColossalAI booster ==
|
|
plugin = create_colossalai_plugin(
|
|
plugin=cfg.get("plugin", "zero2"),
|
|
dtype=cfg_dtype,
|
|
grad_clip=cfg.get("grad_clip", 0),
|
|
sp_size=cfg.get("sp_size", 1),
|
|
reduce_bucket_size_in_m=cfg.get("reduce_bucket_size_in_m", 20),
|
|
)
|
|
booster = Booster(plugin=plugin)
|
|
torch.set_num_threads(1)
|
|
|
|
# == build text-encoder ==
|
|
text_encoder = build_module(cfg.get("text_encoder", None), MODELS, device=device, dtype=dtype)
|
|
if text_encoder is not None:
|
|
text_encoder_output_dim = text_encoder.output_dim
|
|
text_encoder_model_max_length = text_encoder.model_max_length
|
|
cfg.dataset.tokenize_fn = text_encoder.tokenize_fn
|
|
else:
|
|
text_encoder_output_dim = cfg.get("text_encoder_output_dim", 4096)
|
|
text_encoder_model_max_length = cfg.get("text_encoder_model_max_length", 300)
|
|
|
|
# ======================================================
|
|
# 2. build dataset and dataloader
|
|
# ======================================================
|
|
logger.info("Building dataset...")
|
|
# == build dataset ==
|
|
dataset = build_module(cfg.dataset, DATASETS)
|
|
logger.info("Dataset contains %s samples.", len(dataset))
|
|
|
|
# == build dataloader ==
|
|
dataloader_args = dict(
|
|
dataset=dataset,
|
|
batch_size=cfg.get("batch_size", None),
|
|
num_workers=cfg.get("num_workers", 4),
|
|
seed=cfg.get("seed", 1024),
|
|
shuffle=True,
|
|
drop_last=True,
|
|
pin_memory=True,
|
|
process_group=get_data_parallel_group(),
|
|
prefetch_factor=cfg.get("prefetch_factor", None),
|
|
)
|
|
dataloader, sampler = prepare_dataloader(
|
|
bucket_config=cfg.get("bucket_config", None),
|
|
num_bucket_build_workers=cfg.get("num_bucket_build_workers", 1),
|
|
**dataloader_args,
|
|
)
|
|
num_steps_per_epoch = len(dataloader)
|
|
|
|
# ======================================================
|
|
# 3. build model
|
|
# ======================================================
|
|
logger.info("Building models...")
|
|
|
|
# == build vae ==
|
|
vae = build_module(cfg.get("vae", None), MODELS)
|
|
if vae is not None:
|
|
vae = vae.to(device, dtype).eval()
|
|
if vae is not None:
|
|
input_size = (dataset.num_frames, *dataset.image_size)
|
|
latent_size = vae.get_latent_size(input_size)
|
|
vae_out_channels = vae.out_channels
|
|
else:
|
|
latent_size = (None, None, None)
|
|
vae_out_channels = cfg.get("vae_out_channels", 4)
|
|
|
|
# == build diffusion model ==
|
|
model = (
|
|
build_module(
|
|
cfg.model,
|
|
MODELS,
|
|
input_size=latent_size,
|
|
in_channels=vae_out_channels,
|
|
caption_channels=text_encoder_output_dim,
|
|
model_max_length=text_encoder_model_max_length,
|
|
enable_sequence_parallelism=cfg.get("sp_size", 1) > 1,
|
|
)
|
|
.to(device, dtype)
|
|
.train()
|
|
)
|
|
model_numel, model_numel_trainable = get_model_numel(model)
|
|
logger.info(
|
|
"[Diffusion] Trainable model params: %s, Total model params: %s",
|
|
format_numel_str(model_numel_trainable),
|
|
format_numel_str(model_numel),
|
|
)
|
|
|
|
# == build ema for diffusion model ==
|
|
ema = deepcopy(model).to(torch.float32).to(device)
|
|
requires_grad(ema, False)
|
|
ema_shape_dict = record_model_param_shape(ema)
|
|
ema.eval()
|
|
update_ema(ema, model, decay=0, sharded=False)
|
|
|
|
# == setup loss function, build scheduler ==
|
|
scheduler = build_module(cfg.scheduler, SCHEDULERS)
|
|
|
|
# == setup optimizer ==
|
|
optimizer = HybridAdam(
|
|
filter(lambda p: p.requires_grad, model.parameters()),
|
|
adamw_mode=True,
|
|
lr=cfg.get("lr", 1e-4),
|
|
weight_decay=cfg.get("weight_decay", 0),
|
|
eps=cfg.get("adam_eps", 1e-8),
|
|
)
|
|
|
|
warmup_steps = cfg.get("warmup_steps", None)
|
|
use_cosine_scheduler = cfg.get("use_cosine_scheduler", False)
|
|
|
|
if warmup_steps is None and not use_cosine_scheduler:
|
|
lr_scheduler = None
|
|
elif use_cosine_scheduler:
|
|
lr_scheduler = CosineAnnealingWarmupLR(
|
|
optimizer,
|
|
total_steps=num_steps_per_epoch * cfg.get("epochs", 1000),
|
|
warmup_steps=cfg.get("warmup_steps", 0),
|
|
)
|
|
else:
|
|
lr_scheduler = LinearWarmupLR(optimizer, initial_lr=1e-6, warmup_steps=cfg.get("warmup_steps"))
|
|
|
|
# == additional preparation ==
|
|
if cfg.get("grad_checkpoint", False):
|
|
set_grad_checkpoint(model)
|
|
if cfg.get("mask_ratios", None) is not None:
|
|
mask_generator = MaskGenerator(cfg.mask_ratios)
|
|
|
|
# =======================================================
|
|
# 4. distributed training preparation with colossalai
|
|
# =======================================================
|
|
logger.info("Preparing for distributed training...")
|
|
# == boosting ==
|
|
# NOTE: we set dtype first to make initialization of model consistent with the dtype; then reset it to the fp32 as we make diffusion scheduler in fp32
|
|
torch.set_default_dtype(dtype)
|
|
model, optimizer, _, dataloader, lr_scheduler = booster.boost(
|
|
model=model,
|
|
optimizer=optimizer,
|
|
lr_scheduler=lr_scheduler,
|
|
dataloader=dataloader,
|
|
)
|
|
torch.set_default_dtype(torch.float)
|
|
logger.info("Boosting model for distributed training")
|
|
|
|
# == global variables ==
|
|
cfg_epochs = cfg.get("epochs", 1000)
|
|
start_epoch = start_step = log_step = acc_step = 0
|
|
running_loss = 0.0
|
|
logger.info("Training for %s epochs with %s steps per epoch", cfg_epochs, num_steps_per_epoch)
|
|
|
|
# == resume ==
|
|
if cfg.get("load", None) is not None:
|
|
logger.info("Loading checkpoint")
|
|
ret = load(
|
|
booster,
|
|
cfg.load,
|
|
model=model,
|
|
ema=ema,
|
|
optimizer=optimizer,
|
|
lr_scheduler=lr_scheduler,
|
|
sampler=None if cfg.get("start_from_scratch", False) else sampler,
|
|
)
|
|
if not cfg.get("start_from_scratch", False):
|
|
start_epoch, start_step = ret
|
|
logger.info("Loaded checkpoint %s at epoch %s step %s", cfg.load, start_epoch, start_step)
|
|
|
|
model_sharding(ema)
|
|
|
|
# == mask ==
|
|
mask_types = cfg.get("mask_types", None)
|
|
if mask_types is not None:
|
|
mask_randgen = random.Random(dist.get_rank())
|
|
|
|
# =======================================================
|
|
# 5. training loop
|
|
# =======================================================
|
|
dist.barrier()
|
|
timers = {}
|
|
timer_keys = [
|
|
"move_data",
|
|
"mask_index",
|
|
"encode",
|
|
"mask",
|
|
"move_args",
|
|
"diffusion",
|
|
"backward",
|
|
"update_ema",
|
|
"reduce_loss",
|
|
"log",
|
|
"checkpoint",
|
|
]
|
|
for key in timer_keys:
|
|
if record_time:
|
|
timers[key] = Timer(key, coordinator=coordinator)
|
|
else:
|
|
timers[key] = nullcontext()
|
|
if record_time:
|
|
record_file = open(os.path.join(exp_dir, f"record_time_r{dist.get_rank()}.txt"), "w")
|
|
|
|
accumulation_steps = cfg.get("accumulation_steps", 1)
|
|
for epoch in range(start_epoch, cfg_epochs):
|
|
# == set dataloader to new epoch ==
|
|
sampler.set_epoch(epoch)
|
|
dataloader_iter = iter(dataloader)
|
|
logger.info("Beginning epoch %s...", epoch)
|
|
|
|
# == training loop in an epoch ==
|
|
with tqdm(
|
|
enumerate(dataloader_iter, start=start_step),
|
|
desc=f"Epoch {epoch}",
|
|
disable=not coordinator.is_master(),
|
|
initial=start_step,
|
|
total=num_steps_per_epoch,
|
|
) as pbar:
|
|
for step, batch in pbar:
|
|
timer_list = []
|
|
paths = batch.pop("path")
|
|
with timers["move_data"] as move_data_t:
|
|
x = batch.pop("video").to(device, dtype) # [B, C, T, H, W]
|
|
y = batch.pop("text")
|
|
input_ids = batch.pop("input_ids")
|
|
attention_mask = batch.pop("attention_mask")
|
|
if record_time:
|
|
timer_list.append(move_data_t)
|
|
|
|
# == prepare i2v&v2v mask_index ==
|
|
with timers["mask_index"] as mask_index_t:
|
|
num_frames = x.shape[2]
|
|
latent_t = vae.get_latent_size(x.shape[2:])[0]
|
|
mask_index = []
|
|
text_uncond_prob = cfg.model.get("class_dropout_prob", 0.1)
|
|
if mask_types is not None:
|
|
mask_cond = get_mask_cond(mask_randgen, mask_types)
|
|
if num_frames > 1: # NOTE: only use mask_indx for video
|
|
mask_index = get_mask_index(mask_cond, latent_t)
|
|
if len(mask_index) > 0:
|
|
text_uncond_prob = 0.0
|
|
if record_time:
|
|
timer_list.append(mask_index_t)
|
|
|
|
# == visual and text encoding ==
|
|
with timers["encode"] as encode_t:
|
|
x_noisy_ref = None # for v2v, add a little noise to video's referenced part
|
|
with torch.no_grad():
|
|
# Prepare visual inputs
|
|
if cfg.get("load_video_features", False):
|
|
x = x.to(device, dtype)
|
|
x_gt = x
|
|
# NOTE: x_noisy_ref is skipped for now
|
|
elif cfg.get("noise_augmentation", False) and x.shape[2] > 1:
|
|
x, x_gt = aug_x(
|
|
x,
|
|
vae,
|
|
cfg.get("noise_prob", {}),
|
|
cfg.get("noise_strength", {}),
|
|
)
|
|
# NOTE: x_noisy_ref is skipped for now
|
|
else:
|
|
if 0 in mask_index and "noisy" in mask_cond:
|
|
v2v_noise_min_weight = cfg.model.get("v2v_noise_min_weight", 0.1)
|
|
v2v_noise_max_weight = cfg.model.get("v2v_noise_max_weight", 0.3)
|
|
v2v_noise_ratio = v2v_noise_min_weight + random.uniform(0, 1) * (
|
|
v2v_noise_max_weight - v2v_noise_min_weight
|
|
)
|
|
x_noisy = v2v_noise_ratio * torch.randn_like(x) + (1 - v2v_noise_ratio) * x
|
|
else:
|
|
x_noisy = x
|
|
|
|
x_noisy_ref = vae.encode(x_noisy)
|
|
|
|
x_gt = x = vae.encode(x) # [B, C, T, H/P, W/P]
|
|
|
|
# Prepare text inputs
|
|
if cfg.get("load_text_features", False):
|
|
model_args = {"y": y.to(device, dtype)}
|
|
mask = batch.pop("mask")
|
|
if isinstance(mask, torch.Tensor):
|
|
mask = mask.to(device, dtype)
|
|
model_args["mask"] = mask
|
|
else:
|
|
model_args = text_encoder.encode(input_ids, attention_mask=attention_mask)
|
|
if record_time:
|
|
timer_list.append(encode_t)
|
|
|
|
# == mask ==
|
|
with timers["mask"] as mask_t:
|
|
mask = None
|
|
if cfg.get("mask_ratios", None) is not None:
|
|
mask = mask_generator.get_masks(x)
|
|
model_args["x_mask"] = mask
|
|
if record_time:
|
|
timer_list.append(mask_t)
|
|
|
|
# == video meta info ==
|
|
with timers["move_args"] as move_args_t:
|
|
for k, v in batch.items():
|
|
if isinstance(v, torch.Tensor):
|
|
model_args[k] = v.to(device, dtype)
|
|
if record_time:
|
|
timer_list.append(move_args_t)
|
|
|
|
# == diffusion loss computation ==
|
|
with timers["diffusion"] as loss_t:
|
|
if len(mask_index) > 0: # i2v and v2v training
|
|
model_args["x_mask"] = None # Don't use any other input masks
|
|
mask = None
|
|
loss_dict = scheduler.training_losses(
|
|
model,
|
|
x,
|
|
model_args=model_args,
|
|
mask=mask,
|
|
mask_index=mask_index,
|
|
x_gt=x_gt,
|
|
noise_disable_threshold=cfg.get("noise_disable_threshold", None),
|
|
text_uncond_prob=text_uncond_prob,
|
|
x_noisy_ref=x_noisy_ref,
|
|
)
|
|
if record_time:
|
|
timer_list.append(loss_t)
|
|
|
|
# == backward & update ==
|
|
with timers["backward"] as backward_t:
|
|
loss = loss_dict["loss"].mean()
|
|
loss = loss / accumulation_steps
|
|
ctx = (
|
|
booster.no_sync(model, optimizer)
|
|
if cfg.get("plugin", "zero2") in ("zero1", "zero1-seq") and (step + 1) % accumulation_steps != 0
|
|
else nullcontext()
|
|
)
|
|
with ctx:
|
|
booster.backward(loss=loss, optimizer=optimizer)
|
|
if (step + 1) % accumulation_steps == 0:
|
|
optimizer.step()
|
|
optimizer.zero_grad()
|
|
|
|
# update learning rate
|
|
if lr_scheduler is not None:
|
|
lr_scheduler.step()
|
|
if record_time:
|
|
timer_list.append(backward_t)
|
|
|
|
# == update EMA ==
|
|
with timers["update_ema"] as ema_t:
|
|
update_ema(ema, model.module, optimizer=optimizer, decay=cfg.get("ema_decay", 0.9999))
|
|
if record_time:
|
|
timer_list.append(ema_t)
|
|
|
|
# == update log info ==
|
|
with timers["reduce_loss"] as reduce_loss_t:
|
|
all_reduce_mean(loss.data)
|
|
running_loss += loss.item() * accumulation_steps
|
|
global_step = epoch * num_steps_per_epoch + step
|
|
log_step += 1
|
|
acc_step += 1
|
|
if record_time:
|
|
timer_list.append(reduce_loss_t)
|
|
|
|
with timers["log"] as log_t:
|
|
# == logging ==
|
|
if coordinator.is_master() and (global_step + 1) % cfg.get("log_every", 1) == 0:
|
|
avg_loss = running_loss / log_step
|
|
# progress bar
|
|
pbar.set_postfix({"loss": avg_loss, "step": step, "global_step": global_step})
|
|
# tensorboard
|
|
tb_writer.add_scalar("loss", loss.item() * accumulation_steps, global_step)
|
|
# wandb
|
|
if cfg.get("wandb", False):
|
|
wandb_dict = {
|
|
"iter": global_step,
|
|
"acc_step": acc_step,
|
|
"epoch": epoch,
|
|
"loss": loss.item() * accumulation_steps,
|
|
"avg_loss": avg_loss,
|
|
"lr": optimizer.param_groups[0]["lr"],
|
|
}
|
|
if record_time:
|
|
wandb_dict.update(
|
|
{
|
|
"debug/move_data_time": move_data_t.elapsed_time,
|
|
"debug/encode_time": encode_t.elapsed_time,
|
|
"debug/mask_time": mask_t.elapsed_time,
|
|
"debug/diffusion_time": loss_t.elapsed_time,
|
|
"debug/backward_time": backward_t.elapsed_time,
|
|
"debug/update_ema_time": ema_t.elapsed_time,
|
|
"debug/reduce_loss_time": reduce_loss_t.elapsed_time,
|
|
}
|
|
)
|
|
wandb.log(wandb_dict, step=global_step)
|
|
|
|
running_loss = 0.0
|
|
log_step = 0
|
|
if record_time:
|
|
timer_list.append(log_t)
|
|
|
|
# == checkpoint saving ==
|
|
with timers["checkpoint"] as checkpoint_t:
|
|
ckpt_every = cfg.get("ckpt_every", 0)
|
|
if ckpt_every > 0 and (global_step + 1) % ckpt_every == 0:
|
|
model_gathering(ema, ema_shape_dict)
|
|
save_dir = save(
|
|
booster,
|
|
exp_dir,
|
|
model=model,
|
|
ema=ema,
|
|
optimizer=optimizer,
|
|
lr_scheduler=lr_scheduler,
|
|
sampler=sampler,
|
|
epoch=epoch,
|
|
step=step + 1,
|
|
global_step=global_step + 1,
|
|
batch_size=cfg.get("batch_size", None),
|
|
)
|
|
if dist.get_rank() == 0:
|
|
model_sharding(ema)
|
|
logger.info(
|
|
"Saved checkpoint at epoch %s, step %s, global_step %s to %s",
|
|
epoch,
|
|
step + 1,
|
|
global_step + 1,
|
|
save_dir,
|
|
)
|
|
if record_time:
|
|
timer_list.append(checkpoint_t)
|
|
if record_time:
|
|
total_step_time = sum([timer.elapsed_time for timer in timer_list])
|
|
log_str = (
|
|
f"Rank {dist.get_rank()} | Epoch {epoch} | Step {step} | Step time: {total_step_time:.3f}s | "
|
|
)
|
|
for timer in timer_list:
|
|
log_str += f"{timer.name}: {timer.elapsed_time:.3f}s | "
|
|
# print(log_str)
|
|
log_str += f"path: {paths}"
|
|
record_file.write(log_str + "\n")
|
|
record_file.flush()
|
|
sampler.reset()
|
|
start_step = 0
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|