This commit is contained in:
Shen-Chenhui 2024-04-27 17:18:19 +08:00
parent 7b134d71dc
commit 6fb4e3cd22
3 changed files with 6 additions and 10 deletions

View file

@@ -11,24 +11,21 @@ dataset = dict(
fps = 24 // 3
is_vae = True
use_pipeline = True
# Define dataset
root = None
data_path = "CSV_PATH"
max_test_samples = -1
use_image_transform = False
num_workers = 4
video_contains_first_frame = True
# Define acceleration
num_workers = 4
dtype = "bf16"
grad_checkpoint = True
plugin = "zero2"
sp_size = 1
use_pipeline = True
video_contains_first_frame = True
# Define model
@@ -52,7 +49,6 @@ model = dict(
disable_space = True,
custom_conv_padding = None,
encoder_double_z = False,
custom_conv_padding=None,
)
discriminator = dict(

View file

@@ -76,7 +76,7 @@ def merge_args(cfg, args, training=False):
if cfg.get("discriminator") is not None:
cfg.discriminator["from_pretrained"] = args.ckpt_path
args.ckpt_path = None
if training and args.data_path is not None:
if (training or cfg.get("is_vae", False)) and args.data_path is not None:
cfg.dataset["data_path"] = args.data_path
args.data_path = None
if not training and args.cfg_scale is not None:

View file

@@ -12,7 +12,7 @@ from opensora.registry import MODELS, SCHEDULERS, build_module
from opensora.utils.config_utils import parse_configs
from opensora.utils.misc import to_torch_dtype
from opensora.datasets import prepare_dataloader, prepare_variable_dataloader
from opensora.datasets import DATASETS, MODELS, build_module
from opensora.registry import DATASETS, MODELS, build_module
from opensora.acceleration.parallel_states import (
get_data_parallel_group,
set_data_parallel_group,