# Mirror of https://github.com/hpcaitech/Open-Sora.git
# Synced 2026-02-22 21:43:19 +01:00
import logging
import os

import torch.distributed as dist
def is_distributed() -> bool:
    """
    Check if the code is running in a distributed setting.

    Returns:
        bool: True if running in a distributed setting, False otherwise
    """
    # Launchers such as torchrun export WORLD_SIZE for every worker, so its
    # mere presence in the environment signals a distributed launch — even
    # before the process group has been initialized.
    return "WORLD_SIZE" in os.environ
def is_main_process() -> bool:
    """
    Check if the current process is the main process.

    Returns:
        bool: True if the current process is the main process, False otherwise.
    """
    # Outside a distributed launch there is only one process, which is
    # trivially the main one; otherwise defer to the process-group rank.
    if is_distributed():
        return dist.get_rank() == 0
    return True
def get_world_size() -> int:
    """
    Get the number of processes in the distributed setting.

    Returns:
        int: The number of processes.
    """
    # A non-distributed run behaves like a single-process "world".
    return dist.get_world_size() if is_distributed() else 1
def create_logger(logging_dir: str = None) -> logging.Logger:
    """
    Create a logger that writes to a log file and stdout. Only the main process logs.

    Args:
        logging_dir (str): The directory to save the log file. If None, the
            main process logs to stdout only.

    Returns:
        logging.Logger: The logger. On non-main ranks it carries a
        NullHandler so log calls are silently discarded.
    """
    if is_main_process():
        additional_args = dict()
        if logging_dir is not None:
            # FileHandler does not create missing directories and would raise
            # FileNotFoundError, so make sure the directory exists first.
            os.makedirs(logging_dir, exist_ok=True)
            additional_args["handlers"] = [
                logging.StreamHandler(),
                logging.FileHandler(f"{logging_dir}/log.txt"),
            ]
        logging.basicConfig(
            level=logging.INFO,
            # ANSI blue escape codes colorize the timestamp on terminals.
            format="[\033[34m%(asctime)s\033[0m] %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
            **additional_args,
        )
        logger = logging.getLogger(__name__)
        if logging_dir is not None:
            logger.info("Experiment directory created at %s", logging_dir)
    else:
        # Non-main ranks get a silent logger so library code can log freely
        # without duplicating output across processes.
        logger = logging.getLogger(__name__)
        logger.addHandler(logging.NullHandler())
    return logger
def log_message(*args, level: str = "info"):
    """
    Log a message to the logger.

    Args:
        *args: The message to log.
        level (str): The logging level.

    Raises:
        ValueError: If ``level`` is not one of "info", "warning", "error",
            or "print".
    """
    logger = logging.getLogger(__name__)
    # Dispatch table maps each supported level to its emitting callable.
    emitters = {
        "info": logger.info,
        "warning": logger.warning,
        "error": logger.error,
        "print": print,
    }
    emit = emitters.get(level)
    if emit is None:
        raise ValueError(f"Invalid logging level: {level}")
    emit(*args)