[feature] make timer optional and make reduce bucket size configurable (#549)

* [feature] make reduce bucket size configurable

* [feature] make timer optional
This commit is contained in:
Hongxin Liu 2024-06-27 13:37:54 +08:00 committed by GitHub
parent 92253e97ec
commit 332d9fc9c9
4 changed files with 54 additions and 29 deletions

View file

@ -79,6 +79,7 @@ def parse_args(training=False):
parser.add_argument("--load", default=None, type=str, help="path to continue training")
parser.add_argument("--start-from-scratch", action="store_true", help="start training from scratch")
parser.add_argument("--warmup-steps", default=None, type=int, help="warmup steps")
parser.add_argument("--record-time", default=False, action="store_true", help="record time of each part")
return parser.parse_args()

View file

@ -6,11 +6,12 @@ import time
from collections import OrderedDict
from collections.abc import Sequence
from itertools import repeat
from typing import Tuple
from typing import Optional, Tuple
import numpy as np
import torch
import torch.distributed as dist
from colossalai.cluster.dist_coordinator import DistCoordinator
# ======================================================
# Logging
@ -358,11 +359,12 @@ def all_exists(paths):
class Timer:
def __init__(self, name, log=False):
def __init__(self, name, log=False, coordinator: Optional[DistCoordinator] = None):
self.name = name
self.start_time = None
self.end_time = None
self.log = log
self.coordinator = coordinator
@property
def elapsed_time(self):
@ -374,6 +376,8 @@ class Timer:
return self
def __exit__(self, exc_type, exc_val, exc_tb):
if self.coordinator is not None:
self.coordinator.block_all()
torch.cuda.synchronize()
self.end_time = time.time()
if self.log:

View file

@ -12,7 +12,7 @@ from opensora.acceleration.plugin import ZeroSeqParallelPlugin
from .misc import get_logger
def create_colossalai_plugin(plugin, dtype, grad_clip, sp_size):
def create_colossalai_plugin(plugin, dtype, grad_clip, sp_size, reduce_bucket_size_in_m: int = 20):
if plugin == "zero2":
assert sp_size == 1, "Zero2 plugin does not support sequence parallelism"
plugin = LowLevelZeroPlugin(
@ -20,6 +20,7 @@ def create_colossalai_plugin(plugin, dtype, grad_clip, sp_size):
precision=dtype,
initial_scale=2**16,
max_norm=grad_clip,
reduce_bucket_size_in_m=reduce_bucket_size_in_m,
)
set_data_parallel_group(dist.group.WORLD)
elif plugin == "zero2-seq":
@ -30,6 +31,7 @@ def create_colossalai_plugin(plugin, dtype, grad_clip, sp_size):
precision=dtype,
initial_scale=2**16,
max_norm=grad_clip,
reduce_bucket_size_in_m=reduce_bucket_size_in_m,
)
set_sequence_parallel_group(plugin.sp_group)
set_data_parallel_group(plugin.dp_group)

View file

@ -1,4 +1,5 @@
import os
from contextlib import nullcontext
from copy import deepcopy
from datetime import timedelta
from pprint import pformat
@ -38,6 +39,7 @@ def main():
# ======================================================
# == parse configs ==
cfg = parse_configs(training=True)
record_time = cfg.get("record_time", False)
# == device and dtype ==
assert torch.cuda.is_available(), "Training currently requires at least one GPU."
@ -76,6 +78,7 @@ def main():
dtype=cfg_dtype,
grad_clip=cfg.get("grad_clip", 0),
sp_size=cfg.get("sp_size", 1),
reduce_bucket_size_in_m=cfg.get("reduce_bucket_size_in_m", 20),
)
booster = Booster(plugin=plugin)
torch.set_num_threads(1)
@ -229,6 +232,21 @@ def main():
# 5. training loop
# =======================================================
dist.barrier()
# Build one context manager per timed phase of a training step, keyed by phase name.
# When record_time is true each phase gets a Timer that is handed the coordinator;
# otherwise it gets a nullcontext(), so the `with timers[...]` blocks in the loop
# still execute their bodies but measure nothing.
timers = {}
timer_keys = [
"move_data",
"encode",
"mask",
"diffusion",
"backward",
"update_ema",
"reduce_loss",
]
for key in timer_keys:
if record_time:
timers[key] = Timer(key, coordinator=coordinator)
else:
# nullcontext() binds None in `with ... as t`, so `t` is only meaningful when record_time is true.
timers[key] = nullcontext()
for epoch in range(start_epoch, cfg_epochs):
# == set dataloader to new epoch ==
sampler.set_epoch(epoch)
@ -245,13 +263,14 @@ def main():
) as pbar:
for step, batch in pbar:
timer_list = []
with Timer("move data") as move_data_t:
with timers["move_data"] as move_data_t:
x = batch.pop("video").to(device, dtype) # [B, C, T, H, W]
y = batch.pop("text")
timer_list.append(move_data_t)
if record_time:
timer_list.append(move_data_t)
# == visual and text encoding ==
with Timer("encode") as encode_t:
with timers["encode"] as encode_t:
with torch.no_grad():
# Prepare visual inputs
if cfg.get("load_video_features", False):
@ -267,17 +286,17 @@ def main():
model_args["mask"] = mask
else:
model_args = text_encoder.encode(y)
coordinator.block_all()
timer_list.append(encode_t)
if record_time:
timer_list.append(encode_t)
# == mask ==
with Timer("mask") as mask_t:
with timers["mask"] as mask_t:
mask = None
if cfg.get("mask_ratios", None) is not None:
mask = mask_generator.get_masks(x)
model_args["x_mask"] = mask
coordinator.block_all()
timer_list.append(mask_t)
if record_time:
timer_list.append(mask_t)
# == video meta info ==
for k, v in batch.items():
@ -285,13 +304,13 @@ def main():
model_args[k] = v.to(device, dtype)
# == diffusion loss computation ==
with Timer("diffusion") as loss_t:
with timers["diffusion"] as loss_t:
loss_dict = scheduler.training_losses(model, x, model_args, mask=mask)
coordinator.block_all()
timer_list.append(loss_t)
if record_time:
timer_list.append(loss_t)
# == backward & update ==
with Timer("backward") as backward_t:
with timers["backward"] as backward_t:
loss = loss_dict["loss"].mean()
booster.backward(loss=loss, optimizer=optimizer)
optimizer.step()
@ -300,24 +319,24 @@ def main():
# update learning rate
if lr_scheduler is not None:
lr_scheduler.step()
coordinator.block_all()
timer_list.append(backward_t)
if record_time:
timer_list.append(backward_t)
# == update EMA ==
with Timer("update_ema") as ema_t:
with timers["update_ema"] as ema_t:
update_ema(ema, model.module, optimizer=optimizer, decay=cfg.get("ema_decay", 0.9999))
coordinator.block_all()
timer_list.append(ema_t)
if record_time:
timer_list.append(ema_t)
# == update log info ==
with Timer("reduce_loss") as reduce_loss_t:
with timers["reduce_loss"] as reduce_loss_t:
all_reduce_mean(loss)
running_loss += loss.item()
global_step = epoch * num_steps_per_epoch + step
log_step += 1
acc_step += 1
coordinator.block_all()
timer_list.append(reduce_loss_t)
if record_time:
timer_list.append(reduce_loss_t)
# == logging ==
if coordinator.is_master() and (global_step + 1) % cfg.get("log_every", 1) == 0:
@ -376,12 +395,11 @@ def main():
global_step + 1,
save_dir,
)
log_str = f"Rank {dist.get_rank()} | Epoch {epoch} | Step {step} | "
for timer in timer_list:
log_str += f"{timer.name}: {timer.elapsed_time:.3f}s | "
print(log_str)
coordinator.block_all()
if record_time:
log_str = f"Rank {dist.get_rank()} | Epoch {epoch} | Step {step} | "
for timer in timer_list:
log_str += f"{timer.name}: {timer.elapsed_time:.3f}s | "
print(log_str)
sampler.reset()
start_step = 0