diff --git a/configs/opensora-v1-1/train/video.py b/configs/opensora-v1-1/train/video.py index 285ae84..7be2630 100644 --- a/configs/opensora-v1-1/train/video.py +++ b/configs/opensora-v1-1/train/video.py @@ -11,7 +11,7 @@ bucket_config = { # 6s/it "240p": {16: (1.0, 16), 32: (1.0, 8), 64: (1.0, 4), 128: (1.0, 2)}, "256": {1: (1.0, 256)}, "512": {1: (1.0, 80)}, - "480p": {1: (1.0, 52), 16: (0.5, 4), 32: (0.0, None)}, + "480p": {1: (0.5, 52), 16: (0.5, 4), 32: (0.0, None)}, "720p": {16: (1.0, 2), 32: (0.0, None)}, # No examples now "1024": {1: (1.0, 20)}, "1080p": {1: (1.0, 8)}, diff --git a/opensora/datasets/bucket.py b/opensora/datasets/bucket.py index f5c5b46..cf4569b 100644 --- a/opensora/datasets/bucket.py +++ b/opensora/datasets/bucket.py @@ -87,6 +87,7 @@ class Bucket: hw_id = find_approximate_hw(hw, self.hw_criteria) if hw_id is None: return None + hw_id_index = list(self.hw_criteria.keys()).index(hw_id) # hw drops by probablity while True: @@ -96,11 +97,12 @@ class Bucket: prob = self.get_prob((hw_id, T_id)) if torch.rand(1, generator=generator).item() < prob: break - hw_id_index = list(self.hw_criteria.keys()).index(hw_id) - hw_id = list(self.hw_criteria.keys())[hw_id_index + 1] - if hw_id_index == len(self.hw_criteria) - 1: + hw_id_index += 1 + if hw_id_index >= len(self.hw_criteria) - 1: break - if T_id is None: + hw_id = list(self.hw_criteria.keys())[hw_id_index + 1] + + if T_id is None or hw_id_index >= len(self.hw_criteria) - 1: return None # ar