2024-03-28 10:58:36 +01:00
|
|
|
import random
|
2024-05-09 10:07:56 +02:00
|
|
|
from typing import Optional
|
2024-03-28 10:58:36 +01:00
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
import torch
|
|
|
|
|
from torch.distributed import ProcessGroup
|
|
|
|
|
from torch.distributed.distributed_c10d import _get_default_group
|
2024-05-09 10:07:56 +02:00
|
|
|
from torch.utils.data import DataLoader
|
2024-03-28 10:58:36 +01:00
|
|
|
|
2024-05-20 10:40:45 +02:00
|
|
|
from .datasets import VariableVideoTextDataset, VideoTextDataset
|
2024-05-09 10:07:56 +02:00
|
|
|
from .sampler import StatefulDistributedSampler, VariableVideoBatchSampler
|
2024-03-28 10:58:36 +01:00
|
|
|
|
|
|
|
|
|
2024-05-20 10:40:45 +02:00
|
|
|
# Deterministic dataloader
|
|
|
|
|
def get_seed_worker(seed):
    """Build a ``worker_init_fn`` that seeds the RNGs of a DataLoader worker.

    The returned callable seeds NumPy, PyTorch and Python's ``random`` module
    (in that order) with ``seed``.

    Note:
        The worker id handed to the callable is not mixed into the seed, so
        every worker process is seeded with the very same value.

    Args:
        seed: Integer seed applied to all three random number generators.

    Returns:
        A function accepting a single ``worker_id`` argument and returning
        ``None``.
    """

    def seed_worker(worker_id):
        # `worker_id` is accepted only to satisfy the worker_init_fn
        # signature; the seed is identical for every worker.
        for apply_seed in (np.random.seed, torch.manual_seed, random.seed):
            apply_seed(seed)

    return seed_worker
|
|
|
|
|
|
|
|
|
|
|
2024-03-28 10:58:36 +01:00
|
|
|
def prepare_dataloader(
    dataset,
    batch_size=None,
    shuffle=False,
    seed=1024,
    drop_last=False,
    pin_memory=False,
    num_workers=0,
    process_group: Optional[ProcessGroup] = None,
    bucket_config=None,
    num_bucket_build_workers=1,
    **kwargs,
):
    """Create a distributed ``DataLoader`` together with its sampler.

    The sampler type is selected from the dataset type:

    * ``VariableVideoTextDataset`` -> ``VariableVideoBatchSampler``, passed
      to the ``DataLoader`` as ``batch_sampler``. ``batch_size`` is not used
      in this branch.
    * ``VideoTextDataset`` -> ``StatefulDistributedSampler``, passed to the
      ``DataLoader`` as ``sampler`` alongside ``batch_size``.

    The variable-size check is performed first.
    NOTE(review): presumably ``VariableVideoTextDataset`` subclasses
    ``VideoTextDataset``, which would make this order significant -- confirm.

    Args:
        dataset: Dataset instance to load from.
        batch_size: Batch size for the ``VideoTextDataset`` branch.
        shuffle: Forwarded to the sampler.
        seed: Seed for the worker init function; additionally forwarded to
            the batch sampler in the variable-size branch.
        drop_last: Forwarded to the batch sampler (variable-size branch) or
            to the ``DataLoader`` (fixed-size branch).
        pin_memory: Forwarded to the ``DataLoader``.
        num_workers: Forwarded to the ``DataLoader``.
        process_group: Process group providing world size and rank. When
            falsy (the default ``None``), the default process group is used.
        bucket_config: Forwarded to ``VariableVideoBatchSampler``.
        num_bucket_build_workers: Forwarded to ``VariableVideoBatchSampler``.
        **kwargs: Extra keyword arguments forwarded to the ``DataLoader``.

    Returns:
        A ``(dataloader, sampler)`` tuple; ``sampler`` is the batch sampler
        in the variable-size branch.

    Raises:
        ValueError: If ``dataset`` is neither a ``VariableVideoTextDataset``
            nor a ``VideoTextDataset``.
    """
    _kwargs = kwargs.copy()
    if isinstance(dataset, VariableVideoTextDataset):
        # Fall back to the default group exactly like the fixed-size branch
        # does; previously a missing process_group crashed on `None.size()`.
        process_group = process_group or _get_default_group()
        batch_sampler = VariableVideoBatchSampler(
            dataset,
            bucket_config,
            num_replicas=process_group.size(),
            rank=process_group.rank(),
            shuffle=shuffle,
            seed=seed,
            drop_last=drop_last,
            verbose=True,
            num_bucket_build_workers=num_bucket_build_workers,
        )
        return (
            DataLoader(
                dataset,
                batch_sampler=batch_sampler,
                worker_init_fn=get_seed_worker(seed),
                pin_memory=pin_memory,
                num_workers=num_workers,
                **_kwargs,
            ),
            batch_sampler,
        )
    elif isinstance(dataset, VideoTextDataset):
        process_group = process_group or _get_default_group()
        sampler = StatefulDistributedSampler(
            dataset,
            num_replicas=process_group.size(),
            rank=process_group.rank(),
            shuffle=shuffle,
        )
        return (
            DataLoader(
                dataset,
                batch_size=batch_size,
                sampler=sampler,
                worker_init_fn=get_seed_worker(seed),
                drop_last=drop_last,
                pin_memory=pin_memory,
                num_workers=num_workers,
                **_kwargs,
            ),
            sampler,
        )
    else:
        raise ValueError(f"Unsupported dataset type: {type(dataset)}")
|