diff --git a/opensora/utils/config_utils.py b/opensora/utils/config_utils.py index ecb9438..af02ec3 100644 --- a/opensora/utils/config_utils.py +++ b/opensora/utils/config_utils.py @@ -149,7 +149,7 @@ def parse_configs(training=False): return cfg -def create_experiment_workspace(cfg, get_last_workspace=False): +def define_experiment_workspace(cfg, get_last_workspace=False): """ This function creates a folder for experiment tracking. @@ -169,7 +169,6 @@ def create_experiment_workspace(cfg, get_last_workspace=False): model_name = cfg.model["type"].replace("/", "-") exp_name = f"{experiment_index:03d}-{model_name}" exp_dir = f"{cfg.outputs}/{exp_name}" - os.makedirs(exp_dir, exist_ok=True) return exp_name, exp_dir diff --git a/scripts/train.py b/scripts/train.py index 43137b2..93fc3b6 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -1,10 +1,10 @@ +import os from copy import deepcopy from datetime import timedelta from pprint import pprint import torch import torch.distributed as dist -import wandb from colossalai.booster import Booster from colossalai.booster.plugin import LowLevelZeroPlugin from colossalai.cluster import DistCoordinator @@ -12,6 +12,7 @@ from colossalai.nn.optimizer import HybridAdam from colossalai.utils import get_current_device, set_seed from tqdm import tqdm +import wandb from opensora.acceleration.checkpoint import set_grad_checkpoint from opensora.acceleration.parallel_states import ( get_data_parallel_group, @@ -23,8 +24,8 @@ from opensora.datasets import prepare_dataloader, prepare_variable_dataloader from opensora.registry import DATASETS, MODELS, SCHEDULERS, build_module from opensora.utils.ckpt_utils import create_logger, load, model_sharding, record_model_param_shape, save from opensora.utils.config_utils import ( - create_experiment_workspace, create_tensorboard_writer, + define_experiment_workspace, parse_configs, save_training_config, ) @@ -37,8 +38,6 @@ def main(): # 1. args & cfg # ====================================================== cfg = parse_configs(training=True) - exp_name, exp_dir = create_experiment_workspace(cfg) - save_training_config(cfg._cfg_dict, exp_dir) # ====================================================== # 2. runtime variables & colossalai launch @@ -55,7 +54,14 @@ def main(): device = get_current_device() dtype = to_torch_dtype(cfg.dtype) - # 2.2. init logger, tensorboard & wandb + # 2.2. init exp_dir, logger, tensorboard & wandb + exp_name, exp_dir = define_experiment_workspace(cfg) + coordinator.block_all() + if coordinator.is_master(): + os.makedirs(exp_dir, exist_ok=True) + save_training_config(cfg._cfg_dict, exp_dir) + coordinator.block_all() + if not coordinator.is_master(): logger = create_logger(None) else: diff --git a/scripts/train_vae.py b/scripts/train_vae.py index f02bd3d..5398e04 100644 --- a/scripts/train_vae.py +++ b/scripts/train_vae.py @@ -5,7 +5,6 @@ from pprint import pprint import torch import torch.distributed as dist -import wandb from colossalai.booster import Booster from colossalai.booster.plugin import LowLevelZeroPlugin from colossalai.cluster import DistCoordinator @@ -14,6 +13,7 @@ from colossalai.utils import get_current_device, set_seed from einops import rearrange from tqdm import tqdm +import wandb from opensora.acceleration.checkpoint import set_grad_checkpoint from opensora.acceleration.parallel_states import get_data_parallel_group, set_data_parallel_group from opensora.datasets import prepare_dataloader @@ -21,8 +21,8 @@ from opensora.models.vae.losses import AdversarialLoss, DiscriminatorLoss, VAELo from opensora.registry import DATASETS, MODELS, build_module from opensora.utils.ckpt_utils import create_logger, load_json, save_json from opensora.utils.config_utils import ( - create_experiment_workspace, create_tensorboard_writer, + define_experiment_workspace, parse_configs, save_training_config, ) @@ -34,8 +34,6 @@ def main(): # 1. args & cfg # ====================================================== cfg = parse_configs(training=True) - exp_name, exp_dir = create_experiment_workspace(cfg) - save_training_config(cfg._cfg_dict, exp_dir) # ====================================================== # 2. runtime variables & colossalai launch @@ -52,7 +50,14 @@ def main(): device = get_current_device() dtype = to_torch_dtype(cfg.dtype) - # 2.2. init logger, tensorboard & wandb + # 2.2. init exp_dir, logger, tensorboard & wandb + exp_name, exp_dir = define_experiment_workspace(cfg) + coordinator.block_all() + if coordinator.is_master(): + os.makedirs(exp_dir, exist_ok=True) + save_training_config(cfg._cfg_dict, exp_dir) + coordinator.block_all() + if not coordinator.is_master(): logger = create_logger(None) else: