From a1a2f29e2b29a87e41100aae99cd5bc3bed6f191 Mon Sep 17 00:00:00 2001 From: xyupeng Date: Thu, 30 May 2024 16:31:37 +0800 Subject: [PATCH] update read_video() --- opensora/datasets/datasets.py | 10 +++--- opensora/datasets/read_video.py | 56 ++++++++++++++++++++++++++++++++- 2 files changed, 59 insertions(+), 7 deletions(-) diff --git a/opensora/datasets/datasets.py b/opensora/datasets/datasets.py index 582eebb..9287c70 100644 --- a/opensora/datasets/datasets.py +++ b/opensora/datasets/datasets.py @@ -68,10 +68,9 @@ class VideoTextDataset(torch.utils.data.Dataset): if file_type == "video": # loading - vframes, _, infos = read_video(filename=path, pts_unit="sec", output_format="TCHW") + vframes, vinfo = read_video(path, backend='cv2') + video_fps = vinfo["video_fps"] if 'video_fps' in vinfo else 24 - if "video_fps" in infos: - video_fps = infos["video_fps"] # Sampling video frames video = temporal_random_crop(vframes, self.num_frames, self.frame_interval) @@ -146,9 +145,8 @@ class VariableVideoTextDataset(VideoTextDataset): video_fps = 24 # default fps if file_type == "video": # loading - vframes, _, infos = read_video(filename=path, pts_unit="sec", output_format="TCHW") - if "video_fps" in infos: - video_fps = infos["video_fps"] + vframes, vinfo = read_video(path, backend='cv2') + video_fps = vinfo["video_fps"] if 'video_fps' in vinfo else 24 # Sampling video frames video = temporal_random_crop(vframes, num_frames, self.frame_interval) diff --git a/opensora/datasets/read_video.py b/opensora/datasets/read_video.py index b19849e..7061507 100644 --- a/opensora/datasets/read_video.py +++ b/opensora/datasets/read_video.py @@ -1,4 +1,5 @@ import math +import cv2 import os from fractions import Fraction from typing import Any, Dict, List, Optional, Tuple, Union @@ -15,7 +16,7 @@ from torchvision.io.video import ( ) -def read_video( +def read_video_av( filename: str, start_pts: Union[float, Fraction] = 0, end_pts: Optional[Union[float, Fraction]] = None, @@ -127,3 +128,56 
@@ def read_video(
         vframes = vframes.permute(0, 3, 1, 2)
 
     return vframes, aframes, info
+
+
+def read_video_cv2(video_path):
+    """Decode every frame of ``video_path`` with OpenCV.
+
+    Returns ``(frames, vinfo)``: ``frames`` is a tensor laid out as
+    [T, C=3, H, W] in RGB order; ``vinfo`` holds ``'video_fps'`` only when
+    OpenCV reports a positive rate, so callers can fall back to a default.
+
+    Raises:
+        ValueError: if the file cannot be opened or yields no frames.
+    """
+    cap = cv2.VideoCapture(video_path)
+    try:
+        if not cap.isOpened():
+            raise ValueError(f"Unable to open video: {video_path}")
+
+        # cap.get() returns 0 when the backend does not know the property.
+        fps = cap.get(cv2.CAP_PROP_FPS)
+        vinfo = {'video_fps': fps} if fps > 0 else {}
+
+        frames = []
+        while True:
+            ret, frame = cap.read()
+            if not ret:  # end of stream or decode failure
+                break
+            frames.append(frame[:, :, ::-1])  # BGR to RGB
+    finally:
+        # Always free the decoder, including on the error paths above.
+        cap.release()
+
+    if not frames:
+        raise ValueError(f"No frames decoded from video: {video_path}")
+
+    frames = torch.from_numpy(np.stack(frames))  # [T, H, W, C=3]
+    return frames.permute(0, 3, 1, 2), vinfo  # [T, C=3, H, W]
+
+
+def read_video(video_path, backend='cv2'):
+    """Read a video as ``(vframes [T, C, H, W], vinfo)`` via 'cv2' or 'av'."""
+    if backend == 'cv2':
+        return read_video_cv2(video_path)
+    if backend == 'av':
+        vframes, _, vinfo = read_video_av(filename=video_path, pts_unit="sec", output_format="TCHW")
+        return vframes, vinfo
+    # Same exception type as before, now naming the offending value.
+    raise ValueError(f"Unsupported backend: {backend!r}")
+
+
+if __name__ == '__main__':
+    # Manual smoke test: decode a sample clip and report what was read.
+    vframes, vinfo = read_video('./data/colors/9.mp4', backend='cv2')
+    print(vframes.shape, vinfo)