fix bucket index

This commit is contained in:
Zangwei Zheng 2024-04-07 19:35:30 +08:00
parent 882225897b
commit 02b6335628
2 changed files with 7 additions and 5 deletions

View file

@ -11,7 +11,7 @@ bucket_config = { # 6s/it
"240p": {16: (1.0, 16), 32: (1.0, 8), 64: (1.0, 4), 128: (1.0, 2)},
"256": {1: (1.0, 256)},
"512": {1: (1.0, 80)},
"480p": {1: (1.0, 52), 16: (0.5, 4), 32: (0.0, None)},
"480p": {1: (0.5, 52), 16: (0.5, 4), 32: (0.0, None)},
"720p": {16: (1.0, 2), 32: (0.0, None)}, # No examples now
"1024": {1: (1.0, 20)},
"1080p": {1: (1.0, 8)},

View file

@ -87,6 +87,7 @@ class Bucket:
hw_id = find_approximate_hw(hw, self.hw_criteria)
if hw_id is None:
return None
hw_id_index = list(self.hw_criteria.keys()).index(hw_id)
# hw drops by probablity
while True:
@ -96,11 +97,12 @@ class Bucket:
prob = self.get_prob((hw_id, T_id))
if torch.rand(1, generator=generator).item() < prob:
break
hw_id_index = list(self.hw_criteria.keys()).index(hw_id)
hw_id = list(self.hw_criteria.keys())[hw_id_index + 1]
if hw_id_index == len(self.hw_criteria) - 1:
hw_id_index += 1
if hw_id_index >= len(self.hw_criteria) - 1:
break
if T_id is None:
hw_id = list(self.hw_criteria.keys())[hw_id_index + 1]
if T_id is None or hw_id_index >= len(self.hw_criteria) - 1:
return None
# ar