Open-Sora/opensora/datasets/sampler.py
2024-04-22 12:05:59 +08:00

242 lines
9.8 KiB
Python

import warnings
from collections import OrderedDict, defaultdict
from pprint import pprint
from typing import Iterator, List, Optional
import torch
import torch.distributed as dist
from pandarallel import pandarallel
from torch.utils.data import DistributedSampler
from .bucket import Bucket
from .datasets import VariableVideoTextDataset
# HACK: use pandarallel
# pandarallel should only access local variables
def apply(data, method=None, frame_interval=None, seed=None, num_bucket=None):
return method(
data["num_frames"],
data["height"],
data["width"],
frame_interval,
seed + data["id"] * num_bucket,
)
class VariableVideoBatchSampler(DistributedSampler):
def __init__(
self,
dataset: VariableVideoTextDataset,
bucket_config: dict,
num_replicas: Optional[int] = None,
rank: Optional[int] = None,
shuffle: bool = True,
seed: int = 0,
drop_last: bool = False,
verbose: bool = False,
num_bucket_build_workers: int = 1,
) -> None:
super().__init__(
dataset=dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle, seed=seed, drop_last=drop_last
)
self.dataset = dataset
self.bucket = Bucket(bucket_config)
self.verbose = verbose
self.last_micro_batch_access_index = 0
self.approximate_num_batch = None
self._get_num_batch_cached_bucket_sample_dict = None
self.num_bucket_build_workers = num_bucket_build_workers
def group_by_bucket(self) -> dict:
bucket_sample_dict = OrderedDict()
pandarallel.initialize(nb_workers=self.num_bucket_build_workers, progress_bar=False)
bucket_ids = self.dataset.data.parallel_apply(
apply,
axis=1,
method=self.bucket.get_bucket_id,
frame_interval=self.dataset.frame_interval,
seed=self.seed + self.epoch,
num_bucket=self.bucket.num_bucket,
)
# group by bucket
# each data sample is put into a bucket with a similar image/video size
for i in range(len(self.dataset)):
bucket_id = bucket_ids[i]
if bucket_id is None:
continue
if bucket_id not in bucket_sample_dict:
bucket_sample_dict[bucket_id] = []
bucket_sample_dict[bucket_id].append(i)
return bucket_sample_dict
def get_num_batch(self) -> int:
bucket_sample_dict = self.group_by_bucket()
self._get_num_batch_cached_bucket_sample_dict = bucket_sample_dict
# calculate the number of batches
if self.verbose:
self._print_bucket_info(bucket_sample_dict)
return self.approximate_num_batch
def __iter__(self) -> Iterator[List[int]]:
if self._get_num_batch_cached_bucket_sample_dict is not None:
bucket_sample_dict = self._get_num_batch_cached_bucket_sample_dict
self._get_num_batch_cached_bucket_sample_dict = None
else:
bucket_sample_dict = self.group_by_bucket()
if self.verbose:
self._print_bucket_info(bucket_sample_dict)
g = torch.Generator()
g.manual_seed(self.seed + self.epoch)
bucket_micro_batch_count = OrderedDict()
bucket_last_consumed = OrderedDict()
# process the samples
for bucket_id, data_list in bucket_sample_dict.items():
# handle droplast
bs_per_gpu = self.bucket.get_batch_size(bucket_id)
remainder = len(data_list) % bs_per_gpu
if remainder > 0:
if not self.drop_last:
# if there is remainder, we pad to make it divisible
data_list += data_list[: bs_per_gpu - remainder]
else:
# we just drop the remainder to make it divisible
data_list = data_list[:-remainder]
bucket_sample_dict[bucket_id] = data_list
# handle shuffle
if self.shuffle:
data_indices = torch.randperm(len(data_list), generator=g).tolist()
data_list = [data_list[i] for i in data_indices]
bucket_sample_dict[bucket_id] = data_list
# compute how many micro-batches each bucket has
num_micro_batches = len(data_list) // bs_per_gpu
bucket_micro_batch_count[bucket_id] = num_micro_batches
# compute the bucket access order
# each bucket may have more than one batch of data
# thus bucket_id may appear more than 1 time
bucket_id_access_order = []
for bucket_id, num_micro_batch in bucket_micro_batch_count.items():
bucket_id_access_order.extend([bucket_id] * num_micro_batch)
# randomize the access order
if self.shuffle:
bucket_id_access_order_indices = torch.randperm(len(bucket_id_access_order), generator=g).tolist()
bucket_id_access_order = [bucket_id_access_order[i] for i in bucket_id_access_order_indices]
# make the number of bucket accesses divisible by dp size
remainder = len(bucket_id_access_order) % self.num_replicas
if remainder > 0:
if self.drop_last:
bucket_id_access_order = bucket_id_access_order[: len(bucket_id_access_order) - remainder]
else:
bucket_id_access_order += bucket_id_access_order[: self.num_replicas - remainder]
# prepare each batch from its bucket
# according to the predefined bucket access order
num_iters = len(bucket_id_access_order) // self.num_replicas
start_iter_idx = self.last_micro_batch_access_index // self.num_replicas
# re-compute the micro-batch consumption
# this is useful when resuming from a state dict with a different number of GPUs
self.last_micro_batch_access_index = start_iter_idx * self.num_replicas
for i in range(self.last_micro_batch_access_index):
bucket_id = bucket_id_access_order[i]
bucket_bs = self.bucket.get_batch_size(bucket_id)
if bucket_id in bucket_last_consumed:
bucket_last_consumed[bucket_id] += bucket_bs
else:
bucket_last_consumed[bucket_id] = bucket_bs
for i in range(start_iter_idx, num_iters):
bucket_access_list = bucket_id_access_order[i * self.num_replicas : (i + 1) * self.num_replicas]
self.last_micro_batch_access_index += self.num_replicas
# comppute the data samples consumed by each access
bucket_access_boundaries = []
for bucket_id in bucket_access_list:
bucket_bs = self.bucket.get_batch_size(bucket_id)
last_consumed_index = bucket_last_consumed.get(bucket_id, 0)
bucket_access_boundaries.append([last_consumed_index, last_consumed_index + bucket_bs])
# update consumption
if bucket_id in bucket_last_consumed:
bucket_last_consumed[bucket_id] += bucket_bs
else:
bucket_last_consumed[bucket_id] = bucket_bs
# compute the range of data accessed by each GPU
bucket_id = bucket_access_list[self.rank]
boundary = bucket_access_boundaries[self.rank]
cur_micro_batch = bucket_sample_dict[bucket_id][boundary[0] : boundary[1]]
# encode t, h, w into the sample index
real_t, real_h, real_w = self.bucket.get_thw(bucket_id)
cur_micro_batch = [f"{idx}-{real_t}-{real_h}-{real_w}" for idx in cur_micro_batch]
yield cur_micro_batch
self._reset()
def _reset(self):
self.last_micro_batch_access_index = 0
def state_dict(self, num_steps: int) -> dict:
# the last_micro_batch_access_index in the __iter__ is often
# not accurate during multi-workers and data prefetching
# thus, we need the user to pass the actual steps which have been executed
# to calculate the correct last_micro_batch_access_index
return {"seed": self.seed, "epoch": self.epoch, "last_micro_batch_access_index": num_steps * self.num_replicas}
def load_state_dict(self, state_dict: dict) -> None:
self.__dict__.update(state_dict)
def _print_bucket_info(self, bucket_sample_dict: dict, verbose=True) -> None:
total_samples = 0
num_batch = 0
num_dict = {}
num_aspect_dict = defaultdict(int)
num_hwt_dict = defaultdict(int)
for k, v in bucket_sample_dict.items():
size = len(v)
total_samples += size
num_dict[k] = size
num_aspect_dict[k[-1]] += size
num_hwt_dict[k[:-1]] += size
num_batch += size // self.bucket.get_batch_size(k[:-1])
if dist.get_rank() == 0 and verbose:
print(f"Total training samples: {total_samples}, num buckets: {len(num_dict)}")
print("Bucket samples:")
pprint(num_dict)
print("Bucket samples by aspect ratio:")
pprint(num_aspect_dict)
print("Bucket samples by HxWxT:")
pprint(num_hwt_dict)
print(f"Number of batches: {num_batch}")
self.approximate_num_batch = num_batch
def set_epoch(self, epoch: int) -> None:
super().set_epoch(epoch)
def __len__(self) -> int:
warnings.warn(
"The length of VariableVideoBatchSampler is dynamic and may not be accurate. Return the max value."
)
min_batch_size = None
for v in self.bucket.bucket_bs.values():
for bs in v.values():
if bs is not None and (min_batch_size is None or bs < min_batch_size):
min_batch_size = bs
if self.drop_last:
return len(self.dataset) // min_batch_size
else:
return (len(self.dataset) + min_batch_size - 1) // min_batch_size