mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-11 05:13:31 +02:00
update utils
This commit is contained in:
parent
e03676e7e8
commit
8509094de2
|
|
@ -284,56 +284,3 @@ def build_logger(work_dir, cfgname):
|
|||
logger.propagate = False
|
||||
|
||||
return logger
|
||||
|
||||
|
||||
timings = {}
|
||||
from contextlib import contextmanager
|
||||
|
||||
|
||||
@contextmanager
|
||||
def torch_timer(prefix):
|
||||
import time
|
||||
|
||||
import torch
|
||||
|
||||
torch.cuda.synchronize()
|
||||
start = time.time()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
torch.cuda.synchronize()
|
||||
t_diff = (time.time() - start) * 1000
|
||||
if prefix not in timings:
|
||||
timings[prefix] = []
|
||||
|
||||
timings[prefix].append(t_diff)
|
||||
|
||||
num_ignored = 10
|
||||
|
||||
if len(timings[prefix]) > num_ignored:
|
||||
# avg = sum(timings[prefix][num_ignored:]) / (len(timings[prefix]) - num_ignored)
|
||||
avg = sum(timings[prefix][-num_ignored:]) / num_ignored
|
||||
print("{}: {} ({})".format(prefix, t_diff, avg))
|
||||
else:
|
||||
print("{}: {}".format(prefix, t_diff))
|
||||
|
||||
|
||||
def strip_dc(x):
|
||||
"""
|
||||
strip DataContainer
|
||||
"""
|
||||
try:
|
||||
import mmcv
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
if isinstance(x, dict):
|
||||
res = {}
|
||||
for k, v in x.items():
|
||||
res[k] = strip_dc(v)
|
||||
return res
|
||||
if isinstance(x, (list, tuple)) and isinstance(x[0], mmcv.parallel.DataContainer):
|
||||
return strip_dc(x[0])
|
||||
elif isinstance(x, mmcv.parallel.DataContainer):
|
||||
return strip_dc(x.data)
|
||||
return x
|
||||
|
|
|
|||
Loading…
Reference in a new issue