diff --git a/tools/caption/pllava_dir/caption_pllava.py b/tools/caption/pllava_dir/caption_pllava.py index f220f5a..ceb0721 100644 --- a/tools/caption/pllava_dir/caption_pllava.py +++ b/tools/caption/pllava_dir/caption_pllava.py @@ -1,3 +1,17 @@ +import sys +import os +import os +from pathlib import Path + +current_file = Path(__file__) # Gets the path of the current file +fourth_level_parent = current_file.parents[3] + +datasets_dir = os.path.join(fourth_level_parent, "opensora/datasets") +import sys +sys.path.append(datasets_dir) +from read_video import read_video_av +sys.path.remove(datasets_dir) + import itertools import logging import multiprocessing as mp @@ -95,21 +109,49 @@ def get_index(num_frames, num_segments): return offsets +# def load_video(video_path, num_frames, return_msg=False, resolution=336): +# transforms = torchvision.transforms.Resize(size=resolution) +# vr = VideoReader(video_path, ctx=cpu(0), num_threads=1) +# total_num_frames = len(vr) +# frame_indices = get_index(total_num_frames, num_frames) +# images_group = list() +# for frame_index in frame_indices: +# img = Image.fromarray(vr[frame_index].asnumpy()) +# images_group.append(transforms(img)) +# if return_msg: +# fps = float(vr.get_avg_fps()) +# sec = ", ".join([str(round(f / fps, 1)) for f in frame_indices]) +# # " " should be added in the start and end +# msg = f"The video contains {len(frame_indices)} frames sampled at {sec} seconds." 
#         return images_group, msg
#     else:
#         return images_group


def load_video(video_path, num_frames, return_msg=False, resolution=336):
    """Decode a video and return ``num_frames`` sampled, resized PIL frames.

    The video is decoded with ``read_video_av`` (requested in ``THWC``
    layout), frame positions are chosen by ``get_index``, and each selected
    frame is converted to a PIL image and passed through
    ``torchvision.transforms.Resize(size=resolution)``.

    Args:
        video_path: Path of the video file handed to ``read_video_av``.
        num_frames: Number of frames to sample (forwarded to ``get_index``).
        return_msg: When true, also return a sentence listing the timestamps
            (in seconds) of the sampled frames.
        resolution: ``size`` argument for ``torchvision.transforms.Resize``.

    Returns:
        The list of resized PIL images, or ``(images, msg)`` when
        ``return_msg`` is true.

    Raises:
        ValueError: If ``return_msg`` is true but the decoder reported no
            usable frame rate, so timestamps cannot be computed.
    """
    resize = torchvision.transforms.Resize(size=resolution)
    # The audio frames returned by the decoder are not needed for captioning.
    vframes, _, info = read_video_av(
        video_path,
        pts_unit="sec",
        output_format="THWC",
    )
    frame_indices = get_index(len(vframes), num_frames)
    images_group = [
        resize(Image.fromarray(vframes[frame_index].numpy()))
        for frame_index in frame_indices
    ]
    if not return_msg:
        return images_group

    # NOTE(review): assumes read_video_av follows torchvision.io.read_video
    # and exposes the average frame rate as info["video_fps"] -- confirm
    # against opensora/datasets/read_video.py.
    fps = info.get("video_fps")
    if not fps:
        # Raise instead of calling exit(): terminating the interpreter from a
        # helper would silently kill a whole captioning worker process.
        raise ValueError(
            f"Cannot build frame timestamps for {video_path!r}: "
            "decoder reported no video_fps"
        )
    fps = float(fps)
    sec = ", ".join(str(round(f / fps, 1)) for f in frame_indices)
    # " " should be added in the start and end
    msg = f"The video contains {len(frame_indices)} frames sampled at {sec} seconds."
    return images_group, msg