Open-Sora/opensora/datasets/datasets.py

89 lines
2.6 KiB
Python
Raw Normal View History

2024-03-15 15:00:46 +01:00
import numpy as np
2024-03-25 11:36:56 +01:00
import pandas as pd
2024-03-15 15:00:46 +01:00
import torch
import torchvision
from torchvision.datasets.folder import IMG_EXTENSIONS, pil_loader
2024-03-26 10:02:41 +01:00
from opensora.registry import DATASETS
2024-03-15 15:00:46 +01:00
2024-03-26 10:32:15 +01:00
from .utils import VID_EXTENSIONS, get_transforms_image, get_transforms_video, temporal_random_crop
2024-03-15 15:00:46 +01:00
2024-03-26 10:02:41 +01:00
@DATASETS.register_module()
class VideoTextDataset(torch.utils.data.Dataset):
    """Dataset of (video, text) pairs described by a CSV file.

    The CSV at ``data_path`` is loaded with ``pandas.read_csv`` and must
    provide a ``path`` column (location of the media file) and a ``text``
    column (the caption returned alongside the clip). A row may reference
    either a video or a still image; a still image is turned into a clip by
    repeating it ``num_frames`` times.

    Args:
        data_path: Path of the CSV file listing the samples.
        num_frames: Number of frames in each returned clip.
        frame_interval: Forwarded to ``temporal_random_crop`` together with
            ``num_frames`` (presumably the stride between sampled frames --
            confirm against ``.utils``).
        image_size: Target spatial size. NOTE: only ``image_size[0]`` is
            passed to the transform builders below; ``image_size[1]`` is
            stored on the instance but not used in this class.
    """

    # How many samples ``__getitem__`` tries before giving up.
    max_retries = 10

    def __init__(
        self,
        data_path,
        num_frames=16,
        frame_interval=1,
        image_size=(256, 256),
    ):
        self.data_path = data_path
        self.data = pd.read_csv(data_path)
        self.num_frames = num_frames
        self.frame_interval = frame_interval
        self.image_size = image_size
        # One transform pipeline per media kind, keyed by get_type()'s result.
        self.transforms = {
            "image": get_transforms_image(image_size[0]),
            "video": get_transforms_video(image_size[0]),
        }

    def get_type(self, path):
        """Classify ``path`` as ``"video"`` or ``"image"`` from its extension.

        Raises:
            AssertionError: If the extension is in neither ``VID_EXTENSIONS``
                nor ``IMG_EXTENSIONS``.
        """
        ext = path.split(".")[-1]
        # VID_EXTENSIONS is matched without a leading dot, IMG_EXTENSIONS with
        # one.
        if ext.lower() in VID_EXTENSIONS:
            return "video"
        if f".{ext.lower()}" not in IMG_EXTENSIONS:
            # Raised explicitly instead of via ``assert`` so the check is not
            # stripped under ``python -O``; the exception type is kept for
            # backward compatibility.
            raise AssertionError(f"Unsupported file format: {ext}")
        return "image"

    def getitem(self, index):
        """Load, sample and transform the sample at row ``index``.

        Returns:
            dict: ``{"video": tensor, "text": caption}``. The tensor is
            permuted from (T, C, H, W) to (C, T, H, W) before being returned.
        """
        sample = self.data.iloc[index]
        path = sample["path"]
        text = sample["text"]
        file_type = self.get_type(path)

        if file_type == "video":
            # loading
            vframes, _, _ = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")

            # Sampling video frames
            video = temporal_random_crop(vframes, self.num_frames, self.frame_interval)

            # transform
            transform = self.transforms["video"]
            video = transform(video)  # T C H W
        else:
            # loading
            image = pil_loader(path)

            # transform
            transform = self.transforms["image"]
            image = transform(image)

            # repeat the single frame so an image yields a num_frames-long clip
            # (assumes the image transform returns a 3-D tensor -- TODO confirm)
            video = image.unsqueeze(0).repeat(self.num_frames, 1, 1, 1)

        # TCHW -> CTHW
        video = video.permute(1, 0, 2, 3)
        return {"video": video, "text": text}

    def __getitem__(self, index):
        """Return a sample, falling back to random rows when loading fails.

        Up to ``max_retries`` attempts are made. After a failure the error is
        printed and a new index is drawn uniformly at random, so the returned
        sample is not guaranteed to be the one requested.

        Raises:
            RuntimeError: If every attempt fails; the last underlying error
                is attached as ``__cause__``.
        """
        last_error = None
        for _ in range(self.max_retries):
            try:
                return self.getitem(index)
            except Exception as e:  # deliberate best-effort: skip unreadable samples
                print(e)
                last_error = e
                index = np.random.randint(len(self))
        # Chain the last failure so the root cause is visible in the traceback.
        raise RuntimeError("Too many bad data.") from last_error

    def __len__(self):
        """Return the number of rows in the CSV."""
        return len(self.data)