Open-Sora/opensora/acceleration/parallel_states.py
Zheng Zangwei (Alex Zheng) febf3ad4b2
Update Open-Sora 2.0 (#807)
* upload v2.0

* update docs

* [hotfix] fit latest fa3 (#802)

* update readme

* update readme

* update readme

* update train readme

* update readme

* update readme: motion score

* cleaning video dc ae WIP

* update config

* add dependency functions

* undo cleaning

* use latest dcae

* complete high compression training

* update hcae config

* cleaned up vae

* update ae.md

* further cleanup

* update vae & ae paths

* align naming of ae

* [hotfix] fix ring attn bwd for fa3 (#803)

* train ae default without wandb

* update config

* update evaluation results

* added hcae report

* update readme

* update readme demo

* update readme demo

* update readme gif

* display demo directly in readme

* update paper

* delete files

---------

Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: Shen-Chenhui <shen_chenhui@u.nus.edu>
Co-authored-by: wuxiwen <wuxiwen.simon@gmail.com>
2025-03-12 13:14:22 +08:00


import torch.distributed as dist

# Global registry mapping a parallelism kind ("data", "sequence", "tensor",
# or "mixed_dp_group") to its torch.distributed process group.
_GLOBAL_PARALLEL_GROUPS = dict()


def set_data_parallel_group(group: dist.ProcessGroup):
    _GLOBAL_PARALLEL_GROUPS["data"] = group


def get_data_parallel_group(get_mixed_dp_pg: bool = False):
    # Prefer the mixed DP group when it is requested and registered;
    # otherwise fall back to the plain data-parallel group, and finally
    # to the default WORLD group if nothing has been set.
    if get_mixed_dp_pg and "mixed_dp_group" in _GLOBAL_PARALLEL_GROUPS:
        return _GLOBAL_PARALLEL_GROUPS["mixed_dp_group"]
    return _GLOBAL_PARALLEL_GROUPS.get("data", dist.group.WORLD)


def set_sequence_parallel_group(group: dist.ProcessGroup):
    _GLOBAL_PARALLEL_GROUPS["sequence"] = group


def get_sequence_parallel_group():
    return _GLOBAL_PARALLEL_GROUPS.get("sequence", None)


def set_tensor_parallel_group(group: dist.ProcessGroup):
    _GLOBAL_PARALLEL_GROUPS["tensor"] = group


def get_tensor_parallel_group():
    return _GLOBAL_PARALLEL_GROUPS.get("tensor", None)