# Open-Sora/opensora/datasets/datasets.py
import os
from glob import glob

import numpy as np
import torch
import torchvision
from torchvision.datasets.folder import IMG_EXTENSIONS, pil_loader

from opensora.registry import DATASETS

from .read_video import read_video
from .utils import VID_EXTENSIONS, get_transforms_image, get_transforms_video, read_file, temporal_random_crop
IMG_FPS = 120
2024-03-15 15:00:46 +01:00
2024-03-26 10:02:41 +01:00
@DATASETS.register_module()
class VideoTextDataset(torch.utils.data.Dataset):
    """Load (video or image, optional caption) samples listed in a data file.

    Args:
        data_path (str): path to the file read by ``read_file``; must contain a
            "path" column and may contain a "text" column with captions.
        num_frames (int): number of video frames to load per sample.
        frame_interval (int): stride between sampled frames.
        image_size (tuple): target (height, width) of the transforms.
        transform_name (str): name of the transform pipeline (e.g. "center").
    """

    def __init__(
        self,
        data_path=None,
        num_frames=16,
        frame_interval=1,
        image_size=(256, 256),
        transform_name="center",
    ):
        self.data_path = data_path
        self.data = read_file(data_path)
        # Captions are only emitted when the data file provides them.
        self.get_text = "text" in self.data.columns
        self.num_frames = num_frames
        self.frame_interval = frame_interval
        self.image_size = image_size
        self.transforms = {
            "image": get_transforms_image(transform_name, image_size),
            "video": get_transforms_video(transform_name, image_size),
        }

    def _print_data_number(self):
        # Report how many rows are videos vs. images (by file extension).
        num_videos = 0
        num_images = 0
        for path in self.data["path"]:
            if self.get_type(path) == "video":
                num_videos += 1
            else:
                num_images += 1
        print(f"Dataset contains {num_videos} videos and {num_images} images.")

    def get_type(self, path):
        """Return "video" or "image" according to the file extension.

        Raises:
            AssertionError: if the extension is neither a known video nor
                image extension.
        """
        # `.lower()` once here; the original lowered the already-lowered
        # extension again on every comparison.
        ext = os.path.splitext(path)[-1].lower()
        if ext in VID_EXTENSIONS:
            return "video"
        assert ext in IMG_EXTENSIONS, f"Unsupported file format: {ext}"
        return "image"

    def getitem(self, index):
        """Load one sample; returns {"video": CTHW tensor, "fps": float, ["text": str]}."""
        sample = self.data.iloc[index]
        path = sample["path"]
        file_type = self.get_type(path)

        if file_type == "video":
            # loading
            vframes, _, infos = read_video(filename=path, pts_unit="sec", output_format="TCHW")
            # BUG FIX: previously `video_fps` was only assigned when the key
            # existed, leaving it unbound (NameError at `ret = ...`) for
            # containers without fps metadata. Default to 24, consistent with
            # VariableVideoTextDataset.
            video_fps = infos["video_fps"] if "video_fps" in infos else 24
            # Sampling video frames
            video = temporal_random_crop(vframes, self.num_frames, self.frame_interval)
            # transform
            transform = self.transforms["video"]
            video = transform(video)  # T C H W
        else:
            # loading
            image = pil_loader(path)
            video_fps = IMG_FPS
            # transform
            transform = self.transforms["image"]
            image = transform(image)
            # Repeat the still image so it behaves like a num_frames-long clip.
            video = image.unsqueeze(0).repeat(self.num_frames, 1, 1, 1)

        # TCHW -> CTHW
        video = video.permute(1, 0, 2, 3)

        ret = {"video": video, "fps": video_fps}
        if self.get_text:
            ret["text"] = sample["text"]
        return ret

    def __getitem__(self, index):
        # Best-effort loading: on a decode failure, log and retry with a
        # random substitute index up to 10 times before giving up.
        for _ in range(10):
            try:
                return self.getitem(index)
            except Exception as e:
                path = self.data.iloc[index]["path"]
                print(f"data {path}: {e}")
                index = np.random.randint(len(self))
        raise RuntimeError("Too many bad data.")

    def __len__(self):
        return len(self.data)
@DATASETS.register_module()
class VariableVideoTextDataset(VideoTextDataset):
    """Variable-resolution variant of ``VideoTextDataset``.

    The sampler packs the target (index, num_frames, height, width) into a
    dash-separated string index, so each sample is loaded at its own clip
    length and resolution rather than a dataset-wide fixed size.
    """

    def __init__(
        self,
        data_path=None,
        num_frames=None,
        frame_interval=1,
        image_size=(None, None),
        transform_name=None,
        dummy_text_feature=False,
    ):
        super().__init__(data_path, num_frames, frame_interval, image_size, transform_name=None)
        self.transform_name = transform_name
        # Stable per-row id consumed by the bucket sampler.
        self.data["id"] = np.arange(len(self.data))
        self.dummy_text_feature = dummy_text_feature

    def get_data_info(self, index):
        """Return (num_frames, height, width) recorded for row ``index``."""
        row = self.data.iloc[index]
        return row["num_frames"], row["height"], row["width"]

    def getitem(self, index):
        # a hack to pass in the (time, height, width) info from sampler
        index, num_frames, height, width = map(int, index.split("-"))
        sample = self.data.iloc[index]
        path = sample["path"]

        ar = height / width
        video_fps = 24  # default fps
        if self.get_type(path) == "video":
            # Decode, then crop a num_frames-long clip at the configured stride.
            vframes, _, infos = read_video(filename=path, pts_unit="sec", output_format="TCHW")
            if "video_fps" in infos:
                video_fps = infos["video_fps"]
            clip = temporal_random_crop(vframes, num_frames, self.frame_interval)
            # Per-sample transform: the target size varies with each index.
            video = get_transforms_video(self.transform_name, (height, width))(clip)  # T C H W
        else:
            frame = pil_loader(path)
            video_fps = IMG_FPS
            frame = get_transforms_image(self.transform_name, (height, width))(frame)
            # A still image becomes a single-frame clip (no repetition here).
            video = frame.unsqueeze(0)

        # TCHW -> CTHW
        video = video.permute(1, 0, 2, 3)

        ret = {
            "video": video,
            "num_frames": num_frames,
            "height": height,
            "width": width,
            "ar": ar,
            "fps": video_fps,
        }
        if self.get_text:
            ret["text"] = sample["text"]
        if self.dummy_text_feature:
            # Replace the caption with a zero text-feature placeholder tensor.
            text_len = 50
            ret["text"] = torch.zeros((1, text_len, 1152))
            ret["mask"] = text_len
        return ret

    def __getitem__(self, index):
        # No retry loop here: the encoded index must reach getitem untouched.
        return self.getitem(index)
@DATASETS.register_module()
2024-05-21 07:45:06 +02:00
class BatchFeatureDataset(torch.utils.data.Dataset):
    """
    The dataset is composed of multiple .bin files.
    Each .bin file is a list of batch data (like a buffer). All .bin files have the same length.
    In each training iteration, one batch is fetched from the current buffer.
    Once a buffer is consumed, load another one.
    Avoid loading the same .bin on two different GPUs, i.e., one .bin is assigned to one GPU only.
    """

    def __init__(self, data_path=None):
        # BUG FIX: glob needs recursive=True for "**" to match any depth;
        # without it "**" degenerates to "*" and only files exactly one
        # directory below data_path were found (files directly in data_path
        # or nested deeper were silently skipped).
        self.path_list = sorted(glob(os.path.join(data_path, "**", "*.bin"), recursive=True))
        # Load one buffer to learn the (shared) per-file batch count.
        self._len_buffer = len(torch.load(self.path_list[0], map_location="cpu"))
        self._num_buffers = len(self.path_list)
        self.num_samples = self.len_buffer * len(self.path_list)

        # Cache of the most recently loaded buffer (one file at a time).
        self.cur_file_idx = -1
        self.cur_buffer = None

    @property
    def num_buffers(self):
        return self._num_buffers

    @property
    def len_buffer(self):
        return self._len_buffer

    def _load_buffer(self, idx):
        # Map a global sample index to its containing .bin file; reload the
        # cached buffer only when the file changes.
        file_idx = idx // self.len_buffer
        if file_idx != self.cur_file_idx:
            self.cur_file_idx = file_idx
            # map_location="cpu" so buffers saved on GPU load inside CPU
            # dataloader workers without requiring a CUDA context.
            self.cur_buffer = torch.load(self.path_list[file_idx], map_location="cpu")

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        self._load_buffer(idx)

        batch = self.cur_buffer[idx % self.len_buffer]  # dict; keys are {'x', 'fps'} and text related

        ret = {
            "video": batch["x"],
            "text": batch["y"],
            "mask": batch["mask"],
            "fps": batch["fps"],
            "height": batch["height"],
            "width": batch["width"],
            "num_frames": batch["num_frames"],
        }
        return ret