mirror of https://github.com/hpcaitech/Open-Sora.git (synced 2026-04-10)

import os
import random
from datetime import timedelta
from pprint import pformat

import torch
import torch.distributed as dist
import wandb
from colossalai.booster import Booster
from colossalai.cluster import DistCoordinator
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device, set_seed
from tqdm import tqdm

from opensora.acceleration.checkpoint import set_grad_checkpoint
from opensora.acceleration.parallel_states import get_data_parallel_group
from opensora.datasets.dataloader import prepare_dataloader
from opensora.models.vae_v1_3.losses import DiscriminatorLoss, GeneratorLoss, VAELoss
from opensora.registry import DATASETS, MODELS, build_module
from opensora.utils.ckpt_utils import load, save
from opensora.utils.config_utils import define_experiment_workspace, parse_configs, save_training_config
from opensora.utils.lr_scheduler import LinearWarmupLR
from opensora.utils.misc import (
    all_reduce_mean,
    create_logger,
    create_tensorboard_writer,
    format_numel_str,
    get_model_numel,
    to_torch_dtype,
)
from opensora.utils.train_utils import create_colossalai_plugin


def main():
    # ======================================================
    # 1. configs & runtime variables
    # ======================================================
    # == parse configs ==
    cfg = parse_configs(training=True)
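
    # `parse_configs(training=True)` typically reads a Python config file and
    # applies command-line overrides on top of it; a hypothetical launch
    # (script and config paths are illustrative, not taken from this repo):
    #   torchrun --nproc_per_node 8 train_vae.py configs/vae/train.py --lr 1e-5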

    # == device and dtype ==
    assert torch.cuda.is_available(), "Training currently requires at least one GPU."
    cfg_dtype = cfg.get("dtype", "bf16")
    assert cfg_dtype in ["fp16", "bf16"], f"Unknown mixed precision {cfg_dtype}"
    dtype = to_torch_dtype(cfg_dtype)

    # == colossalai init distributed training ==
    # NOTE: a very large timeout is set to avoid some processes exiting early
    dist.init_process_group(backend="nccl", timeout=timedelta(hours=24))
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
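    # With one process per GPU, `rank % device_count` recovers the node-local
    # device index (what torchrun exposes as LOCAL_RANK): e.g. global rank 11
    # on 8-GPU nodes maps to cuda:3.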
    set_seed(cfg.get("seed", 1024))
    coordinator = DistCoordinator()
    device = get_current_device()

    # == init exp_dir ==
    exp_name, exp_dir = define_experiment_workspace(cfg)
    coordinator.block_all()
    if coordinator.is_master():
        os.makedirs(exp_dir, exist_ok=True)
        save_training_config(cfg.to_dict(), exp_dir)
    coordinator.block_all()

    # == init logger, tensorboard & wandb ==
    logger = create_logger(exp_dir)
    logger.info("Experiment directory created at %s", exp_dir)
    logger.info("Training configuration:\n %s", pformat(cfg.to_dict()))
    if coordinator.is_master():
        tb_writer = create_tensorboard_writer(exp_dir)
        if cfg.get("wandb", False):
            wandb.init(project="causalvae", name=exp_name, config=cfg.to_dict(), dir="./outputs/wandb")

    # == init ColossalAI booster ==
    plugin = create_colossalai_plugin(
        plugin=cfg.get("plugin", "zero2"),
        dtype=cfg_dtype,
        grad_clip=cfg.get("grad_clip", 0),
        sp_size=cfg.get("sp_size", 1),
    )
    booster = Booster(plugin=plugin)
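    # The default "zero2" plugin applies ZeRO stage 2: optimizer states and
    # gradients are sharded across data-parallel ranks while every rank keeps
    # a full copy of the parameters.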

    # ======================================================
    # 2. build dataset and dataloader
    # ======================================================
    logger.info("Building dataset...")
    # == build dataset ==
    assert cfg.dataset.type == "VideoTextDataset", "Only VideoTextDataset is supported for VAE training"
    dataset = build_module(cfg.dataset, DATASETS)
    logger.info("Dataset contains %s samples.", len(dataset))

    # == build dataloader ==
    dataloader_args = dict(
        dataset=dataset,
        batch_size=cfg.batch_size,
        num_workers=cfg.get("num_workers", 4),
        seed=cfg.get("seed", 1024),
        shuffle=True,
        drop_last=True,
        pin_memory=True,
        process_group=get_data_parallel_group(),
    )
    dataloader, sampler = prepare_dataloader(**dataloader_args)
    total_batch_size = cfg.batch_size * dist.get_world_size() // cfg.get("sp_size", 1)
    logger.info("Total batch size: %s", total_batch_size)
    num_steps_per_epoch = len(dataloader)
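    # Example: batch_size=1 per GPU on 8 GPUs with sp_size=1 gives
    # 1 * 8 // 1 = 8 samples per optimizer step; ranks within one
    # sequence-parallel group (sp_size > 1) share a sample, hence the division.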

    # ======================================================
    # 3. build model
    # ======================================================
    logger.info("Building models...")
    # == build vae model ==
    model = build_module(cfg.model, MODELS).to(device, dtype).train()
    model_numel, model_numel_trainable = get_model_numel(model)
    logger.info(
        "[VAE] Trainable model params: %s, Total model params: %s",
        format_numel_str(model_numel_trainable),
        format_numel_str(model_numel),
    )

    # == build discriminator model ==
    use_discriminator = cfg.get("discriminator", None) is not None
    if use_discriminator:
        discriminator = build_module(cfg.discriminator, MODELS).to(device, dtype).train()
        discriminator_numel, discriminator_numel_trainable = get_model_numel(discriminator)
        logger.info(
            "[Discriminator] Trainable model params: %s, Total model params: %s",
            format_numel_str(discriminator_numel_trainable),
            format_numel_str(discriminator_numel),
        )

    # == setup loss functions ==
    vae_loss_fn = VAELoss(**cfg.vae_loss_config, device=device, dtype=dtype)

    if use_discriminator:
        generator_loss_fn = GeneratorLoss(**cfg.gan_loss_confg)
        discriminator_loss_fn = DiscriminatorLoss(**cfg.gan_loss_confg)
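
    # VAELoss pairs a pixel reconstruction term (surfaced below as nll_loss)
    # with a KL penalty on the encoder posterior; GeneratorLoss and
    # DiscriminatorLoss implement the adversarial half of a VAE-GAN setup.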

    # == setup vae optimizer ==
    optimizer = HybridAdam(
        filter(lambda p: p.requires_grad, model.parameters()),
        adamw_mode=True,
        lr=cfg.get("lr", 1e-5),
        weight_decay=cfg.get("weight_decay", 0),
    )

    warmup_steps = cfg.get("warmup_steps", None)
    if warmup_steps is None:
        lr_scheduler = None
    else:
        lr_scheduler = LinearWarmupLR(optimizer, warmup_steps=warmup_steps)
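
    # A linear warmup ramps the learning rate over the first warmup_steps
    # updates, roughly lr_t = base_lr * min(1, (t + 1) / warmup_steps)
    # (assuming LinearWarmupLR follows the usual convention); no decay
    # schedule is applied afterwards.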

    # == setup discriminator optimizer ==
    if use_discriminator:
        disc_optimizer = HybridAdam(
            filter(lambda p: p.requires_grad, discriminator.parameters()),
            adamw_mode=True,
            lr=cfg.get("lr", 1e-5),
            weight_decay=cfg.get("weight_decay", 0),
        )
        disc_lr_scheduler = None

    # == additional preparation ==
    if cfg.get("grad_checkpoint", False):
        set_grad_checkpoint(model)
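
    # Gradient checkpointing trades compute for memory: activations inside the
    # wrapped blocks are recomputed during backward instead of being stored,
    # at the cost of an extra partial forward pass.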

    # =======================================================
    # 4. distributed training preparation with colossalai
    # =======================================================
    logger.info("Preparing for distributed training...")
    # == boosting ==
    # NOTE: set the default dtype first so that tensors created during boosting
    # are consistent with the training dtype; reset it to fp32 afterwards so
    # the rest of the setup runs in full precision
    torch.set_default_dtype(dtype)
    model, optimizer, _, dataloader, lr_scheduler = booster.boost(
        model=model,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        dataloader=dataloader,
    )
    if use_discriminator:
        discriminator, disc_optimizer, _, _, disc_lr_scheduler = booster.boost(
            model=discriminator,
            optimizer=disc_optimizer,
            lr_scheduler=disc_lr_scheduler,
        )
    torch.set_default_dtype(torch.float)
    logger.info("Boosting model for distributed training")
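    # After `boost`, each model is wrapped by the plugin: the raw module is
    # reachable at `model.module` (used below for get_last_layer), and mixed
    # precision plus gradient clipping are handled by the boosted optimizer.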

    # == global variables ==
    cfg_epochs = cfg.get("epochs", 1000)
    start_epoch = start_step = log_step = sampler_start_idx = acc_step = 0
    running_loss = running_disc_loss = running_gen_loss = 0.0
    logger.info("Training for %s epochs with %s steps per epoch", cfg_epochs, num_steps_per_epoch)

    # == resume == # TODO
    if cfg.get("load", None) is not None:
        logger.info("Loading checkpoint")
        start_epoch, start_step = load(
            booster,
            cfg.load,
            model=model,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            sampler=sampler,
        )
        if use_discriminator and os.path.exists(os.path.join(cfg.load, "discriminator")):
            booster.load_model(discriminator, os.path.join(cfg.load, "discriminator"))
            booster.load_optimizer(disc_optimizer, os.path.join(cfg.load, "disc_optimizer"))
        dist.barrier()
        logger.info("Loaded checkpoint %s at epoch %s step %s", cfg.load, start_epoch, start_step)

    # =======================================================
    # 5. training loop
    # =======================================================
    dist.barrier()
    for epoch in range(start_epoch, cfg_epochs):
        # == set dataloader to new epoch ==
        sampler.set_epoch(epoch)
        dataiter = iter(dataloader)
        logger.info("Beginning epoch %s...", epoch)

        # == training loop in an epoch ==
        with tqdm(
            enumerate(dataiter, start=start_step),
            desc=f"Epoch {epoch}",
            disable=not coordinator.is_master(),
            total=num_steps_per_epoch,
            initial=start_step,
        ) as pbar:
            for step, batch in pbar:
                # == log config ==
                global_step = epoch * num_steps_per_epoch + step
                log_step += 1
                acc_step += 1

                # == vae encoding & decoding ==
                x = batch["video"].to(device, dtype)  # [B, C, T, H, W]
                # mixed training setting
                mixed_strategy = cfg.get("mixed_strategy", None)
                if mixed_strategy == "mixed_video_image":
                    if random.random() < cfg.get("mixed_image_ratio", 0.0):
                        x = x[:, :, :1, :, :]
                elif mixed_strategy == "mixed_video_random":
                    length = random.randint(1, x.size(2))
                    x = x[:, :, :length, :, :]
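                # "mixed_video_image" collapses a fraction of batches to a
                # single frame (x[:, :, :1]) so the causal VAE also trains on
                # image-like inputs; "mixed_video_random" samples a clip length
                # uniformly from [1, T] to cover variable temporal extents.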
                z, x_rec, posterior = model(x, is_training=True)

                # == loss initialization ==
                vae_loss = torch.tensor(0.0, device=device, dtype=dtype)
                disc_loss = torch.tensor(0.0, device=device, dtype=dtype)
                log_dict = {}

                # == reconstruction loss ==
                nll_loss, weighted_nll_loss, weighted_kl_loss = vae_loss_fn(x, x_rec, posterior)
                log_dict["kl_loss"] = weighted_kl_loss.item()
                log_dict["nll_loss"] = weighted_nll_loss.item()
                vae_loss += weighted_nll_loss + weighted_kl_loss
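                # Together these form the weighted negative ELBO; the raw
                # nll_loss is kept separately because the generator loss below
                # consumes it unweighted.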

                # == generator loss ==
                if use_discriminator:
                    fake_logits = discriminator(x_rec.contiguous())
                    generator_loss = generator_loss_fn(
                        fake_logits,
                        nll_loss,
                        model.module.get_last_layer(),
                        global_step,
                        is_training=model.training,
                    )
                    log_dict["generator_loss"] = generator_loss.item()
                    vae_loss += generator_loss
                    running_gen_loss += generator_loss.item()
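                    # Passing nll_loss and the decoder's last layer follows the
                    # common VQGAN-style recipe: the adversarial term is scaled
                    # by an adaptive weight derived from the two losses'
                    # gradient norms at that layer, and global_step typically
                    # gates the GAN term off until a configured warmup step
                    # (assuming GeneratorLoss follows this convention).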

                # == generator backward & update ==
                optimizer.zero_grad()
                booster.backward(loss=vae_loss, optimizer=optimizer)
                optimizer.step()
                all_reduce_mean(vae_loss)
                running_loss += vae_loss.item()
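                # booster.backward routes backprop through the plugin (loss
                # scaling for fp16, ZeRO gradient reduction); all_reduce_mean
                # only averages the loss value across ranks for logging and
                # has no effect on gradients.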

                # == loss: discriminator adversarial ==
                if use_discriminator:
                    real_logits = discriminator(x.contiguous().detach())
                    fake_logits = discriminator(x_rec.contiguous().detach())
                    disc_loss = discriminator_loss_fn(
                        real_logits,
                        fake_logits,
                        global_step,
                    )
                    log_dict["disc_loss"] = disc_loss.item()
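                    # Both inputs are detached so this update only trains the
                    # discriminator: real samples and reconstructions are
                    # scored and compared (commonly with a hinge loss), without
                    # backpropagating into the VAE.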

                    # == discriminator backward & update ==
                    disc_optimizer.zero_grad()
                    booster.backward(loss=disc_loss, optimizer=disc_optimizer)
                    disc_optimizer.step()
                    all_reduce_mean(disc_loss)
                    running_disc_loss += disc_loss.item()

                # == logging ==
                if coordinator.is_master() and (global_step + 1) % cfg.get("log_every", 1) == 0:
                    avg_loss = running_loss / log_step
                    avg_disc_loss = running_disc_loss / log_step
                    avg_gen_loss = running_gen_loss / log_step
                    # progress bar
                    pbar.set_postfix(
                        {
                            "vae_loss": avg_loss,
                            "gen_loss": avg_gen_loss,
                            "disc_loss": avg_disc_loss,
                            "step": step,
                            "global_step": global_step,
                        }
                    )
                    # tensorboard
                    tb_writer.add_scalar("loss", vae_loss.item(), global_step)
                    # wandb
                    if cfg.get("wandb", False):
                        wandb.log(
                            {
                                "iter": global_step,
                                "num_samples": global_step * total_batch_size,
                                "epoch": epoch,
                                "loss": vae_loss.item(),
                                "avg_loss": avg_loss,
                                **log_dict,
                            },
                            step=global_step,
                        )
                    running_loss = running_disc_loss = running_gen_loss = 0.0
                    log_step = 0

                # == checkpoint saving ==
                ckpt_every = cfg.get("ckpt_every", 0)
                if ckpt_every > 0 and (global_step + 1) % ckpt_every == 0:
                    save(
                        booster,
                        exp_dir,
                        model=model,
                        optimizer=optimizer,
                        lr_scheduler=lr_scheduler,
                        epoch=epoch,
                        step=step + 1,
                        global_step=global_step + 1,
                        batch_size=cfg.get("batch_size", None),
                        sampler=sampler,
                    )

                    save_dir = os.path.join(exp_dir, f"epoch{epoch}-global_step{global_step + 1}")
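                    # The discriminator and its optimizer are saved separately
                    # as sharded checkpoints: shard=True avoids gathering one
                    # monolithic state dict, and size_per_shard caps each
                    # optimizer shard file (in MB, so 4096 is ~4 GiB per file).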
                    if use_discriminator:
                        booster.save_model(discriminator, os.path.join(save_dir, "discriminator"), shard=True)
                        booster.save_optimizer(
                            disc_optimizer, os.path.join(save_dir, "disc_optimizer"), shard=True, size_per_shard=4096
                        )
                    dist.barrier()

                    logger.info(
                        "Saved checkpoint at epoch %s step %s global_step %s to %s",
                        epoch,
                        step + 1,
                        global_step + 1,
                        exp_dir,
                    )

        sampler.reset()
        start_step = 0
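        # start_step is only non-zero for the first epoch after resuming from
        # a checkpoint; subsequent epochs restart from step 0 with a fresh
        # sampler state.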


if __name__ == "__main__":
    main()