mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-15 03:15:20 +02:00
Merge branch 'dev/v1.0.1' of github.com:hpcaitech/Open-Sora-dev into dev/v1.0.1
This commit is contained in:
commit
947a313eb5
|
|
@ -12,7 +12,7 @@ bucket_config = { # 6s/it
|
|||
"256": {1: (1.0, 254)},
|
||||
"512": {1: (0.5, 86)},
|
||||
"480p": {1: (0.4, 54), 16: (0.4, 4), 32: (0.0, None)},
|
||||
"720p": {16: (0.1, 2), 32: (0.0, None)}, # No examples now
|
||||
"720p": {16: (0.1, 2), 32: (0.0, None)},
|
||||
"1024": {1: (0.3, 20)},
|
||||
"1080p": {1: (0.4, 8)},
|
||||
}
|
||||
|
|
@ -30,6 +30,7 @@ mask_ratios = {
|
|||
|
||||
# Define acceleration
|
||||
num_workers = 4
|
||||
num_bucket_build_workers = 16
|
||||
dtype = "bf16"
|
||||
grad_checkpoint = True
|
||||
plugin = "zero2"
|
||||
|
|
|
|||
|
|
@ -1,8 +1,7 @@
|
|||
from collections import OrderedDict
|
||||
|
||||
import torch
|
||||
|
||||
from .aspect import ASPECT_RATIOS, get_closest_ratio
|
||||
import numpy as np
|
||||
|
||||
|
||||
def find_approximate_hw(hw, hw_dict, approx=0.8):
|
||||
|
|
@ -44,12 +43,17 @@ class Bucket:
|
|||
hw_criteria = dict()
|
||||
t_criteria = dict()
|
||||
ar_criteria = dict()
|
||||
bucket_id = OrderedDict()
|
||||
bucket_id_cnt = 0
|
||||
for k1, v1 in bucket_probs.items():
|
||||
hw_criteria[k1] = ASPECT_RATIOS[k1][0]
|
||||
t_criteria[k1] = dict()
|
||||
ar_criteria[k1] = dict()
|
||||
bucket_id[k1] = dict()
|
||||
for k2, _ in v1.items():
|
||||
t_criteria[k1][k2] = k2
|
||||
bucket_id[k1][k2] = bucket_id_cnt
|
||||
bucket_id_cnt += 1
|
||||
ar_criteria[k1][k2] = dict()
|
||||
for k3, v3 in ASPECT_RATIOS[k1][1].items():
|
||||
ar_criteria[k1][k2][k3] = v3
|
||||
|
|
@ -57,58 +61,52 @@ class Bucket:
|
|||
|
||||
self.bucket_probs = bucket_probs
|
||||
self.bucket_bs = bucket_bs
|
||||
self.bucket_id = bucket_id
|
||||
self.hw_criteria = hw_criteria
|
||||
self.t_criteria = t_criteria
|
||||
self.ar_criteria = ar_criteria
|
||||
self.num_bucket = num_bucket
|
||||
print(f"Number of buckets: {num_bucket}")
|
||||
|
||||
def info_bucket(self, dataset, frame_interval=1):
|
||||
infos = dict()
|
||||
infos_ar = dict()
|
||||
for i in range(len(dataset)):
|
||||
T, H, W = dataset.get_data_info(i)
|
||||
bucket_id = self.get_bucket_id(T, H, W, frame_interval)
|
||||
if bucket_id is None:
|
||||
def get_bucket_id(self, T, H, W, frame_interval=1, seed=None):
|
||||
resolution = H * W
|
||||
approx = 0.8
|
||||
|
||||
fail = True
|
||||
for hw_id, t_criteria in self.bucket_probs.items():
|
||||
if resolution < self.hw_criteria[hw_id] * approx:
|
||||
continue
|
||||
if f"{(bucket_id[0], bucket_id[1])}" not in infos:
|
||||
infos[f"{(bucket_id[0], bucket_id[1])}"] = 0
|
||||
if f"{bucket_id[2]}" not in infos_ar:
|
||||
infos_ar[f"{bucket_id[2]}"] = 0
|
||||
infos[f"{(bucket_id[0], bucket_id[1])}"] += 1
|
||||
infos_ar[f"{bucket_id[2]}"] += 1
|
||||
print(f"Dataset contains {len(dataset)} samples.")
|
||||
print("Bucket info:", infos)
|
||||
print("Aspect ratio info:", infos_ar)
|
||||
|
||||
def get_bucket_id(self, T, H, W, frame_interval=1, generator=None):
|
||||
# hw
|
||||
hw = H * W
|
||||
hw_id = find_approximate_hw(hw, self.hw_criteria)
|
||||
if hw_id is None:
|
||||
return None
|
||||
hw_id_index = list(self.hw_criteria.keys()).index(hw_id)
|
||||
|
||||
# hw drops by probability
|
||||
while True:
|
||||
# T
|
||||
T_id = find_closet_smaller_bucket(T, self.t_criteria[hw_id], frame_interval)
|
||||
if T_id is not None:
|
||||
prob = self.get_prob((hw_id, T_id))
|
||||
if torch.rand(1, generator=generator).item() < prob:
|
||||
# if sample is an image
|
||||
if T == 1:
|
||||
if 1 in t_criteria:
|
||||
fail = False
|
||||
t_id = 1
|
||||
break
|
||||
hw_id_index += 1
|
||||
if hw_id_index > len(self.hw_criteria) - 1:
|
||||
break
|
||||
hw_id = list(self.hw_criteria.keys())[hw_id_index]
|
||||
else:
|
||||
continue
|
||||
|
||||
if T_id is None or hw_id_index > len(self.hw_criteria) - 1:
|
||||
# otherwise, find suitable t_id for video
|
||||
t_fail = True
|
||||
for t_id, prob in t_criteria.items():
|
||||
if T > t_id * frame_interval and t_id != 1:
|
||||
t_fail = False
|
||||
break
|
||||
if t_fail:
|
||||
continue
|
||||
|
||||
# leave the loop if prob is high enough
|
||||
rng = np.random.default_rng(seed + self.bucket_id[hw_id][t_id])
|
||||
if prob == 1 or rng.random() < prob:
|
||||
fail = False
|
||||
break
|
||||
if fail:
|
||||
return None
|
||||
|
||||
# ar
|
||||
ar_criteria = self.ar_criteria[hw_id][T_id]
|
||||
# get aspect ratio id
|
||||
ar_criteria = self.ar_criteria[hw_id][t_id]
|
||||
ar_id = get_closest_ratio(H, W, ar_criteria)
|
||||
return hw_id, T_id, ar_id
|
||||
return hw_id, t_id, ar_id
|
||||
|
||||
def get_thw(self, bucket_id):
|
||||
assert len(bucket_id) == 3
|
||||
|
|
|
|||
|
|
@ -107,6 +107,7 @@ def prepare_variable_dataloader(
|
|||
pin_memory=False,
|
||||
num_workers=0,
|
||||
process_group=None,
|
||||
num_bucket_build_workers=1,
|
||||
**kwargs,
|
||||
):
|
||||
_kwargs = kwargs.copy()
|
||||
|
|
@ -120,6 +121,7 @@ def prepare_variable_dataloader(
|
|||
seed=seed,
|
||||
drop_last=drop_last,
|
||||
verbose=True,
|
||||
num_bucket_build_workers=num_bucket_build_workers,
|
||||
)
|
||||
|
||||
# Deterministic dataloader
|
||||
|
|
|
|||
|
|
@ -113,6 +113,7 @@ class VariableVideoTextDataset(VideoTextDataset):
|
|||
):
|
||||
super().__init__(data_path, num_frames, frame_interval, image_size, transform_name=None)
|
||||
self.transform_name = transform_name
|
||||
self.data["id"] = np.arange(len(self.data))
|
||||
|
||||
def get_data_info(self, index):
|
||||
T = self.data.iloc[index]["num_frames"]
|
||||
|
|
|
|||
|
|
@ -3,14 +3,28 @@ from collections import OrderedDict, defaultdict
|
|||
from pprint import pprint
|
||||
from typing import Iterator, List, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from pandarallel import pandarallel
|
||||
from torch.utils.data import DistributedSampler
|
||||
|
||||
from .bucket import Bucket
|
||||
from .datasets import VariableVideoTextDataset
|
||||
|
||||
|
||||
# HACK: use pandarallel
|
||||
# pandarallel should only access local variables
|
||||
def apply(data, method=None, frame_interval=None, seed=None, num_bucket=None):
    """Row-wise adapter that forwards one metadata row to ``method``.

    Everything the call needs arrives through the arguments, so the function
    touches no module-level state.

    Args:
        data: Mapping-like row providing the ``"num_frames"``, ``"height"``,
            ``"width"`` and ``"id"`` entries.
        method: Callable invoked as
            ``method(num_frames, height, width, frame_interval, sample_seed)``.
        frame_interval: Passed through to ``method`` unchanged.
        seed: Base seed shared by all rows.
        num_bucket: Stride applied to the row id when deriving the per-row
            seed.

    Returns:
        Whatever ``method`` returns for this row.
    """
    num_frames = data["num_frames"]
    height = data["height"]
    width = data["width"]
    # Each row gets its own seed, spaced ``num_bucket`` apart.
    # NOTE(review): presumably the callee adds a bucket index below
    # ``num_bucket`` so seeds of different rows never collide; confirm.
    sample_seed = seed + data["id"] * num_bucket
    return method(num_frames, height, width, frame_interval, sample_seed)
|
||||
|
||||
|
||||
class VariableVideoBatchSampler(DistributedSampler):
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -22,6 +36,7 @@ class VariableVideoBatchSampler(DistributedSampler):
|
|||
seed: int = 0,
|
||||
drop_last: bool = False,
|
||||
verbose: bool = False,
|
||||
num_bucket_build_workers: int = 1,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
dataset=dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle, seed=seed, drop_last=drop_last
|
||||
|
|
@ -32,48 +47,56 @@ class VariableVideoBatchSampler(DistributedSampler):
|
|||
self.last_micro_batch_access_index = 0
|
||||
self.approximate_num_batch = None
|
||||
|
||||
def get_num_batch(self) -> int:
|
||||
g = torch.Generator()
|
||||
g.manual_seed(self.seed + self.epoch)
|
||||
self._get_num_batch_cached_bucket_sample_dict = None
|
||||
self.num_bucket_build_workers = num_bucket_build_workers
|
||||
|
||||
def group_by_bucket(self) -> dict:
    """Group dataset sample indices by the bucket each sample is assigned to.

    Bucket ids for all rows of ``self.dataset.data`` are computed in a single
    ``parallel_apply`` pass (``self.num_bucket_build_workers`` workers) by
    calling ``self.bucket.get_bucket_id`` through the module-level ``apply``
    adapter. The base seed handed to the adapter is ``self.seed + self.epoch``.

    Returns:
        An ``OrderedDict`` mapping each bucket id to the list of sample
        indices placed in it, in order of first appearance. Samples whose
        bucket id is ``None`` are left out.
    """
    pandarallel.initialize(progress_bar=True, nb_workers=self.num_bucket_build_workers)
    bucket_ids = self.dataset.data.parallel_apply(
        apply,
        axis=1,
        method=self.bucket.get_bucket_id,
        frame_interval=self.dataset.frame_interval,
        seed=self.seed + self.epoch,
        num_bucket=self.bucket.num_bucket,
    )

    # group by bucket
    # each data sample is put into a bucket with a similar image/video size
    bucket_sample_dict = OrderedDict()
    for i in range(len(self.dataset)):
        # The ids were already computed above; the former per-sample
        # get_data_info / get_bucket_id calls were removed because they
        # referenced an undefined generator and their result was overwritten.
        # ``i`` is a positional sample index, so read the result positionally
        # instead of by index label.
        bucket_id = bucket_ids.iloc[i]
        if bucket_id is None:
            continue
        bucket_sample_dict.setdefault(bucket_id, []).append(i)
    return bucket_sample_dict
|
||||
|
||||
def get_num_batch(self) -> int:
    """Build the bucket grouping and report the approximate batch count.

    The grouping from ``group_by_bucket`` is stored on
    ``_get_num_batch_cached_bucket_sample_dict`` so that the next ``__iter__``
    can pick it up instead of grouping again.

    Returns:
        The current value of ``self.approximate_num_batch``.
        NOTE(review): presumably ``_print_bucket_info`` refreshes this value;
        when ``verbose`` is off, the previously stored value is returned
        unchanged. Confirm.
    """
    grouped = self.group_by_bucket()
    self._get_num_batch_cached_bucket_sample_dict = grouped

    # calculate the number of batches
    if not self.verbose:
        return self.approximate_num_batch
    self._print_bucket_info(grouped)
    return self.approximate_num_batch
|
||||
|
||||
def __iter__(self) -> Iterator[List[int]]:
|
||||
g = torch.Generator()
|
||||
g.manual_seed(self.seed + self.epoch)
|
||||
bucket_sample_dict = OrderedDict()
|
||||
bucket_micro_batch_count = OrderedDict()
|
||||
bucket_last_consumed = OrderedDict()
|
||||
|
||||
# group by bucket
|
||||
# each data sample is put into a bucket with a similar image/video size
|
||||
for i in range(len(self.dataset)):
|
||||
t, h, w = self.dataset.get_data_info(i)
|
||||
bucket_id = self.bucket.get_bucket_id(t, h, w, self.dataset.frame_interval, g)
|
||||
if bucket_id is None:
|
||||
continue
|
||||
if bucket_id not in bucket_sample_dict:
|
||||
bucket_sample_dict[bucket_id] = []
|
||||
bucket_sample_dict[bucket_id].append(i)
|
||||
|
||||
# print bucket info
|
||||
if self._get_num_batch_cached_bucket_sample_dict is not None:
|
||||
bucket_sample_dict = self._get_num_batch_cached_bucket_sample_dict
|
||||
self._get_num_batch_cached_bucket_sample_dict = None
|
||||
else:
|
||||
bucket_sample_dict = self.group_by_bucket()
|
||||
if self.verbose:
|
||||
self._print_bucket_info(bucket_sample_dict)
|
||||
|
||||
g = torch.Generator()
|
||||
g.manual_seed(self.seed + self.epoch)
|
||||
bucket_micro_batch_count = OrderedDict()
|
||||
bucket_last_consumed = OrderedDict()
|
||||
|
||||
# process the samples
|
||||
for bucket_id, data_list in bucket_sample_dict.items():
|
||||
# handle droplast
|
||||
|
|
|
|||
|
|
@ -103,6 +103,8 @@ def merge_args(cfg, args, training=False):
|
|||
cfg["bucket_config"] = None
|
||||
if "transform_name" not in cfg.dataset:
|
||||
cfg.dataset["transform_name"] = "center"
|
||||
if "num_bucket_build_workers" not in cfg:
|
||||
cfg["num_bucket_build_workers"] = 1
|
||||
|
||||
# Both training and inference
|
||||
if "multi_resolution" not in cfg:
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ from pprint import pprint
|
|||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import wandb
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.booster.plugin import LowLevelZeroPlugin
|
||||
from colossalai.cluster import DistCoordinator
|
||||
|
|
@ -11,7 +12,6 @@ from colossalai.nn.optimizer import HybridAdam
|
|||
from colossalai.utils import get_current_device, set_seed
|
||||
from tqdm import tqdm
|
||||
|
||||
import wandb
|
||||
from opensora.acceleration.checkpoint import set_grad_checkpoint
|
||||
from opensora.acceleration.parallel_states import (
|
||||
get_data_parallel_group,
|
||||
|
|
@ -95,6 +95,7 @@ def main():
|
|||
# 3. build dataset and dataloader
|
||||
# ======================================================
|
||||
dataset = build_module(cfg.dataset, DATASETS)
|
||||
logger.info(f"Dataset contains {len(dataset)} samples.")
|
||||
dataloader_args = dict(
|
||||
dataset=dataset,
|
||||
batch_size=cfg.batch_size,
|
||||
|
|
@ -109,7 +110,11 @@ def main():
|
|||
if cfg.bucket_config is None:
|
||||
dataloader = prepare_dataloader(**dataloader_args)
|
||||
else:
|
||||
dataloader = prepare_variable_dataloader(bucket_config=cfg.bucket_config, **dataloader_args)
|
||||
dataloader = prepare_variable_dataloader(
|
||||
bucket_config=cfg.bucket_config,
|
||||
num_bucket_build_workers=cfg.num_bucket_build_workers,
|
||||
**dataloader_args,
|
||||
)
|
||||
if cfg.dataset.type == "VideoTextDataset":
|
||||
total_batch_size = cfg.batch_size * dist.get_world_size() // cfg.sp_size
|
||||
logger.info(f"Total batch size: {total_batch_size}")
|
||||
|
|
|
|||
Loading…
Reference in a new issue