align pllava video loader with the one in get video info (#167)

This commit is contained in:
Tom Young 2024-07-04 11:17:05 +08:00 committed by GitHub
parent 08c5222cb9
commit 32b06c62d0

View file

@ -1,3 +1,17 @@
import os
import sys
from pathlib import Path

# Locate the "opensora/datasets" directory relative to this file: it lives
# under the ancestor found three levels above this file's own directory
# (``parents[3]``).
current_file = Path(__file__)  # Gets the path of the current file
fourth_level_parent = current_file.parents[3]
datasets_dir = os.path.join(fourth_level_parent, "opensora/datasets")

# Put the directory on sys.path only for the duration of the import so that
# ``read_video`` resolves as a top-level module. The ``finally`` guarantees
# sys.path is restored even when the import itself fails.
sys.path.append(datasets_dir)
try:
    from read_video import read_video_av
finally:
    sys.path.remove(datasets_dir)
import itertools
import logging
import multiprocessing as mp
@ -95,21 +109,49 @@ def get_index(num_frames, num_segments):
return offsets
# Legacy VideoReader-based implementation of load_video, kept commented out for reference.
# def load_video(video_path, num_frames, return_msg=False, resolution=336):
# transforms = torchvision.transforms.Resize(size=resolution)
# vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
# total_num_frames = len(vr)
# frame_indices = get_index(total_num_frames, num_frames)
# images_group = list()
# for frame_index in frame_indices:
# img = Image.fromarray(vr[frame_index].asnumpy())
# images_group.append(transforms(img))
# if return_msg:
# fps = float(vr.get_avg_fps())
# sec = ", ".join([str(round(f / fps, 1)) for f in frame_indices])
# # " " should be added in the start and end
# msg = f"The video contains {len(frame_indices)} frames sampled at {sec} seconds."
# return images_group, msg
# else:
# return images_group
def load_video(video_path, num_frames, return_msg=False, resolution=336):
    """Sample frames from a video and return them as resized PIL images.

    The video is decoded once with ``read_video_av`` (requested in ``THWC``
    layout with ``pts_unit="sec"``). ``get_index`` picks which frames to keep,
    and each kept frame is converted to a PIL image and passed through
    ``torchvision.transforms.Resize``.

    Args:
        video_path: Path of the video, passed straight to ``read_video_av``.
        num_frames: Number of frames to sample; forwarded to ``get_index`` as
            its ``num_segments`` argument.
        return_msg: When true, also return a sentence listing the timestamps
            (in seconds) of the sampled frames.
        resolution: ``size`` argument given to ``torchvision.transforms.Resize``.

    Returns:
        The list of resized images, or ``(images, msg)`` when ``return_msg``
        is true.

    Raises:
        ValueError: If ``return_msg`` is true but no frame rate is available
            to convert frame indices into timestamps.
    """
    resize = torchvision.transforms.Resize(size=resolution)

    # Audio frames are decoded by read_video_av but not needed here.
    vframes, _aframes, info = read_video_av(
        video_path,
        pts_unit="sec",
        output_format="THWC",
    )
    total_num_frames = len(vframes)
    frame_indices = get_index(total_num_frames, num_frames)

    images_group = [
        resize(Image.fromarray(vframes[frame_index].numpy()))
        for frame_index in frame_indices
    ]

    if not return_msg:
        return images_group

    # NOTE(review): assumes read_video_av follows torchvision.io.read_video and
    # reports the average frame rate under info["video_fps"] - confirm against
    # opensora/datasets/read_video.py.
    fps = info.get("video_fps")
    if not fps:
        # Raise instead of calling exit(): a helper must not kill the process.
        raise ValueError(
            f"Cannot build the frame message for {video_path!r}: "
            "no video frame rate was reported."
        )
    fps = float(fps)
    sec = ", ".join(str(round(f / fps, 1)) for f in frame_indices)
    # " " should be added in the start and end
    msg = f"The video contains {len(frame_indices)} frames sampled at {sec} seconds."
    return images_group, msg