update utils

This commit is contained in:
pxy 2024-03-16 16:13:34 +08:00
parent e03676e7e8
commit 8509094de2

View file

@ -284,56 +284,3 @@ def build_logger(work_dir, cfgname):
logger.propagate = False
return logger
timings = {}
from contextlib import contextmanager
@contextmanager
def torch_timer(prefix):
import time
import torch
torch.cuda.synchronize()
start = time.time()
try:
yield
finally:
torch.cuda.synchronize()
t_diff = (time.time() - start) * 1000
if prefix not in timings:
timings[prefix] = []
timings[prefix].append(t_diff)
num_ignored = 10
if len(timings[prefix]) > num_ignored:
# avg = sum(timings[prefix][num_ignored:]) / (len(timings[prefix]) - num_ignored)
avg = sum(timings[prefix][-num_ignored:]) / num_ignored
print("{}: {} ({})".format(prefix, t_diff, avg))
else:
print("{}: {}".format(prefix, t_diff))
def strip_dc(x):
"""
strip DataContainer
"""
try:
import mmcv
except ImportError:
pass
if isinstance(x, dict):
res = {}
for k, v in x.items():
res[k] = strip_dc(v)
return res
if isinstance(x, (list, tuple)) and isinstance(x[0], mmcv.parallel.DataContainer):
return strip_dc(x[0])
elif isinstance(x, mmcv.parallel.DataContainer):
return strip_dc(x.data)
return x