This commit is contained in:
zhengzangw 2024-05-08 08:38:04 +00:00
parent eb5ff82fc5
commit 322e566aed
2 changed files with 12 additions and 17 deletions

View file

@@ -22,8 +22,13 @@ def parse_args(training=False):
# ======================================================
# General
# ======================================================
parser.add_argument("--seed", default=42, type=int, help="generation seed")
parser.add_argument("--ckpt-path", type=str, help="path to model ckpt; will overwrite cfg.ckpt_path if specified")
parser.add_argument("--seed", default=None, type=int, help="seed for reproducibility")
parser.add_argument(
"--ckpt-path",
default=None,
type=str,
help="path to model ckpt; will overwrite cfg.model.from_pretrained if specified",
)
parser.add_argument("--batch-size", default=None, type=int, help="batch size")
parser.add_argument("--outputs", default=None, type=str, help="the dir to save model weights")
@@ -121,23 +126,13 @@ def merge_args(cfg, args, training=False):
cfg["prompt"] = cfg["prompt"][args.start_index :]
elif args.end_index is not None:
cfg["prompt"] = cfg["prompt"][: args.end_index]
if "multi_resolution" not in cfg:
cfg["multi_resolution"] = False
else:
# Training only
# - Allow not set
if "mask_ratios" not in cfg:
cfg["mask_ratios"] = None
if "start_from_scratch" not in cfg:
cfg["start_from_scratch"] = False
if "bucket_config" not in cfg:
cfg["bucket_config"] = None
if "transform_name" not in cfg.dataset:
cfg.dataset["transform_name"] = "center"
if "num_bucket_build_workers" not in cfg:
cfg["num_bucket_build_workers"] = 1
# Both training and inference
if "multi_resolution" not in cfg:
cfg["multi_resolution"] = False
return cfg

View file

@@ -45,7 +45,7 @@ def main():
# NOTE: A very large timeout is set to avoid some processes exit early
dist.init_process_group(backend="nccl", timeout=timedelta(hours=24))
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
set_seed(1024)
set_seed(cfg.get("seed", 1024))
coordinator = DistCoordinator()
device = get_current_device()
@@ -102,8 +102,8 @@ def main():
logger.info("Total batch size: %s", total_batch_size)
else:
dataloader = prepare_variable_dataloader(
bucket_config=cfg.bucket_config,
num_bucket_build_workers=cfg.num_bucket_build_workers,
bucket_config=cfg.get("bucket_config", None),
num_bucket_build_workers=cfg.get("num_bucket_build_workers", 1),
**dataloader_args,
)