2024-03-15 15:00:46 +01:00
|
|
|
import os
|
|
|
|
|
|
|
|
|
|
import numpy as np
|
2024-03-25 11:36:56 +01:00
|
|
|
import pandas as pd
|
2024-03-15 15:00:46 +01:00
|
|
|
import torch
|
|
|
|
|
import torchvision
|
|
|
|
|
from torchvision.datasets.folder import IMG_EXTENSIONS, pil_loader
|
|
|
|
|
|
|
|
|
|
from . import video_transforms
|
2024-03-23 13:28:34 +01:00
|
|
|
from .utils import VID_EXTENSIONS
|
2024-03-15 15:00:46 +01:00
|
|
|
|
|
|
|
|
|
|
|
|
|
class DatasetFromCSV(torch.utils.data.Dataset):
    """Dataset of (video, text) samples listed in a CSV file.

    The CSV is read with ``pandas.read_csv`` and must provide a ``path``
    column (media file, joined onto ``root`` when one is given) and a ``text``
    column. The extension of the first row's path decides whether the whole
    dataset is handled as videos or as still images; an image is repeated
    ``num_frames`` times so both kinds yield a clip.

    Args:
        csv_path (str): location of the CSV file.
        num_frames (int): number of frames in every returned clip.
        frame_interval (int): the temporal crop is built with size
            ``num_frames * frame_interval``.
        transform (callable, optional): applied to the selected video frames
            (a T C H W tensor) or to the loaded PIL image, in which case it
            must return a C H W tensor. When ``None``, video frames are
            returned as decoded and images are only converted to a tensor.
        root (str, optional): directory prepended to every ``path`` entry.

    Raises:
        ValueError: if the first path has neither a supported video nor a
            supported image extension.
    """

    def __init__(
        self,
        csv_path,
        num_frames=16,
        frame_interval=1,
        transform=None,
        root=None,
    ):
        self.csv_path = csv_path
        self.data = pd.read_csv(csv_path)

        # The first row is taken as representative of the whole file.
        ext = self.data["path"][0].split(".")[-1]
        if ext.lower() in VID_EXTENSIONS:
            self.is_video = True
        elif f".{ext.lower()}" in IMG_EXTENSIONS:
            self.is_video = False
        else:
            # Explicit raise instead of ``assert`` so the check survives ``python -O``.
            raise ValueError(f"Unsupported file format: {ext}")

        self.transform = transform
        self.num_frames = num_frames
        self.frame_interval = frame_interval
        self.temporal_sample = video_transforms.TemporalRandomCrop(num_frames * frame_interval)
        self.root = root

    def _load_video(self, path, index):
        """Decode ``path`` and return ``num_frames`` frames as a T C H W tensor."""
        vframes, _, _ = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
        total_frames = len(vframes)

        # Sampling video frames
        start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
        if end_frame_ind - start_frame_ind < self.num_frames:
            # Not an ``assert``: under ``-O`` a short clip would otherwise slip
            # through and np.linspace would emit duplicated frame indices.
            raise ValueError(f"{path} with index {index} has not enough frames.")
        frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.num_frames, dtype=int)

        video = vframes[frame_indice]
        if self.transform is not None:
            video = self.transform(video)  # T C H W
        return video

    def _load_image(self, path):
        """Load the image at ``path`` and repeat it into a T C H W clip."""
        image = pil_loader(path)
        if self.transform is not None:
            image = self.transform(image)
        else:
            # Without a transform the PIL image still has to become a tensor.
            image = torchvision.transforms.functional.pil_to_tensor(image)
        return image.unsqueeze(0).repeat(self.num_frames, 1, 1, 1)

    def getitem(self, index):
        """Load the sample at ``index`` without any retry logic.

        Returns:
            dict: ``{"video": C T H W tensor, "text": value of the text column}``.

        Raises:
            ValueError: if the sampled window holds fewer than ``num_frames``
                frames. Errors from reading or decoding the file propagate.
        """
        sample = self.data.iloc[index]
        path = sample["path"]
        if self.root:
            path = os.path.join(self.root, path)
        text = sample["text"]

        if self.is_video:
            video = self._load_video(path, index)
        else:
            video = self._load_image(path)

        # TCHW -> CTHW
        video = video.permute(1, 0, 2, 3)

        return {"video": video, "text": text}

    def __getitem__(self, index):
        """Return the sample at ``index``, substituting a random one on failure.

        Up to 10 attempts are made; every failure is printed and a new index is
        drawn uniformly at random.

        Raises:
            RuntimeError: if all attempts fail. The last underlying error is
                attached as ``__cause__``.
        """
        last_error = None
        for _ in range(10):
            try:
                return self.getitem(index)
            except Exception as e:
                print(e)
                last_error = e
                index = np.random.randint(len(self))
        raise RuntimeError("Too many bad data.") from last_error

    def __len__(self):
        """Return the number of rows in the CSV."""
        return len(self.data)
|