diff --git a/configs/opensora-v1-2/train/stage1-gc.py b/configs/opensora-v1-2/train/stage1-gc.py index 46344bc..001aa69 100644 --- a/configs/opensora-v1-2/train/stage1-gc.py +++ b/configs/opensora-v1-2/train/stage1-gc.py @@ -24,6 +24,7 @@ dataset = dict( # # --- # "2048": {1: (0.1, 5)}, # } + # webvid bucket_config = { # 20s/it "144p": {1: (1.0, 100), 51: (1.0, 30), 102: ((1.0, 0.33), 20), 204: ((1.0, 0.1), 8), 408: ((1.0, 0.1), 4)}, @@ -32,16 +33,16 @@ bucket_config = { # 20s/it "240p": {1: (0.3, 100), 51: (0.4, 24), 102: ((0.4, 0.33), 12), 204: ((0.4, 0.1), 4), 408: ((0.4, 0.1), 2)}, # --- "360p": {1: (0.2, 60), 51: (0.15, 12), 102: ((0.15, 0.33), 6), 204: ((0.15, 0.1), 2), 408: ((0.15, 0.1), 1)}, - # "512": {1: (0.1, 60), 51: (0.3, 12), 102: (0.3, 6), 204: (0.3, 2), 408: (0.3, 1)}, - # # --- - # "480p": {1: (0.1, 40), 51: (0.3, 6), 102: (0.3, 3), 204: (0.3, 1), 408: (0.0, None)}, - # # --- - # "720p": {1: (0.05, 20), 51: (0.3, 2), 102: (0.3, 1), 204: (0.0, None)}, - # "1024": {1: (0.05, 20), 51: (0.3, 2), 102: (0.3, 1), 204: (0.0, None)}, - # # --- - # "1080p": {1: (0.1, 10)}, - # # --- - # "2048": {1: (0.1, 5)}, + "512": {1: (0.1, 60), 51: (0.3, 12), 102: (0.3, 6), 204: (0.3, 2), 408: (0.3, 1)}, + # --- + "480p": {1: (0.1, 40), 51: (0.3, 6), 102: (0.3, 3), 204: (0.3, 1), 408: (0.0, None)}, + # --- + "720p": {1: (0.05, 20), 51: (0.3, 2), 102: (0.3, 1), 204: (0.0, None)}, + "1024": {1: (0.05, 20), 51: (0.3, 2), 102: (0.3, 1), 204: (0.0, None)}, + # --- + "1080p": {1: (0.1, 10)}, + # --- + "2048": {1: (0.1, 5)}, } grad_checkpoint = True diff --git a/opensora/utils/misc.py b/opensora/utils/misc.py index f3ae8cd..e492bdd 100644 --- a/opensora/utils/misc.py +++ b/opensora/utils/misc.py @@ -1,6 +1,7 @@ import collections import importlib import logging +import os import time from collections import OrderedDict from collections.abc import Sequence @@ -16,19 +17,30 @@ import torch.distributed as dist # ====================================================== -def 
create_logger(logging_dir): +def is_distributed(): + return os.environ.get("WORLD_SIZE", None) is not None + + +def is_main_process(): + return not is_distributed() or dist.get_rank() == 0 + + +def create_logger(logging_dir=None): """ Create a logger that writes to a log file and stdout. """ - if dist.get_rank() == 0: # real logger + if is_main_process(): # real logger + additional_args = dict() + if logging_dir is not None: + additional_args["handlers"] = [ + logging.StreamHandler(), + logging.FileHandler(f"{logging_dir}/log.txt"), + ] logging.basicConfig( level=logging.INFO, format="[\033[34m%(asctime)s\033[0m] %(message)s", datefmt="%Y-%m-%d %H:%M:%S", - handlers=[ - logging.StreamHandler(), - logging.FileHandler(f"{logging_dir}/log.txt"), - ], + **additional_args, ) logger = logging.getLogger(__name__) else: # dummy logger (does nothing) diff --git a/scripts/inference.py b/scripts/inference.py index dceb7bf..a839ca5 100644 --- a/scripts/inference.py +++ b/scripts/inference.py @@ -12,7 +12,7 @@ from opensora.datasets import IMG_FPS, save_sample from opensora.models.text_encoder.t5 import text_preprocessing from opensora.registry import MODELS, SCHEDULERS, build_module from opensora.utils.config_utils import parse_configs -from opensora.utils.misc import to_torch_dtype +from opensora.utils.misc import create_logger, is_distributed, is_main_process, to_torch_dtype def main(): @@ -31,29 +31,28 @@ def main(): torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True - verbose = cfg.get("verbose", 2) - print(cfg) - - # init distributed - if os.environ.get("WORLD_SIZE", None): - use_dist = True + # == init distributed env == + if is_distributed(): colossalai.launch_from_torch({}) coordinator = DistCoordinator() - - if coordinator.world_size > 1: + enable_sequence_parallelism = coordinator.world_size > 1 + if enable_sequence_parallelism: set_sequence_parallel_group(dist.group.WORLD) - enable_sequence_parallelism = True - else: - 
enable_sequence_parallelism = False else: - use_dist = False enable_sequence_parallelism = False + set_random_seed(seed=cfg.seed) + + # == init logger == + create_logger() + verbose = cfg.get("verbose", 1) + + + print(cfg) # ====================================================== # 2. runtime variables # ====================================================== - set_random_seed(seed=cfg.seed) prompts = cfg.prompt # ====================================================== @@ -167,7 +166,7 @@ def main(): samples = vae.decode(samples.to(dtype), num_frames=cfg.num_frames) # 4.4. save samples - if not use_dist or coordinator.is_master(): + if is_main_process(): for idx, sample in enumerate(samples): if verbose >= 2: print(f"Prompt: {batch_prompts_raw[idx]}") diff --git a/scripts/train.py b/scripts/train.py index 0c671e6..0647432 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -313,7 +313,8 @@ def main(): global_step=global_step + 1, batch_size=cfg.get("batch_size", None), ) - model_sharding(ema) + if dist.get_rank() == 0: + model_sharding(ema) logger.info( "Saved checkpoint at epoch %s step %s global_step %s to %s", epoch,