From 552b7e8f7947834d3741962285b06a7da3c1c588 Mon Sep 17 00:00:00 2001 From: Zangwei Zheng Date: Tue, 26 Mar 2024 17:02:41 +0800 Subject: [PATCH] register dataset --- configs/opensora-v1-1/train/Vx256x256.py | 11 +++++++---- opensora/datasets/__init__.py | 2 +- opensora/datasets/datasets.py | 16 +++++++++------- opensora/registry.py | 5 +++++ opensora/utils/config_utils.py | 5 ++++- scripts/train.py | 17 ++++++----------- 6 files changed, 32 insertions(+), 24 deletions(-) diff --git a/configs/opensora-v1-1/train/Vx256x256.py b/configs/opensora-v1-1/train/Vx256x256.py index c2bd76d..594193d 100644 --- a/configs/opensora-v1-1/train/Vx256x256.py +++ b/configs/opensora-v1-1/train/Vx256x256.py @@ -1,8 +1,11 @@ # Define dataset -data_path = "CSV_PATH" -num_frames = 16 -frame_interval = 3 -image_size = (256, 256) +dataset = dict( + type="VideoTextDataset", + data_path=None, + num_frames=16, + frame_interval=3, + image_size=(256, 256), +) # Define acceleration num_workers = 4 diff --git a/opensora/datasets/__init__.py b/opensora/datasets/__init__.py index 94c5447..2526d70 100644 --- a/opensora/datasets/__init__.py +++ b/opensora/datasets/__init__.py @@ -1,2 +1,2 @@ -from .datasets import DatasetFromCSV +from .datasets import VideoTextDataset from .utils import get_transforms_image, get_transforms_video, prepare_dataloader, save_sample diff --git a/opensora/datasets/datasets.py b/opensora/datasets/datasets.py index fde74e5..318779f 100644 --- a/opensora/datasets/datasets.py +++ b/opensora/datasets/datasets.py @@ -6,12 +6,13 @@ import torch import torchvision from torchvision.datasets.folder import IMG_EXTENSIONS, pil_loader +from opensora.registry import DATASETS + from . import video_transforms -from .utils import VID_EXTENSIONS -from .utils import get_transforms_image, get_transforms_video +from .utils import VID_EXTENSIONS, get_transforms_image, get_transforms_video - -class DatasetFromCSV(torch.utils.data.Dataset): +@DATASETS.register_module() +class VideoTextDataset(torch.utils.data.Dataset): """load video according to the csv file. Args: @@ -22,15 +23,16 @@ class DatasetFromCSV(torch.utils.data.Dataset): def __init__( self, - csv_path, + data_path, num_frames=16, frame_interval=1, image_size=(256, 256), ): - self.csv_path = csv_path - self.data = pd.read_csv(csv_path) + self.data_path = data_path + self.data = pd.read_csv(data_path) self.num_frames = num_frames self.frame_interval = frame_interval + self.image_size = image_size self.temporal_sample = video_transforms.TemporalRandomCrop(num_frames * frame_interval) self.transforms = { "image": get_transforms_image(image_size[0]), diff --git a/opensora/registry.py b/opensora/registry.py index 7797d36..4335d38 100644 --- a/opensora/registry.py +++ b/opensora/registry.py @@ -37,3 +37,8 @@ SCHEDULERS = Registry( "scheduler", locations=["opensora.schedulers"], ) + +DATASETS = Registry( + "dataset", + locations=["opensora.datasets"], +) diff --git a/opensora/utils/config_utils.py b/opensora/utils/config_utils.py index 350cc8e..d820058 100644 --- a/opensora/utils/config_utils.py +++ b/opensora/utils/config_utils.py @@ -47,6 +47,9 @@ def merge_args(cfg, args, training=False): if args.ckpt_path is not None: cfg.model["from_pretrained"] = args.ckpt_path args.ckpt_path = None + if args.data_path is not None: + cfg.dataset["data_path"] = args.data_path + args.data_path = None for k, v in vars(args).items(): if k in cfg and v is not None: @@ -96,7 +99,7 @@ def create_experiment_workspace(cfg): # Create an experiment folder model_name = cfg.model["type"].replace("/", "-") - exp_name = f"{experiment_index:03d}-F{cfg.num_frames}S{cfg.frame_interval}-{model_name}" + exp_name = f"{experiment_index:03d}-{model_name}" exp_dir = f"{cfg.outputs}/{exp_name}" os.makedirs(exp_dir, exist_ok=True) return exp_name, exp_dir diff --git a/scripts/train.py b/scripts/train.py index 63be9f7..774fcec 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -18,8 +18,8 @@ from opensora.acceleration.parallel_states import ( set_sequence_parallel_group, ) from opensora.acceleration.plugin import ZeroSeqParallelPlugin -from opensora.datasets import DatasetFromCSV, get_transforms_image, get_transforms_video, prepare_dataloader -from opensora.registry import MODELS, SCHEDULERS, build_module +from opensora.datasets import prepare_dataloader +from opensora.registry import MODELS, SCHEDULERS, DATASETS, build_module from opensora.utils.ckpt_utils import create_logger, load, model_sharding, record_model_param_shape, save from opensora.utils.config_utils import ( create_experiment_workspace, @@ -89,12 +89,7 @@ def main(): # ====================================================== # 3. build dataset and dataloader # ====================================================== - dataset = DatasetFromCSV( - csv_path=cfg.data_path, - num_frames=cfg.num_frames, - frame_interval=cfg.frame_interval, - image_size=cfg.image_size, - ) + dataset = build_module(cfg.dataset, DATASETS) # TODO: use plugin's prepare dataloader # a batch contains: @@ -111,7 +106,7 @@ def main(): pin_memory=True, process_group=get_data_parallel_group(), ) - logger.info(f"Dataset contains {len(dataset):,} videos ({cfg.data_path})") + logger.info(f"Dataset contains {len(dataset):,} videos ({dataset.data_path})") total_batch_size = cfg.batch_size * dist.get_world_size() // cfg.sp_size logger.info(f"Total batch size: {total_batch_size}") @@ -120,10 +115,10 @@ def main(): # 4. build model # ====================================================== # 4.1. build model - input_size = (cfg.num_frames, *cfg.image_size) + text_encoder = build_module(cfg.text_encoder, MODELS, device=device) vae = build_module(cfg.vae, MODELS) + input_size = (dataset.num_frames, *dataset.image_size) latent_size = vae.get_latent_size(input_size) - text_encoder = build_module(cfg.text_encoder, MODELS, device=device) # T5 must be fp32 model = build_module( cfg.model, MODELS,