Open-Sora/opensora/utils/logger.py
Zheng Zangwei (Alex Zheng) febf3ad4b2
Update Open-Sora 2.0 (#807)
* upload v2.0

* update docs

* [hotfix] fit latest fa3 (#802)

* update readme

* update readme

* update readme

* update train readme

* update readme

* update readme: motion score

* cleaning video dc ae WIP

* update config

* add dependency functions

* undo cleaning

* use latest dcae

* complete high compression training

* update hcae config

* cleaned up vae

* update ae.md

* further cleanup

* update vae & ae paths

* align naming of ae

* [hotfix] fix ring attn bwd for fa3 (#803)

* train ae default without wandb

* update config

* update evaluation results

* added hcae report

* update readme

* update readme demo

* update readme demo

* update readme gif

* display demo directly in readme

* update paper

* delete files

---------

Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: Shen-Chenhui <shen_chenhui@u.nus.edu>
Co-authored-by: wuxiwen <wuxiwen.simon@gmail.com>
2025-03-12 13:14:22 +08:00

91 lines
2.3 KiB
Python

import logging
import os
import torch.distributed as dist
def is_distributed() -> bool:
    """
    Report whether the environment indicates a distributed launch.

    The decision is based solely on the presence of the ``WORLD_SIZE``
    environment variable; its value is not inspected.

    Returns:
        bool: True if ``WORLD_SIZE`` is set, False otherwise.
    """
    return "WORLD_SIZE" in os.environ
def is_main_process() -> bool:
    """
    Check if the current process is the main process.

    A non-distributed run always counts as the main process. In a distributed
    run the main process is the one with rank 0.

    Returns:
        bool: True if the current process is the main process, False otherwise.
    """
    if not is_distributed():
        return True
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank() == 0
    # WORLD_SIZE is set but no default process group exists yet (e.g. the
    # logger is created before init_process_group). dist.get_rank() would
    # raise here, so fall back to the RANK environment variable and treat a
    # missing RANK as rank 0.
    return int(os.environ.get("RANK", "0")) == 0
def get_world_size() -> int:
    """
    Get the number of processes in the distributed setting.

    Returns:
        int: The number of processes; 1 when not running distributed.

    Raises:
        ValueError: If the process group is not initialized and the
            ``WORLD_SIZE`` environment variable is not a valid integer.
    """
    if not is_distributed():
        return 1
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size()
    # No default process group yet: dist.get_world_size() would raise, so use
    # the WORLD_SIZE environment variable that is_distributed() relies on.
    return int(os.environ["WORLD_SIZE"])
def create_logger(logging_dir: str = None) -> logging.Logger:
    """
    Create a logger that writes to stdout and, optionally, to a log file.

    Only the main process configures the root logger. Every other process gets
    the same named logger with just a ``NullHandler`` attached.

    Args:
        logging_dir (str): The directory to save the log file (``log.txt``).
            It is created if it does not exist. If None, the default
            ``logging.basicConfig`` handler is used and no file is written.

    Returns:
        logging.Logger: The logger.
    """
    logger = logging.getLogger(__name__)

    if not is_main_process():
        # Repeated calls must not stack up duplicate NullHandlers.
        if not any(isinstance(h, logging.NullHandler) for h in logger.handlers):
            logger.addHandler(logging.NullHandler())
        return logger

    additional_args = {}
    if logging_dir is not None:
        if logging_dir:
            # FileHandler opens the file on construction and fails when the
            # parent directory is missing, so make sure it exists first.
            os.makedirs(logging_dir, exist_ok=True)
        additional_args["handlers"] = [
            logging.StreamHandler(),
            logging.FileHandler(f"{logging_dir}/log.txt"),
        ]
    logging.basicConfig(
        level=logging.INFO,
        format="[\033[34m%(asctime)s\033[0m] %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        **additional_args,
    )
    if logging_dir is not None:
        logger.info("Experiment directory created at %s", logging_dir)
    return logger
def log_message(*args, level: str = "info"):
    """
    Emit a message through this module's logger, or through ``print``.

    Args:
        *args: Forwarded unchanged to the chosen logger method (message first,
            then any %-format arguments) or to ``print``.
        level (str): One of "info", "warning", "error" or "print".

    Raises:
        ValueError: If ``level`` is not one of the supported values.
    """
    module_logger = logging.getLogger(__name__)

    if level in ("info", "warning", "error"):
        # The level name doubles as the Logger method name.
        getattr(module_logger, level)(*args)
        return

    if level == "print":
        print(*args)
        return

    raise ValueError(f"Invalid logging level: {level}")