mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-14 18:25:35 +02:00
143 lines
4.6 KiB
Python
143 lines
4.6 KiB
Python
from collections import OrderedDict
|
|
|
|
import numpy as np
|
|
|
|
from opensora.utils.misc import get_logger
|
|
|
|
from .aspect import ASPECT_RATIOS, get_closest_ratio
|
|
|
|
|
|
def find_approximate_hw(hw, hw_dict, approx=0.8):
|
|
for k, v in hw_dict.items():
|
|
if hw >= v * approx:
|
|
return k
|
|
return None
|
|
|
|
|
|
def find_closet_smaller_bucket(t, t_dict, frame_interval):
|
|
# process image
|
|
if t == 1:
|
|
if 1 in t_dict:
|
|
return 1
|
|
else:
|
|
return None
|
|
# process video
|
|
for k, v in t_dict.items():
|
|
if t >= v * frame_interval and v != 1:
|
|
return k
|
|
return None
|
|
|
|
|
|
class Bucket:
|
|
def __init__(self, bucket_config):
|
|
for key in bucket_config:
|
|
assert key in ASPECT_RATIOS, f"Aspect ratio {key} not found."
|
|
# wrap config with OrderedDict
|
|
bucket_probs = OrderedDict()
|
|
bucket_bs = OrderedDict()
|
|
bucket_names = sorted(bucket_config.keys(), key=lambda x: ASPECT_RATIOS[x][0], reverse=True)
|
|
for key in bucket_names:
|
|
bucket_time_names = sorted(bucket_config[key].keys(), key=lambda x: x, reverse=True)
|
|
bucket_probs[key] = OrderedDict({k: bucket_config[key][k][0] for k in bucket_time_names})
|
|
bucket_bs[key] = OrderedDict({k: bucket_config[key][k][1] for k in bucket_time_names})
|
|
|
|
# first level: HW
|
|
num_bucket = 0
|
|
hw_criteria = dict()
|
|
t_criteria = dict()
|
|
ar_criteria = dict()
|
|
bucket_id = OrderedDict()
|
|
bucket_id_cnt = 0
|
|
for k1, v1 in bucket_probs.items():
|
|
hw_criteria[k1] = ASPECT_RATIOS[k1][0]
|
|
t_criteria[k1] = dict()
|
|
ar_criteria[k1] = dict()
|
|
bucket_id[k1] = dict()
|
|
for k2, _ in v1.items():
|
|
t_criteria[k1][k2] = k2
|
|
bucket_id[k1][k2] = bucket_id_cnt
|
|
bucket_id_cnt += 1
|
|
ar_criteria[k1][k2] = dict()
|
|
for k3, v3 in ASPECT_RATIOS[k1][1].items():
|
|
ar_criteria[k1][k2][k3] = v3
|
|
num_bucket += 1
|
|
|
|
self.bucket_probs = bucket_probs
|
|
self.bucket_bs = bucket_bs
|
|
self.bucket_id = bucket_id
|
|
self.hw_criteria = hw_criteria
|
|
self.t_criteria = t_criteria
|
|
self.ar_criteria = ar_criteria
|
|
self.num_bucket = num_bucket
|
|
get_logger().info("Number of buckets: %s", num_bucket)
|
|
|
|
def get_bucket_id(self, T, H, W, frame_interval=1, seed=None):
|
|
resolution = H * W
|
|
approx = 0.8
|
|
|
|
fail = True
|
|
for hw_id, t_criteria in self.bucket_probs.items():
|
|
if resolution < self.hw_criteria[hw_id] * approx:
|
|
continue
|
|
|
|
# if sample is an image
|
|
if T == 1:
|
|
if 1 in t_criteria:
|
|
rng = np.random.default_rng(seed + self.bucket_id[hw_id][1])
|
|
if rng.random() < t_criteria[1]:
|
|
fail = False
|
|
t_id = 1
|
|
break
|
|
else:
|
|
continue
|
|
|
|
# otherwise, find suitable t_id for video
|
|
t_fail = True
|
|
for t_id, prob in t_criteria.items():
|
|
rng = np.random.default_rng(seed + self.bucket_id[hw_id][t_id])
|
|
if isinstance(prob, tuple):
|
|
prob_t = prob[1]
|
|
if rng.random() > prob_t:
|
|
continue
|
|
if T > t_id * frame_interval and t_id != 1:
|
|
t_fail = False
|
|
break
|
|
if t_fail:
|
|
continue
|
|
|
|
# leave the loop if prob is high enough
|
|
if isinstance(prob, tuple):
|
|
prob = prob[0]
|
|
if prob >= 1 or rng.random() < prob:
|
|
fail = False
|
|
break
|
|
if fail:
|
|
return None
|
|
|
|
# get aspect ratio id
|
|
ar_criteria = self.ar_criteria[hw_id][t_id]
|
|
ar_id = get_closest_ratio(H, W, ar_criteria)
|
|
return hw_id, t_id, ar_id
|
|
|
|
def get_thw(self, bucket_id):
|
|
assert len(bucket_id) == 3
|
|
T = self.t_criteria[bucket_id[0]][bucket_id[1]]
|
|
H, W = self.ar_criteria[bucket_id[0]][bucket_id[1]][bucket_id[2]]
|
|
return T, H, W
|
|
|
|
def get_prob(self, bucket_id):
|
|
return self.bucket_probs[bucket_id[0]][bucket_id[1]]
|
|
|
|
def get_batch_size(self, bucket_id):
|
|
return self.bucket_bs[bucket_id[0]][bucket_id[1]]
|
|
|
|
def __len__(self):
|
|
return self.num_bucket
|
|
|
|
|
|
def closet_smaller_bucket(value, bucket):
|
|
for i in range(1, len(bucket)):
|
|
if value < bucket[i]:
|
|
return bucket[i - 1]
|
|
return bucket[-1]
|