# Source: Open-Sora/scripts/train.py (Python, ~390 lines, 15 KiB)

import os
from copy import deepcopy
from datetime import timedelta
from pprint import pformat

import torch
import torch.distributed as dist
import wandb
from colossalai.booster import Booster
from colossalai.cluster import DistCoordinator
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device, set_seed
from tqdm import tqdm

from opensora.acceleration.checkpoint import set_grad_checkpoint
from opensora.acceleration.parallel_states import get_data_parallel_group
from opensora.datasets.dataloader import prepare_dataloader
from opensora.registry import DATASETS, MODELS, SCHEDULERS, build_module
from opensora.utils.ckpt_utils import load, model_gathering, model_sharding, record_model_param_shape, save
from opensora.utils.config_utils import define_experiment_workspace, parse_configs, save_training_config
from opensora.utils.lr_scheduler import LinearWarmupLR
from opensora.utils.misc import (
    Timer,
    all_reduce_mean,
    create_logger,
    create_tensorboard_writer,
    format_numel_str,
    get_model_numel,
    requires_grad,
    to_torch_dtype,
)
from opensora.utils.train_utils import MaskGenerator, create_colossalai_plugin, update_ema

def main():
    """Train an Open-Sora diffusion model with ColossalAI.

    Driven entirely by the parsed config (`parse_configs`): builds the dataset/
    dataloader, the VAE + text encoder + diffusion model (with an fp32 EMA copy),
    boosts everything with a ColossalAI plugin (zero2 by default), optionally
    resumes from a checkpoint, then runs the epoch/step training loop with
    per-stage timing, logging (tensorboard/wandb) and periodic checkpointing.

    Requires at least one CUDA GPU and an initialized NCCL environment
    (torchrun / colossalai launcher). No return value; side effects are the
    experiment directory, logs, and checkpoints.
    """
    # ======================================================
    # 1. configs & runtime variables
    # ======================================================
    # == parse configs ==
    cfg = parse_configs(training=True)

    # == device and dtype ==
    assert torch.cuda.is_available(), "Training currently requires at least one GPU."
    cfg_dtype = cfg.get("dtype", "bf16")
    assert cfg_dtype in ["fp16", "bf16"], f"Unknown mixed precision {cfg_dtype}"
    dtype = to_torch_dtype(cfg.get("dtype", "bf16"))

    # == colossalai init distributed training ==
    # NOTE: A very large timeout is set to avoid some processes exit early
    dist.init_process_group(backend="nccl", timeout=timedelta(hours=24))
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
    set_seed(cfg.get("seed", 1024))
    coordinator = DistCoordinator()
    device = get_current_device()

    # == init exp_dir ==
    exp_name, exp_dir = define_experiment_workspace(cfg)
    coordinator.block_all()
    if coordinator.is_master():
        os.makedirs(exp_dir, exist_ok=True)
        save_training_config(cfg.to_dict(), exp_dir)
    coordinator.block_all()

    # == init logger, tensorboard & wandb ==
    logger = create_logger(exp_dir)
    logger.info("Experiment directory created at %s", exp_dir)
    logger.info("Training configuration:\n %s", pformat(cfg.to_dict()))
    if coordinator.is_master():
        tb_writer = create_tensorboard_writer(exp_dir)
        if cfg.get("wandb", False):
            wandb.init(project="Open-Sora", name=exp_name, config=cfg.to_dict(), dir="./outputs/wandb")

    # == init ColossalAI booster ==
    plugin = create_colossalai_plugin(
        plugin=cfg.get("plugin", "zero2"),
        dtype=cfg_dtype,
        grad_clip=cfg.get("grad_clip", 0),
        sp_size=cfg.get("sp_size", 1),
    )
    booster = Booster(plugin=plugin)
    # avoid CPU oversubscription when many ranks share a node
    torch.set_num_threads(1)

    # ======================================================
    # 2. build dataset and dataloader
    # ======================================================
    logger.info("Building dataset...")
    # == build dataset ==
    dataset = build_module(cfg.dataset, DATASETS)
    logger.info("Dataset contains %s samples.", len(dataset))

    # == build dataloader ==
    dataloader_args = dict(
        dataset=dataset,
        batch_size=cfg.get("batch_size", None),
        num_workers=cfg.get("num_workers", 4),
        seed=cfg.get("seed", 1024),
        shuffle=True,
        drop_last=True,
        pin_memory=True,
        process_group=get_data_parallel_group(),
    )
    dataloader, sampler = prepare_dataloader(
        bucket_config=cfg.get("bucket_config", None),
        num_bucket_build_workers=cfg.get("num_bucket_build_workers", 1),
        **dataloader_args,
    )
    num_steps_per_epoch = len(dataloader)

    # ======================================================
    # 3. build model
    # ======================================================
    logger.info("Building models...")
    # == build text-encoder ==
    # text_encoder may be None when captions are pre-encoded offline; fall back
    # to config-provided dims so the diffusion model can still be constructed.
    text_encoder = build_module(cfg.get("text_encoder", None), MODELS, device=device, dtype=dtype)
    if text_encoder is not None:
        text_encoder_output_dim = text_encoder.output_dim
        text_encoder_model_max_length = text_encoder.model_max_length
    else:
        text_encoder_output_dim = cfg.get("text_encoder_output_dim", 4096)
        text_encoder_model_max_length = cfg.get("text_encoder_model_max_length", 300)

    # == build vae ==
    # vae may be None when videos are pre-encoded to latents offline.
    vae = build_module(cfg.get("vae", None), MODELS)
    if vae is not None:
        vae = vae.to(device, dtype).eval()
    if vae is not None:
        input_size = (dataset.num_frames, *dataset.image_size)
        latent_size = vae.get_latent_size(input_size)
        vae_out_channels = vae.out_channels
    else:
        latent_size = (None, None, None)
        vae_out_channels = cfg.get("vae_out_channels", 4)

    # == build diffusion model ==
    model = (
        build_module(
            cfg.model,
            MODELS,
            input_size=latent_size,
            in_channels=vae_out_channels,
            caption_channels=text_encoder_output_dim,
            model_max_length=text_encoder_model_max_length,
        )
        .to(device, dtype)
        .train()
    )
    model_numel, model_numel_trainable = get_model_numel(model)
    logger.info(
        "[Diffusion] Trainable model params: %s, Total model params: %s",
        format_numel_str(model_numel_trainable),
        format_numel_str(model_numel),
    )

    # == build ema for diffusion model ==
    # EMA weights are kept in fp32 regardless of training dtype.
    ema = deepcopy(model).to(torch.float32).to(device)
    requires_grad(ema, False)
    ema_shape_dict = record_model_param_shape(ema)
    ema.eval()
    update_ema(ema, model, decay=0, sharded=False)  # decay=0: copy weights on init

    # == setup loss function, build scheduler ==
    scheduler = build_module(cfg.scheduler, SCHEDULERS)

    # == setup optimizer ==
    optimizer = HybridAdam(
        filter(lambda p: p.requires_grad, model.parameters()),
        adamw_mode=True,
        lr=cfg.get("lr", 1e-4),
        weight_decay=cfg.get("weight_decay", 0),
        eps=cfg.get("adam_eps", 1e-8),
    )
    warmup_steps = cfg.get("warmup_steps", None)
    if warmup_steps is None:
        lr_scheduler = None
    else:
        lr_scheduler = LinearWarmupLR(optimizer, warmup_steps=warmup_steps)

    # == additional preparation ==
    if cfg.get("grad_checkpoint", False):
        set_grad_checkpoint(model)
    if cfg.get("mask_ratios", None) is not None:
        mask_generator = MaskGenerator(cfg.mask_ratios)

    # =======================================================
    # 4. distributed training preparation with colossalai
    # =======================================================
    logger.info("Preparing for distributed training...")
    # == boosting ==
    # NOTE: we set dtype first to make initialization of model consistent with the dtype;
    # then reset it to fp32 as we make the diffusion scheduler in fp32
    torch.set_default_dtype(dtype)
    model, optimizer, _, dataloader, lr_scheduler = booster.boost(
        model=model,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        dataloader=dataloader,
    )
    torch.set_default_dtype(torch.float)
    logger.info("Boosting model for distributed training")

    # == global variables ==
    cfg_epochs = cfg.get("epochs", 1000)
    start_epoch = start_step = log_step = acc_step = 0
    running_loss = 0.0
    logger.info("Training for %s epochs with %s steps per epoch", cfg_epochs, num_steps_per_epoch)

    # == resume ==
    if cfg.get("load", None) is not None:
        logger.info("Loading checkpoint")
        ret = load(
            booster,
            cfg.load,
            model=model,
            ema=ema,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            sampler=None if cfg.get("start_from_scratch", False) else sampler,
        )
        # start_from_scratch keeps the weights but restarts the schedule
        if not cfg.get("start_from_scratch", False):
            start_epoch, start_step = ret
        logger.info("Loaded checkpoint %s at epoch %s step %s", cfg.load, start_epoch, start_step)

    # shard EMA params across ranks to save memory (gathered again before saving)
    model_sharding(ema)

    # =======================================================
    # 5. training loop
    # =======================================================
    dist.barrier()
    for epoch in range(start_epoch, cfg_epochs):
        # == set dataloader to new epoch ==
        sampler.set_epoch(epoch)
        dataloader_iter = iter(dataloader)
        logger.info("Beginning epoch %s...", epoch)

        # == training loop in an epoch ==
        with tqdm(
            enumerate(dataloader_iter, start=start_step),
            desc=f"Epoch {epoch}",
            disable=not coordinator.is_master(),
            initial=start_step,
            total=num_steps_per_epoch,
        ) as pbar:
            for step, batch in pbar:
                # Per-stage timers for debugging throughput; each stage ends with
                # a collective barrier so elapsed times include rank skew.
                timer_list = []
                with Timer("move data") as move_data_t:
                    x = batch.pop("video").to(device, dtype)  # [B, C, T, H, W]
                    y = batch.pop("text")
                timer_list.append(move_data_t)

                # == visual and text encoding ==
                with Timer("encode") as encode_t:
                    with torch.no_grad():
                        # Prepare visual inputs
                        if cfg.get("load_video_features", False):
                            x = x.to(device, dtype)
                        else:
                            x = vae.encode(x)  # [B, C, T, H/P, W/P]
                        # Prepare text inputs
                        if cfg.get("load_text_features", False):
                            model_args = {"y": y.to(device, dtype)}
                            mask = batch.pop("mask")
                            if isinstance(mask, torch.Tensor):
                                mask = mask.to(device, dtype)
                            model_args["mask"] = mask
                        else:
                            model_args = text_encoder.encode(y)
                    coordinator.block_all()
                timer_list.append(encode_t)

                # == mask ==
                with Timer("mask") as mask_t:
                    mask = None
                    if cfg.get("mask_ratios", None) is not None:
                        mask = mask_generator.get_masks(x)
                        model_args["x_mask"] = mask
                    coordinator.block_all()
                timer_list.append(mask_t)

                # == video meta info ==
                for k, v in batch.items():
                    if isinstance(v, torch.Tensor):
                        model_args[k] = v.to(device, dtype)

                # == diffusion loss computation ==
                with Timer("diffusion") as loss_t:
                    loss_dict = scheduler.training_losses(model, x, model_args, mask=mask)
                    coordinator.block_all()
                timer_list.append(loss_t)

                # == backward & update ==
                with Timer("backward") as backward_t:
                    loss = loss_dict["loss"].mean()
                    booster.backward(loss=loss, optimizer=optimizer)
                    optimizer.step()
                    optimizer.zero_grad()

                    # update learning rate
                    if lr_scheduler is not None:
                        lr_scheduler.step()
                    coordinator.block_all()
                timer_list.append(backward_t)

                # == update EMA ==
                with Timer("update_ema") as ema_t:
                    update_ema(ema, model.module, optimizer=optimizer, decay=cfg.get("ema_decay", 0.9999))
                    coordinator.block_all()
                timer_list.append(ema_t)

                # == update log info ==
                with Timer("reduce_loss") as reduce_loss_t:
                    all_reduce_mean(loss)
                    running_loss += loss.item()
                    global_step = epoch * num_steps_per_epoch + step
                    log_step += 1
                    acc_step += 1
                    coordinator.block_all()
                timer_list.append(reduce_loss_t)

                # == logging ==
                if coordinator.is_master() and (global_step + 1) % cfg.get("log_every", 1) == 0:
                    avg_loss = running_loss / log_step
                    # progress bar
                    pbar.set_postfix({"loss": avg_loss, "step": step, "global_step": global_step})
                    # tensorboard
                    tb_writer.add_scalar("loss", loss.item(), global_step)
                    # wandb
                    if cfg.get("wandb", False):
                        wandb.log(
                            {
                                "iter": global_step,
                                "acc_step": acc_step,
                                "epoch": epoch,
                                "loss": loss.item(),
                                "avg_loss": avg_loss,
                                "lr": optimizer.param_groups[0]["lr"],
                                "debug/move_data_time": move_data_t.elapsed_time,
                                "debug/encode_time": encode_t.elapsed_time,
                                "debug/mask_time": mask_t.elapsed_time,
                                "debug/diffusion_time": loss_t.elapsed_time,
                                "debug/backward_time": backward_t.elapsed_time,
                                "debug/update_ema_time": ema_t.elapsed_time,
                                "debug/reduce_loss_time": reduce_loss_t.elapsed_time,
                            },
                            step=global_step,
                        )

                    running_loss = 0.0
                    log_step = 0

                # == checkpoint saving ==
                ckpt_every = cfg.get("ckpt_every", 0)
                if ckpt_every > 0 and (global_step + 1) % ckpt_every == 0:
                    # gather sharded EMA params before saving, reshard after
                    model_gathering(ema, ema_shape_dict)
                    save_dir = save(
                        booster,
                        exp_dir,
                        model=model,
                        ema=ema,
                        optimizer=optimizer,
                        lr_scheduler=lr_scheduler,
                        sampler=sampler,
                        epoch=epoch,
                        step=step + 1,
                        global_step=global_step + 1,
                        batch_size=cfg.get("batch_size", None),
                    )
                    if dist.get_rank() == 0:
                        model_sharding(ema)
                    logger.info(
                        "Saved checkpoint at epoch %s, step %s, global_step %s to %s",
                        epoch,
                        step + 1,
                        global_step + 1,
                        save_dir,
                    )

                # per-rank stage timing report for this step
                log_str = f"Rank {dist.get_rank()} | Epoch {epoch} | Step {step} | "
                for timer in timer_list:
                    log_str += f"{timer.name}: {timer.elapsed_time:.3f}s | "
                print(log_str)
                coordinator.block_all()

        sampler.reset()
        # only resume mid-epoch once; subsequent epochs start from step 0
        start_step = 0


if __name__ == "__main__":
    main()