mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-10 21:01:26 +02:00
[fix] read_video for tools
This commit is contained in:
parent
fd51586704
commit
2988d5be99
|
|
@ -172,7 +172,7 @@ def read_video_cv2(video_path):
|
|||
return frames, vinfo
|
||||
|
||||
|
||||
def read_video(video_path, backend='cv2'):
|
||||
def read_video(video_path, backend='av'):
|
||||
if backend == 'cv2':
|
||||
vframes, vinfo = read_video_cv2(video_path)
|
||||
elif backend == 'av':
|
||||
|
|
|
|||
|
|
@ -10,11 +10,12 @@ from glob import glob
|
|||
import cv2
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torchvision
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
import torchvision
|
||||
|
||||
from .utils import IMG_EXTENSIONS
|
||||
from opensora.datasets.read_video import read_video
|
||||
|
||||
tqdm.pandas()
|
||||
|
||||
|
|
@ -116,7 +117,8 @@ def get_image_info(path, backend="pillow"):
|
|||
def get_video_info(path, backend="torchvision"):
|
||||
if backend == "torchvision":
|
||||
try:
|
||||
vframes, _, infos = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
|
||||
# vframes, infos = read_video(path)
|
||||
vframes, _, infos = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="THWC")
|
||||
num_frames, height, width = vframes.shape[0], vframes.shape[2], vframes.shape[3]
|
||||
if "video_fps" in infos:
|
||||
fps = infos["video_fps"]
|
||||
|
|
|
|||
Loading…
Reference in a new issue