mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-10 04:37:45 +02:00
* upload v2.0 * update docs * [hotfix] fit latest fa3 (#802) * update readme * update readme * update readme * update train readme * update readme * update readme: motion score * cleaning video dc ae WIP * update config * add dependency functions * undo cleaning * use latest dcae * complete high compression training * update hcae config * cleaned up vae * update ae.md * further cleanup * update vae & ae paths * align naming of ae * [hotfix] fix ring attn bwd for fa3 (#803) * train ae default without wandb * update config * update evaluation results * added hcae report * update readme * update readme demo * update readme demo * update readme gif * display demo directly in readme * update paper * delete files --------- Co-authored-by: Hongxin Liu <lhx0217@gmail.com> Co-authored-by: Shen-Chenhui <shen_chenhui@u.nus.edu> Co-authored-by: wuxiwen <wuxiwen.simon@gmail.com>
42 lines
1 KiB
Python
42 lines
1 KiB
Python
from copy import deepcopy
|
|
|
|
import torch.nn as nn
|
|
from mmengine.registry import Registry
|
|
|
|
|
|
def build_module(
    module: dict | nn.Module | None, builder: Registry, **kwargs
) -> nn.Module | None:
    """Build a module from a config dict, or pass an existing module through.

    Args:
        module (dict | nn.Module | None): A config dict to be built by
            ``builder``, an already-constructed ``nn.Module`` (returned
            unchanged), or ``None``.
        builder (Registry): The registry whose ``build`` method is called with
            the (copied and updated) config dict.
        **kwargs: Extra config entries. They are merged into a deep copy of a
            dict ``module`` before building, overriding keys of the same name.
            They are ignored when ``module`` is an ``nn.Module`` or ``None``.

    Returns:
        nn.Module | None: Whatever ``builder.build`` returns for a dict config,
        the given module itself for an ``nn.Module``, or ``None`` when
        ``module`` is ``None``.

    Raises:
        TypeError: If ``module`` is not a dict, an ``nn.Module`` or ``None``.
    """
    if module is None:
        return None
    if isinstance(module, dict):
        # Work on a deep copy so neither the overrides nor the builder can
        # mutate the caller's config (including nested dicts).
        cfg = deepcopy(module)
        cfg.update(kwargs)
        return builder.build(cfg)
    if isinstance(module, nn.Module):
        return module
    raise TypeError(f"Only support dict and nn.Module, but got {type(module)}.")
|
|
|
|
|
|
# Registry for model classes. ``locations`` names the package that is expected
# to register them (see mmengine's ``Registry`` API for how it is used).
MODELS = Registry("model", locations=["opensora.models"])

# Registry for dataset classes, with ``opensora.datasets`` as its location.
DATASETS = Registry("dataset", locations=["opensora.datasets"])
|