From 322e566aed81572033c47cd00f1f4d055250d9da Mon Sep 17 00:00:00 2001
From: zhengzangw
Date: Wed, 8 May 2024 08:38:04 +0000
Subject: [PATCH] update

---
 opensora/utils/config_utils.py | 23 +++++++++--------------
 scripts/train.py               |  6 +++---
 2 files changed, 12 insertions(+), 17 deletions(-)

diff --git a/opensora/utils/config_utils.py b/opensora/utils/config_utils.py
index af02ec3..c03fa4e 100644
--- a/opensora/utils/config_utils.py
+++ b/opensora/utils/config_utils.py
@@ -22,8 +22,13 @@ def parse_args(training=False):
     # ======================================================
     # General
     # ======================================================
-    parser.add_argument("--seed", default=42, type=int, help="generation seed")
-    parser.add_argument("--ckpt-path", type=str, help="path to model ckpt; will overwrite cfg.ckpt_path if specified")
+    parser.add_argument("--seed", default=None, type=int, help="seed for reproducibility")
+    parser.add_argument(
+        "--ckpt-path",
+        default=None,
+        type=str,
+        help="path to model ckpt; will overwrite cfg.model.from_pretrained if specified",
+    )
     parser.add_argument("--batch-size", default=None, type=int, help="batch size")
     parser.add_argument("--outputs", default=None, type=str, help="the dir to save model weights")
 
@@ -121,23 +126,13 @@ def merge_args(cfg, args, training=False):
             cfg["prompt"] = cfg["prompt"][args.start_index :]
         elif args.end_index is not None:
             cfg["prompt"] = cfg["prompt"][: args.end_index]
+        if "multi_resolution" not in cfg:
+            cfg["multi_resolution"] = False
     else:
         # Training only
         # - Allow not set
-        if "mask_ratios" not in cfg:
-            cfg["mask_ratios"] = None
-        if "start_from_scratch" not in cfg:
-            cfg["start_from_scratch"] = False
-        if "bucket_config" not in cfg:
-            cfg["bucket_config"] = None
         if "transform_name" not in cfg.dataset:
             cfg.dataset["transform_name"] = "center"
-        if "num_bucket_build_workers" not in cfg:
-            cfg["num_bucket_build_workers"] = 1
-
-    # Both training and inference
-    if "multi_resolution" not in cfg:
-        cfg["multi_resolution"] = False
 
     return cfg
 
diff --git a/scripts/train.py b/scripts/train.py
index c8c53bc..15e7fa4 100644
--- a/scripts/train.py
+++ b/scripts/train.py
@@ -45,7 +45,7 @@ def main():
     # NOTE: A very large timeout is set to avoid some processes exit early
     dist.init_process_group(backend="nccl", timeout=timedelta(hours=24))
     torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
-    set_seed(1024)
+    set_seed(cfg.get("seed", 1024))
     coordinator = DistCoordinator()
     device = get_current_device()
 
@@ -102,8 +102,8 @@ def main():
         logger.info("Total batch size: %s", total_batch_size)
     else:
         dataloader = prepare_variable_dataloader(
-            bucket_config=cfg.bucket_config,
-            num_bucket_build_workers=cfg.num_bucket_build_workers,
+            bucket_config=cfg.get("bucket_config", None),
+            num_bucket_build_workers=cfg.get("num_bucket_build_workers", 1),
             **dataloader_args,
         )
 