Open-Sora/opensora/acceleration/plugin.py

101 lines
3.3 KiB
Python
Raw Normal View History

2024-03-15 15:00:46 +01:00
import random
from typing import Optional
import numpy as np
import torch
from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.cluster import ProcessGroupMesh
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
DP_AXIS, SP_AXIS = 0, 1
class ZeroSeqParallelPlugin(LowLevelZeroPlugin):
def __init__(
self,
sp_size: int = 1,
stage: int = 2,
precision: str = "fp16",
initial_scale: float = 2**32,
min_scale: float = 1,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
max_scale: float = 2**32,
max_norm: float = 0.0,
norm_type: float = 2.0,
reduce_bucket_size_in_m: int = 12,
communication_dtype: Optional[torch.dtype] = None,
overlap_communication: bool = True,
cpu_offload: bool = False,
master_weights: bool = True,
verbose: bool = False,
) -> None:
super().__init__(
stage=stage,
precision=precision,
initial_scale=initial_scale,
min_scale=min_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
max_scale=max_scale,
max_norm=max_norm,
norm_type=norm_type,
reduce_bucket_size_in_m=reduce_bucket_size_in_m,
communication_dtype=communication_dtype,
overlap_communication=overlap_communication,
cpu_offload=cpu_offload,
master_weights=master_weights,
verbose=verbose,
)
self.sp_size = sp_size
assert self.world_size % sp_size == 0, "world_size must be divisible by sp_size"
self.dp_size = self.world_size // sp_size
self.pg_mesh = ProcessGroupMesh(self.dp_size, self.sp_size)
self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS)
self.sp_group = self.pg_mesh.get_group_along_axis(SP_AXIS)
self.dp_rank = self.pg_mesh.coordinate(DP_AXIS)
self.sp_rank = self.pg_mesh.coordinate(SP_AXIS)
def __del__(self):
"""Destroy the prcess groups in ProcessGroupMesh"""
self.pg_mesh.destroy_mesh_process_groups()
def prepare_dataloader(
self,
dataset,
batch_size,
shuffle=False,
seed=1024,
drop_last=False,
pin_memory=False,
num_workers=0,
distributed_sampler_cls=None,
2024-03-15 15:16:20 +01:00
**kwargs,
2024-03-15 15:00:46 +01:00
):
_kwargs = kwargs.copy()
distributed_sampler_cls = distributed_sampler_cls or DistributedSampler
2024-03-15 15:16:20 +01:00
sampler = distributed_sampler_cls(dataset, num_replicas=self.dp_size, rank=self.dp_rank, shuffle=shuffle)
2024-03-15 15:00:46 +01:00
# Deterministic dataloader
def seed_worker(worker_id):
worker_seed = seed
np.random.seed(worker_seed)
torch.manual_seed(worker_seed)
random.seed(worker_seed)
return DataLoader(
dataset,
batch_size=batch_size,
sampler=sampler,
worker_init_fn=seed_worker,
drop_last=drop_last,
pin_memory=pin_memory,
num_workers=num_workers,
**_kwargs,
2024-03-15 15:16:20 +01:00
)