# Open-Sora/scripts/train.py
from copy import deepcopy
from datetime import timedelta
from pprint import pprint

import torch
import torch.distributed as dist
import wandb
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.cluster import DistCoordinator
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device, set_seed
from tqdm import tqdm

from opensora.acceleration.checkpoint import set_grad_checkpoint
from opensora.acceleration.parallel_states import (
    get_data_parallel_group,
    set_data_parallel_group,
    set_sequence_parallel_group,
)
from opensora.acceleration.plugin import ZeroSeqParallelPlugin
from opensora.datasets import prepare_dataloader, prepare_variable_dataloader
from opensora.registry import DATASETS, MODELS, SCHEDULERS, build_module
from opensora.utils.ckpt_utils import create_logger, load, model_sharding, record_model_param_shape, save
from opensora.utils.config_utils import (
    create_experiment_workspace,
    create_tensorboard_writer,
    parse_configs,
    save_training_config,
)
from opensora.utils.misc import all_reduce_mean, format_numel_str, get_model_numel, requires_grad, to_torch_dtype
from opensora.utils.train_utils import MaskGenerator, update_ema


def main():
    # ======================================================
    # 1. args & cfg
    # ======================================================
    cfg = parse_configs(training=True)
    exp_name, exp_dir = create_experiment_workspace(cfg)
    save_training_config(cfg._cfg_dict, exp_dir)

    # ======================================================
    # 2. runtime variables & colossalai launch
    # ======================================================
    assert torch.cuda.is_available(), "Training currently requires at least one GPU."
    assert cfg.dtype in ["fp16", "bf16"], f"Unknown mixed precision {cfg.dtype}"

    # 2.1. colossalai init distributed training
    # we set a very large timeout to avoid some processes exiting early
    dist.init_process_group(backend="nccl", timeout=timedelta(hours=24))
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
    set_seed(1024)
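    # NOTE: rank % device_count maps each process to a local GPU, so the same
    # script works for single-node and multi-node launches; the fixed seed
    # makes runs reproducible up to nondeterministic CUDA kernels.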
    coordinator = DistCoordinator()
    device = get_current_device()
    dtype = to_torch_dtype(cfg.dtype)

    # 2.2. init logger, tensorboard & wandb
    if not coordinator.is_master():
        logger = create_logger(None)
    else:
        print("Training configuration:")
        pprint(cfg._cfg_dict)
        logger = create_logger(exp_dir)
        logger.info(f"Experiment directory created at {exp_dir}")
        writer = create_tensorboard_writer(exp_dir)
        if cfg.wandb:
            wandb.init(project="minisora", name=exp_name, config=cfg._cfg_dict)

    # 2.3. initialize ColossalAI booster
    if cfg.plugin == "zero2":
        plugin = LowLevelZeroPlugin(
            stage=2,
            precision=cfg.dtype,
            initial_scale=2**16,
            max_norm=cfg.grad_clip,
        )
        set_data_parallel_group(dist.group.WORLD)
    elif cfg.plugin == "zero2-seq":
        plugin = ZeroSeqParallelPlugin(
            sp_size=cfg.sp_size,
            stage=2,
            precision=cfg.dtype,
            initial_scale=2**16,
            max_norm=cfg.grad_clip,
        )
        set_sequence_parallel_group(plugin.sp_group)
        set_data_parallel_group(plugin.dp_group)
    else:
        raise ValueError(f"Unknown plugin {cfg.plugin}")
    booster = Booster(plugin=plugin)
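
    # NOTE: ZeRO stage 2 shards optimizer states and gradients across the
    # data-parallel group while parameters stay replicated. The "zero2-seq"
    # plugin additionally splits each sample's token sequence across sp_size
    # ranks (sequence parallelism), which is what lets long video token
    # sequences fit in memory; ranks in one sequence-parallel group therefore
    # consume the same data.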

    # ======================================================
    # 3. build dataset and dataloader
    # ======================================================
    dataset = build_module(cfg.dataset, DATASETS)
    logger.info(f"Dataset contains {len(dataset)} samples.")
    dataloader_args = dict(
        dataset=dataset,
        batch_size=cfg.batch_size,
        num_workers=cfg.num_workers,
        seed=cfg.seed,
        shuffle=True,
        drop_last=True,
        pin_memory=True,
        process_group=get_data_parallel_group(),
    )
    # TODO: use plugin's prepare dataloader
    if cfg.bucket_config is None:
        dataloader = prepare_dataloader(**dataloader_args)
    else:
        dataloader = prepare_variable_dataloader(
            bucket_config=cfg.bucket_config,
            num_bucket_build_workers=cfg.num_bucket_build_workers,
            **dataloader_args,
        )
    if cfg.dataset.type == "VideoTextDataset":
        total_batch_size = cfg.batch_size * dist.get_world_size() // cfg.sp_size
        logger.info(f"Total batch size: {total_batch_size}")

    # ======================================================
    # 4. build model
    # ======================================================
    # 4.1. build model
    text_encoder = build_module(cfg.text_encoder, MODELS, device=device)
    vae = build_module(cfg.vae, MODELS)
    input_size = (dataset.num_frames, *dataset.image_size)
    latent_size = vae.get_latent_size(input_size)
    model = build_module(
        cfg.model,
        MODELS,
        input_size=latent_size,
        in_channels=vae.out_channels,
        caption_channels=text_encoder.output_dim,
        model_max_length=text_encoder.model_max_length,
    )
    model_numel, model_numel_trainable = get_model_numel(model)
    logger.info(
        f"Trainable model params: {format_numel_str(model_numel_trainable)}, Total model params: {format_numel_str(model_numel)}"
    )
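
    # NOTE: the diffusion model operates in the VAE's latent space: its input
    # size is the latent grid returned by get_latent_size, its in_channels
    # match the VAE latent channels, and caption_channels / model_max_length
    # wire it to the text encoder's embedding width and token limit.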

    # 4.2. create ema
    ema = deepcopy(model).to(torch.float32).to(device)
    requires_grad(ema, False)
    ema_shape_dict = record_model_param_shape(ema)
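    # NOTE: the EMA copy is kept in float32 even though training runs in
    # fp16/bf16; averaging in low precision would accumulate rounding error.
    # Its full parameter shapes are recorded now because the copy is sharded
    # further below.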

    # 4.3. move to device
    vae = vae.to(device, dtype)
    model = model.to(device, dtype)

    # 4.4. build scheduler
    scheduler = build_module(cfg.scheduler, SCHEDULERS)

    # 4.5. setup optimizer
    optimizer = HybridAdam(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=cfg.lr,
        weight_decay=0,
        adamw_mode=True,
    )
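    # NOTE: HybridAdam is ColossalAI's fused Adam that can update parameters
    # resident on either CPU or GPU; adamw_mode=True selects decoupled weight
    # decay (AdamW), which is a no-op here since weight_decay=0.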
    lr_scheduler = None

    # 4.6. prepare for training
    if cfg.grad_checkpoint:
        set_grad_checkpoint(model)
    model.train()
    update_ema(ema, model, decay=0, sharded=False)
    ema.eval()
    if cfg.mask_ratios is not None:
        mask_generator = MaskGenerator(cfg.mask_ratios)
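    # NOTE: update_ema with decay=0 simply copies the current weights into the
    # EMA model as its starting point. mask_ratios configures MaskGenerator,
    # which (as used in the loop below) masks a subset of latent frames per
    # sample so the model also learns frame-conditioned generation, e.g.
    # image-to-video.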

    # =======================================================
    # 5. boost model for distributed training with colossalai
    # =======================================================
    torch.set_default_dtype(dtype)
    model, optimizer, _, dataloader, lr_scheduler = booster.boost(
        model=model,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        dataloader=dataloader,
    )
    torch.set_default_dtype(torch.float)
    logger.info("Boost model for distributed training")

    if cfg.dataset.type == "VariableVideoTextDataset":
        num_steps_per_epoch = dataloader.batch_sampler.get_num_batch() // dist.get_world_size()
    else:
        num_steps_per_epoch = len(dataloader)
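
    # NOTE: the default dtype is switched temporarily so any tensors the
    # booster allocates while wrapping the model are created in the training
    # precision, then restored to float32. After boost(), `model` is the
    # plugin-wrapped module and the optimizer holds the ZeRO-sharded states.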

    # =======================================================
    # 6. training loop
    # =======================================================
    start_epoch = start_step = log_step = sampler_start_idx = acc_step = 0
    running_loss = 0.0
    sampler_to_io = dataloader.batch_sampler if cfg.dataset.type == "VariableVideoTextDataset" else None

    # 6.1. resume training
    if cfg.load is not None:
        logger.info("Loading checkpoint")
        ret = load(
            booster,
            model,
            ema,
            optimizer,
            lr_scheduler,
            cfg.load,
            sampler=sampler_to_io if not cfg.start_from_scratch else None,
        )
        if not cfg.start_from_scratch:
            start_epoch, start_step, sampler_start_idx = ret
        logger.info(f"Loaded checkpoint {cfg.load} at epoch {start_epoch} step {start_step}")
    logger.info(f"Training for {cfg.epochs} epochs with {num_steps_per_epoch} steps per epoch")
    if cfg.dataset.type == "VideoTextDataset":
        dataloader.sampler.set_start_index(sampler_start_idx)
    model_sharding(ema)
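    # NOTE: model_sharding splits the (frozen) EMA parameters across ranks so
    # each rank keeps only a slice, mirroring how ZeRO shards training state;
    # the update_ema call inside the loop works on these shards.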

    # 6.2. training loop
    for epoch in range(start_epoch, cfg.epochs):
        if cfg.dataset.type == "VideoTextDataset":
            dataloader.sampler.set_epoch(epoch)
        dataloader_iter = iter(dataloader)
        logger.info(f"Beginning epoch {epoch}...")

        with tqdm(
            enumerate(dataloader_iter, start=start_step),
            desc=f"Epoch {epoch}",
            disable=not coordinator.is_master(),
            initial=start_step,
            total=num_steps_per_epoch,
        ) as pbar:
            for step, batch in pbar:
                x = batch.pop("video").to(device, dtype)  # [B, C, T, H, W]
                y = batch.pop("text")

                # Visual and text encoding
                with torch.no_grad():
                    # Prepare visual inputs
                    x = vae.encode(x)  # [B, C, T, H/P, W/P]
                    # Prepare text inputs
                    model_args = text_encoder.encode(y)
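                # NOTE: both the VAE and the text encoder are frozen; only the
                # diffusion transformer is trained, so encoding runs under
                # no_grad and keeps no autograd graph.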

                # Mask
                if cfg.mask_ratios is not None:
                    mask = mask_generator.get_masks(x)
                    model_args["x_mask"] = mask
                else:
                    mask = None

                # Video info
                for k, v in batch.items():
                    model_args[k] = v.to(device, dtype)
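                # NOTE: whatever remains in `batch` (e.g. per-sample height,
                # width, fps, num_frames for variable datasets) is forwarded
                # to the model as extra conditioning inputs.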

                # Diffusion
                t = torch.randint(0, scheduler.num_timesteps, (x.shape[0],), device=device)
                loss_dict = scheduler.training_losses(model, x, t, model_args, mask=mask)

                # Backward & update
                loss = loss_dict["loss"].mean()
                booster.backward(loss=loss, optimizer=optimizer)
                optimizer.step()
                optimizer.zero_grad()

                # Update EMA
                update_ema(ema, model.module, optimizer=optimizer)
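                # NOTE: a timestep is drawn uniformly per sample and the
                # scheduler computes the denoising loss at that noise level.
                # booster.backward routes the backward pass through the plugin
                # (loss scaling, ZeRO gradient sharding); update_ema takes the
                # optimizer so it can access the sharded parameters when
                # updating the sharded EMA.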

                # Log loss values:
                all_reduce_mean(loss)
                running_loss += loss.item()
                global_step = epoch * num_steps_per_epoch + step
                log_step += 1
                acc_step += 1

                # Log to tensorboard
                if coordinator.is_master() and (global_step + 1) % cfg.log_every == 0:
                    avg_loss = running_loss / log_step
                    pbar.set_postfix({"loss": avg_loss, "step": step, "global_step": global_step})
                    running_loss = 0
                    log_step = 0
                    writer.add_scalar("loss", loss.item(), global_step)
                    if cfg.wandb:
                        wandb.log(
                            {
                                "iter": global_step,
                                "epoch": epoch,
                                "loss": loss.item(),
                                "avg_loss": avg_loss,
                                "acc_step": acc_step,
                            },
                            step=global_step,
                        )
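
                # NOTE: all_reduce_mean averages the loss across ranks, so the
                # logged value reflects the whole job rather than just the
                # master rank's local batch.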

                # Save checkpoint
                if cfg.ckpt_every > 0 and (global_step + 1) % cfg.ckpt_every == 0:
                    save(
                        booster,
                        model,
                        ema,
                        optimizer,
                        lr_scheduler,
                        epoch,
                        step + 1,
                        global_step + 1,
                        cfg.batch_size,
                        coordinator,
                        exp_dir,
                        ema_shape_dict,
                        sampler=sampler_to_io,
                    )
                    logger.info(
                        f"Saved checkpoint at epoch {epoch} step {step + 1} global_step {global_step + 1} to {exp_dir}"
                    )
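
                    # NOTE: the checkpoint bundles the model, the sharded EMA
                    # (ema_shape_dict supplies its full parameter shapes), the
                    # optimizer state, and the sampler position, so training
                    # can resume mid-epoch.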

        # subsequent epochs are not resumed, so reset the sampler start index and start step
        if cfg.dataset.type == "VideoTextDataset":
            dataloader.sampler.set_start_index(0)
        if cfg.dataset.type == "VariableVideoTextDataset":
            dataloader.batch_sampler.set_epoch(epoch + 1)
            print("Epoch done, recomputing batch sampler")
        start_step = 0


if __name__ == "__main__":
    main()