mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-11 05:13:31 +02:00
fix bucket index
This commit is contained in:
parent
882225897b
commit
02b6335628
|
|
@ -11,7 +11,7 @@ bucket_config = { # 6s/it
|
|||
"240p": {16: (1.0, 16), 32: (1.0, 8), 64: (1.0, 4), 128: (1.0, 2)},
|
||||
"256": {1: (1.0, 256)},
|
||||
"512": {1: (1.0, 80)},
|
||||
"480p": {1: (1.0, 52), 16: (0.5, 4), 32: (0.0, None)},
|
||||
"480p": {1: (0.5, 52), 16: (0.5, 4), 32: (0.0, None)},
|
||||
"720p": {16: (1.0, 2), 32: (0.0, None)}, # No examples now
|
||||
"1024": {1: (1.0, 20)},
|
||||
"1080p": {1: (1.0, 8)},
|
||||
|
|
|
|||
|
|
@ -87,6 +87,7 @@ class Bucket:
|
|||
hw_id = find_approximate_hw(hw, self.hw_criteria)
|
||||
if hw_id is None:
|
||||
return None
|
||||
hw_id_index = list(self.hw_criteria.keys()).index(hw_id)
|
||||
|
||||
# hw drops by probablity
|
||||
while True:
|
||||
|
|
@ -96,11 +97,12 @@ class Bucket:
|
|||
prob = self.get_prob((hw_id, T_id))
|
||||
if torch.rand(1, generator=generator).item() < prob:
|
||||
break
|
||||
hw_id_index = list(self.hw_criteria.keys()).index(hw_id)
|
||||
hw_id = list(self.hw_criteria.keys())[hw_id_index + 1]
|
||||
if hw_id_index == len(self.hw_criteria) - 1:
|
||||
hw_id_index += 1
|
||||
if hw_id_index >= len(self.hw_criteria) - 1:
|
||||
break
|
||||
if T_id is None:
|
||||
hw_id = list(self.hw_criteria.keys())[hw_id_index + 1]
|
||||
|
||||
if T_id is None or hw_id_index >= len(self.hw_criteria) - 1:
|
||||
return None
|
||||
|
||||
# ar
|
||||
|
|
|
|||
Loading…
Reference in a new issue