# Open-Sora/opensora/datasets/dataloader.py
# (150 lines, 4.3 KiB, Python)
import collections
import random
from typing import Optional
import numpy as np
import torch
from torch.distributed import ProcessGroup
from torch.distributed.distributed_c10d import _get_default_group
from torch.utils.data import DataLoader
from .datasets import BatchFeatureDataset, VariableVideoTextDataset, VideoTextDataset
from .sampler import BatchDistributedSampler, StatefulDistributedSampler, VariableVideoBatchSampler
# Deterministic dataloader
def get_seed_worker(seed):
    """Build a DataLoader ``worker_init_fn`` that seeds every worker with *seed*.

    Note: the ``worker_id`` argument is intentionally ignored, so every worker
    process receives the exact same seed — data loading is fully deterministic
    across workers and across runs.
    """

    def seed_worker(worker_id):
        del worker_id  # unused on purpose: all workers share one seed
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

    return seed_worker
def prepare_dataloader(
    dataset,
    batch_size=None,
    shuffle=False,
    seed=1024,
    drop_last=False,
    pin_memory=False,
    num_workers=0,
    process_group: Optional[ProcessGroup] = None,
    bucket_config=None,
    num_bucket_build_workers=1,
    prefetch_factor=None,
    **kwargs,
):
    """Build a distributed ``DataLoader`` (and its sampler) for the dataset type.

    Dispatches on the concrete dataset class:

    - ``VariableVideoTextDataset``  -> bucketed ``VariableVideoBatchSampler``
      (``batch_size`` is ignored; bucket sizes come from ``bucket_config``)
    - ``VideoTextDataset``          -> ``StatefulDistributedSampler``
    - ``BatchFeatureDataset``       -> ``BatchDistributedSampler`` with
      ``batch_size=1`` (each dataset item is already a pre-collated batch)

    Args:
        dataset: one of the three supported dataset types above.
        batch_size: per-rank batch size (``VideoTextDataset`` branch only).
        shuffle: shuffle samples each epoch (where the sampler supports it).
        seed: seed forwarded to the sampler and to the worker init fn.
        drop_last: drop the trailing incomplete batch.
        pin_memory / num_workers / prefetch_factor: forwarded to ``DataLoader``.
        process_group: distributed group used for rank/world-size; defaults to
            the global default group.
        bucket_config: bucket spec for ``VariableVideoBatchSampler``.
        num_bucket_build_workers: workers used to build buckets.
        **kwargs: passed through to ``DataLoader``.

    Returns:
        ``(dataloader, sampler)`` tuple.

    Raises:
        ValueError: if ``dataset`` is not one of the supported types.
    """
    _kwargs = kwargs.copy()
    # Resolve the default group once, for every branch. Previously only the
    # VideoTextDataset branch did this, so passing process_group=None with the
    # other dataset types crashed on `process_group.size()`.
    process_group = process_group or _get_default_group()
    if isinstance(dataset, VariableVideoTextDataset):
        batch_sampler = VariableVideoBatchSampler(
            dataset,
            bucket_config,
            num_replicas=process_group.size(),
            rank=process_group.rank(),
            shuffle=shuffle,
            seed=seed,
            drop_last=drop_last,
            verbose=True,
            num_bucket_build_workers=num_bucket_build_workers,
        )
        return (
            DataLoader(
                dataset,
                batch_sampler=batch_sampler,
                worker_init_fn=get_seed_worker(seed),
                pin_memory=pin_memory,
                num_workers=num_workers,
                collate_fn=collate_fn_default,
                prefetch_factor=prefetch_factor,
                **_kwargs,
            ),
            batch_sampler,
        )
    elif isinstance(dataset, VideoTextDataset):
        sampler = StatefulDistributedSampler(
            dataset,
            num_replicas=process_group.size(),
            rank=process_group.rank(),
            shuffle=shuffle,
        )
        return (
            DataLoader(
                dataset,
                batch_size=batch_size,
                sampler=sampler,
                worker_init_fn=get_seed_worker(seed),
                drop_last=drop_last,
                pin_memory=pin_memory,
                num_workers=num_workers,
                collate_fn=collate_fn_default,
                prefetch_factor=prefetch_factor,
                **_kwargs,
            ),
            sampler,
        )
    elif isinstance(dataset, BatchFeatureDataset):
        sampler = BatchDistributedSampler(
            dataset,
            num_replicas=process_group.size(),
            rank=process_group.rank(),
        )
        return (
            DataLoader(
                dataset,
                # batch_size=1: each dataset item is already a full batch;
                # collate_fn_batch squeezes the extra leading dim back out.
                batch_size=1,
                sampler=sampler,
                worker_init_fn=get_seed_worker(seed),
                pin_memory=pin_memory,
                num_workers=num_workers,
                collate_fn=collate_fn_batch,
                prefetch_factor=prefetch_factor,
                **_kwargs,
            ),
            sampler,
        )
    else:
        raise ValueError(f"Unsupported dataset type: {type(dataset)}")
def collate_fn_default(batch):
    """Collate a list of sample dicts; pre-encoded text features get special handling.

    When a sample carries an int ``"mask"``, the ``"mask"``/``"text"`` entries
    are removed before the default collate (mutating the input dicts), then
    re-attached: masks as a plain list of ints, texts concatenated along dim 1.
    """
    # HACK: for loading text features
    has_text_features = "mask" in batch[0] and isinstance(batch[0]["mask"], int)
    if has_text_features:
        masks = [sample.pop("mask") for sample in batch]
        texts = torch.cat([sample.pop("text") for sample in batch], dim=1)
    collated = torch.utils.data.default_collate(batch)
    if has_text_features:
        collated["mask"] = masks
        collated["text"] = texts
    return collated
def collate_fn_batch(batch):
    """
    Used only with BatchDistributedSampler
    """
    collated = torch.utils.data.default_collate(batch)
    # default_collate() stacks along a new leading axis; with batch_size=1 that
    # axis is spurious, so drop it from every tensor in the result.
    if isinstance(collated, torch.Tensor):
        return collated.squeeze(0)
    if isinstance(collated, collections.abc.Mapping):
        for key, value in collated.items():
            if isinstance(value, torch.Tensor):
                collated[key] = value.squeeze(0)
        return collated
    if isinstance(collated, collections.abc.Sequence):
        return [
            item.squeeze(0) if isinstance(item, torch.Tensor) else item
            for item in collated
        ]
    raise TypeError