# Open-Sora/scripts/misc/search_bs.py
# Snapshot metadata from the original file listing (2024-05-15 12:21:07 +00:00,
# 323 lines, 12 KiB, Python) — kept as a comment so the file parses.
import time
import traceback
from copy import deepcopy
from datetime import timedelta
import torch
import torch.distributed as dist
from colossalai.booster import Booster
from colossalai.cluster import DistCoordinator
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device, set_seed
from tqdm import tqdm
from opensora.acceleration.checkpoint import set_grad_checkpoint
from opensora.acceleration.parallel_states import get_data_parallel_group
from opensora.datasets import prepare_variable_dataloader
from opensora.datasets.aspect import get_num_frames
from opensora.registry import DATASETS, MODELS, SCHEDULERS, build_module
from opensora.utils.ckpt_utils import model_sharding
from opensora.utils.config_utils import parse_configs
from opensora.utils.misc import BColors, create_logger, format_numel_str, get_model_numel, requires_grad, to_torch_dtype
from opensora.utils.train_utils import MaskGenerator, create_colossalai_plugin, update_ema
# Colored log prefix so batch-size-search messages stand out in mixed training logs.
SEARCH_BS_PREFIX = f"{BColors.OKGREEN}[Search BS]{BColors.ENDC}"
def main():
    """Search a per-bucket batch size for variable-resolution video training.

    For every (resolution, num_frames) bucket in ``cfg.bucket_config`` the
    script binary-searches the batch size inside the configured
    ``(guess, variance)`` window: the base bucket gets the largest batch size
    that trains without OOM, and every other bucket gets the batch size whose
    measured step time is closest to the base bucket's step time.  The
    resulting bucket config is logged at the end.
    """
    # ======================================================
    # configs & runtime variables
    # ======================================================
    # == parse configs ==
    cfg = parse_configs()
    assert cfg.dataset.type == "VariableVideoTextDataset", "Only VariableVideoTextDataset is supported"

    # == device and dtype ==
    assert torch.cuda.is_available(), "Training currently requires at least one GPU."
    cfg_dtype = cfg.get("dtype", "bf16")
    assert cfg_dtype in ["fp16", "bf16"], f"Unknown mixed precision {cfg_dtype}"
    # reuse the validated value instead of re-reading the config
    dtype = to_torch_dtype(cfg_dtype)

    # == colossalai init distributed training ==
    # NOTE: A very large timeout is set to avoid some processes exit early
    dist.init_process_group(backend="nccl", timeout=timedelta(hours=24))
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
    set_seed(cfg.get("seed", 1024))
    DistCoordinator()
    device = get_current_device()

    # == init logger ==
    logger = create_logger()

    # == init ColossalAI booster ==
    plugin = create_colossalai_plugin(
        plugin=cfg.get("plugin", "zero2"),
        dtype=cfg_dtype,
        grad_clip=cfg.get("grad_clip", 0),
        sp_size=cfg.get("sp_size", 1),
    )
    booster = Booster(plugin=plugin)

    # ======================================================
    # build model
    # ======================================================
    logger.info("Building models...")
    # == build text-encoder and vae ==
    text_encoder = build_module(cfg.text_encoder, MODELS, device=device)
    vae = build_module(cfg.vae, MODELS).to(device, dtype).eval()

    # == build diffusion model ==
    # (None, None, None) lets the model accept variable (T, H, W) inputs
    input_size = (None, None, None)
    latent_size = vae.get_latent_size(input_size)
    model = (
        build_module(
            cfg.model,
            MODELS,
            input_size=latent_size,
            in_channels=vae.out_channels,
            caption_channels=text_encoder.output_dim,
            model_max_length=text_encoder.model_max_length,
        )
        .to(device, dtype)
        .train()
    )
    model_numel, model_numel_trainable = get_model_numel(model)
    logger.info(
        "[Diffusion] Trainable model params: %s, Total model params: %s",
        format_numel_str(model_numel_trainable),
        format_numel_str(model_numel),
    )

    # == build ema for diffusion model ==
    # EMA weights are kept in fp32 regardless of the training dtype
    ema = deepcopy(model).to(torch.float32).to(device)
    requires_grad(ema, False)
    ema.eval()
    update_ema(ema, model, decay=0, sharded=False)

    # == setup loss function, build scheduler ==
    scheduler = build_module(cfg.scheduler, SCHEDULERS)

    # == setup optimizer ==
    optimizer = HybridAdam(
        filter(lambda p: p.requires_grad, model.parameters()),
        adamw_mode=True,
        lr=cfg.get("lr", 1e-4),
        weight_decay=cfg.get("weight_decay", 0),
        eps=cfg.get("adam_eps", 1e-8),
    )
    lr_scheduler = None

    # == additional preparation ==
    if cfg.get("grad_checkpoint", False):
        set_grad_checkpoint(model)
    if cfg.get("mask_ratios", None) is not None:
        mask_generator = MaskGenerator(cfg.mask_ratios)

    # =======================================================
    # distributed training preparation with colossalai
    # =======================================================
    logger.info("Preparing for distributed training...")
    # == boosting ==
    # NOTE: we set dtype first to make initialization of model consistent with
    # the dtype; then reset it to fp32 as the diffusion scheduler runs in fp32
    torch.set_default_dtype(dtype)
    model, optimizer, _, _, lr_scheduler = booster.boost(
        model=model,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
    )
    torch.set_default_dtype(torch.float)
    logger.info("Boosting model for distributed training")
    model_sharding(ema)

    def reset_optimizer():
        """Clear all gradient/bucket state in the ZeRO optimizer.

        This is essential for the first iteration after an OOM: stale
        gradients and half-filled buckets would otherwise corrupt the next
        benchmark run.  NOTE(review): this reaches into private
        LowLevelZeroOptimizer internals (``_grad_store``/``_bucket_store``);
        verify these attributes still exist when upgrading ColossalAI.
        """
        optimizer._grad_store.reset_all_gradients()
        optimizer._bucket_store.reset_num_elements_in_bucket()
        optimizer._bucket_store.grad_to_param_mapping = dict()
        optimizer._bucket_store._grad_in_bucket = dict()
        optimizer._bucket_store._param_list = []
        optimizer._bucket_store._padding_size = []
        for rank in range(optimizer._bucket_store._world_size):
            optimizer._bucket_store._grad_in_bucket[rank] = []
        optimizer._bucket_store.offset_list = [0]
        optimizer.zero_grad()

    def build_dataset(resolution, num_frames, batch_size):
        """Build a dataloader restricted to a single bucket.

        Returns:
            (dataloader_iter, num_steps_per_epoch, num_batch) where
            ``num_batch`` is the global batch count and ``num_steps_per_epoch``
            is that count divided across all ranks.
        """
        # keep-probability 1.0: use every sample that falls in this bucket
        bucket_config = {resolution: {num_frames: (1.0, batch_size)}}
        dataset = build_module(cfg.dataset, DATASETS)
        dataloader_args = dict(
            dataset=dataset,
            batch_size=None,  # batching is handled by the bucket sampler
            num_workers=cfg.num_workers,
            shuffle=False,
            drop_last=False,
            pin_memory=True,
            process_group=get_data_parallel_group(),
        )
        dataloader = prepare_variable_dataloader(
            bucket_config=bucket_config,
            **dataloader_args,
        )
        num_batch = dataloader.batch_sampler.get_num_batch()
        num_steps_per_epoch = num_batch // dist.get_world_size()
        dataloader_iter = iter(dataloader)
        return dataloader_iter, num_steps_per_epoch, num_batch

    def train(resolution, num_frames, batch_size, warmup_steps=5, active_steps=5):
        """Run a short training benchmark for one bucket.

        Returns:
            Average step time over ``active_steps`` timed steps (the first
            ``warmup_steps`` steps are untimed), or -1 when the bucket has
            no data.

        Raises:
            Whatever the training step raises (typically CUDA OOM); the
            caller's binary search treats an exception as "batch too large".
        """
        logger.info(
            "%s Training resolution=%s, num_frames=%s, batch_size=%s",
            SEARCH_BS_PREFIX,
            resolution,
            num_frames,
            batch_size,
        )
        total_steps = warmup_steps + active_steps
        dataloader_iter, num_steps_per_epoch, num_batch = build_dataset(resolution, num_frames, batch_size)
        if num_batch == 0:  # no data
            logger.info("%s No data found for resolution=%s, num_frames=%s", SEARCH_BS_PREFIX, resolution, num_frames)
            return -1
        assert (
            num_steps_per_epoch >= total_steps
        ), f"num_steps_per_epoch={num_steps_per_epoch} < total_steps={total_steps}"

        duration = 0
        # essential for the first iteration after a previous OOM
        reset_optimizer()
        for step, batch in tqdm(
            enumerate(dataloader_iter),
            desc=f"({resolution},{num_frames}) bs={batch_size}",
            total=total_steps,
        ):
            if step >= total_steps:
                break
            if step >= warmup_steps:
                start = time.time()

            x = batch.pop("video").to(device, dtype)  # [B, C, T, H, W]
            y = batch.pop("text")

            # == visual and text encoding ==
            with torch.no_grad():
                # Prepare visual inputs
                x = vae.encode(x)  # [B, C, T, H/P, W/P]
                # Prepare text inputs
                model_args = text_encoder.encode(y)

            # == mask ==
            mask = None
            if cfg.get("mask_ratios", None) is not None:
                # mask_generator exists whenever mask_ratios is set (see setup above)
                mask = mask_generator.get_masks(x)
                model_args["x_mask"] = mask

            # == video meta info ==
            for k, v in batch.items():
                model_args[k] = v.to(device, dtype)

            # == diffusion loss computation ==
            loss_dict = scheduler.training_losses(model, x, model_args, mask=mask)

            # == backward & update ==
            loss = loss_dict["loss"].mean()
            booster.backward(loss=loss, optimizer=optimizer)
            optimizer.step()
            optimizer.zero_grad()

            # == update EMA ==
            update_ema(ema, model.module, optimizer=optimizer, decay=cfg.get("ema_decay", 0.9999))

            # == time accumulation ==
            if step >= warmup_steps:
                end = time.time()
                duration += end - start

        avg_step_time = duration / active_steps
        logger.info("%s Average step time: %.2f", SEARCH_BS_PREFIX, avg_step_time)
        return avg_step_time

    # =======================================================
    # search for bucket
    # =======================================================
    # == benchmark ==
    def benchmark(resolution, num_frames, lower_bound, upper_bound, ref_step_time=None):
        """Binary-search a batch size for one bucket within [lower_bound, upper_bound].

        Without ``ref_step_time`` the largest batch size that trains without
        failing is selected; with it, the fastest-but-still-under-reference
        batch size closest to the reference step time is selected.

        Returns:
            (batch_size, step_time); (0, 0) when the bucket has no data or
            no candidate succeeded.
        """
        logger.info(
            "%s Benchmarking resolution=%s, num_frames=%s, lower_bound=%s, upper_bound=%s",
            SEARCH_BS_PREFIX,
            resolution,
            num_frames,
            lower_bound,
            upper_bound,
        )
        # binary search the largest valid batch size
        mid = target_batch_size = target_step_time = 0
        if ref_step_time is not None:
            min_dis = float("inf")
        while lower_bound <= upper_bound:
            mid = (lower_bound + upper_bound) // 2
            try:
                step_time = train(resolution, num_frames, mid)
                if step_time < 0:  # no data
                    return 0, 0
                if ref_step_time is not None:
                    if step_time < ref_step_time:
                        # still faster than the reference: try larger batches,
                        # but remember this candidate if it is the closest so far
                        lower_bound = mid + 1
                        dis = abs(step_time - ref_step_time)
                        if dis < min_dis:
                            target_batch_size, target_step_time = mid, step_time
                            min_dis = dis
                    else:
                        upper_bound = mid - 1
                else:
                    # no reference: any successful run becomes the new best
                    target_batch_size, target_step_time = mid, step_time
                    lower_bound = mid + 1
            except Exception:
                # assume the failure (typically CUDA OOM) means the batch is
                # too large; shrink the search window
                traceback.print_exc()
                upper_bound = mid - 1
        logger.info(
            "%s Benchmarking result: batch_size=%s, step_time=%s", SEARCH_BS_PREFIX, target_batch_size, target_step_time
        )
        return target_batch_size, target_step_time

    # == build bucket search windows ==
    output_bucket_cfg = deepcopy(cfg.bucket_config)
    # each bucket's (guess, variance) pair becomes a [guess-variance, guess+variance]
    # batch-size search window (lower bound clamped to 1)
    buckets = {
        (resolution, num_frames): (max(guess_bs - variance, 1), guess_bs + variance)
        for resolution, t_bucket in cfg.bucket_config.items()
        for num_frames, (guess_bs, variance) in t_bucket.items()
    }

    # == get base_step_time ==
    base_step_time = cfg.get("base_step_time", None)
    result_rows = []
    if base_step_time is None:
        base_resolution, base_num_frames = cfg.base
        base_num_frames = get_num_frames(base_num_frames)
        # NOTE(review): assumes bucket_config inner keys are the numeric frame
        # counts that get_num_frames returns — verify against the config format
        assert (
            base_resolution,
            base_num_frames,
        ) in buckets, f"Base bucket {base_resolution} {base_num_frames} not found"
        base_bound = buckets.pop((base_resolution, base_num_frames))
        base_batch_size, base_step_time = benchmark(base_resolution, base_num_frames, *base_bound)
        output_bucket_cfg[base_resolution][base_num_frames] = base_batch_size
        result_rows.append(f"{base_resolution}, {base_num_frames}, {base_batch_size}, {base_step_time:.2f}")

    # == search for other buckets ==
    for (resolution, frames), bounds in buckets.items():
        try:
            batch_size, step_time = benchmark(resolution, frames, *bounds, ref_step_time=base_step_time)
            output_bucket_cfg[resolution][frames] = batch_size
            result_rows.append(f"{resolution}, {frames}, {batch_size}, {step_time:.2f}")
        except RuntimeError:
            # best-effort: keep searching the remaining buckets, but record
            # the failure instead of swallowing it silently
            logger.warning("%s Benchmark failed for resolution=%s, num_frames=%s", SEARCH_BS_PREFIX, resolution, frames)

    result_table = "\n".join(result_rows)
    logger.info("%s Search result:\nResolution, Frames, Batch size, Step time\n%s", SEARCH_BS_PREFIX, result_table)
    logger.info("%s Bucket searched: %s", SEARCH_BS_PREFIX, output_bucket_cfg)


if __name__ == "__main__":
    main()