added timer for benchmarking (#114)

Frank Lee 2024-05-29 14:43:15 +08:00 committed by GitHub
parent 11b3881a86
commit 2e1e26a2e4
5 changed files with 92 additions and 35 deletions

View file

@@ -63,7 +63,7 @@ model = dict(
 )
 vae = dict(
     type="OpenSoraVAE_V1_2",
-    from_pretrained="pretrained_models/vae-pipeline",
+    from_pretrained="/mnt/jfs/sora_checkpoints/vae-pipeline",
     micro_frame_size=17,
     micro_batch_size=4,
 )

View file

@@ -8,7 +8,7 @@ def set_data_parallel_group(group: dist.ProcessGroup):


 def get_data_parallel_group():
-    return _GLOBAL_PARALLEL_GROUPS.get("data", None)
+    return _GLOBAL_PARALLEL_GROUPS.get("data", dist.group.WORLD)


 def set_sequence_parallel_group(group: dist.ProcessGroup):

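Note on the hunk above: get_data_parallel_group() now falls back to the global (WORLD) process group instead of None, so call sites can use the returned handle unconditionally. A minimal sketch of the difference, using a hypothetical stand-in registry:

    # Sketch only, not part of the commit; _GROUPS stands in for the
    # module's _GLOBAL_PARALLEL_GROUPS state.
    import torch.distributed as dist

    _GROUPS = {}

    def get_data_parallel_group():
        # Old: .get("data", None) forced callers to special-case None before
        # calling group methods such as .size() or .rank().
        # New: fall back to the default (WORLD) process group.
        return _GROUPS.get("data", dist.group.WORLD)

    # After dist.init_process_group(...) has run elsewhere:
    #   dp_size = get_data_parallel_group().size()  # safe with no group registered
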
View file

@@ -67,7 +67,7 @@ def temporal_random_crop(vframes, num_frames, frame_interval):
     temporal_sample = video_transforms.TemporalRandomCrop(num_frames * frame_interval)
     total_frames = len(vframes)
     start_frame_ind, end_frame_ind = temporal_sample(total_frames)
-    assert end_frame_ind - start_frame_ind >= num_frames
+    assert end_frame_ind - start_frame_ind >= num_frames, f"Not enough frames to sample, {end_frame_ind} - {start_frame_ind} < {num_frames}"
     frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, num_frames, dtype=int)
     video = vframes[frame_indice]
     return video

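For reference, a standalone sketch of the sampling logic above: crop a random window of num_frames * frame_interval frames, then take num_frames evenly spaced indices from it. TemporalRandomCrop is re-implemented inline here under that assumption, so treat it as illustrative rather than the library's exact behavior:

    import numpy as np

    def sample_frame_indices(total_frames, num_frames, frame_interval, rng=np.random):
        # Hypothetical inline version of video_transforms.TemporalRandomCrop.
        window = min(num_frames * frame_interval, total_frames)
        start = rng.randint(0, total_frames - window + 1)
        end = start + window
        assert end - start >= num_frames, f"Not enough frames to sample, {end} - {start} < {num_frames}"
        # Evenly spaced indices inside the window, as in np.linspace above.
        return np.linspace(start, end - 1, num_frames, dtype=int)

    # e.g. total_frames=100, num_frames=4, frame_interval=8 -> a 32-frame
    # window such as [20, 52), sampled at indices [20, 30, 40, 51].
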
View file

@@ -409,3 +409,23 @@ class FeatureSaver:
         get_logger().info("Saved to %s", save_path)
         self.data_list = []
         self.bin_cnt += 1
+
+
+class Timer:
+    def __init__(self, name):
+        self.name = name
+        self.start_time = None
+        self.end_time = None
+
+    @property
+    def elapsed_time(self):
+        return self.end_time - self.start_time
+
+    def __enter__(self):
+        torch.cuda.synchronize()
+        self.start_time = time.time()
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        torch.cuda.synchronize()
+        self.end_time = time.time()

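Usage sketch for the new Timer (the import path matches how train.py below pulls it in; the tensor shape is arbitrary). The synchronize() in __enter__ drains previously queued GPU work so it is not billed to this timer, and the one in __exit__ waits for the timed kernels to finish:

    import torch
    from opensora.utils.misc import Timer

    x = torch.randn(1024, 1024, device="cuda")  # assumes a CUDA device
    with Timer("matmul") as t:
        y = x @ x  # the kernel is only queued here; __exit__ waits for it
    print(f"{t.name}: {t.elapsed_time:.3f}s")
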
View file

@@ -26,6 +26,7 @@ from opensora.utils.misc import (
     get_model_numel,
     requires_grad,
     to_torch_dtype,
+    Timer
 )
 from opensora.utils.train_utils import MaskGenerator, create_colossalai_plugin, update_ema

@@ -233,31 +234,40 @@ def main():
             total=num_steps_per_epoch,
         ) as pbar:
             for step, batch in pbar:
-                x = batch.pop("video").to(device, dtype)  # [B, C, T, H, W]
-                y = batch.pop("text")
+                timer_list = []
+                with Timer("move data") as move_data_t:
+                    x = batch.pop("video").to(device, dtype)  # [B, C, T, H, W]
+                    y = batch.pop("text")
+                timer_list.append(move_data_t)

                 # == visual and text encoding ==
-                with torch.no_grad():
-                    # Prepare visual inputs
-                    if cfg.get("load_video_features", False):
-                        x = x.to(device, dtype)
-                    else:
-                        x = vae.encode(x)  # [B, C, T, H/P, W/P]
-                    # Prepare text inputs
-                    if cfg.get("load_text_features", False):
-                        model_args = {"y": y.to(device, dtype)}
-                        mask = batch.pop("mask")
-                        if isinstance(mask, torch.Tensor):
-                            mask = mask.to(device, dtype)
-                        model_args["mask"] = mask
-                    else:
-                        model_args = text_encoder.encode(y)
+                with Timer("encode") as encode_t:
+                    with torch.no_grad():
+                        # Prepare visual inputs
+                        if cfg.get("load_video_features", False):
+                            x = x.to(device, dtype)
+                        else:
+                            x = vae.encode(x)  # [B, C, T, H/P, W/P]
+                        # Prepare text inputs
+                        if cfg.get("load_text_features", False):
+                            model_args = {"y": y.to(device, dtype)}
+                            mask = batch.pop("mask")
+                            if isinstance(mask, torch.Tensor):
+                                mask = mask.to(device, dtype)
+                            model_args["mask"] = mask
+                        else:
+                            model_args = text_encoder.encode(y)
+                    coordinator.block_all()
+                timer_list.append(encode_t)

                 # == mask ==
-                mask = None
-                if cfg.get("mask_ratios", None) is not None:
-                    mask = mask_generator.get_masks(x)
-                    model_args["x_mask"] = mask
+                with Timer("mask") as mask_t:
+                    mask = None
+                    if cfg.get("mask_ratios", None) is not None:
+                        mask = mask_generator.get_masks(x)
+                        model_args["x_mask"] = mask
+                    coordinator.block_all()
+                timer_list.append(mask_t)

                 # == video meta info ==
                 for k, v in batch.items():
@@ -265,23 +275,35 @@ def main():
                         model_args[k] = v.to(device, dtype)

                 # == diffusion loss computation ==
-                loss_dict = scheduler.training_losses(model, x, model_args, mask=mask)
+                with Timer("diffusion") as loss_t:
+                    loss_dict = scheduler.training_losses(model, x, model_args, mask=mask)
+                    coordinator.block_all()
+                timer_list.append(loss_t)

                 # == backward & update ==
-                loss = loss_dict["loss"].mean()
-                booster.backward(loss=loss, optimizer=optimizer)
-                optimizer.step()
-                optimizer.zero_grad()
+                with Timer("backward") as backward_t:
+                    loss = loss_dict["loss"].mean()
+                    booster.backward(loss=loss, optimizer=optimizer)
+                    optimizer.step()
+                    optimizer.zero_grad()
+                    coordinator.block_all()
+                timer_list.append(backward_t)

                 # == update EMA ==
-                update_ema(ema, model.module, optimizer=optimizer, decay=cfg.get("ema_decay", 0.9999))
+                with Timer("update_ema") as ema_t:
+                    update_ema(ema, model.module, optimizer=optimizer, decay=cfg.get("ema_decay", 0.9999))
+                    coordinator.block_all()
+                timer_list.append(ema_t)

                 # == update log info ==
-                all_reduce_mean(loss)
-                running_loss += loss.item()
-                global_step = epoch * num_steps_per_epoch + step
-                log_step += 1
-                acc_step += 1
+                with Timer("reduce_loss") as reduce_loss_t:
+                    all_reduce_mean(loss)
+                    running_loss += loss.item()
+                    global_step = epoch * num_steps_per_epoch + step
+                    log_step += 1
+                    acc_step += 1
+                    coordinator.block_all()
+                timer_list.append(reduce_loss_t)

                 # == logging ==
                 if coordinator.is_master() and (global_step + 1) % cfg.get("log_every", 1) == 0:
@@ -299,6 +321,13 @@ def main():
                                 "loss": loss.item(),
                                 "avg_loss": avg_loss,
                                 "acc_step": acc_step,
+                                "move_data_time": move_data_t.elapsed_time,
+                                "encode_time": encode_t.elapsed_time,
+                                "mask_time": mask_t.elapsed_time,
+                                "diffusion_time": loss_t.elapsed_time,
+                                "backward_time": backward_t.elapsed_time,
+                                "update_ema_time": ema_t.elapsed_time,
+                                "reduce_loss_time": reduce_loss_t.elapsed_time,
                             },
                             step=global_step,
                         )
@@ -332,6 +361,14 @@ def main():
                         global_step + 1,
                         save_dir,
                     )
+
+                log_str = f"Rank {dist.get_rank()} | Epoch {epoch} | Step {step} | "
+                for timer in timer_list:
+                    log_str += f"{timer.name}: {timer.elapsed_time:.3f}s | "
+                print(log_str)
+                coordinator.block_all()

         sampler.reset()
         start_step = 0
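
With this change, every rank prints one timing line per step, assembled from timer_list. An illustrative line (the numbers are invented for the example; note the f-string leaves a trailing separator):

    Rank 0 | Epoch 0 | Step 12 | move data: 0.041s | encode: 0.953s | mask: 0.004s | diffusion: 1.820s | backward: 2.237s | update_ema: 0.061s | reduce_loss: 0.003s |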