mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-05-02 02:46:30 +02:00
181 lines
6.6 KiB
Python
181 lines
6.6 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
|
|
# This source code is licensed under the license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
"""
|
|
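A minimal training benchmark for DiT using the ColossalAI Booster (TorchDDP or ZeRO-2 sequence-parallel plugin); it runs on synthetic data and logs throughput and peak device memory.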
|
|
"""
|
|
import argparse
|
|
import time
|
|
|
|
import torch
|
|
import torch.distributed as dist
|
|
from colossalai import launch_from_torch
|
|
from colossalai.accelerator import get_accelerator
|
|
from colossalai.booster import Booster
|
|
from colossalai.booster.plugin import TorchDDPPlugin
|
|
from colossalai.cluster import DistCoordinator
|
|
from colossalai.logging import get_dist_logger
|
|
from colossalai.nn.optimizer import HybridAdam
|
|
from colossalai.utils import get_current_device
|
|
from tqdm import tqdm
|
|
|
|
from open_sora.diffusion import create_diffusion
|
|
from open_sora.modeling import DiT_models
|
|
from open_sora.modeling.dit import SUPPORTED_SEQ_PARALLEL_MODES
|
|
from open_sora.utils.data import create_video_compressor, preprocess_batch
|
|
from open_sora.utils.plugin import ZeroSeqParallelPlugin
|
|
|
|
#################################################################################
|
|
# Training Loop #
|
|
#################################################################################
|
|
|
|
|
|
def main(args):
    """
    Benchmark DiT training steps on a synthetic batch and log the throughput.

    One random batch is built up front and reused for every step. The first
    ``args.warmup_steps`` iterations are executed but excluded from the
    timing. Peak device memory and samples/s are logged on rank 0.

    Args:
        args: parsed command-line namespace. Reads ``plugin``, ``sp_size``,
            ``sp_mode``, ``sp_overlap``, ``compressor``, ``model``,
            ``grad_checkpoint``, ``num_frames``, ``height``, ``width``,
            ``batch_size``, ``num_tokens``, ``text_embed_dim``,
            ``warmup_steps`` and ``steps``.

    Raises:
        ValueError: if ``args.plugin`` is neither ``"ddp"`` nor ``"zero2"``.
        AssertionError: if ``args.num_tokens`` is not divisible by
            ``args.sp_size``.
    """
    # init distributed environment
    launch_from_torch({})
    coordinator = DistCoordinator()
    logger = get_dist_logger()

    # set up acceleration plugins
    if args.plugin == "ddp":
        plugin = TorchDDPPlugin()
    elif args.plugin == "zero2":
        # use bf16 to avoid skipping the first few iterations due to NaNs
        plugin = ZeroSeqParallelPlugin(sp_size=args.sp_size, stage=2, precision="bf16")
    else:
        raise ValueError(f"Unknown plugin {args.plugin}")
    booster = Booster(plugin=plugin)

    # Create video compressor
    video_compressor = create_video_compressor(args.compressor)
    model_kwargs = {
        "in_channels": video_compressor.out_channels,
        # plugins without an ``sp_group`` attribute fall back to None here
        "seq_parallel_group": getattr(plugin, "sp_group", None),
        "seq_parallel_mode": args.sp_mode,
        "seq_parallel_overlap": args.sp_overlap,
    }

    # Create the DiT model (no EMA copy is kept in this benchmark)
    model = DiT_models[args.model](**model_kwargs).to(get_current_device())
    patch_size = model.patch_size
    model.train()  # important! This enables embedding dropout for classifier-free guidance

    # configure gradient checkpointing
    if args.grad_checkpoint:
        model.enable_gradient_checkpointing()

    # create diffusion pipeline
    diffusion = create_diffusion(
        timestep_respacing=""
    )  # default: 1000 steps, linear noise schedule

    # setup optimizer (we used default Adam betas=(0.9, 0.999) and a constant learning rate of 1e-4 in our paper):
    opt = HybridAdam(model.parameters(), lr=1e-4, weight_decay=0)

    # Build one synthetic batch (random videos and text embeddings); it is
    # reused for every step instead of going through a dataloader.
    videos = [
        torch.randn(args.num_frames, args.height, args.width, 3)
        for _ in range(args.batch_size)
    ]
    assert args.num_tokens % args.sp_size == 0, "num_tokens must be divisible by sp_size"
    input_ids = torch.randn(args.batch_size, args.num_tokens, args.text_embed_dim)
    text_mask = torch.ones(input_ids.shape[:2], dtype=torch.int)
    batch = {
        "videos": videos,
        "text_latent_states": input_ids,
        "text_padding_mask": text_mask,
    }
    batch = preprocess_batch(
        batch, patch_size, video_compressor, pad_to_multiple=args.sp_size
    )
    # The video entries are passed to the loss separately from the rest of the batch.
    video_inputs = batch.pop("video_latent_states")
    mask = batch.pop("video_padding_mask")
    logger.info(
        f"Num patches: {video_inputs.shape[1]}, num tokens: {batch['text_latent_states'].shape[1]}",
        ranks=[0],
    )

    # setup booster
    model, opt, *_ = booster.boost(model, opt)
    logger.info(
        f"Booster init max device memory: {get_accelerator().max_memory_allocated() / 1024 ** 2:.2f} MB",
        ranks=[0],
    )

    # Train
    total_samples = 0  # samples processed by THIS rank during timed steps
    total_duration = 0.0  # wall-clock seconds of timed steps on this rank
    for i in tqdm(
        range(args.warmup_steps + args.steps),
        desc="Steps",
        disable=not coordinator.is_master(),
    ):
        # NOTE(review): the clock is read without an explicit device
        # synchronize; confirm the step has finished on the device by the
        # time the end timestamp is taken.
        start = time.time()
        t = torch.randint(
            0,
            diffusion.num_timesteps,
            (video_inputs.shape[0],),
            device=video_inputs.device,
        )
        loss_dict = diffusion.training_losses(model, video_inputs, t, batch, mask=mask)
        loss = loss_dict["loss"].mean()
        booster.backward(loss, opt)
        opt.step()
        opt.zero_grad()
        get_accelerator().empty_cache()
        time_per_iter = time.time() - start
        if i >= args.warmup_steps:
            # Count only the per-rank batch here; the scaling to a global
            # sample count happens exactly once after the loop. Multiplying by
            # the world size here as well would double-count the ranks.
            total_samples += args.batch_size
            total_duration += time_per_iter

    # Average the timed duration over all ranks.
    total_duration = torch.tensor([total_duration], device=get_current_device())
    dist.all_reduce(total_duration)
    total_duration = total_duration / coordinator.world_size
    total_duration = total_duration.item()

    # Ranks inside one sequence-parallel group share the same samples, so the
    # number of independent replicas is world_size // sp_size.
    total_samples *= coordinator.world_size // args.sp_size

    # With zero timed steps there is no duration to divide by.
    throughput = total_samples / total_duration if total_duration > 0 else 0.0
    logger.info(
        f"Training complete, max device memory: {get_accelerator().max_memory_allocated() / 1024 ** 2:.2f} MB",
        ranks=[0],
    )
    logger.info(
        f"Throughput: {throughput:.2f} samples/s",
        ranks=[0],
    )
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line options for the benchmark. By default this runs the
    # DiT-S/8 model with the zero2 plugin for 2 warmup steps plus 3 timed steps.
    # Each entry is (option flags, add_argument keyword arguments); entries
    # are registered in the order listed.
    option_specs = [
        (
            ("-m", "--model"),
            {"type": str, "choices": list(DiT_models.keys()), "default": "DiT-S/8"},
        ),
        (
            ("-p", "--plugin"),
            {"type": str, "default": "zero2", "choices": ["ddp", "zero2"]},
        ),
        (("--sp_size",), {"type": int, "default": 1}),
        (
            ("--sp_mode",),
            {"type": str, "default": "ulysses", "choices": SUPPORTED_SEQ_PARALLEL_MODES},
        ),
        (("--sp_overlap",), {"action": "store_true", "default": False}),
        (("-w", "--warmup_steps"), {"type": int, "default": 2}),
        (("-s", "--steps"), {"type": int, "default": 3}),
        (("-b", "--batch_size"), {"type": int, "default": 4}),
        (("-f", "--num_frames"), {"type": int, "default": 300}),
        (("--height",), {"type": int, "default": 256}),
        (("--width",), {"type": int, "default": 256}),
        (("--num_tokens",), {"type": int, "default": 20}),
        (("--text_embed_dim",), {"type": int, "default": 512}),
        (("-g", "--grad_checkpoint"), {"action": "store_true", "default": False}),
        (
            ("-c", "--compressor"),
            {"choices": ["raw", "vqvae", "vae"], "default": "raw"},
        ),
    ]

    cli = argparse.ArgumentParser()
    for flags, options in option_specs:
        cli.add_argument(*flags, **options)

    main(cli.parse_args())
|