From 2e1e26a2e41a94a30f8f74c9342ca9e0da173efb Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Wed, 29 May 2024 14:43:15 +0800 Subject: [PATCH] added timer for benchmarking (#114) --- configs/opensora-v1-2/train/stage1.py | 2 +- opensora/acceleration/parallel_states.py | 2 +- opensora/datasets/utils.py | 2 +- opensora/utils/misc.py | 20 +++++ scripts/train.py | 101 ++++++++++++++++------- 5 files changed, 92 insertions(+), 35 deletions(-) diff --git a/configs/opensora-v1-2/train/stage1.py b/configs/opensora-v1-2/train/stage1.py index 5a9b54d..eb36c80 100644 --- a/configs/opensora-v1-2/train/stage1.py +++ b/configs/opensora-v1-2/train/stage1.py @@ -63,7 +63,7 @@ model = dict( ) vae = dict( type="OpenSoraVAE_V1_2", - from_pretrained="pretrained_models/vae-pipeline", + from_pretrained="/mnt/jfs/sora_checkpoints/vae-pipeline", micro_frame_size=17, micro_batch_size=4, ) diff --git a/opensora/acceleration/parallel_states.py b/opensora/acceleration/parallel_states.py index ff2893e..3c05cf1 100644 --- a/opensora/acceleration/parallel_states.py +++ b/opensora/acceleration/parallel_states.py @@ -8,7 +8,7 @@ def set_data_parallel_group(group: dist.ProcessGroup): def get_data_parallel_group(): - return _GLOBAL_PARALLEL_GROUPS.get("data", None) + return _GLOBAL_PARALLEL_GROUPS.get("data", dist.group.WORLD) def set_sequence_parallel_group(group: dist.ProcessGroup): diff --git a/opensora/datasets/utils.py b/opensora/datasets/utils.py index a85cf94..194ea24 100644 --- a/opensora/datasets/utils.py +++ b/opensora/datasets/utils.py @@ -67,7 +67,7 @@ def temporal_random_crop(vframes, num_frames, frame_interval): temporal_sample = video_transforms.TemporalRandomCrop(num_frames * frame_interval) total_frames = len(vframes) start_frame_ind, end_frame_ind = temporal_sample(total_frames) - assert end_frame_ind - start_frame_ind >= num_frames + assert end_frame_ind - start_frame_ind >= num_frames, f"Not enough frames to sample, {end_frame_ind} - {start_frame_ind} < {num_frames}" frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, num_frames, dtype=int) video = vframes[frame_indice] return video diff --git a/opensora/utils/misc.py b/opensora/utils/misc.py index 61e00f4..cfee23d 100644 --- a/opensora/utils/misc.py +++ b/opensora/utils/misc.py @@ -409,3 +409,23 @@ class FeatureSaver: get_logger().info("Saved to %s", save_path) self.data_list = [] self.bin_cnt += 1 + + +class Timer: + def __init__(self, name): + self.name = name + self.start_time = None + self.end_time = None + + @property + def elapsed_time(self): + return self.end_time - self.start_time + + def __enter__(self): + torch.cuda.synchronize() + self.start_time = time.time() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + torch.cuda.synchronize() + self.end_time = time.time() diff --git a/scripts/train.py b/scripts/train.py index e08524c..5ac9c9a 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -26,6 +26,7 @@ from opensora.utils.misc import ( get_model_numel, requires_grad, to_torch_dtype, + Timer ) from opensora.utils.train_utils import MaskGenerator, create_colossalai_plugin, update_ema @@ -233,31 +234,40 @@ def main(): total=num_steps_per_epoch, ) as pbar: for step, batch in pbar: - x = batch.pop("video").to(device, dtype) # [B, C, T, H, W] - y = batch.pop("text") + timer_list = [] + with Timer("move data") as move_data_t: + x = batch.pop("video").to(device, dtype) # [B, C, T, H, W] + y = batch.pop("text") + timer_list.append(move_data_t) # == visual and text encoding == - with torch.no_grad(): - # Prepare visual inputs - if cfg.get("load_video_features", False): - x = x.to(device, dtype) - else: - x = vae.encode(x) # [B, C, T, H/P, W/P] - # Prepare text inputs - if cfg.get("load_text_features", False): - model_args = {"y": y.to(device, dtype)} - mask = batch.pop("mask") - if isinstance(mask, torch.Tensor): - mask = mask.to(device, dtype) - model_args["mask"] = mask - else: - model_args = text_encoder.encode(y) + with Timer("encode") as encode_t: + with torch.no_grad(): + # Prepare visual inputs + if cfg.get("load_video_features", False): + x = x.to(device, dtype) + else: + x = vae.encode(x) # [B, C, T, H/P, W/P] + # Prepare text inputs + if cfg.get("load_text_features", False): + model_args = {"y": y.to(device, dtype)} + mask = batch.pop("mask") + if isinstance(mask, torch.Tensor): + mask = mask.to(device, dtype) + model_args["mask"] = mask + else: + model_args = text_encoder.encode(y) + coordinator.block_all() + timer_list.append(encode_t) # == mask == - mask = None - if cfg.get("mask_ratios", None) is not None: - mask = mask_generator.get_masks(x) - model_args["x_mask"] = mask + with Timer("mask") as mask_t: + mask = None + if cfg.get("mask_ratios", None) is not None: + mask = mask_generator.get_masks(x) + model_args["x_mask"] = mask + coordinator.block_all() + timer_list.append(mask_t) # == video meta info == for k, v in batch.items(): @@ -265,23 +275,35 @@ def main(): model_args[k] = v.to(device, dtype) # == diffusion loss computation == - loss_dict = scheduler.training_losses(model, x, model_args, mask=mask) + with Timer("diffusion") as loss_t: + loss_dict = scheduler.training_losses(model, x, model_args, mask=mask) + coordinator.block_all() + timer_list.append(loss_t) # == backward & update == - loss = loss_dict["loss"].mean() - booster.backward(loss=loss, optimizer=optimizer) - optimizer.step() - optimizer.zero_grad() + with Timer("backward") as backward_t: + loss = loss_dict["loss"].mean() + booster.backward(loss=loss, optimizer=optimizer) + optimizer.step() + optimizer.zero_grad() + coordinator.block_all() + timer_list.append(backward_t) # == update EMA == - update_ema(ema, model.module, optimizer=optimizer, decay=cfg.get("ema_decay", 0.9999)) + with Timer("update_ema") as ema_t: + update_ema(ema, model.module, optimizer=optimizer, decay=cfg.get("ema_decay", 0.9999)) + coordinator.block_all() + timer_list.append(ema_t) # == update log info == - all_reduce_mean(loss) - running_loss += loss.item() - global_step = epoch * num_steps_per_epoch + step - log_step += 1 - acc_step += 1 + with Timer("reduce_loss") as reduce_loss_t: + all_reduce_mean(loss) + running_loss += loss.item() + global_step = epoch * num_steps_per_epoch + step + log_step += 1 + acc_step += 1 + coordinator.block_all() + timer_list.append(reduce_loss_t) # == logging == if coordinator.is_master() and (global_step + 1) % cfg.get("log_every", 1) == 0: @@ -299,6 +321,13 @@ def main(): "loss": loss.item(), "avg_loss": avg_loss, "acc_step": acc_step, + "move_data_time": move_data_t.elapsed_time, + "encode_time": encode_t.elapsed_time, + "mask_time": mask_t.elapsed_time, + "diffusion_time": loss_t.elapsed_time, + "backward_time": backward_t.elapsed_time, + "update_ema_time": ema_t.elapsed_time, + "reduce_loss_time": reduce_loss_t.elapsed_time, }, step=global_step, ) @@ -332,6 +361,14 @@ def main(): global_step + 1, save_dir, ) + + + log_str = f"Rank {dist.get_rank()} | Epoch {epoch} | Step {step} | " + for timer in timer_list: + log_str += f"{timer.name}: {timer.elapsed_time:.3f}s | " + print(log_str) + coordinator.block_all() + sampler.reset() start_step = 0