# Open-Sora — scripts/train_vae.py
# Training entry script for the Open-Sora VAE (optionally with an adversarial discriminator).
import os
import random
from datetime import timedelta
from pprint import pformat

import torch
import torch.distributed as dist
import wandb
from colossalai.booster import Booster
from colossalai.cluster import DistCoordinator
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device, set_seed
from einops import rearrange
from tqdm import tqdm

from opensora.acceleration.checkpoint import set_grad_checkpoint
from opensora.acceleration.parallel_states import get_data_parallel_group
from opensora.datasets.dataloader import prepare_dataloader
from opensora.models.vae.losses import AdversarialLoss, DiscriminatorLoss, VAELoss
from opensora.registry import DATASETS, MODELS, build_module
from opensora.utils.ckpt_utils import load, save
from opensora.utils.config_utils import define_experiment_workspace, parse_configs, save_training_config
from opensora.utils.misc import (
    all_reduce_mean,
    create_logger,
    create_tensorboard_writer,
    format_numel_str,
    get_model_numel,
    to_torch_dtype,
)
from opensora.utils.train_utils import create_colossalai_plugin

def main():
    """Train the Open-Sora VAE (optionally with an adversarial discriminator).

    Reads all hyper-parameters from the config parsed by ``parse_configs``,
    sets up NCCL distributed training via ColossalAI, builds the dataset,
    VAE (and optional discriminator), then runs the epoch/step training loop
    with periodic logging and checkpointing. Runs until ``cfg.epochs`` epochs
    complete; intended to be launched with one process per GPU.
    """
    # ======================================================
    # 1. configs & runtime variables
    # ======================================================
    # == parse configs ==
    cfg = parse_configs(training=True)

    # == device and dtype ==
    assert torch.cuda.is_available(), "Training currently requires at least one GPU."
    cfg_dtype = cfg.get("dtype", "bf16")
    assert cfg_dtype in ["fp16", "bf16"], f"Unknown mixed precision {cfg_dtype}"
    dtype = to_torch_dtype(cfg_dtype)  # reuse cfg_dtype instead of re-reading the config

    # == colossalai init distributed training ==
    # NOTE: A very large timeout is set to avoid some processes exit early
    dist.init_process_group(backend="nccl", timeout=timedelta(hours=24))
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
    set_seed(cfg.get("seed", 1024))
    coordinator = DistCoordinator()
    device = get_current_device()

    # == init exp_dir ==
    exp_name, exp_dir = define_experiment_workspace(cfg)
    coordinator.block_all()
    if coordinator.is_master():
        os.makedirs(exp_dir, exist_ok=True)
        save_training_config(cfg.to_dict(), exp_dir)
    coordinator.block_all()

    # == init logger, tensorboard & wandb ==
    logger = create_logger(exp_dir)
    logger.info("Experiment directory created at %s", exp_dir)
    logger.info("Training configuration:\n %s", pformat(cfg.to_dict()))
    if coordinator.is_master():
        tb_writer = create_tensorboard_writer(exp_dir)
        if cfg.get("wandb", False):
            wandb.init(project="minisora", name=exp_name, config=cfg.to_dict(), dir="./outputs/wandb")

    # == init ColossalAI booster ==
    plugin = create_colossalai_plugin(
        plugin=cfg.get("plugin", "zero2"),
        dtype=cfg_dtype,
        grad_clip=cfg.get("grad_clip", 0),
        sp_size=cfg.get("sp_size", 1),
    )
    booster = Booster(plugin=plugin)

    # ======================================================
    # 2. build dataset and dataloader
    # ======================================================
    logger.info("Building dataset...")
    # == build dataset ==
    assert cfg.dataset.type == "VideoTextDataset", "Only support VideoTextDataset for vae training"
    dataset = build_module(cfg.dataset, DATASETS)
    logger.info("Dataset contains %s samples.", len(dataset))

    # == build dataloader ==
    dataloader_args = dict(
        dataset=dataset,
        batch_size=cfg.batch_size,
        num_workers=cfg.get("num_workers", 4),
        seed=cfg.get("seed", 1024),
        shuffle=True,
        drop_last=True,
        pin_memory=True,
        process_group=get_data_parallel_group(),
    )
    dataloader, sampler = prepare_dataloader(**dataloader_args)
    total_batch_size = cfg.batch_size * dist.get_world_size() // cfg.get("sp_size", 1)
    logger.info("Total batch size: %s", total_batch_size)
    num_steps_per_epoch = len(dataloader)

    # ======================================================
    # 3. build model
    # ======================================================
    logger.info("Building models...")
    # == build vae model ==
    model = build_module(cfg.model, MODELS).to(device, dtype).train()
    model_numel, model_numel_trainable = get_model_numel(model)
    logger.info(
        "[VAE] Trainable model params: %s, Total model params: %s",
        format_numel_str(model_numel_trainable),
        format_numel_str(model_numel),
    )

    # == build discriminator model ==
    # The discriminator (and all adversarial machinery below) is optional:
    # it is only created when the config declares a `discriminator` section.
    use_discriminator = cfg.get("discriminator", None) is not None
    if use_discriminator:
        discriminator = build_module(cfg.discriminator, MODELS).to(device, dtype).train()
        discriminator_numel, discriminator_numel_trainable = get_model_numel(discriminator)
        logger.info(
            "[Discriminator] Trainable model params: %s, Total model params: %s",
            format_numel_str(discriminator_numel_trainable),
            format_numel_str(discriminator_numel),
        )

    # == setup loss functions ==
    vae_loss_fn = VAELoss(
        logvar_init=cfg.get("logvar_init", 0.0),
        perceptual_loss_weight=cfg.get("perceptual_loss_weight", 0.1),
        kl_loss_weight=cfg.get("kl_loss_weight", 1e-6),
        device=device,
        dtype=dtype,
    )

    if use_discriminator:
        adversarial_loss_fn = AdversarialLoss(
            discriminator_factor=cfg.get("discriminator_factor", 1),
            discriminator_start=cfg.get("discriminator_start", -1),
            generator_factor=cfg.get("generator_factor", 0.5),
            generator_loss_type=cfg.get("generator_loss_type", "hinge"),
        )
        disc_loss_fn = DiscriminatorLoss(
            discriminator_factor=cfg.get("discriminator_factor", 1),
            discriminator_start=cfg.get("discriminator_start", -1),
            discriminator_loss_type=cfg.get("discriminator_loss_type", "hinge"),
            lecam_loss_weight=cfg.get("lecam_loss_weight", None),
            gradient_penalty_loss_weight=cfg.get("gradient_penalty_loss_weight", None),
        )

    # == setup vae optimizer ==
    optimizer = HybridAdam(
        filter(lambda p: p.requires_grad, model.parameters()),
        adamw_mode=True,
        lr=cfg.get("lr", 1e-5),
        weight_decay=cfg.get("weight_decay", 0),
    )
    lr_scheduler = None

    # == setup discriminator optimizer ==
    if use_discriminator:
        disc_optimizer = HybridAdam(
            filter(lambda p: p.requires_grad, discriminator.parameters()),
            adamw_mode=True,
            lr=cfg.get("lr", 1e-5),
            weight_decay=cfg.get("weight_decay", 0),
        )
        disc_lr_scheduler = None

    # == additional preparation ==
    if cfg.get("grad_checkpoint", False):
        set_grad_checkpoint(model)

    # =======================================================
    # 4. distributed training preparation with colossalai
    # =======================================================
    logger.info("Preparing for distributed training...")
    # == boosting ==
    # NOTE: we set dtype first to make initialization of model consistent with the dtype; then reset it to the fp32 as we make diffusion scheduler in fp32
    torch.set_default_dtype(dtype)
    model, optimizer, _, dataloader, lr_scheduler = booster.boost(
        model=model,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        dataloader=dataloader,
    )
    if use_discriminator:
        discriminator, disc_optimizer, _, _, disc_lr_scheduler = booster.boost(
            model=discriminator,
            optimizer=disc_optimizer,
            lr_scheduler=disc_lr_scheduler,
        )
    torch.set_default_dtype(torch.float)
    logger.info("Boosting model for distributed training")

    # == global variables ==
    cfg_epochs = cfg.get("epochs", 1000)
    start_epoch = start_step = log_step = acc_step = 0
    running_loss = running_disc_loss = 0.0
    logger.info("Training for %s epochs with %s steps per epoch", cfg_epochs, num_steps_per_epoch)

    # == resume ==
    if cfg.get("load", None) is not None:
        logger.info("Loading checkpoint")
        start_epoch, start_step = load(
            booster,
            cfg.load,
            model=model,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            sampler=sampler,
        )
        # Discriminator state is saved separately; only restore it when present
        # so resuming from a VAE-only checkpoint still works.
        if use_discriminator and os.path.exists(os.path.join(cfg.load, "discriminator")):
            booster.load_model(discriminator, os.path.join(cfg.load, "discriminator"))
            booster.load_optimizer(disc_optimizer, os.path.join(cfg.load, "disc_optimizer"))
        dist.barrier()
        logger.info("Loaded checkpoint %s at epoch %s step %s", cfg.load, start_epoch, start_step)

    # =======================================================
    # 5. training loop
    # =======================================================
    dist.barrier()
    for epoch in range(start_epoch, cfg_epochs):
        # == set dataloader to new epoch ==
        sampler.set_epoch(epoch)
        dataiter = iter(dataloader)
        logger.info("Beginning epoch %s...", epoch)

        # == training loop in an epoch ==
        with tqdm(
            enumerate(dataiter, start=start_step),
            desc=f"Epoch {epoch}",
            disable=not coordinator.is_master(),
            total=num_steps_per_epoch,
            initial=start_step,
        ) as pbar:
            for step, batch in pbar:
                x = batch["video"].to(device, dtype)  # [B, C, T, H, W]
                # Hoisted: global_step is needed by both adversarial losses and logging.
                global_step = epoch * num_steps_per_epoch + step

                # == mixed training setting ==
                mixed_strategy = cfg.get("mixed_strategy", None)
                if mixed_strategy == "mixed_video_image":
                    # Occasionally train on the first frame only (image training).
                    if random.random() < cfg.get("mixed_image_ratio", 0.0):
                        x = x[:, :, :1, :, :]
                elif mixed_strategy == "mixed_video_random":
                    # Train on a random-length temporal prefix.
                    length = random.randint(1, x.size(2))
                    x = x[:, :, :length, :, :]

                # == vae encoding & decoding ===
                x_rec, x_z_rec, z, posterior, x_z = model(x)

                # == loss initialization ==
                vae_loss = torch.tensor(0.0, device=device, dtype=dtype)
                disc_loss = torch.tensor(0.0, device=device, dtype=dtype)
                log_dict = {}

                # == loss: real image reconstruction ==
                nll_loss, weighted_nll_loss, weighted_kl_loss = vae_loss_fn(x, x_rec, posterior)
                log_dict["kl_loss"] = weighted_kl_loss.item()
                log_dict["nll_loss"] = weighted_nll_loss.item()
                if cfg.get("use_real_rec_loss", False):
                    vae_loss += weighted_nll_loss + weighted_kl_loss

                # == loss: temporal vae reconstruction ==
                _, weighted_z_nll_loss, _ = vae_loss_fn(x_z, x_z_rec, posterior, no_perceptual=True)
                log_dict["z_nll_loss"] = weighted_z_nll_loss.item()
                if cfg.get("use_z_rec_loss", False):
                    vae_loss += weighted_z_nll_loss

                # == loss: image only distillation ==
                # Only applies to single-frame (image) batches.
                if cfg.get("use_image_identity_loss", False) and x.size(2) == 1:
                    _, image_identity_loss, _ = vae_loss_fn(x_z, z, posterior, no_perceptual=True)
                    vae_loss += image_identity_loss
                    log_dict["image_identity_loss"] = image_identity_loss.item()

                # == loss: generator adversarial ==
                if use_discriminator:
                    # Discriminator operates per-frame: fold time into the batch dim.
                    recon_video = rearrange(x_rec, "b c t h w -> (b t) c h w").contiguous()
                    fake_logits = discriminator(recon_video)
                    adversarial_loss = adversarial_loss_fn(
                        fake_logits,
                        nll_loss,
                        model.module.get_temporal_last_layer(),
                        global_step,
                        is_training=model.training,
                    )
                    log_dict["adversarial_loss"] = adversarial_loss.item()
                    vae_loss += adversarial_loss

                # == generator backward & update ==
                optimizer.zero_grad()
                booster.backward(loss=vae_loss, optimizer=optimizer)
                optimizer.step()
                all_reduce_mean(vae_loss)
                running_loss += vae_loss.item()

                # == loss: discriminator adversarial ==
                if use_discriminator:
                    real_video = rearrange(x, "b c t h w -> (b t) c h w").contiguous()
                    fake_video = rearrange(x_rec, "b c t h w -> (b t) c h w").contiguous()
                    # detach(): no gradients should flow back into the generator here.
                    real_logits = discriminator(real_video.detach())
                    fake_logits = discriminator(fake_video.detach())
                    weighted_d_adversarial_loss, _, _ = disc_loss_fn(
                        real_logits,
                        fake_logits,
                        global_step,
                    )
                    disc_loss = weighted_d_adversarial_loss
                    log_dict["disc_loss"] = disc_loss.item()

                    # == discriminator backward & update ==
                    disc_optimizer.zero_grad()
                    booster.backward(loss=disc_loss, optimizer=disc_optimizer)
                    disc_optimizer.step()
                    all_reduce_mean(disc_loss)
                    running_disc_loss += disc_loss.item()

                # == update log info ==
                log_step += 1
                acc_step += 1

                # == logging ==
                if coordinator.is_master() and (global_step + 1) % cfg.get("log_every", 1) == 0:
                    avg_loss = running_loss / log_step
                    avg_disc_loss = running_disc_loss / log_step
                    # progress bar
                    pbar.set_postfix(
                        {"loss": avg_loss, "disc_loss": avg_disc_loss, "step": step, "global_step": global_step}
                    )
                    # tensorboard
                    tb_writer.add_scalar("loss", vae_loss.item(), global_step)
                    # wandb
                    # Fix: use cfg.get(...) like the init above; bare cfg.wandb raises when the key is absent.
                    if cfg.get("wandb", False):
                        wandb.log(
                            {
                                "iter": global_step,
                                "num_samples": global_step * total_batch_size,
                                "epoch": epoch,
                                "loss": vae_loss.item(),
                                "avg_loss": avg_loss,
                                **log_dict,
                            },
                            step=global_step,
                        )
                    running_loss = running_disc_loss = 0.0
                    log_step = 0

                # == checkpoint saving ==
                ckpt_every = cfg.get("ckpt_every", 0)
                if ckpt_every > 0 and (global_step + 1) % ckpt_every == 0:
                    save(
                        booster,
                        exp_dir,
                        model=model,
                        optimizer=optimizer,
                        lr_scheduler=lr_scheduler,
                        epoch=epoch,
                        step=step + 1,
                        global_step=global_step + 1,
                        batch_size=cfg.get("batch_size", None),
                        sampler=sampler,
                    )
                    # Discriminator state lives alongside the VAE checkpoint dir.
                    save_dir = os.path.join(exp_dir, f"epoch{epoch}-global_step{global_step+1}")
                    if use_discriminator:
                        booster.save_model(discriminator, os.path.join(save_dir, "discriminator"), shard=True)
                        booster.save_optimizer(
                            disc_optimizer, os.path.join(save_dir, "disc_optimizer"), shard=True, size_per_shard=4096
                        )
                    dist.barrier()
                    logger.info(
                        "Saved checkpoint at epoch %s step %s global_step %s to %s",
                        epoch,
                        step + 1,
                        global_step + 1,
                        exp_dir,
                    )

        # A resumed run starts mid-epoch; after the first full epoch, restart from step 0.
        sampler.reset()
        start_step = 0

# Standard script entry point: launch VAE training when executed directly.
if __name__ == "__main__":
    main()