diff --git a/opensora/utils/config_utils.py b/opensora/utils/config_utils.py index b5edb93..fe05832 100644 --- a/opensora/utils/config_utils.py +++ b/opensora/utils/config_utils.py @@ -71,15 +71,15 @@ def merge_args(cfg, args, training=False): if args.ckpt_path is not None: cfg.model["from_pretrained"] = args.ckpt_path args.ckpt_path = None - if args.cfg_scale is not None: - cfg.scheduler["cfg_scale"] = args.cfg_scale - args.cfg_scale = None - if args.num_sampling_steps is not None: - cfg.scheduler["num_sampling_steps"] = args.num_sampling_steps - args.num_sampling_steps = None if training and args.data_path is not None: cfg.dataset["data_path"] = args.data_path args.data_path = None + if not training and args.cfg_scale is not None: + cfg.scheduler["cfg_scale"] = args.cfg_scale + args.cfg_scale = None + if not training and args.num_sampling_steps is not None: + cfg.scheduler["num_sampling_steps"] = args.num_sampling_steps + args.num_sampling_steps = None for k, v in vars(args).items(): if v is not None: