mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-10 21:01:26 +02:00
register dataset
This commit is contained in:
parent
01728dc28a
commit
552b7e8f79
|
|
@ -1,8 +1,11 @@
|
|||
# Define dataset
|
||||
data_path = "CSV_PATH"
|
||||
num_frames = 16
|
||||
frame_interval = 3
|
||||
image_size = (256, 256)
|
||||
dataset = dict(
|
||||
type="VideoTextDataset",
|
||||
data_path=None,
|
||||
num_frames=16,
|
||||
frame_interval=3,
|
||||
image_size=(256, 256),
|
||||
)
|
||||
|
||||
# Define acceleration
|
||||
num_workers = 4
|
||||
|
|
|
|||
|
|
@ -1,2 +1,2 @@
|
|||
from .datasets import DatasetFromCSV
|
||||
from .datasets import VideoTextDataset
|
||||
from .utils import get_transforms_image, get_transforms_video, prepare_dataloader, save_sample
|
||||
|
|
|
|||
|
|
@ -6,12 +6,13 @@ import torch
|
|||
import torchvision
|
||||
from torchvision.datasets.folder import IMG_EXTENSIONS, pil_loader
|
||||
|
||||
from opensora.registry import DATASETS
|
||||
|
||||
from . import video_transforms
|
||||
from .utils import VID_EXTENSIONS
|
||||
from .utils import get_transforms_image, get_transforms_video
|
||||
from .utils import VID_EXTENSIONS, get_transforms_image, get_transforms_video
|
||||
|
||||
|
||||
class DatasetFromCSV(torch.utils.data.Dataset):
|
||||
@DATASETS.register_module()
|
||||
class VideoTextDataset(torch.utils.data.Dataset):
|
||||
"""load video according to the csv file.
|
||||
|
||||
Args:
|
||||
|
|
@ -22,15 +23,16 @@ class DatasetFromCSV(torch.utils.data.Dataset):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
csv_path,
|
||||
data_path,
|
||||
num_frames=16,
|
||||
frame_interval=1,
|
||||
image_size=(256, 256),
|
||||
):
|
||||
self.csv_path = csv_path
|
||||
self.data = pd.read_csv(csv_path)
|
||||
self.data_path = data_path
|
||||
self.data = pd.read_csv(data_path)
|
||||
self.num_frames = num_frames
|
||||
self.frame_interval = frame_interval
|
||||
self.image_size = image_size
|
||||
self.temporal_sample = video_transforms.TemporalRandomCrop(num_frames * frame_interval)
|
||||
self.transforms = {
|
||||
"image": get_transforms_image(image_size[0]),
|
||||
|
|
|
|||
|
|
@ -37,3 +37,8 @@ SCHEDULERS = Registry(
|
|||
"scheduler",
|
||||
locations=["opensora.schedulers"],
|
||||
)
|
||||
|
||||
DATASETS = Registry(
|
||||
"dataset",
|
||||
locations=["opensora.datasets"],
|
||||
)
|
||||
|
|
|
|||
|
|
@ -47,6 +47,9 @@ def merge_args(cfg, args, training=False):
|
|||
if args.ckpt_path is not None:
|
||||
cfg.model["from_pretrained"] = args.ckpt_path
|
||||
args.ckpt_path = None
|
||||
if args.data_path is not None:
|
||||
cfg.dataset["data_path"] = args.data_path
|
||||
args.data_path = None
|
||||
|
||||
for k, v in vars(args).items():
|
||||
if k in cfg and v is not None:
|
||||
|
|
@ -96,7 +99,7 @@ def create_experiment_workspace(cfg):
|
|||
|
||||
# Create an experiment folder
|
||||
model_name = cfg.model["type"].replace("/", "-")
|
||||
exp_name = f"{experiment_index:03d}-F{cfg.num_frames}S{cfg.frame_interval}-{model_name}"
|
||||
exp_name = f"{experiment_index:03d}-{model_name}"
|
||||
exp_dir = f"{cfg.outputs}/{exp_name}"
|
||||
os.makedirs(exp_dir, exist_ok=True)
|
||||
return exp_name, exp_dir
|
||||
|
|
|
|||
|
|
@ -18,8 +18,8 @@ from opensora.acceleration.parallel_states import (
|
|||
set_sequence_parallel_group,
|
||||
)
|
||||
from opensora.acceleration.plugin import ZeroSeqParallelPlugin
|
||||
from opensora.datasets import DatasetFromCSV, get_transforms_image, get_transforms_video, prepare_dataloader
|
||||
from opensora.registry import MODELS, SCHEDULERS, build_module
|
||||
from opensora.datasets import prepare_dataloader
|
||||
from opensora.registry import MODELS, SCHEDULERS, DATASETS, build_module
|
||||
from opensora.utils.ckpt_utils import create_logger, load, model_sharding, record_model_param_shape, save
|
||||
from opensora.utils.config_utils import (
|
||||
create_experiment_workspace,
|
||||
|
|
@ -89,12 +89,7 @@ def main():
|
|||
# ======================================================
|
||||
# 3. build dataset and dataloader
|
||||
# ======================================================
|
||||
dataset = DatasetFromCSV(
|
||||
csv_path=cfg.data_path,
|
||||
num_frames=cfg.num_frames,
|
||||
frame_interval=cfg.frame_interval,
|
||||
image_size=cfg.image_size,
|
||||
)
|
||||
dataset = build_module(cfg.dataset, DATASETS)
|
||||
|
||||
# TODO: use plugin's prepare dataloader
|
||||
# a batch contains:
|
||||
|
|
@ -111,7 +106,7 @@ def main():
|
|||
pin_memory=True,
|
||||
process_group=get_data_parallel_group(),
|
||||
)
|
||||
logger.info(f"Dataset contains {len(dataset):,} videos ({cfg.data_path})")
|
||||
logger.info(f"Dataset contains {len(dataset):,} videos ({dataset.data_path})")
|
||||
|
||||
total_batch_size = cfg.batch_size * dist.get_world_size() // cfg.sp_size
|
||||
logger.info(f"Total batch size: {total_batch_size}")
|
||||
|
|
@ -120,10 +115,10 @@ def main():
|
|||
# 4. build model
|
||||
# ======================================================
|
||||
# 4.1. build model
|
||||
input_size = (cfg.num_frames, *cfg.image_size)
|
||||
text_encoder = build_module(cfg.text_encoder, MODELS, device=device)
|
||||
vae = build_module(cfg.vae, MODELS)
|
||||
input_size = (dataset.num_frames, *dataset.image_size)
|
||||
latent_size = vae.get_latent_size(input_size)
|
||||
text_encoder = build_module(cfg.text_encoder, MODELS, device=device) # T5 must be fp32
|
||||
model = build_module(
|
||||
cfg.model,
|
||||
MODELS,
|
||||
|
|
|
|||
Loading…
Reference in a new issue