mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-17 05:36:01 +02:00
* upload v2.0 * update docs * [hotfix] fit latest fa3 (#802) * update readme * update readme * update readme * update train readme * update readme * update readme: motion score * cleaning video dc ae WIP * update config * add dependency functions * undo cleaning * use latest dcae * complete high compression training * update hcae config * cleaned up vae * update ae.md * further cleanup * update vae & ae paths * align naming of ae * [hotfix] fix ring attn bwd for fa3 (#803) * train ae default without wandb * update config * update evaluation results * added hcae report * update readme * update readme demo * update readme demo * update readme gif * display demo directly in readme * update paper * delete files --------- Co-authored-by: Hongxin Liu <lhx0217@gmail.com> Co-authored-by: Shen-Chenhui <shen_chenhui@u.nus.edu> Co-authored-by: wuxiwen <wuxiwen.simon@gmail.com>
30 lines
823 B
Python
30 lines
823 B
Python
import torch.distributed as dist
|
|
|
|
# Module-level registry mapping a parallelism kind ("data", "sequence",
# "tensor", and optionally "mixed_dp_group") to its torch.distributed
# process group. Populated by the set_*_parallel_group helpers below.
_GLOBAL_PARALLEL_GROUPS = dict()
|
|
|
|
|
|
def set_data_parallel_group(group: dist.ProcessGroup):
    """Register ``group`` as the global data-parallel process group."""
    _GLOBAL_PARALLEL_GROUPS.update(data=group)
|
|
|
|
|
|
def get_data_parallel_group(get_mixed_dp_pg: bool = False):
    """Return the data-parallel process group.

    When ``get_mixed_dp_pg`` is true and a "mixed_dp_group" entry has been
    registered, that group is returned instead of the plain data group.
    Falls back to the default world group when no data group was ever set.
    """
    groups = _GLOBAL_PARALLEL_GROUPS
    if get_mixed_dp_pg:
        # EAFP: prefer the mixed DP group whenever the entry exists.
        try:
            return groups["mixed_dp_group"]
        except KeyError:
            pass
    return groups.get("data", dist.group.WORLD)
|
|
|
|
|
|
def set_sequence_parallel_group(group: dist.ProcessGroup):
    """Register ``group`` as the global sequence-parallel process group."""
    _GLOBAL_PARALLEL_GROUPS.update(sequence=group)
|
|
|
|
|
|
def get_sequence_parallel_group():
    """Return the registered sequence-parallel group, or ``None`` if unset."""
    return _GLOBAL_PARALLEL_GROUPS.get("sequence")
|
|
|
|
|
|
def set_tensor_parallel_group(group: dist.ProcessGroup):
    """Register ``group`` as the global tensor-parallel process group."""
    _GLOBAL_PARALLEL_GROUPS.update(tensor=group)
|
|
|
|
|
|
def get_tensor_parallel_group():
    """Return the registered tensor-parallel group, or ``None`` if unset."""
    return _GLOBAL_PARALLEL_GROUPS.get("tensor")
|