register dataset

Zangwei Zheng 2024-03-26 17:02:41 +08:00
parent 01728dc28a
commit 552b7e8f79
6 changed files with 32 additions and 24 deletions

View file

@@ -1,8 +1,11 @@
 # Define dataset
-data_path = "CSV_PATH"
-num_frames = 16
-frame_interval = 3
-image_size = (256, 256)
+dataset = dict(
+    type="VideoTextDataset",
+    data_path=None,
+    num_frames=16,
+    frame_interval=3,
+    image_size=(256, 256),
+)
 
 # Define acceleration
 num_workers = 4
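
Note (not part of the diff): the flat keys become a single nested dict that the trainer later hands to the dataset registry. A minimal sketch of consuming the rewritten config, assuming the repo's mmengine-style Config loader; the config path below is a placeholder:

# Sketch only: load the Python config and inspect the new nested dataset dict.
# Assumes mmengine's Config loader; the file path is made up for illustration.
from mmengine.config import Config

cfg = Config.fromfile("configs/train_16x256x256.py")  # hypothetical path
print(cfg.dataset["type"])       # "VideoTextDataset"
print(cfg.dataset["data_path"])  # None until --data-path is merged in (see merge_args below)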

View file

@@ -1,2 +1,2 @@
-from .datasets import DatasetFromCSV
+from .datasets import VideoTextDataset
 from .utils import get_transforms_image, get_transforms_video, prepare_dataloader, save_sample

View file

@@ -6,12 +6,13 @@ import torch
 import torchvision
 from torchvision.datasets.folder import IMG_EXTENSIONS, pil_loader
 
+from opensora.registry import DATASETS
+
 from . import video_transforms
-from .utils import VID_EXTENSIONS
-from .utils import get_transforms_image, get_transforms_video
+from .utils import VID_EXTENSIONS, get_transforms_image, get_transforms_video
 
-class DatasetFromCSV(torch.utils.data.Dataset):
+@DATASETS.register_module()
+class VideoTextDataset(torch.utils.data.Dataset):
     """load video according to the csv file.
 
     Args:
@@ -22,15 +23,16 @@ class DatasetFromCSV(torch.utils.data.Dataset):
     def __init__(
         self,
-        csv_path,
+        data_path,
         num_frames=16,
         frame_interval=1,
         image_size=(256, 256),
     ):
-        self.csv_path = csv_path
-        self.data = pd.read_csv(csv_path)
+        self.data_path = data_path
+        self.data = pd.read_csv(data_path)
         self.num_frames = num_frames
         self.frame_interval = frame_interval
         self.image_size = image_size
         self.temporal_sample = video_transforms.TemporalRandomCrop(num_frames * frame_interval)
         self.transforms = {
             "image": get_transforms_image(image_size[0]),

View file

@@ -37,3 +37,8 @@ SCHEDULERS = Registry(
     "scheduler",
     locations=["opensora.schedulers"],
 )
+
+DATASETS = Registry(
+    "dataset",
+    locations=["opensora.datasets"],
+)
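
Note (not part of the diff): a registry maps a string name to a class at import time, which is what lets the config above refer to "VideoTextDataset" as plain data. A self-contained toy version of the round-trip the new DATASETS registry enables, assuming an mmengine-style Registry (consistent with the locations= keyword above):

# Toy illustration, not repo code: register a class, then build it from a dict.
from mmengine.registry import Registry

DATASETS = Registry("dataset")

@DATASETS.register_module()  # stores DummyDataset under its class name
class DummyDataset:
    def __init__(self, data_path=None):
        self.data_path = data_path

ds = DATASETS.build(dict(type="DummyDataset", data_path="videos.csv"))
assert isinstance(ds, DummyDataset) and ds.data_path == "videos.csv"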

View file

@@ -47,6 +47,9 @@ def merge_args(cfg, args, training=False):
     if args.ckpt_path is not None:
         cfg.model["from_pretrained"] = args.ckpt_path
         args.ckpt_path = None
+    if args.data_path is not None:
+        cfg.dataset["data_path"] = args.data_path
+        args.data_path = None
 
     for k, v in vars(args).items():
         if k in cfg and v is not None:
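
Note (not part of the diff): the new branch gives the command line precedence over the config placeholder: data_path=None is filled from args.data_path (a --data-path flag, assuming argparse's usual dash-to-underscore mapping), then cleared so the generic vars(args) loop below does not apply it a second time. A minimal, self-contained imitation of that precedence rule, with hypothetical names:

# Hypothetical helper; the real merge_args also walks the remaining argparse fields.
def apply_data_path_override(cfg: dict, data_path):
    if data_path is not None:
        cfg["dataset"]["data_path"] = data_path  # CLI value wins over the config file
    return cfg

cfg = {"dataset": {"type": "VideoTextDataset", "data_path": None}}
cfg = apply_data_path_override(cfg, "/data/videos.csv")
assert cfg["dataset"]["data_path"] == "/data/videos.csv"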
@@ -96,7 +99,7 @@ def create_experiment_workspace(cfg):
     # Create an experiment folder
     model_name = cfg.model["type"].replace("/", "-")
-    exp_name = f"{experiment_index:03d}-F{cfg.num_frames}S{cfg.frame_interval}-{model_name}"
+    exp_name = f"{experiment_index:03d}-{model_name}"
     exp_dir = f"{cfg.outputs}/{exp_name}"
     os.makedirs(exp_dir, exist_ok=True)
     return exp_name, exp_dir

View file

@@ -18,8 +18,8 @@ from opensora.acceleration.parallel_states import (
     set_sequence_parallel_group,
 )
 from opensora.acceleration.plugin import ZeroSeqParallelPlugin
-from opensora.datasets import DatasetFromCSV, get_transforms_image, get_transforms_video, prepare_dataloader
-from opensora.registry import MODELS, SCHEDULERS, build_module
+from opensora.datasets import prepare_dataloader
+from opensora.registry import MODELS, SCHEDULERS, DATASETS, build_module
 from opensora.utils.ckpt_utils import create_logger, load, model_sharding, record_model_param_shape, save
 from opensora.utils.config_utils import (
     create_experiment_workspace,
@@ -89,12 +89,7 @@ def main():
     # ======================================================
     # 3. build dataset and dataloader
     # ======================================================
-    dataset = DatasetFromCSV(
-        csv_path=cfg.data_path,
-        num_frames=cfg.num_frames,
-        frame_interval=cfg.frame_interval,
-        image_size=cfg.image_size,
-    )
+    dataset = build_module(cfg.dataset, DATASETS)
 
     # TODO: use plugin's prepare dataloader
     # a batch contains:
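
Note (not part of the diff): conceptually, build_module(cfg.dataset, DATASETS) looks up cfg.dataset["type"] in the registry and calls that class with the remaining keys as keyword arguments, so with the config above it expands to roughly the following (the CSV path is hypothetical):

# Rough expansion of the registry call; not literal repo code.
from opensora.datasets import VideoTextDataset

dataset = VideoTextDataset(
    data_path="/data/videos.csv",  # supplied via --data-path through merge_args
    num_frames=16,
    frame_interval=3,
    image_size=(256, 256),
)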
@@ -111,7 +106,7 @@ def main():
         pin_memory=True,
         process_group=get_data_parallel_group(),
     )
-    logger.info(f"Dataset contains {len(dataset):,} videos ({cfg.data_path})")
+    logger.info(f"Dataset contains {len(dataset):,} videos ({dataset.data_path})")
 
     total_batch_size = cfg.batch_size * dist.get_world_size() // cfg.sp_size
     logger.info(f"Total batch size: {total_batch_size}")
@@ -120,10 +115,10 @@ def main():
     # 4. build model
     # ======================================================
     # 4.1. build model
-    input_size = (cfg.num_frames, *cfg.image_size)
-    text_encoder = build_module(cfg.text_encoder, MODELS, device=device)
     vae = build_module(cfg.vae, MODELS)
+    input_size = (dataset.num_frames, *dataset.image_size)
     latent_size = vae.get_latent_size(input_size)
+    text_encoder = build_module(cfg.text_encoder, MODELS, device=device)  # T5 must be fp32
     model = build_module(
         cfg.model,
         MODELS,
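
Note (not part of the diff): input_size is now derived from the dataset object rather than flat cfg keys, so the model shapes always match whatever dataset was actually built. As an illustrative calculation only, assuming a VAE with 8x spatial and no temporal compression (the real factors are whatever vae.get_latent_size implements):

# Illustrative numbers, not repo code.
input_size = (16, 256, 256)             # (dataset.num_frames, *dataset.image_size)
latent_size = (16, 256 // 8, 256 // 8)  # -> (16, 32, 32) under 8x spatial compression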