mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-10 12:49:38 +02:00
update
This commit is contained in:
parent
eb5ff82fc5
commit
322e566aed
|
|
@ -22,8 +22,13 @@ def parse_args(training=False):
|
|||
# ======================================================
|
||||
# General
|
||||
# ======================================================
|
||||
parser.add_argument("--seed", default=42, type=int, help="generation seed")
|
||||
parser.add_argument("--ckpt-path", type=str, help="path to model ckpt; will overwrite cfg.ckpt_path if specified")
|
||||
parser.add_argument("--seed", default=None, type=int, help="seed for reproducibility")
|
||||
parser.add_argument(
|
||||
"--ckpt-path",
|
||||
default=None,
|
||||
type=str,
|
||||
help="path to model ckpt; will overwrite cfg.model.from_pretrained if specified",
|
||||
)
|
||||
parser.add_argument("--batch-size", default=None, type=int, help="batch size")
|
||||
parser.add_argument("--outputs", default=None, type=str, help="the dir to save model weights")
|
||||
|
||||
|
|
@ -121,23 +126,13 @@ def merge_args(cfg, args, training=False):
|
|||
cfg["prompt"] = cfg["prompt"][args.start_index :]
|
||||
elif args.end_index is not None:
|
||||
cfg["prompt"] = cfg["prompt"][: args.end_index]
|
||||
if "multi_resolution" not in cfg:
|
||||
cfg["multi_resolution"] = False
|
||||
else:
|
||||
# Training only
|
||||
# - Allow not set
|
||||
if "mask_ratios" not in cfg:
|
||||
cfg["mask_ratios"] = None
|
||||
if "start_from_scratch" not in cfg:
|
||||
cfg["start_from_scratch"] = False
|
||||
if "bucket_config" not in cfg:
|
||||
cfg["bucket_config"] = None
|
||||
if "transform_name" not in cfg.dataset:
|
||||
cfg.dataset["transform_name"] = "center"
|
||||
if "num_bucket_build_workers" not in cfg:
|
||||
cfg["num_bucket_build_workers"] = 1
|
||||
|
||||
# Both training and inference
|
||||
if "multi_resolution" not in cfg:
|
||||
cfg["multi_resolution"] = False
|
||||
|
||||
return cfg
|
||||
|
||||
|
|
|
|||
|
|
@ -45,7 +45,7 @@ def main():
|
|||
# NOTE: A very large timeout is set to prevent some processes from exiting early
|
||||
dist.init_process_group(backend="nccl", timeout=timedelta(hours=24))
|
||||
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
|
||||
set_seed(1024)
|
||||
set_seed(cfg.get("seed", 1024))
|
||||
coordinator = DistCoordinator()
|
||||
device = get_current_device()
|
||||
|
||||
|
|
@ -102,8 +102,8 @@ def main():
|
|||
logger.info("Total batch size: %s", total_batch_size)
|
||||
else:
|
||||
dataloader = prepare_variable_dataloader(
|
||||
bucket_config=cfg.bucket_config,
|
||||
num_bucket_build_workers=cfg.num_bucket_build_workers,
|
||||
bucket_config=cfg.get("bucket_config", None),
|
||||
num_bucket_build_workers=cfg.get("num_bucket_build_workers", 1),
|
||||
**dataloader_args,
|
||||
)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue