From 8509094de2bfc469a7f4078d0fc4b88c7e9ea4fe Mon Sep 17 00:00:00 2001 From: pxy Date: Sat, 16 Mar 2024 16:13:34 +0800 Subject: [PATCH] update utils --- opensora/utils/misc.py | 53 ------------------------------------------ 1 file changed, 53 deletions(-) diff --git a/opensora/utils/misc.py b/opensora/utils/misc.py index 9d89c10..d162526 100644 --- a/opensora/utils/misc.py +++ b/opensora/utils/misc.py @@ -284,56 +284,3 @@ def build_logger(work_dir, cfgname): logger.propagate = False return logger - - -timings = {} -from contextlib import contextmanager - - -@contextmanager -def torch_timer(prefix): - import time - - import torch - - torch.cuda.synchronize() - start = time.time() - try: - yield - finally: - torch.cuda.synchronize() - t_diff = (time.time() - start) * 1000 - if prefix not in timings: - timings[prefix] = [] - - timings[prefix].append(t_diff) - - num_ignored = 10 - - if len(timings[prefix]) > num_ignored: - # avg = sum(timings[prefix][num_ignored:]) / (len(timings[prefix]) - num_ignored) - avg = sum(timings[prefix][-num_ignored:]) / num_ignored - print("{}: {} ({})".format(prefix, t_diff, avg)) - else: - print("{}: {}".format(prefix, t_diff)) - - -def strip_dc(x): - """ - strip DataContainer - """ - try: - import mmcv - except ImportError: - pass - - if isinstance(x, dict): - res = {} - for k, v in x.items(): - res[k] = strip_dc(v) - return res - if isinstance(x, (list, tuple)) and isinstance(x[0], mmcv.parallel.DataContainer): - return strip_dc(x[0]) - elif isinstance(x, mmcv.parallel.DataContainer): - return strip_dc(x.data) - return x