Open-Sora/opensora/acceleration/parallel_states.py
Zheng Zangwei (Alex Zheng) d851a85535 format (#69)
2024-03-15 22:16:20 +08:00

20 lines
457 B
Python

import torch.distributed as dist
# Module-level registry of process groups, keyed by parallelism kind.
# The accessor functions in this module read and write the "data" and
# "sequence" entries; an absent key means no group has been registered.
_GLOBAL_PARALLEL_GROUPS = {}
def set_data_parallel_group(group: dist.ProcessGroup):
    """Register ``group`` as the data-parallel process group.

    The object is stored under the ``"data"`` key of the module-level
    registry, overwriting any group registered earlier. No validation is
    performed, so whatever is passed (including ``None``) is stored as-is.
    """
    _GLOBAL_PARALLEL_GROUPS.update({"data": group})
def get_data_parallel_group():
    """Return the registered data-parallel process group.

    Returns:
        The object most recently stored under the ``"data"`` key, or
        ``None`` when nothing has been registered under that key.
    """
    try:
        return _GLOBAL_PARALLEL_GROUPS["data"]
    except KeyError:
        # Nothing registered yet: fall back to None.
        return None
def set_sequence_parallel_group(group: dist.ProcessGroup):
    """Register ``group`` as the sequence-parallel process group.

    The object is stored under the ``"sequence"`` key of the module-level
    registry, overwriting any group registered earlier. No validation is
    performed, so whatever is passed (including ``None``) is stored as-is.
    """
    _GLOBAL_PARALLEL_GROUPS.update({"sequence": group})
def get_sequence_parallel_group():
    """Return the registered sequence-parallel process group.

    Returns:
        The object most recently stored under the ``"sequence"`` key, or
        ``None`` when nothing has been registered under that key.
    """
    try:
        return _GLOBAL_PARALLEL_GROUPS["sequence"]
    except KeyError:
        # Nothing registered yet: fall back to None.
        return None