Merge branch 'dev/v1.0.1' of github.com:hpcaitech/Open-Sora-dev into dev/v1.0.1

This commit is contained in:
zhengzangw 2024-04-19 03:44:02 +00:00
commit 947a313eb5
7 changed files with 98 additions and 66 deletions

View file

@ -12,7 +12,7 @@ bucket_config = { # 6s/it
"256": {1: (1.0, 254)}, "256": {1: (1.0, 254)},
"512": {1: (0.5, 86)}, "512": {1: (0.5, 86)},
"480p": {1: (0.4, 54), 16: (0.4, 4), 32: (0.0, None)}, "480p": {1: (0.4, 54), 16: (0.4, 4), 32: (0.0, None)},
"720p": {16: (0.1, 2), 32: (0.0, None)}, # No examples now "720p": {16: (0.1, 2), 32: (0.0, None)},
"1024": {1: (0.3, 20)}, "1024": {1: (0.3, 20)},
"1080p": {1: (0.4, 8)}, "1080p": {1: (0.4, 8)},
} }
@ -30,6 +30,7 @@ mask_ratios = {
# Define acceleration # Define acceleration
num_workers = 4 num_workers = 4
num_bucket_build_workers = 16
dtype = "bf16" dtype = "bf16"
grad_checkpoint = True grad_checkpoint = True
plugin = "zero2" plugin = "zero2"

View file

@ -1,8 +1,7 @@
from collections import OrderedDict from collections import OrderedDict
import torch
from .aspect import ASPECT_RATIOS, get_closest_ratio from .aspect import ASPECT_RATIOS, get_closest_ratio
import numpy as np
def find_approximate_hw(hw, hw_dict, approx=0.8): def find_approximate_hw(hw, hw_dict, approx=0.8):
@ -44,12 +43,17 @@ class Bucket:
hw_criteria = dict() hw_criteria = dict()
t_criteria = dict() t_criteria = dict()
ar_criteria = dict() ar_criteria = dict()
bucket_id = OrderedDict()
bucket_id_cnt = 0
for k1, v1 in bucket_probs.items(): for k1, v1 in bucket_probs.items():
hw_criteria[k1] = ASPECT_RATIOS[k1][0] hw_criteria[k1] = ASPECT_RATIOS[k1][0]
t_criteria[k1] = dict() t_criteria[k1] = dict()
ar_criteria[k1] = dict() ar_criteria[k1] = dict()
bucket_id[k1] = dict()
for k2, _ in v1.items(): for k2, _ in v1.items():
t_criteria[k1][k2] = k2 t_criteria[k1][k2] = k2
bucket_id[k1][k2] = bucket_id_cnt
bucket_id_cnt += 1
ar_criteria[k1][k2] = dict() ar_criteria[k1][k2] = dict()
for k3, v3 in ASPECT_RATIOS[k1][1].items(): for k3, v3 in ASPECT_RATIOS[k1][1].items():
ar_criteria[k1][k2][k3] = v3 ar_criteria[k1][k2][k3] = v3
@ -57,58 +61,52 @@ class Bucket:
self.bucket_probs = bucket_probs self.bucket_probs = bucket_probs
self.bucket_bs = bucket_bs self.bucket_bs = bucket_bs
self.bucket_id = bucket_id
self.hw_criteria = hw_criteria self.hw_criteria = hw_criteria
self.t_criteria = t_criteria self.t_criteria = t_criteria
self.ar_criteria = ar_criteria self.ar_criteria = ar_criteria
self.num_bucket = num_bucket self.num_bucket = num_bucket
print(f"Number of buckets: {num_bucket}") print(f"Number of buckets: {num_bucket}")
def info_bucket(self, dataset, frame_interval=1): def get_bucket_id(self, T, H, W, frame_interval=1, seed=None):
infos = dict() resolution = H * W
infos_ar = dict() approx = 0.8
for i in range(len(dataset)):
T, H, W = dataset.get_data_info(i) fail = True
bucket_id = self.get_bucket_id(T, H, W, frame_interval) for hw_id, t_criteria in self.bucket_probs.items():
if bucket_id is None: if resolution < self.hw_criteria[hw_id] * approx:
continue continue
if f"{(bucket_id[0], bucket_id[1])}" not in infos:
infos[f"{(bucket_id[0], bucket_id[1])}"] = 0
if f"{bucket_id[2]}" not in infos_ar:
infos_ar[f"{bucket_id[2]}"] = 0
infos[f"{(bucket_id[0], bucket_id[1])}"] += 1
infos_ar[f"{bucket_id[2]}"] += 1
print(f"Dataset contains {len(dataset)} samples.")
print("Bucket info:", infos)
print("Aspect ratio info:", infos_ar)
def get_bucket_id(self, T, H, W, frame_interval=1, generator=None): # if sample is an image
# hw if T == 1:
hw = H * W if 1 in t_criteria:
hw_id = find_approximate_hw(hw, self.hw_criteria) fail = False
if hw_id is None: t_id = 1
return None
hw_id_index = list(self.hw_criteria.keys()).index(hw_id)
# hw drops by probablity
while True:
# T
T_id = find_closet_smaller_bucket(T, self.t_criteria[hw_id], frame_interval)
if T_id is not None:
prob = self.get_prob((hw_id, T_id))
if torch.rand(1, generator=generator).item() < prob:
break break
hw_id_index += 1 else:
if hw_id_index > len(self.hw_criteria) - 1: continue
break
hw_id = list(self.hw_criteria.keys())[hw_id_index]
if T_id is None or hw_id_index > len(self.hw_criteria) - 1: # otherwise, find suitable t_id for video
t_fail = True
for t_id, prob in t_criteria.items():
if T > t_id * frame_interval and t_id != 1:
t_fail = False
break
if t_fail:
continue
# leave the loop if prob is high enough
rng = np.random.default_rng(seed + self.bucket_id[hw_id][t_id])
if prob == 1 or rng.random() < prob:
fail = False
break
if fail:
return None return None
# ar # get aspect ratio id
ar_criteria = self.ar_criteria[hw_id][T_id] ar_criteria = self.ar_criteria[hw_id][t_id]
ar_id = get_closest_ratio(H, W, ar_criteria) ar_id = get_closest_ratio(H, W, ar_criteria)
return hw_id, T_id, ar_id return hw_id, t_id, ar_id
def get_thw(self, bucket_id): def get_thw(self, bucket_id):
assert len(bucket_id) == 3 assert len(bucket_id) == 3

View file

@ -107,6 +107,7 @@ def prepare_variable_dataloader(
pin_memory=False, pin_memory=False,
num_workers=0, num_workers=0,
process_group=None, process_group=None,
num_bucket_build_workers=1,
**kwargs, **kwargs,
): ):
_kwargs = kwargs.copy() _kwargs = kwargs.copy()
@ -120,6 +121,7 @@ def prepare_variable_dataloader(
seed=seed, seed=seed,
drop_last=drop_last, drop_last=drop_last,
verbose=True, verbose=True,
num_bucket_build_workers=num_bucket_build_workers,
) )
# Deterministic dataloader # Deterministic dataloader

View file

@ -113,6 +113,7 @@ class VariableVideoTextDataset(VideoTextDataset):
): ):
super().__init__(data_path, num_frames, frame_interval, image_size, transform_name=None) super().__init__(data_path, num_frames, frame_interval, image_size, transform_name=None)
self.transform_name = transform_name self.transform_name = transform_name
self.data["id"] = np.arange(len(self.data))
def get_data_info(self, index): def get_data_info(self, index):
T = self.data.iloc[index]["num_frames"] T = self.data.iloc[index]["num_frames"]

View file

@ -3,14 +3,28 @@ from collections import OrderedDict, defaultdict
from pprint import pprint from pprint import pprint
from typing import Iterator, List, Optional from typing import Iterator, List, Optional
import numpy as np
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from pandarallel import pandarallel
from torch.utils.data import DistributedSampler from torch.utils.data import DistributedSampler
from .bucket import Bucket from .bucket import Bucket
from .datasets import VariableVideoTextDataset from .datasets import VariableVideoTextDataset
# HACK: use pandarallel
# pandarallel should only access local variables
def apply(data, method=None, frame_interval=None, seed=None, num_bucket=None):
    """Call ``method`` on one dataset row and return its result.

    Row-wise callback for a pandarallel ``parallel_apply(..., axis=1)`` call.
    Everything it needs is passed in as arguments because, per the HACK note
    on this function, pandarallel should only access local variables.

    Args:
        data: Row-like mapping that provides the ``"num_frames"``,
            ``"height"``, ``"width"`` and ``"id"`` entries.
        method: Callable invoked as
            ``method(num_frames, height, width, frame_interval, sample_seed)``.
        frame_interval: Forwarded to ``method`` unchanged.
        seed: Base seed shared by all rows.
        num_bucket: Stride between the seeds of consecutive row ids.

    Returns:
        Whatever ``method`` returns for this row.

    Raises:
        TypeError: If ``method`` is not callable, or if ``seed`` or
            ``num_bucket`` is ``None``. The defaults exist only so the values
            can be passed by keyword; they are not usable as-is.
    """
    if not callable(method):
        raise TypeError("apply() requires a callable `method`")
    if seed is None or num_bucket is None:
        raise TypeError("apply() requires numeric `seed` and `num_bucket`")

    # Each row gets its own seed, spaced ``num_bucket`` apart.
    # NOTE(review): presumably the spacing keeps any per-bucket offset that
    # ``method`` adds from colliding across rows - confirm against the callee.
    sample_seed = seed + data["id"] * num_bucket
    return method(
        data["num_frames"],
        data["height"],
        data["width"],
        frame_interval,
        sample_seed,
    )
class VariableVideoBatchSampler(DistributedSampler): class VariableVideoBatchSampler(DistributedSampler):
def __init__( def __init__(
self, self,
@ -22,6 +36,7 @@ class VariableVideoBatchSampler(DistributedSampler):
seed: int = 0, seed: int = 0,
drop_last: bool = False, drop_last: bool = False,
verbose: bool = False, verbose: bool = False,
num_bucket_build_workers: int = 1,
) -> None: ) -> None:
super().__init__( super().__init__(
dataset=dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle, seed=seed, drop_last=drop_last dataset=dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle, seed=seed, drop_last=drop_last
@ -32,48 +47,56 @@ class VariableVideoBatchSampler(DistributedSampler):
self.last_micro_batch_access_index = 0 self.last_micro_batch_access_index = 0
self.approximate_num_batch = None self.approximate_num_batch = None
def get_num_batch(self) -> int: self._get_num_batch_cached_bucket_sample_dict = None
g = torch.Generator() self.num_bucket_build_workers = num_bucket_build_workers
g.manual_seed(self.seed + self.epoch)
def group_by_bucket(self) -> dict:
bucket_sample_dict = OrderedDict() bucket_sample_dict = OrderedDict()
pandarallel.initialize(progress_bar=True, nb_workers=self.num_bucket_build_workers)
bucket_ids = self.dataset.data.parallel_apply(
apply,
axis=1,
method=self.bucket.get_bucket_id,
frame_interval=self.dataset.frame_interval,
seed=self.seed + self.epoch,
num_bucket=self.bucket.num_bucket,
)
# group by bucket # group by bucket
# each data sample is put into a bucket with a similar image/video size # each data sample is put into a bucket with a similar image/video size
for i in range(len(self.dataset)): for i in range(len(self.dataset)):
t, h, w = self.dataset.get_data_info(i) bucket_id = bucket_ids[i]
bucket_id = self.bucket.get_bucket_id(t, h, w, self.dataset.frame_interval, g)
if bucket_id is None: if bucket_id is None:
continue continue
if bucket_id not in bucket_sample_dict: if bucket_id not in bucket_sample_dict:
bucket_sample_dict[bucket_id] = [] bucket_sample_dict[bucket_id] = []
bucket_sample_dict[bucket_id].append(i) bucket_sample_dict[bucket_id].append(i)
return bucket_sample_dict
def get_num_batch(self) -> int:
bucket_sample_dict = self.group_by_bucket()
self._get_num_batch_cached_bucket_sample_dict = bucket_sample_dict
# calculate the number of batches # calculate the number of batches
if self.verbose:
self._print_bucket_info(bucket_sample_dict) self._print_bucket_info(bucket_sample_dict)
return self.approximate_num_batch return self.approximate_num_batch
def __iter__(self) -> Iterator[List[int]]: def __iter__(self) -> Iterator[List[int]]:
g = torch.Generator() if self._get_num_batch_cached_bucket_sample_dict is not None:
g.manual_seed(self.seed + self.epoch) bucket_sample_dict = self._get_num_batch_cached_bucket_sample_dict
bucket_sample_dict = OrderedDict() self._get_num_batch_cached_bucket_sample_dict = None
bucket_micro_batch_count = OrderedDict() else:
bucket_last_consumed = OrderedDict() bucket_sample_dict = self.group_by_bucket()
# group by bucket
# each data sample is put into a bucket with a similar image/video size
for i in range(len(self.dataset)):
t, h, w = self.dataset.get_data_info(i)
bucket_id = self.bucket.get_bucket_id(t, h, w, self.dataset.frame_interval, g)
if bucket_id is None:
continue
if bucket_id not in bucket_sample_dict:
bucket_sample_dict[bucket_id] = []
bucket_sample_dict[bucket_id].append(i)
# print bucket info
if self.verbose: if self.verbose:
self._print_bucket_info(bucket_sample_dict) self._print_bucket_info(bucket_sample_dict)
g = torch.Generator()
g.manual_seed(self.seed + self.epoch)
bucket_micro_batch_count = OrderedDict()
bucket_last_consumed = OrderedDict()
# process the samples # process the samples
for bucket_id, data_list in bucket_sample_dict.items(): for bucket_id, data_list in bucket_sample_dict.items():
# handle droplast # handle droplast

View file

@ -103,6 +103,8 @@ def merge_args(cfg, args, training=False):
cfg["bucket_config"] = None cfg["bucket_config"] = None
if "transform_name" not in cfg.dataset: if "transform_name" not in cfg.dataset:
cfg.dataset["transform_name"] = "center" cfg.dataset["transform_name"] = "center"
if "num_bucket_build_workers" not in cfg:
cfg["num_bucket_build_workers"] = 1
# Both training and inference # Both training and inference
if "multi_resolution" not in cfg: if "multi_resolution" not in cfg:

View file

@ -4,6 +4,7 @@ from pprint import pprint
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import wandb
from colossalai.booster import Booster from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.cluster import DistCoordinator from colossalai.cluster import DistCoordinator
@ -11,7 +12,6 @@ from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device, set_seed from colossalai.utils import get_current_device, set_seed
from tqdm import tqdm from tqdm import tqdm
import wandb
from opensora.acceleration.checkpoint import set_grad_checkpoint from opensora.acceleration.checkpoint import set_grad_checkpoint
from opensora.acceleration.parallel_states import ( from opensora.acceleration.parallel_states import (
get_data_parallel_group, get_data_parallel_group,
@ -95,6 +95,7 @@ def main():
# 3. build dataset and dataloader # 3. build dataset and dataloader
# ====================================================== # ======================================================
dataset = build_module(cfg.dataset, DATASETS) dataset = build_module(cfg.dataset, DATASETS)
logger.info(f"Dataset contains {len(dataset)} samples.")
dataloader_args = dict( dataloader_args = dict(
dataset=dataset, dataset=dataset,
batch_size=cfg.batch_size, batch_size=cfg.batch_size,
@ -109,7 +110,11 @@ def main():
if cfg.bucket_config is None: if cfg.bucket_config is None:
dataloader = prepare_dataloader(**dataloader_args) dataloader = prepare_dataloader(**dataloader_args)
else: else:
dataloader = prepare_variable_dataloader(bucket_config=cfg.bucket_config, **dataloader_args) dataloader = prepare_variable_dataloader(
bucket_config=cfg.bucket_config,
num_bucket_build_workers=cfg.num_bucket_build_workers,
**dataloader_args,
)
if cfg.dataset.type == "VideoTextDataset": if cfg.dataset.type == "VideoTextDataset":
total_batch_size = cfg.batch_size * dist.get_world_size() // cfg.sp_size total_batch_size = cfg.batch_size * dist.get_world_size() // cfg.sp_size
logger.info(f"Total batch size: {total_batch_size}") logger.info(f"Total batch size: {total_batch_size}")