From 332d9fc9c91f72078acb9415793ec4d9d5cfa2b2 Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Thu, 27 Jun 2024 13:37:54 +0800 Subject: [PATCH] [feature] make timer optional and make reduce bucket size configurable (#549) * [feature] make reduce bucket size configurable * [feature] make timer optional --- opensora/utils/config_utils.py | 1 + opensora/utils/misc.py | 8 +++- opensora/utils/train_utils.py | 4 +- scripts/train.py | 70 +++++++++++++++++++++------------- 4 files changed, 54 insertions(+), 29 deletions(-) diff --git a/opensora/utils/config_utils.py b/opensora/utils/config_utils.py index f20138b..0333bec 100644 --- a/opensora/utils/config_utils.py +++ b/opensora/utils/config_utils.py @@ -79,6 +79,7 @@ def parse_args(training=False): parser.add_argument("--load", default=None, type=str, help="path to continue training") parser.add_argument("--start-from-scratch", action="store_true", help="start training from scratch") parser.add_argument("--warmup-steps", default=None, type=int, help="warmup steps") + parser.add_argument("--record-time", default=False, action="store_true", help="record time of each part") return parser.parse_args() diff --git a/opensora/utils/misc.py b/opensora/utils/misc.py index b2f6730..bb6b658 100644 --- a/opensora/utils/misc.py +++ b/opensora/utils/misc.py @@ -6,11 +6,12 @@ import time from collections import OrderedDict from collections.abc import Sequence from itertools import repeat -from typing import Tuple +from typing import Optional, Tuple import numpy as np import torch import torch.distributed as dist +from colossalai.cluster.dist_coordinator import DistCoordinator # ====================================================== # Logging @@ -358,11 +359,12 @@ def all_exists(paths): class Timer: - def __init__(self, name, log=False): + def __init__(self, name, log=False, coordinator: Optional[DistCoordinator] = None): self.name = name self.start_time = None self.end_time = None self.log = log + self.coordinator = coordinator @property def elapsed_time(self): @@ -374,6 +376,8 @@ class Timer: return self def __exit__(self, exc_type, exc_val, exc_tb): + if self.coordinator is not None: + self.coordinator.block_all() torch.cuda.synchronize() self.end_time = time.time() if self.log: diff --git a/opensora/utils/train_utils.py b/opensora/utils/train_utils.py index da38cac..95d0011 100644 --- a/opensora/utils/train_utils.py +++ b/opensora/utils/train_utils.py @@ -12,7 +12,7 @@ from opensora.acceleration.plugin import ZeroSeqParallelPlugin from .misc import get_logger -def create_colossalai_plugin(plugin, dtype, grad_clip, sp_size): +def create_colossalai_plugin(plugin, dtype, grad_clip, sp_size, reduce_bucket_size_in_m: int = 20): if plugin == "zero2": assert sp_size == 1, "Zero2 plugin does not support sequence parallelism" plugin = LowLevelZeroPlugin( @@ -20,6 +20,7 @@ def create_colossalai_plugin(plugin, dtype, grad_clip, sp_size): precision=dtype, initial_scale=2**16, max_norm=grad_clip, + reduce_bucket_size_in_m=reduce_bucket_size_in_m, ) set_data_parallel_group(dist.group.WORLD) elif plugin == "zero2-seq": @@ -30,6 +31,7 @@ def create_colossalai_plugin(plugin, dtype, grad_clip, sp_size): precision=dtype, initial_scale=2**16, max_norm=grad_clip, + reduce_bucket_size_in_m=reduce_bucket_size_in_m, ) set_sequence_parallel_group(plugin.sp_group) set_data_parallel_group(plugin.dp_group) diff --git a/scripts/train.py b/scripts/train.py index 98bb2d7..1066977 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -1,4 +1,5 @@ import os +from contextlib import nullcontext from copy import deepcopy from datetime import timedelta from pprint import pformat @@ -38,6 +39,7 @@ def main(): # ====================================================== # == parse configs == cfg = parse_configs(training=True) + record_time = cfg.get("record_time", False) # == device and dtype == assert torch.cuda.is_available(), "Training currently requires at least one GPU." @@ -76,6 +78,7 @@ def main(): dtype=cfg_dtype, grad_clip=cfg.get("grad_clip", 0), sp_size=cfg.get("sp_size", 1), + reduce_bucket_size_in_m=cfg.get("reduce_bucket_size_in_m", 20), ) booster = Booster(plugin=plugin) torch.set_num_threads(1) @@ -229,6 +232,21 @@ def main(): # 5. training loop # ======================================================= dist.barrier() + timers = {} + timer_keys = [ + "move_data", + "encode", + "mask", + "diffusion", + "backward", + "update_ema", + "reduce_loss", + ] + for key in timer_keys: + if record_time: + timers[key] = Timer(key, coordinator=coordinator) + else: + timers[key] = nullcontext() for epoch in range(start_epoch, cfg_epochs): # == set dataloader to new epoch == sampler.set_epoch(epoch) @@ -245,13 +263,14 @@ def main(): ) as pbar: for step, batch in pbar: timer_list = [] - with Timer("move data") as move_data_t: + with timers["move_data"] as move_data_t: x = batch.pop("video").to(device, dtype) # [B, C, T, H, W] y = batch.pop("text") - timer_list.append(move_data_t) + if record_time: + timer_list.append(move_data_t) # == visual and text encoding == - with Timer("encode") as encode_t: + with timers["encode"] as encode_t: with torch.no_grad(): # Prepare visual inputs if cfg.get("load_video_features", False): @@ -267,17 +286,17 @@ def main(): model_args["mask"] = mask else: model_args = text_encoder.encode(y) - coordinator.block_all() - timer_list.append(encode_t) + if record_time: + timer_list.append(encode_t) # == mask == - with Timer("mask") as mask_t: + with timers["mask"] as mask_t: mask = None if cfg.get("mask_ratios", None) is not None: mask = mask_generator.get_masks(x) model_args["x_mask"] = mask - coordinator.block_all() - timer_list.append(mask_t) + if record_time: + timer_list.append(mask_t) # == video meta info == for k, v in batch.items(): @@ -285,13 +304,13 @@ def main(): model_args[k] = v.to(device, dtype) # == diffusion loss computation == - with Timer("diffusion") as loss_t: + with timers["diffusion"] as loss_t: loss_dict = scheduler.training_losses(model, x, model_args, mask=mask) - coordinator.block_all() - timer_list.append(loss_t) + if record_time: + timer_list.append(loss_t) # == backward & update == - with Timer("backward") as backward_t: + with timers["backward"] as backward_t: loss = loss_dict["loss"].mean() booster.backward(loss=loss, optimizer=optimizer) optimizer.step() @@ -300,24 +319,24 @@ def main(): # update learning rate if lr_scheduler is not None: lr_scheduler.step() - coordinator.block_all() - timer_list.append(backward_t) + if record_time: + timer_list.append(backward_t) # == update EMA == - with Timer("update_ema") as ema_t: + with timers["update_ema"] as ema_t: update_ema(ema, model.module, optimizer=optimizer, decay=cfg.get("ema_decay", 0.9999)) - coordinator.block_all() - timer_list.append(ema_t) + if record_time: + timer_list.append(ema_t) # == update log info == - with Timer("reduce_loss") as reduce_loss_t: + with timers["reduce_loss"] as reduce_loss_t: all_reduce_mean(loss) running_loss += loss.item() global_step = epoch * num_steps_per_epoch + step log_step += 1 acc_step += 1 - coordinator.block_all() - timer_list.append(reduce_loss_t) + if record_time: + timer_list.append(reduce_loss_t) # == logging == if coordinator.is_master() and (global_step + 1) % cfg.get("log_every", 1) == 0: @@ -376,12 +395,11 @@ def main(): global_step + 1, save_dir, ) - - log_str = f"Rank {dist.get_rank()} | Epoch {epoch} | Step {step} | " - for timer in timer_list: - log_str += f"{timer.name}: {timer.elapsed_time:.3f}s | " - print(log_str) - coordinator.block_all() + if record_time: + log_str = f"Rank {dist.get_rank()} | Epoch {epoch} | Step {step} | " + for timer in timer_list: + log_str += f"{timer.name}: {timer.elapsed_time:.3f}s | " + print(log_str) sampler.reset() start_step = 0