From 2988d5be99bee0e3b4dc60eadb6597156413ed7e Mon Sep 17 00:00:00 2001 From: zhengzangw Date: Fri, 31 May 2024 07:17:01 +0000 Subject: [PATCH] [fix] read_video for tools --- opensora/datasets/read_video.py | 2 +- tools/datasets/datautil.py | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/opensora/datasets/read_video.py b/opensora/datasets/read_video.py index d344267..8cc4bca 100644 --- a/opensora/datasets/read_video.py +++ b/opensora/datasets/read_video.py @@ -172,7 +172,7 @@ def read_video_cv2(video_path): return frames, vinfo -def read_video(video_path, backend='cv2'): +def read_video(video_path, backend='av'): if backend == 'cv2': vframes, vinfo = read_video_cv2(video_path) elif backend == 'av': diff --git a/tools/datasets/datautil.py b/tools/datasets/datautil.py index 6deba6a..3cbac44 100644 --- a/tools/datasets/datautil.py +++ b/tools/datasets/datautil.py @@ -10,11 +10,12 @@ from glob import glob import cv2 import numpy as np import pandas as pd -import torchvision from PIL import Image from tqdm import tqdm +import torchvision from .utils import IMG_EXTENSIONS +from opensora.datasets.read_video import read_video tqdm.pandas() @@ -116,7 +117,8 @@ def get_image_info(path, backend="pillow"): def get_video_info(path, backend="torchvision"): if backend == "torchvision": try: - vframes, _, infos = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW") + # vframes, infos = read_video(path) + vframes, _, infos = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="THWC") num_frames, height, width = vframes.shape[0], vframes.shape[2], vframes.shape[3] if "video_fps" in infos: fps = infos["video_fps"]