mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-14 18:25:35 +02:00
20 lines
469 B
Python
20 lines
469 B
Python
import torch.distributed as dist
|
|
|
|
# Module-level registry of named process groups. The set_* helpers in this
# module store groups under the "data" and "sequence" keys; the matching
# get_* helpers read them back.
_GLOBAL_PARALLEL_GROUPS = {}
|
|
|
|
|
|
def set_data_parallel_group(group: dist.ProcessGroup) -> None:
    """Register *group* as the data-parallel process group.

    The group is stored in the module-level ``_GLOBAL_PARALLEL_GROUPS``
    registry under the ``"data"`` key, overwriting any group registered
    earlier. ``get_data_parallel_group`` returns it afterwards.

    Args:
        group: The process group to use for data parallelism.
    """
    _GLOBAL_PARALLEL_GROUPS.update(data=group)
|
|
|
|
|
|
def get_data_parallel_group():
    """Return the registered data-parallel process group.

    If ``set_data_parallel_group`` has never been called, this falls back to
    ``dist.group.WORLD``.

    Returns:
        The group stored under the ``"data"`` key of the registry, otherwise
        ``dist.group.WORLD``.
    """
    # The fallback is read up front on every call, matching dict.get's
    # eager default argument.
    world_fallback = dist.group.WORLD
    return _GLOBAL_PARALLEL_GROUPS.get("data", world_fallback)
|
|
|
|
|
|
def set_sequence_parallel_group(group: dist.ProcessGroup) -> None:
    """Register *group* as the sequence-parallel process group.

    The group is stored in the module-level ``_GLOBAL_PARALLEL_GROUPS``
    registry under the ``"sequence"`` key, overwriting any group registered
    earlier. ``get_sequence_parallel_group`` returns it afterwards.

    Args:
        group: The process group to use for sequence parallelism.
    """
    _GLOBAL_PARALLEL_GROUPS.update(sequence=group)
|
|
|
|
|
|
def get_sequence_parallel_group():
    """Return the registered sequence-parallel process group.

    Unlike the data-parallel getter, there is no fallback group.

    Returns:
        The group stored under the ``"sequence"`` key of the registry, or
        ``None`` if ``set_sequence_parallel_group`` has never been called.
    """
    # dict.get already defaults to None, so no explicit default is needed.
    return _GLOBAL_PARALLEL_GROUPS.get("sequence")
|