Merge branch 'dev/v1.2' of github.com:hpcaitech/Open-Sora-dev into dev/v1.2

This commit is contained in:
zhengzangw 2024-05-30 08:49:47 +00:00
commit 6fe7aa36b9
2 changed files with 59 additions and 7 deletions

View file

@ -68,10 +68,9 @@ class VideoTextDataset(torch.utils.data.Dataset):
if file_type == "video":
# loading
vframes, _, infos = read_video(filename=path, pts_unit="sec", output_format="TCHW")
vframes, vinfo = read_video(path, backend='cv2')
video_fps = vinfo["video_fps"] if 'video_fps' in vinfo else 24
if "video_fps" in infos:
video_fps = infos["video_fps"]
# Sampling video frames
video = temporal_random_crop(vframes, self.num_frames, self.frame_interval)
@ -146,9 +145,8 @@ class VariableVideoTextDataset(VideoTextDataset):
video_fps = 24 # default fps
if file_type == "video":
# loading
vframes, _, infos = read_video(filename=path, pts_unit="sec", output_format="TCHW")
if "video_fps" in infos:
video_fps = infos["video_fps"]
vframes, vinfo = read_video(path, backend='cv2')
video_fps = vinfo["video_fps"] if 'video_fps' in vinfo else 24
# Sampling video frames
video = temporal_random_crop(vframes, num_frames, self.frame_interval)

View file

@ -1,5 +1,6 @@
import gc
import math
import cv2
import os
from fractions import Fraction
from typing import Any, Dict, List, Optional, Tuple, Union
@ -16,7 +17,7 @@ from torchvision.io.video import (
)
def read_video(
def read_video_av(
filename: str,
start_pts: Union[float, Fraction] = 0,
end_pts: Optional[Union[float, Fraction]] = None,
@ -132,3 +133,56 @@ def read_video(
vframes = vframes.permute(0, 3, 1, 2)
return vframes, aframes, info
def read_video_cv2(video_path):
    """Decode a video file with OpenCV and return its frames as a tensor.

    Args:
        video_path (str): path to the video file.

    Returns:
        Tuple[torch.Tensor, dict]: a uint8 tensor of shape [T, C, H, W]
        in RGB channel order, and an info dict containing "video_fps".

    Raises:
        ValueError: if the file cannot be opened or yields no frames.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise ValueError(f"Unable to open video: {video_path}")
    try:
        vinfo = {
            'video_fps': cap.get(cv2.CAP_PROP_FPS),
        }
        frames = []
        while True:
            ret, frame = cap.read()
            # ret is False at end of stream or on a decode failure.
            if not ret:
                break
            # OpenCV decodes to BGR; reverse the channel axis to get RGB.
            # (np.stack below copies, so the negative stride is harmless.)
            frames.append(frame[:, :, ::-1])
    finally:
        # Always release the capture handle, even if decoding raised.
        cap.release()
    if not frames:
        raise ValueError(f"No frames decoded from video: {video_path}")
    frames = np.stack(frames)
    frames = torch.from_numpy(frames)  # [T, H, W, C=3]
    frames = frames.permute(0, 3, 1, 2)  # -> [T, C, H, W]
    return frames, vinfo
def read_video(video_path, backend='cv2'):
    """Read a video using the selected decoding backend.

    Args:
        video_path (str): path to the video file.
        backend (str): "cv2" (OpenCV) or "av" (torchvision/PyAV).

    Returns:
        Tuple[torch.Tensor, dict]: frames as a [T, C, H, W] tensor and an
        info dict (contains "video_fps" when the container reports it).

    Raises:
        ValueError: if `backend` is not "cv2" or "av".
    """
    if backend == 'cv2':
        vframes, vinfo = read_video_cv2(video_path)
    elif backend == 'av':
        # The av path returns (video, audio, info); audio is discarded.
        vframes, _, vinfo = read_video_av(filename=video_path, pts_unit="sec", output_format="TCHW")
    else:
        raise ValueError(f"Unsupported video backend: {backend!r} (expected 'cv2' or 'av')")
    return vframes, vinfo
if __name__ == '__main__':
    # Smoke test: decode a sample clip with the OpenCV backend.
    # (Removed the leftover `x = 0` debugger-breakpoint line.)
    vframes, vinfo = read_video('./data/colors/9.mp4', backend='cv2')