[fix] read_video for tools

This commit is contained in:
zhengzangw 2024-05-31 07:17:01 +00:00
parent fd51586704
commit 2988d5be99
2 changed files with 5 additions and 3 deletions

View file

@ -172,7 +172,7 @@ def read_video_cv2(video_path):
return frames, vinfo
def read_video(video_path, backend='cv2'):
def read_video(video_path, backend='av'):
if backend == 'cv2':
vframes, vinfo = read_video_cv2(video_path)
elif backend == 'av':

View file

@ -10,11 +10,12 @@ from glob import glob
import cv2
import numpy as np
import pandas as pd
import torchvision
from PIL import Image
from tqdm import tqdm
import torchvision
from .utils import IMG_EXTENSIONS
from opensora.datasets.read_video import read_video
tqdm.pandas()
@ -116,7 +117,8 @@ def get_image_info(path, backend="pillow"):
def get_video_info(path, backend="torchvision"):
if backend == "torchvision":
try:
vframes, _, infos = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
# vframes, infos = read_video(path)
vframes, _, infos = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="THWC")
num_frames, height, width = vframes.shape[0], vframes.shape[2], vframes.shape[3]
if "video_fps" in infos:
fps = infos["video_fps"]