Co-authored-by: Shen-Chenhui <shen_chenhui@u.nus.edu>
Shen Chenhui 2024-05-08 17:39:21 +08:00 committed by GitHub
parent 338bd83f24
commit dc7b7fd64a
3 changed files with 149 additions and 120 deletions


@@ -249,7 +249,7 @@ def load(
load_dir: str,
sampler=None,
) -> Tuple[int, int, int]:
booster.load_model(model, os.path.join(load_dir, "model"))
booster.load_model(model, os.path.join(load_dir, model_name))
# ema is not boosted, so we don't use booster.load_model
ema.load_state_dict(
torch.load(os.path.join(load_dir, "ema.pt"), map_location=torch.device("cpu")),
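Editor's note: the hunk above replaces the hard-coded "model" subfolder with a model_name argument, so one load routine can serve checkpoints saved under different names. A minimal sketch of the resulting pattern (load_checkpoint and its booster/ema arguments are stand-ins for surrounding code that this hunk does not show):

    import os

    import torch

    def load_checkpoint(booster, model, ema, load_dir: str, model_name: str = "model"):
        # the boosted model goes through the booster's (possibly sharded) loader
        booster.load_model(model, os.path.join(load_dir, model_name))
        # ema is not boosted, so a plain torch.load is enough
        ema.load_state_dict(
            torch.load(os.path.join(load_dir, "ema.pt"), map_location="cpu")
        )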


@@ -98,7 +98,7 @@ def main():
)
if cfg.dataset.type == DEFAULT_DATASET_NAME:
dataloader = prepare_dataloader(**dataloader_args)
total_batch_size = cfg.batch_size * dist.get_world_size() // cfg.sp_size
total_batch_size = cfg.batch_size * dist.get_world_size() // cfg.get("sp_size", 1)
logger.info("Total batch size: %s", total_batch_size)
else:
dataloader = prepare_variable_dataloader(
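Editor's note: the only functional change in this hunk is cfg.sp_size becoming cfg.get("sp_size", 1), so configs that never set a sequence-parallel size keep working. The division reflects that ranks inside one sequence-parallel group consume the same samples; a self-contained sketch of the arithmetic, assuming sp_size evenly divides the world size:

    def total_batch_size(batch_size: int, world_size: int, sp_size: int = 1) -> int:
        # ranks in a sequence-parallel group share samples, so the effective
        # data-parallel degree is world_size // sp_size
        assert world_size % sp_size == 0, "sp_size must divide world_size"
        return batch_size * world_size // sp_size

    assert total_batch_size(4, 8, sp_size=2) == 16  # 4 samples x 4 DP groups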


@@ -1,12 +1,11 @@
import os
import random
from datetime import timedelta
from pprint import pprint
from pprint import pformat
import torch
import torch.distributed as dist
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.cluster import DistCoordinator
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device, set_seed
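Editor's note: the pprint-to-pformat swap supports the logging change later in this diff. pprint writes straight to stdout, while pformat returns a string that can go through the logger's lazy %-style formatting. A small illustration (the config dict here is made up):

    import logging
    from pprint import pformat

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)
    # the string is only built if INFO is enabled, unlike an f-string
    logger.info("Training configuration:\n %s", pformat({"lr": 1e-5, "dtype": "bf16"}))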
@@ -15,7 +14,7 @@ from tqdm import tqdm
import wandb
from opensora.acceleration.checkpoint import set_grad_checkpoint
from opensora.acceleration.parallel_states import get_data_parallel_group, set_data_parallel_group
from opensora.acceleration.parallel_states import get_data_parallel_group
from opensora.datasets import prepare_dataloader
from opensora.models.vae.losses import AdversarialLoss, DiscriminatorLoss, VAELoss
from opensora.registry import DATASETS, MODELS, build_module
@@ -27,68 +26,64 @@ from opensora.utils.config_utils import (
save_training_config,
)
from opensora.utils.misc import all_reduce_mean, format_numel_str, get_model_numel, to_torch_dtype
from opensora.utils.train_utils import create_colossalai_plugin
DEFAULT_DATASET_NAME = "VideoTextDataset"
def main():
# ======================================================
# 1. args & cfg
# 1. configs & runtime variables & colossalai launch
# ======================================================
cfg = parse_configs(training=True)
# ======================================================
# 2. runtime variables & colossalai launch
# ======================================================
assert torch.cuda.is_available(), "Training currently requires at least one GPU."
assert cfg.dtype in ["fp16", "bf16"], f"Unknown mixed precision {cfg.dtype}"
cfg_dtype = cfg.get("dtype", "bf16")
assert cfg_dtype in ["fp16", "bf16"], f"Unknown mixed precision {cfg_dtype}"
dtype = to_torch_dtype(cfg.get("dtype", "bf16"))
# 2.1. colossalai init distributed training
# we set a very large timeout to avoid some processes exiting early
# == colossalai init distributed training ==
# NOTE: A very large timeout is set to avoid some processes exiting early
dist.init_process_group(backend="nccl", timeout=timedelta(hours=24))
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
set_seed(1024)
coordinator = DistCoordinator()
device = get_current_device()
dtype = to_torch_dtype(cfg.dtype)
# 2.2. init exp_dir, logger, tensorboard & wandb
# == init exp_dir ==
exp_name, exp_dir = define_experiment_workspace(cfg)
coordinator.block_all()
if coordinator.is_master():
os.makedirs(exp_dir, exist_ok=True)
save_training_config(cfg._cfg_dict, exp_dir)
save_training_config(cfg.to_dict(), exp_dir)
coordinator.block_all()
if not coordinator.is_master():
logger = create_logger(None)
else:
print("Training configuration:")
pprint(cfg._cfg_dict)
# == init logger, tensorboard & wandb ==
if coordinator.is_master():
logger = create_logger(exp_dir)
logger.info(f"Experiment directory created at {exp_dir}")
writer = create_tensorboard_writer(exp_dir)
if cfg.wandb:
wandb.init(project="minisora", name=exp_name, config=cfg._cfg_dict)
# 2.3. initialize ColossalAI booster
if cfg.plugin == "zero2":
plugin = LowLevelZeroPlugin(
stage=2,
precision=cfg.dtype,
initial_scale=2**16,
max_norm=cfg.grad_clip,
)
set_data_parallel_group(dist.group.WORLD)
logger.info("Training configuration:\n %s", pformat(cfg.to_dict()))
tb_writer = create_tensorboard_writer(exp_dir)
if cfg.get("wandb", False):
wandb.init(project="minisora", name=exp_name, config=cfg.to_dict())
else:
raise ValueError(f"Unknown plugin {cfg.plugin}")
logger = create_logger(None)
# == init ColossalAI booster ==
plugin = create_colossalai_plugin(
plugin=cfg.get("plugin", "zero2"),
dtype=cfg_dtype,
grad_clip=cfg.get("grad_clip", 0),
sp_size=cfg.get("sp_size", 1),
)
booster = Booster(plugin=plugin)
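Editor's note: the inline LowLevelZeroPlugin construction and its hard-coded "zero2" branch move behind create_colossalai_plugin in opensora.utils.train_utils. Judging only from the removed code and the call site, the factory plausibly looks like the sketch below; the real implementation presumably also registers the parallel groups that the removed set_data_parallel_group(dist.group.WORLD) call handled, which this sketch omits:

    from colossalai.booster.plugin import LowLevelZeroPlugin

    def create_colossalai_plugin(plugin: str, dtype: str, grad_clip: float, sp_size: int = 1):
        # sp_size is accepted for parity with the call site but unused in this sketch
        if plugin == "zero2":
            return LowLevelZeroPlugin(
                stage=2,
                precision=dtype,
                initial_scale=2**16,
                max_norm=grad_clip,
            )
        raise ValueError(f"Unknown plugin {plugin}")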
# ======================================================
# 3. build dataset and dataloader
# 2. build dataset and dataloader
# ======================================================
assert cfg.dataset.type == "VideoTextDataset", "Only support VideoTextDataset for now"
dataset = build_module(cfg.dataset, DATASETS)
logger.info(f"Dataset contains {len(dataset)} samples.")
logger.info("Dataset contains %s samples.", len(dataset))
# == build dataloader ==
dataloader_args = dict(
dataset=dataset,
batch_size=cfg.batch_size,
@@ -99,76 +94,93 @@ def main():
pin_memory=True,
process_group=get_data_parallel_group(),
)
# TODO: use plugin's prepare dataloader
dataloader = prepare_dataloader(**dataloader_args)
total_batch_size = cfg.batch_size * dist.get_world_size()
logger.info(f"Total batch size: {total_batch_size}")
if cfg.dataset.type == DEFAULT_DATASET_NAME:
dataloader = prepare_dataloader(**dataloader_args)
total_batch_size = cfg.batch_size * dist.get_world_size() // cfg.get("sp_size", 1)
logger.info("Total batch size: %s", total_batch_size)
else:
dataloader = prepare_variable_dataloader(
bucket_config=cfg.bucket_config,
num_bucket_build_workers=cfg.num_bucket_build_workers,
**dataloader_args,
)
# ======================================================
# 4. build model
# 3. build model
# ======================================================
# 4.1. build model
model = build_module(cfg.model, MODELS)
model.to(device, dtype)
# == build vae model ==
model = build_module(cfg.model, MODELS).to(device, dtype)
model.train()
model_numel, model_numel_trainable = get_model_numel(model)
logger.info(
f"Trainable model params: {format_numel_str(model_numel_trainable)}, Total model params: {format_numel_str(model_numel)}"
"Trainable model params: %s, Total model params: %s",
format_numel_str(model_numel_trainable),
format_numel_str(model_numel),
)
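Editor's note: get_model_numel and format_numel_str come from opensora.utils.misc; below are assumed implementations, consistent only with how they are used here (a (total, trainable) pair and a human-readable count), not taken from the library:

    def get_model_numel(model):
        total = sum(p.numel() for p in model.parameters())
        trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
        return total, trainable

    def format_numel_str(numel: int) -> str:
        for unit, div in (("B", 10**9), ("M", 10**6), ("K", 10**3)):
            if numel >= div:
                return f"{numel / div:.2f} {unit}"
        return str(numel)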
# == build discriminator model ==
if cfg.get("discriminator", False) != False:
discriminator = build_module(cfg.discriminator, MODELS).to(device, dtype)
discriminator.train()
discriminator_numel, discriminator_numel_trainable = get_model_numel(discriminator)
logger.info(
"Trainable model params: %s, Total model params: %s",
format_numel_str(discriminator_numel_trainable),
format_numel_str(discriminator_numel),
)
# 4.4 loss functions
# == setup loss functions ==
vae_loss_fn = VAELoss(
logvar_init=cfg.get("logvar_init", 0.0),
perceptual_loss_weight=cfg.perceptual_loss_weight,
kl_loss_weight=cfg.kl_loss_weight,
perceptual_loss_weight=cfg.get("perceptual_loss_weight", 0.1),
kl_loss_weight=cfg.get("kl_loss_weight", 1e-6),
device=device,
dtype=dtype,
)
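Editor's note: throughout this refactor, required config attributes (cfg.perceptual_loss_weight and friends) become cfg.get(key, default) lookups, so a config that omits a key falls back to a default instead of raising AttributeError. Illustrated with a plain dict standing in for the config object:

    cfg = {"kl_loss_weight": 1e-6}

    perceptual_loss_weight = cfg.get("perceptual_loss_weight", 0.1)  # 0.1, from the default
    kl_loss_weight = cfg.get("kl_loss_weight", 1e-6)                 # 1e-6, from the config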
if cfg.get("discriminator", False) != False:
adversarial_loss_fn = AdversarialLoss(
discriminator_factor=cfg.discriminator_factor,
discriminator_start=cfg.discriminator_start,
generator_factor=cfg.generator_factor,
generator_loss_type=cfg.generator_loss_type,
discriminator_factor=cfg.get("discriminator_factor", 1),
discriminator_start=cfg.get("discriminator_start", -1),
generator_factor=cfg.get("generator_factor", 0.5),
generator_loss_type=cfg.get("generator_loss_type", "hinge"),
)
disc_loss_fn = DiscriminatorLoss(
discriminator_factor=cfg.discriminator_factor,
discriminator_start=cfg.discriminator_start,
discriminator_loss_type=cfg.discriminator_loss_type,
lecam_loss_weight=cfg.lecam_loss_weight,
gradient_penalty_loss_weight=cfg.gradient_penalty_loss_weight,
discriminator_factor=cfg.get("discriminator_factor", 1),
discriminator_start=cfg.get("discriminator_start", -1),
discriminator_loss_type=cfg.get("discriminator_loss_type", "hinge"),
lecam_loss_weight=cfg.get("lecam_loss_weight", None),
gradient_penalty_loss_weight=cfg.get("gradient_penalty_loss_weight", None),
)
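Editor's note: discriminator_start (default -1, i.e. active from the first step) is the usual warm-up threshold in GAN-style VAE training: the adversarial term is switched on only once global_step passes it. A sketch of that gating in the style of VQGAN's adopt_weight; the actual logic lives inside AdversarialLoss/DiscriminatorLoss and may differ:

    def adopt_weight(factor: float, global_step: int, threshold: int = -1) -> float:
        return factor if global_step >= threshold else 0.0

    assert adopt_weight(1.0, 0, threshold=-1) == 1.0    # enabled immediately
    assert adopt_weight(1.0, 99, threshold=100) == 0.0  # still warming up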
# 4.5. setup optimizer
# vae optimizer
# == setup vae optimizer ==
optimizer = HybridAdam(
filter(lambda p: p.requires_grad, model.parameters()), lr=cfg.lr, weight_decay=0, adamw_mode=True
filter(lambda p: p.requires_grad, model.parameters()),
adamw_mode=True,
lr=cfg.get("lr", 1e-5),
weight_decay=cfg.get("weight_decay", 0),
)
lr_scheduler = None
# 4.6. prepare for training
if cfg.grad_checkpoint:
set_grad_checkpoint(model)
model.train()
# 4.7 add discriminator if specified in config
# == setup discriminator optimizer ==
if cfg.get("discriminator", False) != False:
discriminator = build_module(cfg.discriminator, MODELS)
discriminator.to(device, dtype)
discriminator_numel, discriminator_numel_trainable = get_model_numel(discriminator)
logger.info(
f"Trainable model params: {format_numel_str(discriminator_numel_trainable)}, Total model params: {format_numel_str(discriminator_numel)}"
)
disc_optimizer = HybridAdam(
filter(lambda p: p.requires_grad, discriminator.parameters()), lr=cfg.lr, weight_decay=0, adamw_mode=True
filter(lambda p: p.requires_grad, discriminator.parameters()),
adamw_mode=True,
lr=cfg.get("lr", 1e-5),
weight_decay=cfg.get("weight_decay", 0),
)
disc_lr_scheduler = None
discriminator.train()
# == additional preparation ==
if cfg.get("grad_checkpoint", False):
set_grad_checkpoint(model)
# =======================================================
# 5. boost model for distributed training with colossalai
# 4. distributed training preparation with colossalai
# =======================================================
# == boosting ==
# NOTE: we set the default dtype first so model initialization is consistent with that dtype; it is then reset to fp32 because the diffusion scheduler is created in fp32
torch.set_default_dtype(dtype)
model, optimizer, _, dataloader, lr_scheduler = booster.boost(
model=model,
@@ -182,49 +194,53 @@ def main():
optimizer=disc_optimizer,
lr_scheduler=disc_lr_scheduler,
)
torch.set_default_dtype(torch.float)
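Editor's note: the NOTE above describes a common trick. While the default dtype is the mixed-precision dtype, any parameters materialized during boosting come out in that dtype; restoring fp32 afterwards keeps later objects in full precision. A standalone demonstration (bf16 chosen for the example):

    import torch

    torch.set_default_dtype(torch.bfloat16)
    layer = torch.nn.Linear(4, 4)         # weights created in bf16
    torch.set_default_dtype(torch.float)  # back to fp32 for later objects
    assert layer.weight.dtype == torch.bfloat16
    assert torch.nn.Linear(4, 4).weight.dtype == torch.float32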
num_steps_per_epoch = len(dataloader)
logger.info("Boost model for distributed training")
num_steps_per_epoch = len(dataloader)
logger.info("Boosting model for distributed training")
if cfg.dataset.type == DEFAULT_DATASET_NAME:
num_steps_per_epoch = len(dataloader)
else:
num_steps_per_epoch = dataloader.batch_sampler.get_num_batch() // dist.get_world_size()
None if cfg.get("start_from_scratch", False) else dataloader.batch_sampler
# =======================================================
# 6. training loop
# =======================================================
start_epoch = start_step = log_step = sampler_start_idx = 0
acc_step = 0
cfg_epochs = cfg.get("epochs", 1000)
logger.info("Training for %s epochs with %s steps per epoch", cfg_epochs, num_steps_per_epoch)
# == global variables ==
start_epoch = start_step = log_step = sampler_start_idx = acc_step = 0
running_loss = 0.0
running_disc_loss = 0.0
# 6.1. resume training
if cfg.load is not None:
# == resume ==
if cfg.get("load", None) is not None:
logger.info("Loading checkpoint")
booster.load_model(model, os.path.join(cfg.load, "model"))
booster.load_optimizer(optimizer, os.path.join(cfg.load, "optimizer"))
running_states = load_json(os.path.join(cfg.load, "running_states.json"))
if cfg.get("discriminator", False) != False and os.path.exists(os.path.join(cfg.load, "discriminator")):
booster.load_model(discriminator, os.path.join(cfg.load, "discriminator"))
booster.load_optimizer(disc_optimizer, os.path.join(cfg.load, "disc_optimizer"))
dist.barrier()
running_states = load_json(os.path.join(cfg.load, "running_states.json"))
start_epoch, start_step, sampler_start_idx = (
running_states["epoch"],
running_states["step"],
running_states["sample_start_index"],
)
logger.info(f"Loaded checkpoint {cfg.load} at epoch {start_epoch} step {start_step}")
logger.info("Loaded checkpoint %s at epoch %s step %s", cfg.load, start_epoch, start_step)
dist.barrier()
logger.info(f"Training for {cfg.epochs} epochs with {num_steps_per_epoch} steps per epoch")
if cfg.dataset.type == DEFAULT_DATASET_NAME:
dataloader.sampler.set_start_index(sampler_start_idx)
dataloader.sampler.set_start_index(sampler_start_idx)
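Editor's note: for reference, the resume path round-trips a small running_states.json and fast-forwards the sampler so samples are not repeated; a sketch of the bookkeeping using the keys visible in this diff:

    import json

    running_states = json.loads('{"epoch": 2, "step": 40, "sample_start_index": 320}')
    start_epoch = running_states["epoch"]
    start_step = running_states["step"]
    sampler_start_idx = running_states["sample_start_index"]  # (step + 1) * batch_size at save time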
# 6.3. training loop
for epoch in range(start_epoch, cfg.epochs):
dataloader.sampler.set_epoch(epoch)
# =======================================================
# 5. training loop
# =======================================================
for epoch in range(start_epoch, cfg_epochs):
# == set dataloader to new epoch ==
if cfg.dataset.type == DEFAULT_DATASET_NAME:
dataloader.sampler.set_epoch(epoch)
dataloader_iter = iter(dataloader)
logger.info(f"Beginning epoch {epoch}...")
logger.info("Beginning epoch %s...", epoch)
# == training loop in an epoch ==
with tqdm(
enumerate(dataloader_iter, start=start_step),
desc=f"Epoch {epoch}",
@@ -239,35 +255,36 @@ def main():
length = random.randint(1, x.size(2))
x = x[:, :, :length, :, :]
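Editor's note: the loop body starts by randomly cropping the clip along the time axis so the VAE trains on variable-length inputs; x is laid out as (B, C, T, H, W). A self-contained sketch of the crop:

    import random

    import torch

    x = torch.randn(2, 3, 16, 8, 8)        # (B, C, T, H, W)
    length = random.randint(1, x.size(2))  # keep a random prefix of 1..T frames
    x = x[:, :, :length, :, :]
    assert 1 <= x.size(2) <= 16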
# ===== VAE =====
# == vae encoding ==
x_rec, x_z_rec, z, posterior, x_z = model(x)
# ====== Generator Loss ======
# == loss initialization ==
vae_loss = torch.tensor(0.0, device=device, dtype=dtype)
disc_loss = torch.tensor(0.0, device=device, dtype=dtype)
log_dict = {}
# real image reconstruction loss
# == real image reconstruction loss computation ==
nll_loss, weighted_nll_loss, weighted_kl_loss = vae_loss_fn(x, x_rec, posterior)
log_dict["kl_loss"] = weighted_kl_loss.item()
log_dict["nll_loss"] = weighted_nll_loss.item()
if cfg.get("use_real_rec_loss", False):
vae_loss += weighted_nll_loss + weighted_kl_loss
# == temporal vae reconstruction loss computation ==
_, weighted_z_nll_loss, _ = vae_loss_fn(x_z, x_z_rec, posterior, no_perceptual=True)
log_dict["z_nll_loss"] = weighted_z_nll_loss.item()
# z reconstruction loss
# == use z reconstruction loss ==
if cfg.get("use_z_rec_loss", False):
vae_loss += weighted_z_nll_loss
# only for image
# == use image loss only ==
if cfg.get("use_image_identity_loss", False) and x.size(2) == 1:
_, image_identity_loss, _ = vae_loss_fn(x_z, z, posterior, no_perceptual=True)
vae_loss += image_identity_loss
log_dict["image_identity_loss"] = image_identity_loss.item()
# Adversarial Generator Loss
# == generator adversarial loss ==
if cfg.get("discriminator", False) != False:
recon_video = rearrange(x_rec, "b c t h w -> (b t) c h w").contiguous()
global_step = epoch * num_steps_per_epoch + step
@@ -281,12 +298,12 @@ def main():
)
vae_loss += adversarial_loss
# Backward & update
# == generator backward & update ==
optimizer.zero_grad()
booster.backward(loss=vae_loss, optimizer=optimizer)
optimizer.step()
# Adversarial Discriminator loss
# == discriminator adversarial loss ==
if cfg.get("discriminator", False) != False:
real_video = rearrange(x, "b c t h w -> (b t) c h w").contiguous()
fake_video = rearrange(x_rec, "b c t h w -> (b t) c h w").contiguous()
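Editor's note: both adversarial passes fold the time axis into the batch axis before scoring, which suggests a per-frame (2D) discriminator; the reshape on its own:

    import torch
    from einops import rearrange

    video = torch.randn(2, 3, 4, 8, 8)                     # (B, C, T, H, W)
    frames = rearrange(video, "b c t h w -> (b t) c h w")  # (B*T, C, H, W)
    assert frames.shape == (8, 3, 8, 8)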
@@ -298,31 +315,30 @@ def main():
global_step,
)
disc_loss = weighted_d_adversarial_loss
# Backward & update
# == discriminator backward & update ==
disc_optimizer.zero_grad()
booster.backward(loss=disc_loss, optimizer=disc_optimizer)
disc_optimizer.step()
all_reduce_mean(disc_loss)
running_disc_loss += disc_loss.item()
# Log loss values:
# == update log info ==
all_reduce_mean(vae_loss)
running_loss += vae_loss.item()
global_step = epoch * num_steps_per_epoch + step
log_step += 1
acc_step += 1
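Editor's note: all_reduce_mean (from opensora.utils.misc) averages the scalar loss across ranks so every worker logs the same value; an assumed implementation, requiring an initialized process group:

    import torch.distributed as dist

    def all_reduce_mean(tensor):
        dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
        tensor.div_(dist.get_world_size())
        return tensor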
# Log to tensorboard
# == logging ==
if coordinator.is_master() and (global_step + 1) % cfg.log_every == 0:
avg_loss = running_loss / log_step
avg_disc_loss = running_disc_loss / log_step
pbar.set_postfix(
{"loss": avg_loss, "disc_loss": avg_disc_loss, "step": step, "global_step": global_step}
)
running_loss = 0
log_step = 0
running_disc_loss = 0
writer.add_scalar("loss", vae_loss.item(), global_step)
# tensorboard
tb_writer.add_scalar("loss", vae_loss.item(), global_step)
if cfg.wandb:
wandb.log(
{
@@ -335,8 +351,11 @@ def main():
},
step=global_step,
)
running_loss = 0
running_disc_loss = 0
log_step = 0
# Save checkpoint
# == checkpoint saving ==
if cfg.ckpt_every > 0 and (global_step + 1) % cfg.ckpt_every == 0:
save_dir = os.path.join(exp_dir, f"epoch{epoch}-global_step{global_step+1}")
os.makedirs(os.path.join(save_dir, "model"), exist_ok=True) # already handled in booster save_model
@@ -356,15 +375,25 @@ def main():
"global_step": global_step + 1,
"sample_start_index": (step + 1) * cfg.batch_size,
}
if coordinator.is_master():
save_json(running_states, os.path.join(save_dir, "running_states.json"))
dist.barrier()
logger.info(
f"Saved checkpoint at epoch {epoch} step {step + 1} global_step {global_step + 1} to {exp_dir}"
"Saved checkpoint at epoch %s step %s global_step %s to %s",
epoch,
step + 1,
global_step + 1,
exp_dir,
)
# subsequent epochs are not resumed from the checkpoint, so we need to reset the sampler start index and start step
dataloader.sampler.set_start_index(0)
if cfg.dataset.type == DEFAULT_DATASET_NAME:
dataloader.sampler.set_start_index(0)
else:
dataloader.batch_sampler.set_epoch(epoch + 1)
logger.info("Epoch done, recomputing batch sampler")
start_step = 0