Mirror of https://github.com/hpcaitech/Open-Sora.git (synced 2026-04-11 13:14:44 +02:00)
[feature] make timer optional and make reduce bucket size configurable (#549)

* [feature] make reduce bucket size configurable
* [feature] make timer optional

commit 332d9fc9c9
parent 92253e97ec
@@ -79,6 +79,7 @@ def parse_args(training=False):
     parser.add_argument("--load", default=None, type=str, help="path to continue training")
     parser.add_argument("--start-from-scratch", action="store_true", help="start training from scratch")
     parser.add_argument("--warmup-steps", default=None, type=int, help="warmup steps")
+    parser.add_argument("--record-time", default=False, action="store_true", help="record time of each part")
 
     return parser.parse_args()
 
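
The new flag is a plain boolean switch: per-section timing stays off unless --record-time is passed on the command line. A minimal, standalone sketch of the behaviour (not the repository's full parse_args):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--record-time", default=False, action="store_true", help="record time of each part")

print(parser.parse_args([]).record_time)                 # False
print(parser.parse_args(["--record-time"]).record_time)  # True
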
@@ -6,11 +6,12 @@ import time
 from collections import OrderedDict
 from collections.abc import Sequence
 from itertools import repeat
-from typing import Tuple
+from typing import Optional, Tuple
 
 import numpy as np
 import torch
 import torch.distributed as dist
+from colossalai.cluster.dist_coordinator import DistCoordinator
 
 # ======================================================
 # Logging
@@ -358,11 +359,12 @@ def all_exists(paths):
 
 
 class Timer:
-    def __init__(self, name, log=False):
+    def __init__(self, name, log=False, coordinator: Optional[DistCoordinator] = None):
         self.name = name
         self.start_time = None
         self.end_time = None
         self.log = log
+        self.coordinator = coordinator
 
     @property
     def elapsed_time(self):
@@ -374,6 +376,8 @@ class Timer:
         return self
 
     def __exit__(self, exc_type, exc_val, exc_tb):
+        if self.coordinator is not None:
+            self.coordinator.block_all()
         torch.cuda.synchronize()
         self.end_time = time.time()
         if self.log:
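
The coordinator argument is optional so the Timer still works outside a distributed run; when one is supplied, __exit__ blocks all ranks before stopping the clock, so the measured interval reflects the slowest rank rather than whichever rank exits first. A simplified, self-contained sketch of a timer built the same way (the coordinator here is any object exposing block_all(); the class in the diff above additionally calls torch.cuda.synchronize() to flush pending GPU work):

import time
from typing import Optional


class SketchTimer:
    def __init__(self, name, log=False, coordinator: Optional[object] = None):
        self.name = name
        self.log = log
        self.coordinator = coordinator  # None for single-process use
        self.start_time = None
        self.end_time = None

    @property
    def elapsed_time(self):
        return self.end_time - self.start_time

    def __enter__(self):
        self.start_time = time.time()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.coordinator is not None:
            self.coordinator.block_all()  # wait for all ranks before reading the clock
        self.end_time = time.time()
        if self.log:
            print(f"{self.name}: {self.elapsed_time:.3f}s")


with SketchTimer("demo", log=True):
    time.sleep(0.1)
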
@@ -12,7 +12,7 @@ from opensora.acceleration.plugin import ZeroSeqParallelPlugin
 from .misc import get_logger
 
 
-def create_colossalai_plugin(plugin, dtype, grad_clip, sp_size):
+def create_colossalai_plugin(plugin, dtype, grad_clip, sp_size, reduce_bucket_size_in_m: int = 20):
     if plugin == "zero2":
         assert sp_size == 1, "Zero2 plugin does not support sequence parallelism"
         plugin = LowLevelZeroPlugin(
@@ -20,6 +20,7 @@ def create_colossalai_plugin(plugin, dtype, grad_clip, sp_size):
             precision=dtype,
             initial_scale=2**16,
             max_norm=grad_clip,
+            reduce_bucket_size_in_m=reduce_bucket_size_in_m,
         )
         set_data_parallel_group(dist.group.WORLD)
     elif plugin == "zero2-seq":
@@ -30,6 +31,7 @@ def create_colossalai_plugin(plugin, dtype, grad_clip, sp_size):
             precision=dtype,
             initial_scale=2**16,
             max_norm=grad_clip,
+            reduce_bucket_size_in_m=reduce_bucket_size_in_m,
         )
         set_sequence_parallel_group(plugin.sp_group)
         set_data_parallel_group(plugin.dp_group)
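
reduce_bucket_size_in_m is the gradient-reduction bucket size, in millions of elements, used by ColossalAI's ZeRO plugins: larger buckets mean fewer but bigger collective calls at the cost of more temporary memory. A minimal sketch of constructing the plugin with the same keywords the function now forwards (assumes ColossalAI is installed; the precision and max_norm values are only illustrative):

from colossalai.booster.plugin import LowLevelZeroPlugin

reduce_bucket_size_in_m = 20  # would normally come from cfg.get("reduce_bucket_size_in_m", 20)
plugin = LowLevelZeroPlugin(
    precision="bf16",
    initial_scale=2**16,
    max_norm=1.0,
    reduce_bucket_size_in_m=reduce_bucket_size_in_m,  # 20 -> roughly 20M-element buckets
)
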
@@ -1,4 +1,5 @@
 import os
+from contextlib import nullcontext
 from copy import deepcopy
 from datetime import timedelta
 from pprint import pformat
@@ -38,6 +39,7 @@ def main():
     # ======================================================
     # == parse configs ==
     cfg = parse_configs(training=True)
+    record_time = cfg.get("record_time", False)
 
     # == device and dtype ==
     assert torch.cuda.is_available(), "Training currently requires at least one GPU."
@@ -76,6 +78,7 @@ def main():
         dtype=cfg_dtype,
         grad_clip=cfg.get("grad_clip", 0),
         sp_size=cfg.get("sp_size", 1),
+        reduce_bucket_size_in_m=cfg.get("reduce_bucket_size_in_m", 20),
     )
     booster = Booster(plugin=plugin)
     torch.set_num_threads(1)
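
Because both values are read with cfg.get(...) and a default, they can also be set in the training config rather than on the command line. An illustrative snippet of the relevant entries in an Open-Sora-style Python config file (all other required fields omitted; the grad_clip and sp_size values are just examples):

grad_clip = 1.0
sp_size = 1
reduce_bucket_size_in_m = 20  # ZeRO gradient bucket size, in millions of elements
record_time = False           # set True (or pass --record-time) to time each training stage
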
@@ -229,6 +232,21 @@ def main():
     # 5. training loop
     # =======================================================
     dist.barrier()
+    timers = {}
+    timer_keys = [
+        "move_data",
+        "encode",
+        "mask",
+        "diffusion",
+        "backward",
+        "update_ema",
+        "reduce_loss",
+    ]
+    for key in timer_keys:
+        if record_time:
+            timers[key] = Timer(key, coordinator=coordinator)
+        else:
+            timers[key] = nullcontext()
     for epoch in range(start_epoch, cfg_epochs):
         # == set dataloader to new epoch ==
         sampler.set_epoch(epoch)
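
When record_time is off, every entry in timers is a contextlib.nullcontext(), so the with timers[key] as ...: blocks below still execute but add no synchronization or timing overhead. Note that nullcontext() yields None as its as-target, which is why each timer_list.append(...) in the loop is guarded by if record_time. A small standalone illustration of this swap-in pattern (Stopwatch is a hypothetical stand-in for the Timer above):

import time
from contextlib import nullcontext


class Stopwatch:
    def __enter__(self):
        self.start = time.time()
        return self

    def __exit__(self, *exc):
        self.elapsed = time.time() - self.start


record_time = False  # flip to True to collect timings
timers = {key: (Stopwatch() if record_time else nullcontext()) for key in ("move_data", "encode")}

timer_list = []
with timers["move_data"] as t:  # t is None when record_time is False
    time.sleep(0.01)
if record_time:
    timer_list.append(t)  # only appended when a real timer ran
print(timer_list)
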
@@ -245,13 +263,14 @@ def main():
         ) as pbar:
             for step, batch in pbar:
                 timer_list = []
-                with Timer("move data") as move_data_t:
+                with timers["move_data"] as move_data_t:
                     x = batch.pop("video").to(device, dtype)  # [B, C, T, H, W]
                     y = batch.pop("text")
-                timer_list.append(move_data_t)
+                if record_time:
+                    timer_list.append(move_data_t)
 
                 # == visual and text encoding ==
-                with Timer("encode") as encode_t:
+                with timers["encode"] as encode_t:
                     with torch.no_grad():
                         # Prepare visual inputs
                         if cfg.get("load_video_features", False):
@@ -267,17 +286,17 @@ def main():
                             model_args["mask"] = mask
                         else:
                             model_args = text_encoder.encode(y)
-                coordinator.block_all()
-                timer_list.append(encode_t)
+                if record_time:
+                    timer_list.append(encode_t)
 
                 # == mask ==
-                with Timer("mask") as mask_t:
+                with timers["mask"] as mask_t:
                     mask = None
                     if cfg.get("mask_ratios", None) is not None:
                         mask = mask_generator.get_masks(x)
                         model_args["x_mask"] = mask
-                coordinator.block_all()
-                timer_list.append(mask_t)
+                if record_time:
+                    timer_list.append(mask_t)
 
                 # == video meta info ==
                 for k, v in batch.items():
@@ -285,13 +304,13 @@ def main():
                         model_args[k] = v.to(device, dtype)
 
                 # == diffusion loss computation ==
-                with Timer("diffusion") as loss_t:
+                with timers["diffusion"] as loss_t:
                     loss_dict = scheduler.training_losses(model, x, model_args, mask=mask)
-                coordinator.block_all()
-                timer_list.append(loss_t)
+                if record_time:
+                    timer_list.append(loss_t)
 
                 # == backward & update ==
-                with Timer("backward") as backward_t:
+                with timers["backward"] as backward_t:
                     loss = loss_dict["loss"].mean()
                     booster.backward(loss=loss, optimizer=optimizer)
                     optimizer.step()
@@ -300,24 +319,24 @@ def main():
                     # update learning rate
                     if lr_scheduler is not None:
                         lr_scheduler.step()
-                coordinator.block_all()
-                timer_list.append(backward_t)
+                if record_time:
+                    timer_list.append(backward_t)
 
                 # == update EMA ==
-                with Timer("update_ema") as ema_t:
+                with timers["update_ema"] as ema_t:
                     update_ema(ema, model.module, optimizer=optimizer, decay=cfg.get("ema_decay", 0.9999))
-                coordinator.block_all()
-                timer_list.append(ema_t)
+                if record_time:
+                    timer_list.append(ema_t)
 
                 # == update log info ==
-                with Timer("reduce_loss") as reduce_loss_t:
+                with timers["reduce_loss"] as reduce_loss_t:
                     all_reduce_mean(loss)
                     running_loss += loss.item()
                     global_step = epoch * num_steps_per_epoch + step
                     log_step += 1
                     acc_step += 1
-                coordinator.block_all()
-                timer_list.append(reduce_loss_t)
+                if record_time:
+                    timer_list.append(reduce_loss_t)
 
                 # == logging ==
                 if coordinator.is_master() and (global_step + 1) % cfg.get("log_every", 1) == 0:
@@ -376,12 +395,11 @@ def main():
                         global_step + 1,
                         save_dir,
                     )
-
-                log_str = f"Rank {dist.get_rank()} | Epoch {epoch} | Step {step} | "
-                for timer in timer_list:
-                    log_str += f"{timer.name}: {timer.elapsed_time:.3f}s | "
-                print(log_str)
-                coordinator.block_all()
+                if record_time:
+                    log_str = f"Rank {dist.get_rank()} | Epoch {epoch} | Step {step} | "
+                    for timer in timer_list:
+                        log_str += f"{timer.name}: {timer.elapsed_time:.3f}s | "
+                    print(log_str)
 
         sampler.reset()
         start_step = 0
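
With --record-time enabled, each rank prints one line per step built from the timer list above; the output has roughly the following shape (the timings here are purely illustrative):

Rank 0 | Epoch 0 | Step 12 | move_data: 0.004s | encode: 0.113s | mask: 0.001s | diffusion: 0.262s | backward: 0.418s | update_ema: 0.007s | reduce_loss: 0.003s |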