mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-15 03:15:20 +02:00
Merge branch 'dev/v1.0.1' of github.com:hpcaitech/Open-Sora-dev into dev/v1.0.1
This commit is contained in:
commit
947a313eb5
|
|
@ -12,7 +12,7 @@ bucket_config = { # 6s/it
|
||||||
"256": {1: (1.0, 254)},
|
"256": {1: (1.0, 254)},
|
||||||
"512": {1: (0.5, 86)},
|
"512": {1: (0.5, 86)},
|
||||||
"480p": {1: (0.4, 54), 16: (0.4, 4), 32: (0.0, None)},
|
"480p": {1: (0.4, 54), 16: (0.4, 4), 32: (0.0, None)},
|
||||||
"720p": {16: (0.1, 2), 32: (0.0, None)}, # No examples now
|
"720p": {16: (0.1, 2), 32: (0.0, None)},
|
||||||
"1024": {1: (0.3, 20)},
|
"1024": {1: (0.3, 20)},
|
||||||
"1080p": {1: (0.4, 8)},
|
"1080p": {1: (0.4, 8)},
|
||||||
}
|
}
|
||||||
|
|
@ -30,6 +30,7 @@ mask_ratios = {
|
||||||
|
|
||||||
# Define acceleration
|
# Define acceleration
|
||||||
num_workers = 4
|
num_workers = 4
|
||||||
|
num_bucket_build_workers = 16
|
||||||
dtype = "bf16"
|
dtype = "bf16"
|
||||||
grad_checkpoint = True
|
grad_checkpoint = True
|
||||||
plugin = "zero2"
|
plugin = "zero2"
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,7 @@
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from .aspect import ASPECT_RATIOS, get_closest_ratio
|
from .aspect import ASPECT_RATIOS, get_closest_ratio
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
def find_approximate_hw(hw, hw_dict, approx=0.8):
|
def find_approximate_hw(hw, hw_dict, approx=0.8):
|
||||||
|
|
@ -44,12 +43,17 @@ class Bucket:
|
||||||
hw_criteria = dict()
|
hw_criteria = dict()
|
||||||
t_criteria = dict()
|
t_criteria = dict()
|
||||||
ar_criteria = dict()
|
ar_criteria = dict()
|
||||||
|
bucket_id = OrderedDict()
|
||||||
|
bucket_id_cnt = 0
|
||||||
for k1, v1 in bucket_probs.items():
|
for k1, v1 in bucket_probs.items():
|
||||||
hw_criteria[k1] = ASPECT_RATIOS[k1][0]
|
hw_criteria[k1] = ASPECT_RATIOS[k1][0]
|
||||||
t_criteria[k1] = dict()
|
t_criteria[k1] = dict()
|
||||||
ar_criteria[k1] = dict()
|
ar_criteria[k1] = dict()
|
||||||
|
bucket_id[k1] = dict()
|
||||||
for k2, _ in v1.items():
|
for k2, _ in v1.items():
|
||||||
t_criteria[k1][k2] = k2
|
t_criteria[k1][k2] = k2
|
||||||
|
bucket_id[k1][k2] = bucket_id_cnt
|
||||||
|
bucket_id_cnt += 1
|
||||||
ar_criteria[k1][k2] = dict()
|
ar_criteria[k1][k2] = dict()
|
||||||
for k3, v3 in ASPECT_RATIOS[k1][1].items():
|
for k3, v3 in ASPECT_RATIOS[k1][1].items():
|
||||||
ar_criteria[k1][k2][k3] = v3
|
ar_criteria[k1][k2][k3] = v3
|
||||||
|
|
@ -57,58 +61,52 @@ class Bucket:
|
||||||
|
|
||||||
self.bucket_probs = bucket_probs
|
self.bucket_probs = bucket_probs
|
||||||
self.bucket_bs = bucket_bs
|
self.bucket_bs = bucket_bs
|
||||||
|
self.bucket_id = bucket_id
|
||||||
self.hw_criteria = hw_criteria
|
self.hw_criteria = hw_criteria
|
||||||
self.t_criteria = t_criteria
|
self.t_criteria = t_criteria
|
||||||
self.ar_criteria = ar_criteria
|
self.ar_criteria = ar_criteria
|
||||||
self.num_bucket = num_bucket
|
self.num_bucket = num_bucket
|
||||||
print(f"Number of buckets: {num_bucket}")
|
print(f"Number of buckets: {num_bucket}")
|
||||||
|
|
||||||
def info_bucket(self, dataset, frame_interval=1):
|
def get_bucket_id(self, T, H, W, frame_interval=1, seed=None):
|
||||||
infos = dict()
|
resolution = H * W
|
||||||
infos_ar = dict()
|
approx = 0.8
|
||||||
for i in range(len(dataset)):
|
|
||||||
T, H, W = dataset.get_data_info(i)
|
fail = True
|
||||||
bucket_id = self.get_bucket_id(T, H, W, frame_interval)
|
for hw_id, t_criteria in self.bucket_probs.items():
|
||||||
if bucket_id is None:
|
if resolution < self.hw_criteria[hw_id] * approx:
|
||||||
continue
|
continue
|
||||||
if f"{(bucket_id[0], bucket_id[1])}" not in infos:
|
|
||||||
infos[f"{(bucket_id[0], bucket_id[1])}"] = 0
|
|
||||||
if f"{bucket_id[2]}" not in infos_ar:
|
|
||||||
infos_ar[f"{bucket_id[2]}"] = 0
|
|
||||||
infos[f"{(bucket_id[0], bucket_id[1])}"] += 1
|
|
||||||
infos_ar[f"{bucket_id[2]}"] += 1
|
|
||||||
print(f"Dataset contains {len(dataset)} samples.")
|
|
||||||
print("Bucket info:", infos)
|
|
||||||
print("Aspect ratio info:", infos_ar)
|
|
||||||
|
|
||||||
def get_bucket_id(self, T, H, W, frame_interval=1, generator=None):
|
# if sample is an image
|
||||||
# hw
|
if T == 1:
|
||||||
hw = H * W
|
if 1 in t_criteria:
|
||||||
hw_id = find_approximate_hw(hw, self.hw_criteria)
|
fail = False
|
||||||
if hw_id is None:
|
t_id = 1
|
||||||
return None
|
|
||||||
hw_id_index = list(self.hw_criteria.keys()).index(hw_id)
|
|
||||||
|
|
||||||
# hw drops by probablity
|
|
||||||
while True:
|
|
||||||
# T
|
|
||||||
T_id = find_closet_smaller_bucket(T, self.t_criteria[hw_id], frame_interval)
|
|
||||||
if T_id is not None:
|
|
||||||
prob = self.get_prob((hw_id, T_id))
|
|
||||||
if torch.rand(1, generator=generator).item() < prob:
|
|
||||||
break
|
break
|
||||||
hw_id_index += 1
|
else:
|
||||||
if hw_id_index > len(self.hw_criteria) - 1:
|
continue
|
||||||
break
|
|
||||||
hw_id = list(self.hw_criteria.keys())[hw_id_index]
|
|
||||||
|
|
||||||
if T_id is None or hw_id_index > len(self.hw_criteria) - 1:
|
# otherwise, find suitable t_id for video
|
||||||
|
t_fail = True
|
||||||
|
for t_id, prob in t_criteria.items():
|
||||||
|
if T > t_id * frame_interval and t_id != 1:
|
||||||
|
t_fail = False
|
||||||
|
break
|
||||||
|
if t_fail:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# leave the loop if prob is high enough
|
||||||
|
rng = np.random.default_rng(seed + self.bucket_id[hw_id][t_id])
|
||||||
|
if prob == 1 or rng.random() < prob:
|
||||||
|
fail = False
|
||||||
|
break
|
||||||
|
if fail:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# ar
|
# get aspect ratio id
|
||||||
ar_criteria = self.ar_criteria[hw_id][T_id]
|
ar_criteria = self.ar_criteria[hw_id][t_id]
|
||||||
ar_id = get_closest_ratio(H, W, ar_criteria)
|
ar_id = get_closest_ratio(H, W, ar_criteria)
|
||||||
return hw_id, T_id, ar_id
|
return hw_id, t_id, ar_id
|
||||||
|
|
||||||
def get_thw(self, bucket_id):
|
def get_thw(self, bucket_id):
|
||||||
assert len(bucket_id) == 3
|
assert len(bucket_id) == 3
|
||||||
|
|
|
||||||
|
|
@ -107,6 +107,7 @@ def prepare_variable_dataloader(
|
||||||
pin_memory=False,
|
pin_memory=False,
|
||||||
num_workers=0,
|
num_workers=0,
|
||||||
process_group=None,
|
process_group=None,
|
||||||
|
num_bucket_build_workers=1,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
_kwargs = kwargs.copy()
|
_kwargs = kwargs.copy()
|
||||||
|
|
@ -120,6 +121,7 @@ def prepare_variable_dataloader(
|
||||||
seed=seed,
|
seed=seed,
|
||||||
drop_last=drop_last,
|
drop_last=drop_last,
|
||||||
verbose=True,
|
verbose=True,
|
||||||
|
num_bucket_build_workers=num_bucket_build_workers,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Deterministic dataloader
|
# Deterministic dataloader
|
||||||
|
|
|
||||||
|
|
@ -113,6 +113,7 @@ class VariableVideoTextDataset(VideoTextDataset):
|
||||||
):
|
):
|
||||||
super().__init__(data_path, num_frames, frame_interval, image_size, transform_name=None)
|
super().__init__(data_path, num_frames, frame_interval, image_size, transform_name=None)
|
||||||
self.transform_name = transform_name
|
self.transform_name = transform_name
|
||||||
|
self.data["id"] = np.arange(len(self.data))
|
||||||
|
|
||||||
def get_data_info(self, index):
|
def get_data_info(self, index):
|
||||||
T = self.data.iloc[index]["num_frames"]
|
T = self.data.iloc[index]["num_frames"]
|
||||||
|
|
|
||||||
|
|
@ -3,14 +3,28 @@ from collections import OrderedDict, defaultdict
|
||||||
from pprint import pprint
|
from pprint import pprint
|
||||||
from typing import Iterator, List, Optional
|
from typing import Iterator, List, Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
from pandarallel import pandarallel
|
||||||
from torch.utils.data import DistributedSampler
|
from torch.utils.data import DistributedSampler
|
||||||
|
|
||||||
from .bucket import Bucket
|
from .bucket import Bucket
|
||||||
from .datasets import VariableVideoTextDataset
|
from .datasets import VariableVideoTextDataset
|
||||||
|
|
||||||
|
|
||||||
|
# HACK: use pandarallel
|
||||||
|
# pandarallel should only access local variables
|
||||||
|
def apply(data, method=None, frame_interval=None, seed=None, num_bucket=None):
|
||||||
|
return method(
|
||||||
|
data["num_frames"],
|
||||||
|
data["height"],
|
||||||
|
data["width"],
|
||||||
|
frame_interval,
|
||||||
|
seed + data["id"] * num_bucket,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class VariableVideoBatchSampler(DistributedSampler):
|
class VariableVideoBatchSampler(DistributedSampler):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
@ -22,6 +36,7 @@ class VariableVideoBatchSampler(DistributedSampler):
|
||||||
seed: int = 0,
|
seed: int = 0,
|
||||||
drop_last: bool = False,
|
drop_last: bool = False,
|
||||||
verbose: bool = False,
|
verbose: bool = False,
|
||||||
|
num_bucket_build_workers: int = 1,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(
|
super().__init__(
|
||||||
dataset=dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle, seed=seed, drop_last=drop_last
|
dataset=dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle, seed=seed, drop_last=drop_last
|
||||||
|
|
@ -32,48 +47,56 @@ class VariableVideoBatchSampler(DistributedSampler):
|
||||||
self.last_micro_batch_access_index = 0
|
self.last_micro_batch_access_index = 0
|
||||||
self.approximate_num_batch = None
|
self.approximate_num_batch = None
|
||||||
|
|
||||||
def get_num_batch(self) -> int:
|
self._get_num_batch_cached_bucket_sample_dict = None
|
||||||
g = torch.Generator()
|
self.num_bucket_build_workers = num_bucket_build_workers
|
||||||
g.manual_seed(self.seed + self.epoch)
|
|
||||||
|
def group_by_bucket(self) -> dict:
|
||||||
bucket_sample_dict = OrderedDict()
|
bucket_sample_dict = OrderedDict()
|
||||||
|
|
||||||
|
pandarallel.initialize(progress_bar=True, nb_workers=self.num_bucket_build_workers)
|
||||||
|
bucket_ids = self.dataset.data.parallel_apply(
|
||||||
|
apply,
|
||||||
|
axis=1,
|
||||||
|
method=self.bucket.get_bucket_id,
|
||||||
|
frame_interval=self.dataset.frame_interval,
|
||||||
|
seed=self.seed + self.epoch,
|
||||||
|
num_bucket=self.bucket.num_bucket,
|
||||||
|
)
|
||||||
|
|
||||||
# group by bucket
|
# group by bucket
|
||||||
# each data sample is put into a bucket with a similar image/video size
|
# each data sample is put into a bucket with a similar image/video size
|
||||||
for i in range(len(self.dataset)):
|
for i in range(len(self.dataset)):
|
||||||
t, h, w = self.dataset.get_data_info(i)
|
bucket_id = bucket_ids[i]
|
||||||
bucket_id = self.bucket.get_bucket_id(t, h, w, self.dataset.frame_interval, g)
|
|
||||||
if bucket_id is None:
|
if bucket_id is None:
|
||||||
continue
|
continue
|
||||||
if bucket_id not in bucket_sample_dict:
|
if bucket_id not in bucket_sample_dict:
|
||||||
bucket_sample_dict[bucket_id] = []
|
bucket_sample_dict[bucket_id] = []
|
||||||
bucket_sample_dict[bucket_id].append(i)
|
bucket_sample_dict[bucket_id].append(i)
|
||||||
|
return bucket_sample_dict
|
||||||
|
|
||||||
|
def get_num_batch(self) -> int:
|
||||||
|
bucket_sample_dict = self.group_by_bucket()
|
||||||
|
self._get_num_batch_cached_bucket_sample_dict = bucket_sample_dict
|
||||||
|
|
||||||
# calculate the number of batches
|
# calculate the number of batches
|
||||||
|
if self.verbose:
|
||||||
self._print_bucket_info(bucket_sample_dict)
|
self._print_bucket_info(bucket_sample_dict)
|
||||||
return self.approximate_num_batch
|
return self.approximate_num_batch
|
||||||
|
|
||||||
def __iter__(self) -> Iterator[List[int]]:
|
def __iter__(self) -> Iterator[List[int]]:
|
||||||
g = torch.Generator()
|
if self._get_num_batch_cached_bucket_sample_dict is not None:
|
||||||
g.manual_seed(self.seed + self.epoch)
|
bucket_sample_dict = self._get_num_batch_cached_bucket_sample_dict
|
||||||
bucket_sample_dict = OrderedDict()
|
self._get_num_batch_cached_bucket_sample_dict = None
|
||||||
bucket_micro_batch_count = OrderedDict()
|
else:
|
||||||
bucket_last_consumed = OrderedDict()
|
bucket_sample_dict = self.group_by_bucket()
|
||||||
|
|
||||||
# group by bucket
|
|
||||||
# each data sample is put into a bucket with a similar image/video size
|
|
||||||
for i in range(len(self.dataset)):
|
|
||||||
t, h, w = self.dataset.get_data_info(i)
|
|
||||||
bucket_id = self.bucket.get_bucket_id(t, h, w, self.dataset.frame_interval, g)
|
|
||||||
if bucket_id is None:
|
|
||||||
continue
|
|
||||||
if bucket_id not in bucket_sample_dict:
|
|
||||||
bucket_sample_dict[bucket_id] = []
|
|
||||||
bucket_sample_dict[bucket_id].append(i)
|
|
||||||
|
|
||||||
# print bucket info
|
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
self._print_bucket_info(bucket_sample_dict)
|
self._print_bucket_info(bucket_sample_dict)
|
||||||
|
|
||||||
|
g = torch.Generator()
|
||||||
|
g.manual_seed(self.seed + self.epoch)
|
||||||
|
bucket_micro_batch_count = OrderedDict()
|
||||||
|
bucket_last_consumed = OrderedDict()
|
||||||
|
|
||||||
# process the samples
|
# process the samples
|
||||||
for bucket_id, data_list in bucket_sample_dict.items():
|
for bucket_id, data_list in bucket_sample_dict.items():
|
||||||
# handle droplast
|
# handle droplast
|
||||||
|
|
|
||||||
|
|
@ -103,6 +103,8 @@ def merge_args(cfg, args, training=False):
|
||||||
cfg["bucket_config"] = None
|
cfg["bucket_config"] = None
|
||||||
if "transform_name" not in cfg.dataset:
|
if "transform_name" not in cfg.dataset:
|
||||||
cfg.dataset["transform_name"] = "center"
|
cfg.dataset["transform_name"] = "center"
|
||||||
|
if "num_bucket_build_workers" not in cfg:
|
||||||
|
cfg["num_bucket_build_workers"] = 1
|
||||||
|
|
||||||
# Both training and inference
|
# Both training and inference
|
||||||
if "multi_resolution" not in cfg:
|
if "multi_resolution" not in cfg:
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,7 @@ from pprint import pprint
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
import wandb
|
||||||
from colossalai.booster import Booster
|
from colossalai.booster import Booster
|
||||||
from colossalai.booster.plugin import LowLevelZeroPlugin
|
from colossalai.booster.plugin import LowLevelZeroPlugin
|
||||||
from colossalai.cluster import DistCoordinator
|
from colossalai.cluster import DistCoordinator
|
||||||
|
|
@ -11,7 +12,6 @@ from colossalai.nn.optimizer import HybridAdam
|
||||||
from colossalai.utils import get_current_device, set_seed
|
from colossalai.utils import get_current_device, set_seed
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
import wandb
|
|
||||||
from opensora.acceleration.checkpoint import set_grad_checkpoint
|
from opensora.acceleration.checkpoint import set_grad_checkpoint
|
||||||
from opensora.acceleration.parallel_states import (
|
from opensora.acceleration.parallel_states import (
|
||||||
get_data_parallel_group,
|
get_data_parallel_group,
|
||||||
|
|
@ -95,6 +95,7 @@ def main():
|
||||||
# 3. build dataset and dataloader
|
# 3. build dataset and dataloader
|
||||||
# ======================================================
|
# ======================================================
|
||||||
dataset = build_module(cfg.dataset, DATASETS)
|
dataset = build_module(cfg.dataset, DATASETS)
|
||||||
|
logger.info(f"Dataset contains {len(dataset)} samples.")
|
||||||
dataloader_args = dict(
|
dataloader_args = dict(
|
||||||
dataset=dataset,
|
dataset=dataset,
|
||||||
batch_size=cfg.batch_size,
|
batch_size=cfg.batch_size,
|
||||||
|
|
@ -109,7 +110,11 @@ def main():
|
||||||
if cfg.bucket_config is None:
|
if cfg.bucket_config is None:
|
||||||
dataloader = prepare_dataloader(**dataloader_args)
|
dataloader = prepare_dataloader(**dataloader_args)
|
||||||
else:
|
else:
|
||||||
dataloader = prepare_variable_dataloader(bucket_config=cfg.bucket_config, **dataloader_args)
|
dataloader = prepare_variable_dataloader(
|
||||||
|
bucket_config=cfg.bucket_config,
|
||||||
|
num_bucket_build_workers=cfg.num_bucket_build_workers,
|
||||||
|
**dataloader_args,
|
||||||
|
)
|
||||||
if cfg.dataset.type == "VideoTextDataset":
|
if cfg.dataset.type == "VideoTextDataset":
|
||||||
total_batch_size = cfg.batch_size * dist.get_world_size() // cfg.sp_size
|
total_batch_size = cfg.batch_size * dist.get_world_size() // cfg.sp_size
|
||||||
logger.info(f"Total batch size: {total_batch_size}")
|
logger.info(f"Total batch size: {total_batch_size}")
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue