Open-Sora/opensora/datasets/aspect.py
Zheng Zangwei (Alex Zheng) febf3ad4b2
Update Open-Sora 2.0 (#807)
* upload v2.0

* update docs

* [hotfix] fit latest fa3 (#802)

* update readme

* update readme

* update readme

* update train readme

* update readme

* update readme: motion score

* cleaning video dc ae WIP

* update config

* add dependency functions

* undo cleaning

* use latest dcae

* complete high compression training

* update hcae config

* cleaned up vae

* update ae.md

* further cleanup

* update vae & ae paths

* align naming of ae

* [hotfix] fix ring attn bwd for fa3 (#803)

* train ae default without wandb

* update config

* update evaluation results

* added hcae report

* update readme

* update readme demo

* update readme demo

* update readme gif

* display demo directly in readme

* update paper

* delete files

---------

Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: Shen-Chenhui <shen_chenhui@u.nus.edu>
Co-authored-by: wuxiwen <wuxiwen.simon@gmail.com>
2025-03-12 13:14:22 +08:00

152 lines
5 KiB
Python

import math
import os
# Base bucket aspect ratios, written as "width:height". get_aspect_ratios_dict
# also adds the mirrored "height:width" entry for each of these, so the portrait
# entry "9:16" coincides with the mirror generated from "16:9".
# The trailing number on each line is the entry's width / height value.
ASPECT_RATIO_LD_LIST = [  # width:height
    "2.39:1",  # cinemascope, 2.39
    "2:1",  # 2.0
    "16:9",  # widescreen, 1.78
    "1.85:1",  # american widescreen, 1.85
    "9:16",  # portrait (vertical 16:9), 0.5625
    "5:8",  # portrait, 0.625
    "3:2",  # 1.5
    "4:3",  # classic, 1.33
    "1:1",  # square, 1.0
]
def get_ratio(name: str) -> float:
    """Return the height / width value of a "width:height" ratio name.

    For example ``get_ratio("16:9")`` equals ``9 / 16``.
    """
    parts = [float(part) for part in name.split(":")]
    width, height = parts
    return height / width
def get_aspect_ratios_dict(
    total_pixels: int = 256 * 256,
    training: bool = True,
    ratios: "list[str] | None" = None,
) -> dict[str, tuple[int, int]]:
    """Build the ``{ratio_name: (height, width)}`` bucket table for a pixel budget.

    For every "width:height" ratio, a (height, width) pair is chosen whose sides
    are multiples of ``D`` (read from the ``AE_SPATIAL_COMPRESSION`` environment
    variable, default 16) and whose area is close to ``total_pixels``. The
    mirrored "height:width" entry with the sides swapped is added as well.

    Args:
        total_pixels: target number of pixels (height * width).
        training: when True, nudge each shape by one ``D`` step in whichever
            single direction brings its area closest to ``total_pixels``, and
            drop ratios whose shape equals one already produced.
        ratios: "width:height" names to use; defaults to ``ASPECT_RATIO_LD_LIST``.

    Returns:
        dict mapping each ratio name (and its mirror) to ``(height, width)``.

    Raises:
        ValueError: if ``total_pixels`` or ``AE_SPATIAL_COMPRESSION`` is not positive.
    """
    D = int(os.environ.get("AE_SPATIAL_COMPRESSION", 16))
    if D <= 0:
        raise ValueError(f"AE_SPATIAL_COMPRESSION must be a positive integer, got {D}")
    if total_pixels <= 0:
        raise ValueError(f"total_pixels must be positive, got {total_pixels}")
    if ratios is None:
        ratios = ASPECT_RATIO_LD_LIST

    aspect_ratios_dict = {}
    aspect_ratios_vertical_dict = {}
    for ratio in ratios:
        width_ratio, height_ratio = map(float, ratio.split(":"))
        # Snap both sides down to the D grid, but never below one D step: a zero
        # width would divide by zero below and a zero height is an empty bucket.
        width = int(math.sqrt(total_pixels * (width_ratio / height_ratio)) // D) * D
        width = max(D, width)
        height = max(D, int((total_pixels / width) // D) * D)
        if training:
            # Try one D step in each direction and keep whichever shape brings
            # the area closest to total_pixels.
            diff = abs(height * width - total_pixels)
            candidates = [
                (height - D, width),
                (height + D, width),
                (height, width - D),
                (height, width + D),
            ]
            for h, w in candidates:
                if h <= 0 or w <= 0:
                    continue  # shrinking a one-step side would empty it
                if abs(h * w - total_pixels) < diff:
                    height, width = h, w
                    diff = abs(h * w - total_pixels)
            # Drop ratios that collapse onto a shape already emitted.
            if (height, width) in aspect_ratios_dict.values():
                continue
        aspect_ratios_dict[ratio] = (height, width)
        # Mirror "w:h" -> "h:w" with the sides swapped.
        vertical_ratio = ":".join(ratio.split(":")[::-1])
        aspect_ratios_vertical_dict[vertical_ratio] = (width, height)
    aspect_ratios_dict.update(aspect_ratios_vertical_dict)
    return aspect_ratios_dict
def get_num_pexels(aspect_ratios_dict: dict[str, tuple[int, int]]) -> dict[str, int]:
    """Map each ratio name to the pixel count (height * width) of its bucket."""
    pixel_counts = {}
    for name, shape in aspect_ratios_dict.items():
        h, w = shape
        pixel_counts[name] = h * w
    return pixel_counts
def get_num_tokens(aspect_ratios_dict: dict[str, tuple[int, int]]) -> dict[str, int]:
    """Map each ratio name to its token count, (height * width) // D // D.

    ``D`` is read from the ``AE_SPATIAL_COMPRESSION`` environment variable
    (default 16).
    """
    compression = int(os.environ.get("AE_SPATIAL_COMPRESSION", 16))
    tokens = {}
    for name, (h, w) in aspect_ratios_dict.items():
        tokens[name] = h * w // compression // compression
    return tokens
def get_num_pexels_from_name(resolution: str) -> int:
    """Return the pixel budget encoded by a resolution name.

    Only the part before the first "_" is considered. "<N>px" means N * N
    pixels; "<N>p" means N * N * 16 / 9 pixels (an N-high 16:9 frame),
    truncated to an int.

    Raises:
        ValueError: if the name ends in neither "px" nor "p", or N is not an int.
    """
    name = resolution.split("_")[0]
    if name.endswith("px"):
        side = int(name[:-2])
        return side * side
    if name.endswith("p"):
        side = int(name[:-1])
        return int(side * side / 9 * 16)
    raise ValueError(f"Invalid resolution {name}")
def get_resolution_with_aspect_ratio(
    resolution: str,
) -> tuple[int, dict[str, tuple[int, int]]]:
    """Resolve a resolution spec into its pixel budget and aspect-ratio buckets.

    Args:
        resolution (str): either a bare name such as "256px" or "360p", or
            "{name}_{setting}" where setting is "max" (keep only the bucket with
            the largest height * width) or "ar{ratio}" such as "ar1:1" (keep
            only that ratio).

    Returns:
        tuple[int, dict[str, tuple[int, int]]]: the pixel budget of the name and
        the ``{ratio: (height, width)}`` buckets, filtered by the setting.
    """
    parts = resolution.split("_")
    if len(parts) == 1:
        name, setting = parts[0], ""
    else:
        name, setting = parts
        assert setting == "max" or setting.startswith("ar"), f"Invalid setting {setting}"

    num_pexels = get_num_pexels_from_name(name)
    buckets = get_aspect_ratios_dict(num_pexels)

    if setting == "max":
        chosen = max(buckets, key=lambda key: buckets[key][0] * buckets[key][1])
    elif setting.startswith("ar"):
        chosen = setting[2:]
        assert chosen in buckets, f"Aspect ratio {chosen} not found"
    else:
        # No setting: return every bucket.
        return num_pexels, buckets
    return num_pexels, {chosen: buckets[chosen]}
def get_closest_ratio(height: float, width: float, ratios: dict) -> str:
    """Return the key of ``ratios`` whose height/width value is nearest height / width.

    Keys are "width:height" names; ties go to the earliest key in iteration order.
    """
    target = height / width

    def distance(name: str) -> float:
        return abs(target - get_ratio(name))

    return min(ratios.keys(), key=distance)
def get_image_size(
    resolution: str, ar_ratio: str, training: bool = True
) -> tuple[int, int]:
    """Return the ``(height, width)`` bucket for ``ar_ratio`` at ``resolution``.

    ``training`` is forwarded to ``get_aspect_ratios_dict``; an unknown
    ``ar_ratio`` fails the assertion below.
    """
    buckets = get_aspect_ratios_dict(get_num_pexels_from_name(resolution), training)
    assert ar_ratio in buckets, f"Aspect ratio {ar_ratio} not found"
    return buckets[ar_ratio]
def bucket_to_shapes(bucket_config, batch_size=None):
    """Enumerate every ``(bs, 3, num_frames, height, width)`` shape in a bucket config.

    Args:
        bucket_config: mapping of resolution name (e.g. "256px", "360p") to a
            mapping ``{num_frames: (_, bs)}``; only the second element of each
            value is used here, as the batch size.
        batch_size: if not None, overrides every per-bucket ``bs``.

    Returns:
        list of ``(bs, 3, num_frames, height, width)`` tuples, one per
        (resolution, num_frames, aspect ratio) combination, in config order.
        The literal 3 is presumably the RGB channel count — confirm with callers.
    """
    shapes = []
    for resolution, infos in bucket_config.items():
        if not infos:
            # Nothing to enumerate; also avoids parsing the name needlessly.
            continue
        # The aspect-ratio table depends only on the resolution, so build it
        # once here rather than once per frame count.
        aspect_ratios = get_aspect_ratios_dict(get_num_pexels_from_name(resolution))
        for num_frames, (_, bs) in infos.items():
            if batch_size is not None:
                bs = batch_size
            for height, width in aspect_ratios.values():
                shapes.append((bs, 3, num_frames, height, width))
    return shapes