From 6fb4e3cd2292346499663febc174b6b200bc3ff6 Mon Sep 17 00:00:00 2001 From: Shen-Chenhui Date: Sat, 27 Apr 2024 17:18:19 +0800 Subject: [PATCH] debug --- .../vae_magvit_v2/inference/17x128x128_pixabay.py | 12 ++++-------- opensora/utils/config_utils.py | 2 +- scripts/inference-vae-v2.py | 2 +- 3 files changed, 6 insertions(+), 10 deletions(-) diff --git a/configs/vae_magvit_v2/inference/17x128x128_pixabay.py b/configs/vae_magvit_v2/inference/17x128x128_pixabay.py index 9517f12..86b85d8 100644 --- a/configs/vae_magvit_v2/inference/17x128x128_pixabay.py +++ b/configs/vae_magvit_v2/inference/17x128x128_pixabay.py @@ -11,24 +11,21 @@ dataset = dict( fps = 24 // 3 is_vae = True -use_pipeline = True # Define dataset -root = None -data_path = "CSV_PATH" max_test_samples = -1 -use_image_transform = False -num_workers = 4 -video_contains_first_frame = True - # Define acceleration +num_workers = 4 dtype = "bf16" grad_checkpoint = True plugin = "zero2" sp_size = 1 +use_pipeline = True +video_contains_first_frame = True + # Define model @@ -52,7 +49,6 @@ model = dict( disable_space = True, custom_conv_padding = None, encoder_double_z = False, - custom_conv_padding=None, ) discriminator = dict( diff --git a/opensora/utils/config_utils.py b/opensora/utils/config_utils.py index 8262321..e59d4ab 100644 --- a/opensora/utils/config_utils.py +++ b/opensora/utils/config_utils.py @@ -76,7 +76,7 @@ def merge_args(cfg, args, training=False): if cfg.get("discriminator") is not None: cfg.discriminator["from_pretrained"] = args.ckpt_path args.ckpt_path = None - if training and args.data_path is not None: + if (training or cfg.get("is_vae", False)) and args.data_path is not None: cfg.dataset["data_path"] = args.data_path args.data_path = None if not training and args.cfg_scale is not None: diff --git a/scripts/inference-vae-v2.py b/scripts/inference-vae-v2.py index 5a41d79..bd07bf4 100644 --- a/scripts/inference-vae-v2.py +++ b/scripts/inference-vae-v2.py @@ -12,7 +12,7 @@ from opensora.registry import MODELS, SCHEDULERS, build_module from opensora.utils.config_utils import parse_configs from opensora.utils.misc import to_torch_dtype from opensora.datasets import prepare_dataloader, prepare_variable_dataloader -from opensora.datasets import DATASETS, MODELS, build_module +from opensora.registry import DATASETS, MODELS, build_module from opensora.acceleration.parallel_states import ( get_data_parallel_group, set_data_parallel_group,