Open-Sora/opensora/datasets/bucket.py
Zheng Zangwei (Alex Zheng) febf3ad4b2
Update Open-Sora 2.0 (#807)
* upload v2.0

* update docs

* [hotfix] fit latest fa3 (#802)

* update readme

* update readme

* update readme

* update train readme

* update readme

* update readme: motion score

* cleaning video dc ae WIP

* update config

* add dependency functions

* undo cleaning

* use latest dcae

* complete high compression training

* update hcae config

* cleaned up vae

* update ae.md

* further cleanup

* update vae & ae paths

* align naming of ae

* [hotfix] fix ring attn bwd for fa3 (#803)

* train ae default without wandb

* update config

* update evaluation results

* added hcae report

* update readme

* update readme demo

* update readme demo

* update readme gif

* display demo directly in readme

* update paper

* delete files

---------

Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: Shen-Chenhui <shen_chenhui@u.nus.edu>
Co-authored-by: wuxiwen <wuxiwen.simon@gmail.com>
2025-03-12 13:14:22 +08:00

140 lines
5.4 KiB
Python

from collections import OrderedDict
import numpy as np
from opensora.utils.logger import log_message
from .aspect import get_closest_ratio, get_resolution_with_aspect_ratio
from .utils import map_target_fps
class Bucket:
def __init__(self, bucket_config: dict[str, dict[int, tuple[float, int] | tuple[tuple[float, float], int]]]):
"""
Args:
bucket_config (dict): A dictionary containing the bucket configuration.
The dictionary should be in the following format:
{
"bucket_name": {
"time": (probability, batch_size),
"time": (probability, batch_size),
...
},
...
}
Or in the following format:
{
"bucket_name": {
"time": ((probability, next_probability), batch_size),
"time": ((probability, next_probability), batch_size),
...
},
...
}
The bucket_name should be the name of the bucket, and the time should be the number of frames in the video.
The probability should be a float between 0 and 1, and the batch_size should be an integer.
If the probability is a tuple, the second value should be the probability to skip to the next time.
"""
aspect_ratios = {key: get_resolution_with_aspect_ratio(key) for key in bucket_config.keys()}
bucket_probs = OrderedDict()
bucket_bs = OrderedDict()
bucket_names = sorted(bucket_config.keys(), key=lambda x: aspect_ratios[x][0], reverse=True)
for key in bucket_names:
bucket_time_names = sorted(bucket_config[key].keys(), key=lambda x: x, reverse=True)
bucket_probs[key] = OrderedDict({k: bucket_config[key][k][0] for k in bucket_time_names})
bucket_bs[key] = OrderedDict({k: bucket_config[key][k][1] for k in bucket_time_names})
self.hw_criteria = {k: aspect_ratios[k][0] for k in bucket_names}
self.t_criteria = {k1: {k2: k2 for k2 in bucket_config[k1].keys()} for k1 in bucket_names}
self.ar_criteria = {
k1: {k2: {k3: v3 for k3, v3 in aspect_ratios[k1][1].items()} for k2 in bucket_config[k1].keys()}
for k1 in bucket_names
}
bucket_id_cnt = num_bucket = 0
bucket_id = dict()
for k1, v1 in bucket_probs.items():
bucket_id[k1] = dict()
for k2, _ in v1.items():
bucket_id[k1][k2] = bucket_id_cnt
bucket_id_cnt += 1
num_bucket += len(aspect_ratios[k1][1])
self.bucket_probs = bucket_probs
self.bucket_bs = bucket_bs
self.bucket_id = bucket_id
self.num_bucket = num_bucket
log_message("Number of buckets: %s", num_bucket)
def get_bucket_id(
self,
T: int,
H: int,
W: int,
fps: float,
path: str | None = None,
seed: int | None = None,
fps_max: int = 16,
) -> tuple[str, int, int] | None:
approx = 0.8
_, sampling_interval = map_target_fps(fps, fps_max)
T = T // sampling_interval
resolution = H * W
rng = np.random.default_rng(seed)
# Reference to probabilities and criteria for faster access
bucket_probs = self.bucket_probs
hw_criteria = self.hw_criteria
ar_criteria = self.ar_criteria
# Start searching for the appropriate bucket
for hw_id, t_criteria in bucket_probs.items():
# if resolution is too low, skip
if resolution < hw_criteria[hw_id] * approx:
continue
# if sample is an image
if T == 1:
if 1 in t_criteria:
if rng.random() < t_criteria[1]:
return hw_id, 1, get_closest_ratio(H, W, ar_criteria[hw_id][1])
continue
# Look for suitable t_id for video
for t_id, prob in t_criteria.items():
if T >= t_id and t_id != 1:
# if prob is a tuple, use the second value as the threshold to skip
# to the next t_id
if isinstance(prob, tuple):
next_hw_prob, next_t_prob = prob
if next_t_prob >= 1 or rng.random() <= next_t_prob:
continue
else:
next_hw_prob = prob
if next_hw_prob >= 1 or rng.random() <= next_hw_prob:
ar_id = get_closest_ratio(H, W, ar_criteria[hw_id][t_id])
return hw_id, t_id, ar_id
else:
break
return None
def get_thw(self, bucket_idx: tuple[str, int, int]) -> tuple[int, int, int]:
assert len(bucket_idx) == 3
T = self.t_criteria[bucket_idx[0]][bucket_idx[1]]
H, W = self.ar_criteria[bucket_idx[0]][bucket_idx[1]][bucket_idx[2]]
return T, H, W
def get_prob(self, bucket_idx: tuple[str, int]) -> float:
return self.bucket_probs[bucket_idx[0]][bucket_idx[1]]
def get_batch_size(self, bucket_idx: tuple[str, int]) -> int:
return self.bucket_bs[bucket_idx[0]][bucket_idx[1]]
def __len__(self) -> int:
return self.num_bucket