mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-11 05:13:31 +02:00
hotfix config args
This commit is contained in:
parent
da345719c7
commit
fea6aba655
|
|
@ -71,15 +71,15 @@ def merge_args(cfg, args, training=False):
|
|||
if args.ckpt_path is not None:
|
||||
cfg.model["from_pretrained"] = args.ckpt_path
|
||||
args.ckpt_path = None
|
||||
if args.cfg_scale is not None:
|
||||
cfg.scheduler["cfg_scale"] = args.cfg_scale
|
||||
args.cfg_scale = None
|
||||
if args.num_sampling_steps is not None:
|
||||
cfg.scheduler["num_sampling_steps"] = args.num_sampling_steps
|
||||
args.num_sampling_steps = None
|
||||
if training and args.data_path is not None:
|
||||
cfg.dataset["data_path"] = args.data_path
|
||||
args.data_path = None
|
||||
if not training and args.cfg_scale is not None:
|
||||
cfg.scheduler["cfg_scale"] = args.cfg_scale
|
||||
args.cfg_scale = None
|
||||
if not training and args.num_sampling_steps is not None:
|
||||
cfg.scheduler["num_sampling_steps"] = args.num_sampling_steps
|
||||
args.num_sampling_steps = None
|
||||
|
||||
for k, v in vars(args).items():
|
||||
if v is not None:
|
||||
|
|
|
|||
Loading…
Reference in a new issue