mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-11 05:13:31 +02:00
323 lines
12 KiB
Python
323 lines
12 KiB
Python
import time
|
|
import traceback
|
|
from copy import deepcopy
|
|
from datetime import timedelta
|
|
|
|
import torch
|
|
import torch.distributed as dist
|
|
from colossalai.booster import Booster
|
|
from colossalai.cluster import DistCoordinator
|
|
from colossalai.nn.optimizer import HybridAdam
|
|
from colossalai.utils import get_current_device, set_seed
|
|
from tqdm import tqdm
|
|
|
|
from opensora.acceleration.checkpoint import set_grad_checkpoint
|
|
from opensora.acceleration.parallel_states import get_data_parallel_group
|
|
from opensora.datasets import prepare_variable_dataloader
|
|
from opensora.datasets.aspect import get_num_frames
|
|
from opensora.registry import DATASETS, MODELS, SCHEDULERS, build_module
|
|
from opensora.utils.ckpt_utils import model_sharding
|
|
from opensora.utils.config_utils import parse_configs
|
|
from opensora.utils.misc import BColors, create_logger, format_numel_str, get_model_numel, requires_grad, to_torch_dtype
|
|
from opensora.utils.train_utils import MaskGenerator, create_colossalai_plugin, update_ema
|
|
|
|
SEARCH_BS_PREFIX = f"{BColors.OKGREEN}[Search BS]{BColors.ENDC}"
|
|
|
|
|
|
def main():
    """Benchmark training step time to pick a batch size for every bucket in ``cfg.bucket_config``.

    The function parses the config, initialises NCCL distributed training, builds the
    text encoder, VAE, diffusion model, EMA copy and optimizer, boosts model and
    optimizer with the configured ColossalAI plugin, and then runs short training
    trials:

    1. Unless ``cfg.base_step_time`` is provided, the bucket named by ``cfg.base`` is
       binary-searched for the largest batch size whose trial finishes without raising;
       its average step time becomes the reference step time.
    2. Every remaining bucket is binary-searched for the batch size whose step time is
       below, and closest to, that reference.

    Each leaf of ``cfg.bucket_config`` is unpacked as ``(guess_bs, variance)`` and the
    search range for that bucket is ``[max(guess_bs - variance, 1), guess_bs + variance]``.

    The outcome (a result table and the updated bucket config) is reported through the
    logger only.
    """
    # ======================================================
    # configs & runtime variables
    # ======================================================
    # == parse configs ==
    cfg = parse_configs()
    assert cfg.dataset.type == "VariableVideoTextDataset", "Only VariableVideoTextDataset is supported"

    # == device and dtype ==
    assert torch.cuda.is_available(), "Training currently requires at least one GPU."
    cfg_dtype = cfg.get("dtype", "bf16")
    assert cfg_dtype in ["fp16", "bf16"], f"Unknown mixed precision {cfg_dtype}"
    # Reuse the value validated above instead of reading the config a second time.
    dtype = to_torch_dtype(cfg_dtype)

    # == colossalai init distributed training ==
    # NOTE: A very large timeout is set so that no process exits early while
    # other ranks are still busy.
    dist.init_process_group(backend="nccl", timeout=timedelta(hours=24))
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
    set_seed(cfg.get("seed", 1024))
    # Instantiated for its side effects only; the instance itself is not used here.
    DistCoordinator()
    device = get_current_device()

    # == init logger, tensorboard & wandb ==
    logger = create_logger()

    # == init ColossalAI booster ==
    plugin = create_colossalai_plugin(
        plugin=cfg.get("plugin", "zero2"),
        dtype=cfg_dtype,
        grad_clip=cfg.get("grad_clip", 0),
        sp_size=cfg.get("sp_size", 1),
    )
    booster = Booster(plugin=plugin)

    # ======================================================
    # build model
    # ======================================================
    logger.info("Building models...")
    # == build text-encoder and vae ==
    text_encoder = build_module(cfg.text_encoder, MODELS, device=device)
    vae = build_module(cfg.vae, MODELS).to(device, dtype).eval()

    # == build diffusion model ==
    # NOTE(review): the all-None input size presumably means "variable size" because
    # every bucket has its own resolution / frame count - confirm against the VAE.
    input_size = (None, None, None)
    latent_size = vae.get_latent_size(input_size)
    model = (
        build_module(
            cfg.model,
            MODELS,
            input_size=latent_size,
            in_channels=vae.out_channels,
            caption_channels=text_encoder.output_dim,
            model_max_length=text_encoder.model_max_length,
        )
        .to(device, dtype)
        .train()
    )
    model_numel, model_numel_trainable = get_model_numel(model)
    logger.info(
        "[Diffusion] Trainable model params: %s, Total model params: %s",
        format_numel_str(model_numel_trainable),
        format_numel_str(model_numel),
    )

    # == build ema for diffusion model ==
    ema = deepcopy(model).to(torch.float32).to(device)
    requires_grad(ema, False)
    ema.eval()
    # NOTE(review): decay=0 presumably makes the EMA an exact copy of the model - confirm.
    update_ema(ema, model, decay=0, sharded=False)

    # == setup loss function, build scheduler ==
    scheduler = build_module(cfg.scheduler, SCHEDULERS)

    # == setup optimizer ==
    optimizer = HybridAdam(
        filter(lambda p: p.requires_grad, model.parameters()),
        adamw_mode=True,
        lr=cfg.get("lr", 1e-4),
        weight_decay=cfg.get("weight_decay", 0),
        eps=cfg.get("adam_eps", 1e-8),
    )
    lr_scheduler = None

    # == additional preparation ==
    if cfg.get("grad_checkpoint", False):
        set_grad_checkpoint(model)
    if cfg.get("mask_ratios", None) is not None:
        # Only bound when masking is configured; train() guards its use with the same check.
        mask_generator = MaskGenerator(cfg.mask_ratios)

    # =======================================================
    # distributed training preparation with colossalai
    # =======================================================
    logger.info("Preparing for distributed training...")
    # == boosting ==
    # NOTE: we set dtype first to make initialization of model consistent with the dtype;
    # then reset it to fp32 as the diffusion scheduler works in fp32
    torch.set_default_dtype(dtype)
    model, optimizer, _, _, lr_scheduler = booster.boost(
        model=model,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
    )
    torch.set_default_dtype(torch.float)
    logger.info("Boosting model for distributed training")

    model_sharding(ema)

    def reset_optimizer():
        """Clear the boosted optimizer's gradient and bucket bookkeeping.

        This is essential for the first iteration after an OOM. It reaches into
        private attributes (``_grad_store``, ``_bucket_store``) of the boosted
        optimizer, so it depends on that optimizer's internal layout.
        """
        optimizer._grad_store.reset_all_gradients()
        bucket_store = optimizer._bucket_store
        bucket_store.reset_num_elements_in_bucket()
        bucket_store.grad_to_param_mapping = dict()
        bucket_store._grad_in_bucket = dict()
        bucket_store._param_list = []
        bucket_store._padding_size = []
        for rank in range(bucket_store._world_size):
            bucket_store._grad_in_bucket[rank] = []
        bucket_store.offset_list = [0]
        optimizer.zero_grad()

    def build_dataset(resolution, num_frames, batch_size):
        """Build a fresh dataset and dataloader restricted to a single bucket.

        Args:
            resolution: bucket resolution key.
            num_frames: bucket frame-count key.
            batch_size: batch size to try for this bucket.

        Returns:
            ``(dataloader_iter, num_steps_per_epoch, num_batch)`` where
            ``num_batch`` is what the batch sampler reports and
            ``num_steps_per_epoch`` is ``num_batch // world_size``.
        """
        # NOTE(review): the leading 1.0 is presumably the bucket's keep probability - confirm.
        bucket_config = {resolution: {num_frames: (1.0, batch_size)}}
        dataset = build_module(cfg.dataset, DATASETS)
        dataloader_args = dict(
            dataset=dataset,
            batch_size=None,
            num_workers=cfg.num_workers,
            shuffle=False,
            drop_last=False,
            pin_memory=True,
            process_group=get_data_parallel_group(),
        )
        dataloader = prepare_variable_dataloader(
            bucket_config=bucket_config,
            **dataloader_args,
        )
        num_batch = dataloader.batch_sampler.get_num_batch()
        num_steps_per_epoch = num_batch // dist.get_world_size()

        dataloader_iter = iter(dataloader)

        return dataloader_iter, num_steps_per_epoch, num_batch

    def train(resolution, num_frames, batch_size, warmup_steps=5, active_steps=5):
        """Run a short training trial and measure the average step time.

        ``warmup_steps`` untimed steps are followed by ``active_steps`` timed steps.

        Args:
            resolution: bucket resolution key.
            num_frames: bucket frame-count key.
            batch_size: batch size to train with.
            warmup_steps: number of leading steps excluded from timing.
            active_steps: number of timed steps.

        Returns:
            The accumulated time of the timed steps divided by ``active_steps``
            (seconds), or ``-1`` when the bucket contains no data.

        Raises:
            AssertionError: if one epoch has fewer steps than
                ``warmup_steps + active_steps``.
            Exception: whatever the training step raises is propagated to the
                caller, which treats it as "batch size too large".
        """
        logger.info(
            "%s Training resolution=%s, num_frames=%s, batch_size=%s",
            SEARCH_BS_PREFIX,
            resolution,
            num_frames,
            batch_size,
        )
        total_steps = warmup_steps + active_steps
        dataloader_iter, num_steps_per_epoch, num_batch = build_dataset(resolution, num_frames, batch_size)
        if num_batch == 0:  # no data
            logger.info("%s No data found for resolution=%s, num_frames=%s", SEARCH_BS_PREFIX, resolution, num_frames)
            return -1
        assert (
            num_steps_per_epoch >= total_steps
        ), f"num_steps_per_epoch={num_steps_per_epoch} < total_steps={total_steps}"
        duration = 0

        reset_optimizer()
        for step, batch in tqdm(
            enumerate(dataloader_iter),
            desc=f"({resolution},{num_frames}) bs={batch_size}",
            total=total_steps,
        ):
            if step >= total_steps:
                break
            timed = step >= warmup_steps
            if timed:
                # CUDA work is queued asynchronously; drain what is pending so the
                # timed window starts on an idle device.
                torch.cuda.synchronize()
                start = time.perf_counter()

            x = batch.pop("video").to(device, dtype)  # [B, C, T, H, W]
            y = batch.pop("text")

            # == visual and text encoding ==
            with torch.no_grad():
                # Prepare visual inputs
                x = vae.encode(x)  # [B, C, T, H/P, W/P]
                # Prepare text inputs
                model_args = text_encoder.encode(y)

            # == mask ==
            mask = None
            if cfg.get("mask_ratios", None) is not None:
                mask = mask_generator.get_masks(x)
                model_args["x_mask"] = mask

            # == video meta info ==
            # Everything left in the batch after popping video/text is forwarded to the model.
            for k, v in batch.items():
                model_args[k] = v.to(device, dtype)

            # == diffusion loss computation ==
            loss_dict = scheduler.training_losses(model, x, model_args, mask=mask)

            # == backward & update ==
            loss = loss_dict["loss"].mean()
            booster.backward(loss=loss, optimizer=optimizer)
            optimizer.step()
            optimizer.zero_grad()

            # == update EMA ==
            update_ema(ema, model.module, optimizer=optimizer, decay=cfg.get("ema_decay", 0.9999))

            # == time accumulation ==
            if timed:
                # Wait for this step's GPU work to finish so it is fully counted.
                torch.cuda.synchronize()
                duration += time.perf_counter() - start

        avg_step_time = duration / active_steps
        logger.info("%s Average step time: %.2f", SEARCH_BS_PREFIX, avg_step_time)
        return avg_step_time

    # =======================================================
    # search for bucket
    # =======================================================
    # == benchmark ==
    def benchmark(resolution, num_frames, lower_bound, upper_bound, ref_step_time=None):
        """Binary-search a batch size for one bucket within ``[lower_bound, upper_bound]``.

        Without ``ref_step_time`` the largest batch size whose trial completes is
        kept. With ``ref_step_time`` only batch sizes whose step time is below the
        reference are candidates, and the one closest to the reference is kept.
        Any exception raised by a trial is printed and treated as "too large".

        Args:
            resolution: bucket resolution key.
            num_frames: bucket frame-count key.
            lower_bound: smallest batch size to try (inclusive).
            upper_bound: largest batch size to try (inclusive).
            ref_step_time: optional reference step time to stay below.

        Returns:
            ``(batch_size, step_time)``; ``(0, 0)`` when the bucket has no data
            or no candidate in the range was accepted.
        """
        logger.info(
            "%s Benchmarking resolution=%s, num_frames=%s, lower_bound=%s, upper_bound=%s",
            SEARCH_BS_PREFIX,
            resolution,
            num_frames,
            lower_bound,
            upper_bound,
        )

        # binary search the largest valid batch size
        mid = target_batch_size = target_step_time = 0
        # Only consulted when a reference step time is given; always bound for clarity.
        min_dis = float("inf")
        while lower_bound <= upper_bound:
            mid = (lower_bound + upper_bound) // 2
            try:
                step_time = train(resolution, num_frames, mid)
                if step_time < 0:  # no data
                    return 0, 0
                if ref_step_time is not None:
                    if step_time < ref_step_time:
                        # Faster than the reference: remember the closest one and try larger.
                        lower_bound = mid + 1
                        dis = abs(step_time - ref_step_time)
                        if dis < min_dis:
                            target_batch_size, target_step_time = mid, step_time
                            min_dis = dis
                    else:
                        upper_bound = mid - 1
                else:
                    target_batch_size, target_step_time = mid, step_time
                    lower_bound = mid + 1
            except Exception:
                # Deliberately broad: any failed trial shrinks the search range.
                traceback.print_exc()
                upper_bound = mid - 1

        logger.info(
            "%s Benchmarking result: batch_size=%s, step_time=%s", SEARCH_BS_PREFIX, target_batch_size, target_step_time
        )
        return target_batch_size, target_step_time

    # == build bucket ==
    output_bucket_cfg = deepcopy(cfg.bucket_config)
    # Map every (resolution, num_frames) bucket to its inclusive search range.
    buckets = {
        (resolution, num_frames): (max(guess_bs - variance, 1), guess_bs + variance)
        for resolution, t_bucket in cfg.bucket_config.items()
        for num_frames, (guess_bs, variance) in t_bucket.items()
    }

    # == get base_step_time ==
    base_step_time = cfg.get("base_step_time", None)
    result_rows = []
    if base_step_time is None:
        base_resolution, base_num_frames = cfg.base
        base_num_frames = get_num_frames(base_num_frames)
        assert (
            base_resolution,
            base_num_frames,
        ) in buckets, f"Base bucket {base_resolution} {base_num_frames} not found"
        # Remove the base bucket so it is not searched a second time below.
        base_bound = buckets.pop((base_resolution, base_num_frames))

        base_batch_size, base_step_time = benchmark(base_resolution, base_num_frames, *base_bound)
        output_bucket_cfg[base_resolution][base_num_frames] = base_batch_size
        result_rows.append(f"{base_resolution}, {base_num_frames}, {base_batch_size}, {base_step_time:.2f}")

    # == search for other buckets ==
    for (resolution, frames), bounds in buckets.items():
        try:
            batch_size, step_time = benchmark(resolution, frames, *bounds, ref_step_time=base_step_time)
            output_bucket_cfg[resolution][frames] = batch_size
            result_rows.append(f"{resolution}, {frames}, {batch_size}, {step_time:.2f}")
        except RuntimeError:
            # Best effort: keep searching the remaining buckets, but leave a trace
            # instead of dropping the failure silently.
            traceback.print_exc()
            logger.info(
                "%s Skipping resolution=%s, num_frames=%s after RuntimeError",
                SEARCH_BS_PREFIX,
                resolution,
                frames,
            )
    result_table = "\n".join(result_rows)
    logger.info("%s Search result:\nResolution, Frames, Batch size, Step time\n%s", SEARCH_BS_PREFIX, result_table)
    logger.info("%s Bucket searched: %s", SEARCH_BS_PREFIX, output_bucket_cfg)
|
|
|
|
|
|
# Script entry point: run the batch-size search only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()
|