mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-10 12:49:38 +02:00
align pllava video loader with the one in get video info (#167)
This commit is contained in:
parent
08c5222cb9
commit
32b06c62d0
|
|
@ -1,3 +1,17 @@
|
|||
import sys
|
||||
import os
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
current_file = Path(__file__) # Gets the path of the current file
|
||||
fourth_level_parent = current_file.parents[3]
|
||||
|
||||
datasets_dir = os.path.join(fourth_level_parent, "opensora/datasets")
|
||||
import sys
|
||||
sys.path.append(datasets_dir)
|
||||
from read_video import read_video_av
|
||||
sys.path.remove(datasets_dir)
|
||||
|
||||
import itertools
|
||||
import logging
|
||||
import multiprocessing as mp
|
||||
|
|
@ -95,21 +109,49 @@ def get_index(num_frames, num_segments):
|
|||
return offsets
|
||||
|
||||
|
||||
# def load_video(video_path, num_frames, return_msg=False, resolution=336):
|
||||
# transforms = torchvision.transforms.Resize(size=resolution)
|
||||
# vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
|
||||
# total_num_frames = len(vr)
|
||||
# frame_indices = get_index(total_num_frames, num_frames)
|
||||
# images_group = list()
|
||||
# for frame_index in frame_indices:
|
||||
# img = Image.fromarray(vr[frame_index].asnumpy())
|
||||
# images_group.append(transforms(img))
|
||||
# if return_msg:
|
||||
# fps = float(vr.get_avg_fps())
|
||||
# sec = ", ".join([str(round(f / fps, 1)) for f in frame_indices])
|
||||
# # " " should be added in the start and end
|
||||
# msg = f"The video contains {len(frame_indices)} frames sampled at {sec} seconds."
|
||||
# return images_group, msg
|
||||
# else:
|
||||
# return images_group
|
||||
|
||||
|
||||
def load_video(video_path, num_frames, return_msg=False, resolution=336):
|
||||
transforms = torchvision.transforms.Resize(size=resolution)
|
||||
vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
|
||||
total_num_frames = len(vr)
|
||||
# vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
|
||||
vframes, aframes, info = read_video_av(
|
||||
video_path,
|
||||
pts_unit="sec",
|
||||
output_format="THWC"
|
||||
)
|
||||
print(vframes.shape)
|
||||
total_num_frames = len(vframes)
|
||||
# print("Video path: ", video_path)
|
||||
# print("Total number of frames: ", total_num_frames)
|
||||
frame_indices = get_index(total_num_frames, num_frames)
|
||||
images_group = list()
|
||||
for frame_index in frame_indices:
|
||||
img = Image.fromarray(vr[frame_index].asnumpy())
|
||||
img = Image.fromarray(vframes[frame_index].numpy())
|
||||
images_group.append(transforms(img))
|
||||
if return_msg:
|
||||
fps = float(vr.get_avg_fps())
|
||||
sec = ", ".join([str(round(f / fps, 1)) for f in frame_indices])
|
||||
# " " should be added in the start and end
|
||||
msg = f"The video contains {len(frame_indices)} frames sampled at {sec} seconds."
|
||||
return images_group, msg
|
||||
# fps = float(vframes.get_avg_fps())
|
||||
# sec = ", ".join([str(round(f / fps, 1)) for f in frame_indices])
|
||||
# # " " should be added in the start and end
|
||||
# msg = f"The video contains {len(frame_indices)} frames sampled at {sec} seconds."
|
||||
# return images_group, msg
|
||||
exit('return_msg not implemented yet')
|
||||
else:
|
||||
return images_group
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue