From 62097da2d378baa04780c308116f0d4beb20c29f Mon Sep 17 00:00:00 2001
From: Zangwei Zheng
Date: Tue, 26 Mar 2024 17:32:15 +0800
Subject: [PATCH] refactor temporal_random_crop

---
 opensora/datasets/datasets.py          | 15 +----
 opensora/datasets/datasets_variable.py | 88 ++++++++++++++++++++++++++
 opensora/datasets/utils.py             |  9 +++
 opensora/datasets/video_transforms.py  | 18 ------
 4 files changed, 100 insertions(+), 30 deletions(-)
 create mode 100644 opensora/datasets/datasets_variable.py

diff --git a/opensora/datasets/datasets.py b/opensora/datasets/datasets.py
index 318779f..bef4a25 100644
--- a/opensora/datasets/datasets.py
+++ b/opensora/datasets/datasets.py
@@ -1,5 +1,3 @@
-import os
-
 import numpy as np
 import pandas as pd
 import torch
@@ -8,8 +6,8 @@ from torchvision.datasets.folder import IMG_EXTENSIONS, pil_loader
 
 from opensora.registry import DATASETS
 
-from . import video_transforms
-from .utils import VID_EXTENSIONS, get_transforms_image, get_transforms_video
+from .utils import VID_EXTENSIONS, get_transforms_image, get_transforms_video, temporal_random_crop
+
 
 @DATASETS.register_module()
 class VideoTextDataset(torch.utils.data.Dataset):
@@ -33,7 +31,6 @@ class VideoTextDataset(torch.utils.data.Dataset):
         self.num_frames = num_frames
         self.frame_interval = frame_interval
         self.image_size = image_size
-        self.temporal_sample = video_transforms.TemporalRandomCrop(num_frames * frame_interval)
         self.transforms = {
             "image": get_transforms_image(image_size[0]),
             "video": get_transforms_video(image_size[0]),
@@ -58,13 +55,7 @@ class VideoTextDataset(torch.utils.data.Dataset):
             vframes, _, _ = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
 
             # Sampling video frames
-            total_frames = len(vframes)
-            start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
-            assert (
-                end_frame_ind - start_frame_ind >= self.num_frames
-            ), f"{path} with index {index} has not enough frames."
-            frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.num_frames, dtype=int)
-            video = vframes[frame_indice]
+            video = temporal_random_crop(vframes, self.num_frames, self.frame_interval)
 
             # transform
             transform = self.transforms["video"]
diff --git a/opensora/datasets/datasets_variable.py b/opensora/datasets/datasets_variable.py
new file mode 100644
index 0000000..45507ff
--- /dev/null
+++ b/opensora/datasets/datasets_variable.py
@@ -0,0 +1,88 @@
+import numpy as np
+import pandas as pd
+import torch
+import torchvision
+from torchvision.datasets.folder import IMG_EXTENSIONS, pil_loader
+
+from opensora.registry import DATASETS
+
+from .utils import VID_EXTENSIONS, get_transforms_image, get_transforms_video, temporal_random_crop
+
+
+@DATASETS.register_module()
+class VariableVideoTextDataset(torch.utils.data.Dataset):
+    """Load video or image samples listed in a CSV file with "path" and "text" columns.
+
+    Args:
+        num_frames (int): number of frames to sample from each video.
+        frame_interval (int): interval between two sampled frames.
+        image_size (tuple): target (height, width) of the returned frames.
+ """ + + def __init__( + self, + data_path, + num_frames=16, + frame_interval=1, + image_size=(256, 256), + ): + self.data_path = data_path + self.data = pd.read_csv(data_path) + self.num_frames = num_frames + self.frame_interval = frame_interval + self.image_size = image_size + self.transforms = { + "image": get_transforms_image(image_size[0]), + "video": get_transforms_video(image_size[0]), + } + + def get_type(self, path): + ext = path.split(".")[-1] + if ext.lower() in VID_EXTENSIONS: + return "video" + else: + assert f".{ext.lower()}" in IMG_EXTENSIONS, f"Unsupported file format: {ext}" + return "image" + + def getitem(self, index): + sample = self.data.iloc[index] + path = sample["path"] + text = sample["text"] + file_type = self.get_type(path) + + if file_type == "video": + # loading + vframes, _, _ = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW") + + # Sampling video frames + video = temporal_random_crop(vframes, self.num_frames, self.frame_interval) + + # transform + transform = self.transforms["video"] + video = transform(video) # T C H W + else: + # loading + image = pil_loader(path) + + # transform + transform = self.transforms["image"] + image = transform(image) + + # repeat + video = image.unsqueeze(0).repeat(self.num_frames, 1, 1, 1) + + # TCHW -> CTHW + video = video.permute(1, 0, 2, 3) + return {"video": video, "text": text} + + def __getitem__(self, index): + for _ in range(10): + try: + return self.getitem(index) + except Exception as e: + print(e) + index = np.random.randint(len(self)) + raise RuntimeError("Too many bad data.") + + def __len__(self): + return len(self.data) diff --git a/opensora/datasets/utils.py b/opensora/datasets/utils.py index 206d29d..38d50cd 100644 --- a/opensora/datasets/utils.py +++ b/opensora/datasets/utils.py @@ -18,6 +18,15 @@ from . import video_transforms VID_EXTENSIONS = ("mp4", "avi", "mov", "mkv") +def temporal_random_crop(vframes, num_frames, frame_interval): + temporal_sample = video_transforms.TemporalRandomCrop(num_frames * frame_interval) + total_frames = len(vframes) + start_frame_ind, end_frame_ind = temporal_sample(total_frames) + assert end_frame_ind - start_frame_ind >= num_frames + frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, num_frames, dtype=int) + video = vframes[frame_indice] + return video + def get_transforms_video(resolution=256): transform_video = transforms.Compose( diff --git a/opensora/datasets/video_transforms.py b/opensora/datasets/video_transforms.py index 8d7d095..fa0b328 100644 --- a/opensora/datasets/video_transforms.py +++ b/opensora/datasets/video_transforms.py @@ -20,7 +20,6 @@ import random import numpy as np import torch -from PIL import Image def _is_tensor_video_clip(clip): @@ -33,23 +32,6 @@ def _is_tensor_video_clip(clip): return True -def center_crop_arr(pil_image, image_size): - """ - Center cropping implementation from ADM. 
-    https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
-    """
-    while min(*pil_image.size) >= 2 * image_size:
-        pil_image = pil_image.resize(tuple(x // 2 for x in pil_image.size), resample=Image.BOX)
-
-    scale = image_size / min(*pil_image.size)
-    pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC)
-
-    arr = np.array(pil_image)
-    crop_y = (arr.shape[0] - image_size) // 2
-    crop_x = (arr.shape[1] - image_size) // 2
-    return Image.fromarray(arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size])
-
-
 def crop(clip, i, j, h, w):
     """
     Args:
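
Usage sketch (illustrative only, not part of the commit): with the sampling logic
factored out of the dataset class, temporal_random_crop can be exercised on its
own. A minimal check, assuming the opensora package is importable; the tensor
sizes below are made up for illustration:

    import torch

    from opensora.datasets.utils import temporal_random_crop

    # A dummy 120-frame clip in TCHW layout (e.g. 5 s at 24 fps, 3x256x256).
    vframes = torch.randint(0, 256, (120, 3, 256, 256), dtype=torch.uint8)

    # Sample 16 frames with stride 3, drawn from a random 48-frame window.
    video = temporal_random_crop(vframes, num_frames=16, frame_interval=3)
    assert video.shape == (16, 3, 256, 256)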