mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-10 21:01:26 +02:00
Merge branch 'dev/v1.2' of github.com:hpcaitech/Open-Sora-dev into dev/v1.2
This commit is contained in:
commit
6fe7aa36b9
|
|
@ -68,10 +68,9 @@ class VideoTextDataset(torch.utils.data.Dataset):
|
|||
|
||||
if file_type == "video":
|
||||
# loading
|
||||
vframes, _, infos = read_video(filename=path, pts_unit="sec", output_format="TCHW")
|
||||
vframes, vinfo = read_video(path, backend='cv2')
|
||||
video_fps = vinfo["video_fps"] if 'video_fps' in vinfo else 24
|
||||
|
||||
if "video_fps" in infos:
|
||||
video_fps = infos["video_fps"]
|
||||
# Sampling video frames
|
||||
video = temporal_random_crop(vframes, self.num_frames, self.frame_interval)
|
||||
|
||||
|
|
@ -146,9 +145,8 @@ class VariableVideoTextDataset(VideoTextDataset):
|
|||
video_fps = 24 # default fps
|
||||
if file_type == "video":
|
||||
# loading
|
||||
vframes, _, infos = read_video(filename=path, pts_unit="sec", output_format="TCHW")
|
||||
if "video_fps" in infos:
|
||||
video_fps = infos["video_fps"]
|
||||
vframes, vinfo = read_video(path, backend='cv2')
|
||||
video_fps = vinfo["video_fps"] if 'video_fps' in vinfo else 24
|
||||
|
||||
# Sampling video frames
|
||||
video = temporal_random_crop(vframes, num_frames, self.frame_interval)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import gc
|
||||
import math
|
||||
import cv2
|
||||
import os
|
||||
from fractions import Fraction
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
|
@ -16,7 +17,7 @@ from torchvision.io.video import (
|
|||
)
|
||||
|
||||
|
||||
def read_video(
|
||||
def read_video_av(
|
||||
filename: str,
|
||||
start_pts: Union[float, Fraction] = 0,
|
||||
end_pts: Optional[Union[float, Fraction]] = None,
|
||||
|
|
@ -132,3 +133,56 @@ def read_video(
|
|||
vframes = vframes.permute(0, 3, 1, 2)
|
||||
|
||||
return vframes, aframes, info
|
||||
|
||||
|
||||
def read_video_cv2(video_path):
|
||||
cap = cv2.VideoCapture(video_path)
|
||||
|
||||
if not cap.isOpened():
|
||||
# print("Error: Unable to open video")
|
||||
raise ValueError
|
||||
else:
|
||||
fps = cap.get(cv2.CAP_PROP_FPS)
|
||||
vinfo = {
|
||||
'video_fps': fps,
|
||||
}
|
||||
|
||||
frames = []
|
||||
while True:
|
||||
# Read a frame from the video
|
||||
ret, frame = cap.read()
|
||||
|
||||
# If frame is not read correctly, break the loop
|
||||
if not ret:
|
||||
break
|
||||
|
||||
frames.append(frame[:, :, ::-1]) # BGR to RGB
|
||||
|
||||
# Exit if 'q' is pressed
|
||||
if cv2.waitKey(25) & 0xFF == ord('q'):
|
||||
break
|
||||
|
||||
# Release the video capture object and close all windows
|
||||
cap.release()
|
||||
cv2.destroyAllWindows()
|
||||
|
||||
frames = np.stack(frames)
|
||||
frames = torch.from_numpy(frames) # [T, H, W, C=3]
|
||||
frames = frames.permute(0, 3, 1, 2)
|
||||
return frames, vinfo
|
||||
|
||||
|
||||
def read_video(video_path, backend='cv2'):
|
||||
if backend == 'cv2':
|
||||
vframes, vinfo = read_video_cv2(video_path)
|
||||
elif backend == 'av':
|
||||
vframes, _, vinfo = read_video_av(filename=video_path, pts_unit="sec", output_format="TCHW")
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
return vframes, vinfo
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
vframes, vinfo = read_video('./data/colors/9.mp4', backend='cv2')
|
||||
x = 0
|
||||
|
|
|
|||
Loading…
Reference in a new issue