mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-17 14:25:07 +02:00
refactor temporal_random_crop
This commit is contained in:
parent
3a0b85456c
commit
62097da2d3
|
|
@ -1,5 +1,3 @@
|
|||
import os
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
|
|
@ -8,8 +6,8 @@ from torchvision.datasets.folder import IMG_EXTENSIONS, pil_loader
|
|||
|
||||
from opensora.registry import DATASETS
|
||||
|
||||
from . import video_transforms
|
||||
from .utils import VID_EXTENSIONS, get_transforms_image, get_transforms_video
|
||||
from .utils import VID_EXTENSIONS, get_transforms_image, get_transforms_video, temporal_random_crop
|
||||
|
||||
|
||||
@DATASETS.register_module()
|
||||
class VideoTextDataset(torch.utils.data.Dataset):
|
||||
|
|
@ -33,7 +31,6 @@ class VideoTextDataset(torch.utils.data.Dataset):
|
|||
self.num_frames = num_frames
|
||||
self.frame_interval = frame_interval
|
||||
self.image_size = image_size
|
||||
self.temporal_sample = video_transforms.TemporalRandomCrop(num_frames * frame_interval)
|
||||
self.transforms = {
|
||||
"image": get_transforms_image(image_size[0]),
|
||||
"video": get_transforms_video(image_size[0]),
|
||||
|
|
@ -58,13 +55,7 @@ class VideoTextDataset(torch.utils.data.Dataset):
|
|||
vframes, _, _ = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
|
||||
|
||||
# Sampling video frames
|
||||
total_frames = len(vframes)
|
||||
start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
|
||||
assert (
|
||||
end_frame_ind - start_frame_ind >= self.num_frames
|
||||
), f"{path} with index {index} has not enough frames."
|
||||
frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.num_frames, dtype=int)
|
||||
video = vframes[frame_indice]
|
||||
video = temporal_random_crop(vframes, self.num_frames, self.frame_interval)
|
||||
|
||||
# transform
|
||||
transform = self.transforms["video"]
|
||||
|
|
|
|||
88
opensora/datasets/datasets_variable.py
Normal file
88
opensora/datasets/datasets_variable.py
Normal file
|
|
@ -0,0 +1,88 @@
|
|||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
import torchvision
|
||||
from torchvision.datasets.folder import IMG_EXTENSIONS, pil_loader
|
||||
|
||||
from opensora.registry import DATASETS
|
||||
|
||||
from .utils import VID_EXTENSIONS, get_transforms_image, get_transforms_video, temporal_random_crop
|
||||
|
||||
|
||||
@DATASETS.register_module()
class VariableVideoTextDataset(torch.utils.data.Dataset):
    """Dataset of captioned videos/images listed in a CSV file.

    The CSV is read with ``pandas.read_csv`` and every row is expected to
    provide a ``path`` column (media file) and a ``text`` column (caption).
    Videos are temporally cropped to ``num_frames`` frames; still images are
    repeated ``num_frames`` times so both kinds yield a clip.

    Args:
        data_path: location of the CSV file describing the samples.
        num_frames (int): number of frames in each returned clip.
        frame_interval (int): forwarded to ``temporal_random_crop`` together
            with ``num_frames`` to size the sampled temporal window.
        image_size (tuple): target size; only ``image_size[0]`` is passed to
            the transform builders.
    """

    def __init__(
        self,
        data_path,
        num_frames=16,
        frame_interval=1,
        image_size=(256, 256),
    ):
        self.data_path = data_path
        self.data = pd.read_csv(data_path)
        self.num_frames = num_frames
        self.frame_interval = frame_interval
        self.image_size = image_size
        # One transform pipeline per media type, selected in ``getitem``.
        self.transforms = {
            "image": get_transforms_image(image_size[0]),
            "video": get_transforms_video(image_size[0]),
        }

    def get_type(self, path):
        """Classify ``path`` as ``"video"`` or ``"image"`` from its extension.

        Raises:
            AssertionError: if the extension is in neither ``VID_EXTENSIONS``
                nor ``IMG_EXTENSIONS``.
        """
        ext = path.split(".")[-1]
        # VID_EXTENSIONS entries have no leading dot; IMG_EXTENSIONS is
        # checked with a dot-prefixed extension.
        if ext.lower() in VID_EXTENSIONS:
            return "video"
        else:
            assert f".{ext.lower()}" in IMG_EXTENSIONS, f"Unsupported file format: {ext}"
            return "image"

    def getitem(self, index):
        """Load, sample and transform the sample at row ``index``.

        Returns:
            dict: ``{"video": clip, "text": caption}``, where ``clip`` has been
            permuted from T C H W to C T H W.
        """
        sample = self.data.iloc[index]
        path = sample["path"]
        text = sample["text"]
        file_type = self.get_type(path)

        if file_type == "video":
            # loading
            vframes, _, _ = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")

            # Sampling video frames
            video = temporal_random_crop(vframes, self.num_frames, self.frame_interval)

            # transform
            transform = self.transforms["video"]
            video = transform(video)  # T C H W
        else:
            # loading
            image = pil_loader(path)

            # transform
            transform = self.transforms["image"]
            image = transform(image)

            # repeat the still image along a new leading time axis
            video = image.unsqueeze(0).repeat(self.num_frames, 1, 1, 1)

        # TCHW -> CTHW
        video = video.permute(1, 0, 2, 3)
        return {"video": video, "text": text}

    def __getitem__(self, index):
        """Return a sample, retrying with random indices when loading fails.

        Up to 10 attempts are made; each failure is reported on stdout and a
        new random index is drawn.

        Raises:
            RuntimeError: if all 10 attempts fail.
        """
        for _ in range(10):
            try:
                return self.getitem(index)
            except Exception as e:
                # Best-effort loading: report which sample failed, then
                # resample instead of aborting the whole epoch.
                print(f"Failed to load sample {index}: {e}")
                index = np.random.randint(len(self))
        raise RuntimeError("Too many bad data.")

    def __len__(self):
        """Number of rows in the CSV file."""
        return len(self.data)
|
||||
|
|
@ -18,6 +18,15 @@ from . import video_transforms
|
|||
|
||||
VID_EXTENSIONS = ("mp4", "avi", "mov", "mkv")
|
||||
|
||||
def temporal_random_crop(vframes, num_frames, frame_interval):
    """Sample ``num_frames`` frames from a random temporal window of ``vframes``.

    A ``video_transforms.TemporalRandomCrop`` of size
    ``num_frames * frame_interval`` chooses the window; ``num_frames`` indices
    are then spread evenly over it with ``np.linspace``.

    Args:
        vframes: indexable sequence of frames; ``len(vframes)`` is the total
            frame count and it must accept an integer index array.
        num_frames (int): number of frames to return.
        frame_interval (int): multiplier applied to ``num_frames`` to size the
            sampled window.

    Returns:
        The selected frames, ``vframes[frame_indice]``.

    Raises:
        AssertionError: if the sampled window holds fewer than ``num_frames``
            frames.
    """
    total_frames = len(vframes)
    temporal_sample = video_transforms.TemporalRandomCrop(num_frames * frame_interval)
    # NOTE(review): the end index is treated as exclusive (hence ``- 1``
    # below) - confirm against TemporalRandomCrop.
    start_frame_ind, end_frame_ind = temporal_sample(total_frames)
    # Keep a diagnostic message so callers that only print the exception
    # still see why the sample was rejected.
    assert end_frame_ind - start_frame_ind >= num_frames, (
        f"Not enough frames: need {num_frames}, but the sampled window has "
        f"{end_frame_ind - start_frame_ind} (video has {total_frames} frames)."
    )
    frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, num_frames, dtype=int)
    return vframes[frame_indice]
|
||||
|
||||
|
||||
def get_transforms_video(resolution=256):
|
||||
transform_video = transforms.Compose(
|
||||
|
|
|
|||
|
|
@ -20,7 +20,6 @@ import random
|
|||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def _is_tensor_video_clip(clip):
|
||||
|
|
@ -33,23 +32,6 @@ def _is_tensor_video_clip(clip):
|
|||
return True
|
||||
|
||||
|
||||
def center_crop_arr(pil_image, image_size):
    """
    Center cropping implementation from ADM.
    https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126

    The image is halved with a BOX filter while its shorter side is at least
    twice ``image_size``, then rescaled with BICUBIC so the shorter side
    matches ``image_size``, and finally cropped to a centered
    ``image_size`` x ``image_size`` square.
    """
    # Repeated halving before the final bicubic resize.
    while min(pil_image.size) >= 2 * image_size:
        halved = tuple(side // 2 for side in pil_image.size)
        pil_image = pil_image.resize(halved, resample=Image.BOX)

    # Rescale relative to the shorter side.
    ratio = image_size / min(pil_image.size)
    target = tuple(round(side * ratio) for side in pil_image.size)
    pil_image = pil_image.resize(target, resample=Image.BICUBIC)

    pixels = np.array(pil_image)
    top = (pixels.shape[0] - image_size) // 2
    left = (pixels.shape[1] - image_size) // 2
    window = pixels[top : top + image_size, left : left + image_size]
    return Image.fromarray(window)
|
||||
|
||||
|
||||
def crop(clip, i, j, h, w):
|
||||
"""
|
||||
Args:
|
||||
|
|
|
|||
Loading…
Reference in a new issue