# Mirror of https://github.com/hpcaitech/Open-Sora.git
# Synced 2026-04-11 13:14:44 +02:00 · 101 lines · 3.3 KiB · Python
import random
from typing import Optional

import numpy as np
import torch
from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.cluster import ProcessGroupMesh
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

DP_AXIS, SP_AXIS = 0, 1
|
|
|
|
|
|
class ZeroSeqParallelPlugin(LowLevelZeroPlugin):
    """Low-level ZeRO plugin extended with a sequence-parallel axis.

    The world is laid out as a 2-D ``ProcessGroupMesh`` of shape
    ``(dp_size, sp_size)``: axis 0 is data parallelism, axis 1 is sequence
    parallelism.  All ZeRO-related keyword arguments are forwarded unchanged
    to ``LowLevelZeroPlugin``.
    """

    def __init__(
        self,
        sp_size: int = 1,
        stage: int = 2,
        precision: str = "fp16",
        initial_scale: float = 2**32,
        min_scale: float = 1,
        growth_factor: float = 2,
        backoff_factor: float = 0.5,
        growth_interval: int = 1000,
        hysteresis: int = 2,
        max_scale: float = 2**32,
        max_norm: float = 0.0,
        norm_type: float = 2.0,
        reduce_bucket_size_in_m: int = 12,
        communication_dtype: Optional[torch.dtype] = None,
        overlap_communication: bool = True,
        cpu_offload: bool = False,
        master_weights: bool = True,
        verbose: bool = False,
    ) -> None:
        # Every argument except ``sp_size`` belongs to the parent ZeRO plugin;
        # collect them once and forward in bulk.
        zero_kwargs = dict(
            stage=stage,
            precision=precision,
            initial_scale=initial_scale,
            min_scale=min_scale,
            growth_factor=growth_factor,
            backoff_factor=backoff_factor,
            growth_interval=growth_interval,
            hysteresis=hysteresis,
            max_scale=max_scale,
            max_norm=max_norm,
            norm_type=norm_type,
            reduce_bucket_size_in_m=reduce_bucket_size_in_m,
            communication_dtype=communication_dtype,
            overlap_communication=overlap_communication,
            cpu_offload=cpu_offload,
            master_weights=master_weights,
            verbose=verbose,
        )
        super().__init__(**zero_kwargs)

        assert self.world_size % sp_size == 0, "world_size must be divisible by sp_size"
        self.sp_size = sp_size
        self.dp_size = self.world_size // sp_size
        # Build the (dp, sp) mesh, then derive this rank's per-axis group and
        # coordinate along each axis.
        self.pg_mesh = ProcessGroupMesh(self.dp_size, self.sp_size)
        self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS)
        self.sp_group = self.pg_mesh.get_group_along_axis(SP_AXIS)
        self.dp_rank = self.pg_mesh.coordinate(DP_AXIS)
        self.sp_rank = self.pg_mesh.coordinate(SP_AXIS)

    def __del__(self):
        """Destroy the process groups in ProcessGroupMesh."""
        self.pg_mesh.destroy_mesh_process_groups()

    def prepare_dataloader(
        self,
        dataset,
        batch_size,
        shuffle=False,
        seed=1024,
        drop_last=False,
        pin_memory=False,
        num_workers=0,
        distributed_sampler_cls=None,
        **kwargs,
    ):
        """Build a ``DataLoader`` sharded over the data-parallel axis only.

        Sharding uses ``dp_size``/``dp_rank``, so ranks that differ only in
        their sequence-parallel coordinate receive the same data shard.
        Extra ``**kwargs`` are passed through to ``DataLoader``.
        """
        extra = dict(kwargs)
        sampler_cls = DistributedSampler if distributed_sampler_cls is None else distributed_sampler_cls
        sampler = sampler_cls(dataset, num_replicas=self.dp_size, rank=self.dp_rank, shuffle=shuffle)

        # Deterministic dataloader: fix every worker's RNG state up front.
        def seed_worker(worker_id):
            # NOTE(review): every worker gets the *same* seed (worker_id is
            # ignored), so all workers share one RNG stream — presumably
            # intentional for reproducibility; confirm if per-worker
            # randomness (e.g. independent augmentations) is expected.
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)

        return DataLoader(
            dataset,
            batch_size=batch_size,
            sampler=sampler,
            worker_init_fn=seed_worker,
            drop_last=drop_last,
            pin_memory=pin_memory,
            num_workers=num_workers,
            **extra,
        )