Open-Sora/scripts/train-vae.py
Shen-Chenhui 40fd8a6b7b typo
2024-05-04 21:21:35 +08:00

371 lines
15 KiB
Python

import os
import random
from datetime import timedelta
from pprint import pprint
import torch
import torch.distributed as dist
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.cluster import DistCoordinator
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device, set_seed
from tqdm import tqdm
from einops import rearrange
import wandb
from opensora.acceleration.checkpoint import set_grad_checkpoint
from opensora.acceleration.parallel_states import get_data_parallel_group, set_data_parallel_group
from opensora.datasets import prepare_dataloader
from opensora.models.vae.losses import VAELoss, AdversarialLoss, DiscriminatorLoss
from opensora.registry import DATASETS, MODELS, build_module
from opensora.utils.ckpt_utils import create_logger, load_json, save_json
from opensora.utils.config_utils import (
create_experiment_workspace,
create_tensorboard_writer,
parse_configs,
save_training_config,
)
from opensora.utils.misc import all_reduce_mean, format_numel_str, get_model_numel, to_torch_dtype
def main():
# ======================================================
# 1. args & cfg
# ======================================================
cfg = parse_configs(training=True)
exp_name, exp_dir = create_experiment_workspace(cfg)
save_training_config(cfg._cfg_dict, exp_dir)
# ======================================================
# 2. runtime variables & colossalai launch
# ======================================================
assert torch.cuda.is_available(), "Training currently requires at least one GPU."
assert cfg.dtype in ["fp16", "bf16"], f"Unknown mixed precision {cfg.dtype}"
# 2.1. colossalai init distributed training
# we set a very large timeout to avoid some processes exit early
dist.init_process_group(backend="nccl", timeout=timedelta(hours=24))
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
set_seed(1024)
coordinator = DistCoordinator()
device = get_current_device()
dtype = to_torch_dtype(cfg.dtype)
# 2.2. init logger, tensorboard & wandb
if not coordinator.is_master():
logger = create_logger(None)
else:
print("Training configuration:")
pprint(cfg._cfg_dict)
logger = create_logger(exp_dir)
logger.info(f"Experiment directory created at {exp_dir}")
writer = create_tensorboard_writer(exp_dir)
if cfg.wandb:
wandb.init(project="minisora", name=exp_name, config=cfg._cfg_dict)
# 2.3. initialize ColossalAI booster
if cfg.plugin == "zero2":
plugin = LowLevelZeroPlugin(
stage=2,
precision=cfg.dtype,
initial_scale=2**16,
max_norm=cfg.grad_clip,
)
set_data_parallel_group(dist.group.WORLD)
else:
raise ValueError(f"Unknown plugin {cfg.plugin}")
booster = Booster(plugin=plugin)
# ======================================================
# 3. build dataset and dataloader
# ======================================================
assert cfg.dataset.type == "VideoTextDataset", "Only support VideoTextDataset for now"
dataset = build_module(cfg.dataset, DATASETS)
logger.info(f"Dataset contains {len(dataset)} samples.")
dataloader_args = dict(
dataset=dataset,
batch_size=cfg.batch_size,
num_workers=cfg.num_workers,
seed=cfg.seed,
shuffle=True,
drop_last=True,
pin_memory=True,
process_group=get_data_parallel_group(),
)
# TODO: use plugin's prepare dataloader
dataloader = prepare_dataloader(**dataloader_args)
total_batch_size = cfg.batch_size * dist.get_world_size()
logger.info(f"Total batch size: {total_batch_size}")
# ======================================================
# 4. build model
# ======================================================
# 4.1. build model
model = build_module(cfg.model, MODELS)
model.to(device, dtype)
model_numel, model_numel_trainable = get_model_numel(model)
logger.info(
f"Trainable model params: {format_numel_str(model_numel_trainable)}, Total model params: {format_numel_str(model_numel)}"
)
# 4.4 loss functions
vae_loss_fn = VAELoss(
logvar_init=cfg.get("logvar_init", 0.0),
perceptual_loss_weight=cfg.perceptual_loss_weight,
kl_loss_weight=cfg.kl_loss_weight,
device=device,
dtype=dtype,
)
if cfg.get("discriminator", False) != False:
adversarial_loss_fn = AdversarialLoss(
discriminator_factor=cfg.discriminator_factor,
discriminator_start=cfg.discriminator_start,
generator_factor=cfg.generator_factor,
generator_loss_type=cfg.generator_loss_type,
)
disc_loss_fn = DiscriminatorLoss(
discriminator_factor=cfg.discriminator_factor,
discriminator_start=cfg.discriminator_start,
discriminator_loss_type=cfg.discriminator_loss_type,
lecam_loss_weight=cfg.lecam_loss_weight,
gradient_penalty_loss_weight=cfg.gradient_penalty_loss_weight,
)
# 4.5. setup optimizer
# vae optimizer
optimizer = HybridAdam(
filter(lambda p: p.requires_grad, model.parameters()), lr=cfg.lr, weight_decay=0, adamw_mode=True
)
lr_scheduler = None
# 4.6. prepare for training
if cfg.grad_checkpoint:
set_grad_checkpoint(model)
model.train()
# 4.7 add discriminator if specified in config
if cfg.get("discriminator", False) != False:
discriminator = build_module(cfg.discriminator, MODELS)
discriminator.to(device, dtype)
discriminator_numel, discriminator_numel_trainable = get_model_numel(discriminator)
logger.info(
f"Trainable model params: {format_numel_str(discriminator_numel_trainable)}, Total model params: {format_numel_str(discriminator_numel)}"
)
disc_optimizer = HybridAdam(
filter(lambda p: p.requires_grad, discriminator.parameters()), lr=cfg.lr, weight_decay=0, adamw_mode=True
)
disc_lr_scheduler = None
discriminator.train()
# =======================================================
# 5. boost model for distributed training with colossalai
# =======================================================
torch.set_default_dtype(dtype)
model, optimizer, _, dataloader, lr_scheduler = booster.boost(
model=model,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
dataloader=dataloader,
)
if cfg.get("discriminator", False) != False:
discriminator, disc_optimizer, _, _, disc_lr_scheduler = booster.boost(
model=discriminator,
optimizer=disc_optimizer,
lr_scheduler=disc_lr_scheduler,
)
torch.set_default_dtype(torch.float)
num_steps_per_epoch = len(dataloader)
logger.info("Boost model for distributed training")
num_steps_per_epoch = len(dataloader)
# =======================================================
# 6. training loop
# =======================================================
start_epoch = start_step = log_step = sampler_start_idx = 0
acc_step = 0
running_loss = 0.0
running_disc_loss = 0.0
# 6.1. resume training
if cfg.load is not None:
logger.info("Loading checkpoint")
booster.load_model(model, os.path.join(cfg.load, "model"))
booster.load_optimizer(optimizer, os.path.join(cfg.load, "optimizer"))
running_states = load_json(os.path.join(cfg.load, "running_states.json"))
if cfg.get("discriminator", False) != False and os.path.exists(os.path.join(cfg.load, "discriminator")):
booster.load_model(discriminator, os.path.join(cfg.load, "discriminator"))
booster.load_optimizer(disc_optimizer, os.path.join(cfg.load, "disc_optimizer"))
dist.barrier()
start_epoch, start_step, sampler_start_idx = (
running_states["epoch"],
running_states["step"],
running_states["sample_start_index"],
)
logger.info(f"Loaded checkpoint {cfg.load} at epoch {start_epoch} step {start_step}")
logger.info(f"Training for {cfg.epochs} epochs with {num_steps_per_epoch} steps per epoch")
dataloader.sampler.set_start_index(sampler_start_idx)
# 6.3. training loop
for epoch in range(start_epoch, cfg.epochs):
dataloader.sampler.set_epoch(epoch)
dataloader_iter = iter(dataloader)
logger.info(f"Beginning epoch {epoch}...")
with tqdm(
enumerate(dataloader_iter, start=start_step),
desc=f"Epoch {epoch}",
disable=not coordinator.is_master(),
total=num_steps_per_epoch,
initial=start_step,
) as pbar:
for step, batch in pbar:
x = batch["video"].to(device, dtype) # [B, C, T, H, W]
# if random.random() < cfg.get("mixed_image_ratio", 0.0):
# x = x[:, :, :1, :, :]
length = random.randint(1, x.size(2))
x = x[:, :, :length, :, :]
# ===== VAE =====
x_rec, x_z_rec, z, posterior, x_z = model(x)
# ====== Generator Loss ======
vae_loss = torch.tensor(0.0, device=device, dtype=dtype)
disc_loss = torch.tensor(0.0, device=device, dtype=dtype)
log_dict = {}
# real image reconstruction loss
nll_loss, weighted_nll_loss, weighted_kl_loss = vae_loss_fn(x, x_rec, posterior)
log_dict["kl_loss"] = weighted_kl_loss.item()
log_dict["nll_loss"] = weighted_nll_loss.item()
if cfg.get("use_real_rec_loss", False):
vae_loss += weighted_nll_loss + weighted_kl_loss
_, weighted_z_nll_loss, _ = vae_loss_fn(x_z, x_z_rec, posterior, no_perceptual=True)
log_dict["z_nll_loss"] = weighted_z_nll_loss.item()
# z reconstruction loss
if cfg.get("use_z_rec_loss", False):
vae_loss += weighted_z_nll_loss
# only for image
if cfg.get("use_image_identity_loss", False) and x.size(2) == 1:
_, image_identity_loss, _ = vae_loss_fn(x_z, z, posterior, no_perceptual=True)
vae_loss += image_identity_loss
log_dict["image_identity_loss"] = image_identity_loss.item()
# Adversarial Generator Loss
if cfg.get("discriminator", False) != False:
recon_video = rearrange(x_rec, "b c t h w -> (b t) c h w").contiguous()
global_step = epoch * num_steps_per_epoch + step
fake_logits = discriminator(recon_video.contiguous())
adversarial_loss = adversarial_loss_fn(
fake_logits,
nll_loss,
model.module.get_temporal_last_layer(),
global_step,
is_training=model.training,
)
vae_loss += adversarial_loss
# Backward & update
optimizer.zero_grad()
booster.backward(loss=vae_loss, optimizer=optimizer)
optimizer.step()
# Adversarial Discriminator loss
if cfg.get("discriminator", False) != False:
real_video = rearrange(x, "b c t h w -> (b t) c h w").contiguous()
fake_video = rearrange(x_rec, "b c t h w -> (b t) c h w").contiguous()
real_logits = discriminator(real_video.contiguous().detach())
fake_logits = discriminator(fake_video.contiguous().detach())
weighted_d_adversarial_loss, _, _ = disc_loss_fn(
real_logits,
fake_logits,
global_step,
)
disc_loss = weighted_d_adversarial_loss
# Backward & update
disc_optimizer.zero_grad()
booster.backward(loss=disc_loss, optimizer=disc_optimizer)
disc_optimizer.step()
all_reduce_mean(disc_loss)
running_disc_loss += disc_loss.item()
# Log loss values:
all_reduce_mean(vae_loss)
running_loss += vae_loss.item()
global_step = epoch * num_steps_per_epoch + step
log_step += 1
acc_step += 1
# Log to tensorboard
if coordinator.is_master() and (global_step + 1) % cfg.log_every == 0:
avg_loss = running_loss / log_step
avg_disc_loss = running_disc_loss / log_step
pbar.set_postfix(
{"loss": avg_loss, "disc_loss": avg_disc_loss, "step": step, "global_step": global_step}
)
running_loss = 0
log_step = 0
running_disc_loss = 0
writer.add_scalar("loss", vae_loss.item(), global_step)
if cfg.wandb:
wandb.log(
{
"iter": global_step,
"num_samples": global_step * total_batch_size,
"epoch": epoch,
"loss": vae_loss.item(),
"avg_loss": avg_loss,
**log_dict,
},
step=global_step,
)
# Save checkpoint
if cfg.ckpt_every > 0 and (global_step + 1) % cfg.ckpt_every == 0:
save_dir = os.path.join(exp_dir, f"epoch{epoch}-global_step{global_step+1}")
os.makedirs(os.path.join(save_dir, "model"), exist_ok=True) # already handled in booster save_model
booster.save_model(model, os.path.join(save_dir, "model"), shard=True)
booster.save_optimizer(
optimizer, os.path.join(save_dir, "optimizer"), shard=True, size_per_shard=4096
)
if cfg.get("discriminator", False) != False:
booster.save_model(discriminator, os.path.join(save_dir, "discriminator"), shard=True)
booster.save_optimizer(
disc_optimizer, os.path.join(save_dir, "disc_optimizer"), shard=True, size_per_shard=4096
)
running_states = {
"epoch": epoch,
"step": step + 1,
"global_step": global_step + 1,
"sample_start_index": (step + 1) * cfg.batch_size,
}
if coordinator.is_master():
save_json(running_states, os.path.join(save_dir, "running_states.json"))
dist.barrier()
logger.info(
f"Saved checkpoint at epoch {epoch} step {step + 1} global_step {global_step + 1} to {exp_dir}"
)
# the continue epochs are not resumed, so we need to reset the sampler start index and start step
dataloader.sampler.set_start_index(0)
start_step = 0
if __name__ == "__main__":
main()