2024-03-28 10:58:36 +01:00
|
|
|
import random
|
2024-05-09 10:07:56 +02:00
|
|
|
from typing import Optional
|
2024-03-28 10:58:36 +01:00
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
import torch
|
|
|
|
|
from torch.distributed import ProcessGroup
|
|
|
|
|
from torch.distributed.distributed_c10d import _get_default_group
|
2024-05-09 10:07:56 +02:00
|
|
|
from torch.utils.data import DataLoader
|
2024-03-28 10:58:36 +01:00
|
|
|
|
2024-05-20 10:40:45 +02:00
|
|
|
from .datasets import VariableVideoTextDataset, VideoTextDataset
|
2024-05-20 11:40:05 +02:00
|
|
|
from .sampler import StatefulDistributedSampler, VariableVideoBatchSampler, BatchDistributedSampler
|
2024-03-28 10:58:36 +01:00
|
|
|
|
|
|
|
|
|
2024-05-20 10:40:45 +02:00
|
|
|
# Deterministic dataloader
def get_seed_worker(seed):
    """Build a ``worker_init_fn`` that seeds every RNG with a fixed *seed*.

    The returned callable is passed to ``torch.utils.data.DataLoader`` and is
    invoked once in each worker process; it seeds Python's ``random``, NumPy,
    and torch so that data-loading randomness is reproducible across runs.

    Args:
        seed (int): seed applied to all three RNGs in every worker.

    Returns:
        Callable[[int], None]: a worker-init function (the worker id argument
        is ignored — all workers get the same seed).
    """

    def worker_init_fn(_worker_id):
        # The three RNGs are independent, so seeding order is irrelevant.
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

    return worker_init_fn
|
|
|
|
|
|
|
|
|
|
|
2024-03-28 10:58:36 +01:00
|
|
|
def prepare_dataloader(
    dataset,
    batch_size=None,
    shuffle=False,
    seed=1024,
    drop_last=False,
    pin_memory=False,
    num_workers=0,
    process_group: Optional[ProcessGroup] = None,
    bucket_config=None,
    num_bucket_build_workers=1,
    **kwargs,
):
    """Build a distributed ``(DataLoader, sampler)`` pair for a video-text dataset.

    Dispatches on the dataset type:

    * :class:`VariableVideoTextDataset` — uses a :class:`VariableVideoBatchSampler`
      driven by ``bucket_config`` (``batch_size`` is ignored; the bucket sampler
      decides per-bucket batch sizes).
    * :class:`VideoTextDataset` — uses a :class:`StatefulDistributedSampler` with
      a fixed ``batch_size``.

    Args:
        dataset: a ``VariableVideoTextDataset`` or ``VideoTextDataset`` instance.
        batch_size (int, optional): per-rank batch size (fixed-shape datasets only).
        shuffle (bool): whether the sampler shuffles each epoch.
        seed (int): seed for the sampler and the deterministic worker init.
        drop_last (bool): drop the trailing incomplete batch.
        pin_memory (bool): passed through to ``DataLoader``.
        num_workers (int): passed through to ``DataLoader``.
        process_group: process group used for rank/world-size; defaults to the
            global default group.
        bucket_config: bucket specification for variable-shape datasets.
        num_bucket_build_workers (int): workers used to build buckets.
        **kwargs: forwarded verbatim to ``DataLoader``.

    Returns:
        tuple: ``(DataLoader, sampler)`` where ``sampler`` is the batch sampler
        (variable datasets) or the distributed sampler (fixed datasets).

    Raises:
        ValueError: if ``dataset`` is neither supported type.
    """
    _kwargs = kwargs.copy()
    # Bug fix: resolve the default process group for BOTH branches. Previously
    # only the VideoTextDataset branch had this fallback, so passing
    # process_group=None with a VariableVideoTextDataset crashed with
    # AttributeError on process_group.size().
    process_group = process_group or _get_default_group()
    if isinstance(dataset, VariableVideoTextDataset):
        batch_sampler = VariableVideoBatchSampler(
            dataset,
            bucket_config,
            num_replicas=process_group.size(),
            rank=process_group.rank(),
            shuffle=shuffle,
            seed=seed,
            drop_last=drop_last,
            verbose=True,
            num_bucket_build_workers=num_bucket_build_workers,
        )
        return (
            DataLoader(
                dataset,
                batch_sampler=batch_sampler,
                worker_init_fn=get_seed_worker(seed),
                pin_memory=pin_memory,
                num_workers=num_workers,
                **_kwargs,
            ),
            batch_sampler,
        )
    elif isinstance(dataset, VideoTextDataset):
        sampler = StatefulDistributedSampler(
            dataset,
            num_replicas=process_group.size(),
            rank=process_group.rank(),
            shuffle=shuffle,
        )
        return (
            DataLoader(
                dataset,
                batch_size=batch_size,
                sampler=sampler,
                worker_init_fn=get_seed_worker(seed),
                drop_last=drop_last,
                pin_memory=pin_memory,
                num_workers=num_workers,
                **_kwargs,
            ),
            sampler,
        )
    else:
        raise ValueError(f"Unsupported dataset type: {type(dataset)}")
|
2024-05-20 11:40:05 +02:00
|
|
|
|
|
|
|
|
def prepare_variable_dataloader(
    dataset,
    batch_size,
    bucket_config,
    shuffle=False,
    seed=1024,
    drop_last=False,
    pin_memory=False,
    num_workers=0,
    process_group=None,
    num_bucket_build_workers=1,
    **kwargs,
):
    """Build a ``DataLoader`` over a variable-shape video-text dataset.

    Uses a :class:`VariableVideoBatchSampler` configured from ``bucket_config``.

    Args:
        dataset: the variable-shape video-text dataset.
        batch_size: unused by this function — the bucket sampler determines
            per-bucket batch sizes; kept for interface compatibility.
        bucket_config: bucket specification passed to the batch sampler.
        shuffle (bool): whether the sampler shuffles each epoch.
        seed (int): seed for the sampler and the deterministic worker init.
        drop_last (bool): drop the trailing incomplete batch per bucket.
        pin_memory (bool): passed through to ``DataLoader``.
        num_workers (int): passed through to ``DataLoader``.
        process_group: process group for rank/world-size; defaults to the
            global default group.
        num_bucket_build_workers (int): workers used to build buckets.
        **kwargs: forwarded verbatim to ``DataLoader``.

    Returns:
        DataLoader: loader driven by the variable-video batch sampler.
    """
    _kwargs = kwargs.copy()
    process_group = process_group or _get_default_group()
    batch_sampler = VariableVideoBatchSampler(
        dataset,
        bucket_config,
        num_replicas=process_group.size(),
        rank=process_group.rank(),
        shuffle=shuffle,
        seed=seed,
        drop_last=drop_last,
        verbose=True,
        num_bucket_build_workers=num_bucket_build_workers,
    )

    # Consistency: reuse the module-level get_seed_worker helper instead of a
    # duplicated inline seed_worker closure (behavior is identical), and use
    # the directly imported DataLoader like the rest of this module.
    return DataLoader(
        dataset,
        batch_sampler=batch_sampler,
        worker_init_fn=get_seed_worker(seed),
        pin_memory=pin_memory,
        num_workers=num_workers,
        **_kwargs,
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def build_batch_dataloader(
    dataset,
    seed=1024,
    pin_memory=False,
    num_workers=0,
    process_group: Optional[ProcessGroup] = None,
    distributed=True,
    **kwargs,
):
    r"""Prepare a dataloader for distributed training.

    The dataloader is wrapped by `torch.utils.data.DataLoader` and
    `BatchDistributedSampler`.

    ``batch_size`` is fixed at 1 and ``shuffle``/``drop_last`` are not
    supported so far (hence no such parameters are exposed).

    Args:
        dataset: dataset whose items are already full batches.
        seed (int): seed for the deterministic worker init.
        pin_memory (bool): passed through to ``DataLoader``.
        num_workers (int): passed through to ``DataLoader``.
        process_group: process group for rank/world-size; defaults to the
            global default group.
        distributed (bool): must be True — the non-distributed path is not
            implemented.
        **kwargs: forwarded verbatim to ``DataLoader``.

    Returns:
        DataLoader: loader driven by ``BatchDistributedSampler``.

    Raises:
        NotImplementedError: if ``distributed`` is False.
    """
    _kwargs = kwargs.copy()
    if distributed:
        process_group = process_group or _get_default_group()
        sampler = BatchDistributedSampler(
            dataset,
            num_replicas=process_group.size(),
            rank=process_group.rank(),
        )
    else:
        # Dead-code fix: the original had an unreachable `sampler = None`
        # after this raise; removed.
        raise NotImplementedError

    # Consistency: reuse the shared get_seed_worker helper instead of a
    # duplicated inline seed_worker closure (behavior is identical).
    return DataLoader(
        dataset,
        batch_size=1,
        sampler=sampler,
        worker_init_fn=get_seed_worker(seed),
        pin_memory=pin_memory,
        num_workers=num_workers,
        **_kwargs,
    )
|