mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-10 21:01:26 +02:00
update
This commit is contained in:
parent
eb5ff82fc5
commit
322e566aed
|
|
@ -22,8 +22,13 @@ def parse_args(training=False):
|
||||||
# ======================================================
|
# ======================================================
|
||||||
# General
|
# General
|
||||||
# ======================================================
|
# ======================================================
|
||||||
parser.add_argument("--seed", default=42, type=int, help="generation seed")
|
parser.add_argument("--seed", default=None, type=int, help="seed for reproducibility")
|
||||||
parser.add_argument("--ckpt-path", type=str, help="path to model ckpt; will overwrite cfg.ckpt_path if specified")
|
parser.add_argument(
|
||||||
|
"--ckpt-path",
|
||||||
|
default=None,
|
||||||
|
type=str,
|
||||||
|
help="path to model ckpt; will overwrite cfg.model.from_pretrained if specified",
|
||||||
|
)
|
||||||
parser.add_argument("--batch-size", default=None, type=int, help="batch size")
|
parser.add_argument("--batch-size", default=None, type=int, help="batch size")
|
||||||
parser.add_argument("--outputs", default=None, type=str, help="the dir to save model weights")
|
parser.add_argument("--outputs", default=None, type=str, help="the dir to save model weights")
|
||||||
|
|
||||||
|
|
@ -121,23 +126,13 @@ def merge_args(cfg, args, training=False):
|
||||||
cfg["prompt"] = cfg["prompt"][args.start_index :]
|
cfg["prompt"] = cfg["prompt"][args.start_index :]
|
||||||
elif args.end_index is not None:
|
elif args.end_index is not None:
|
||||||
cfg["prompt"] = cfg["prompt"][: args.end_index]
|
cfg["prompt"] = cfg["prompt"][: args.end_index]
|
||||||
|
if "multi_resolution" not in cfg:
|
||||||
|
cfg["multi_resolution"] = False
|
||||||
else:
|
else:
|
||||||
# Training only
|
# Training only
|
||||||
# - Allow not set
|
# - Allow not set
|
||||||
if "mask_ratios" not in cfg:
|
|
||||||
cfg["mask_ratios"] = None
|
|
||||||
if "start_from_scratch" not in cfg:
|
|
||||||
cfg["start_from_scratch"] = False
|
|
||||||
if "bucket_config" not in cfg:
|
|
||||||
cfg["bucket_config"] = None
|
|
||||||
if "transform_name" not in cfg.dataset:
|
if "transform_name" not in cfg.dataset:
|
||||||
cfg.dataset["transform_name"] = "center"
|
cfg.dataset["transform_name"] = "center"
|
||||||
if "num_bucket_build_workers" not in cfg:
|
|
||||||
cfg["num_bucket_build_workers"] = 1
|
|
||||||
|
|
||||||
# Both training and inference
|
|
||||||
if "multi_resolution" not in cfg:
|
|
||||||
cfg["multi_resolution"] = False
|
|
||||||
|
|
||||||
return cfg
|
return cfg
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -45,7 +45,7 @@ def main():
|
||||||
# NOTE: A very large timeout is set to avoid some processes exit early
|
# NOTE: A very large timeout is set to avoid some processes exit early
|
||||||
dist.init_process_group(backend="nccl", timeout=timedelta(hours=24))
|
dist.init_process_group(backend="nccl", timeout=timedelta(hours=24))
|
||||||
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
|
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
|
||||||
set_seed(1024)
|
set_seed(cfg.get("seed", 1024))
|
||||||
coordinator = DistCoordinator()
|
coordinator = DistCoordinator()
|
||||||
device = get_current_device()
|
device = get_current_device()
|
||||||
|
|
||||||
|
|
@ -102,8 +102,8 @@ def main():
|
||||||
logger.info("Total batch size: %s", total_batch_size)
|
logger.info("Total batch size: %s", total_batch_size)
|
||||||
else:
|
else:
|
||||||
dataloader = prepare_variable_dataloader(
|
dataloader = prepare_variable_dataloader(
|
||||||
bucket_config=cfg.bucket_config,
|
bucket_config=cfg.get("bucket_config", None),
|
||||||
num_bucket_build_workers=cfg.num_bucket_build_workers,
|
num_bucket_build_workers=cfg.get("num_bucket_build_workers", 1),
|
||||||
**dataloader_args,
|
**dataloader_args,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue