2024-07-04 05:17:05 +02:00
import sys
import os
import os
from pathlib import Path
# Import `read_video_av` from the `read_video` module that lives in the
# datasets directory computed below, without leaving that directory on
# sys.path afterwards.
current_file = Path ( __file__ ) # Gets the path of the current file
# parents[0] is the containing directory, so parents[3] is the fourth ancestor.
fourth_level_parent = current_file . parents [ 3 ]
datasets_dir = os . path . join ( fourth_level_parent , " opensora/datasets " )
import sys
# Temporarily extend the module search path so `read_video` resolves by bare name.
sys . path . append ( datasets_dir )
from read_video import read_video_av
# Undo the search-path change once the import has completed.
sys . path . remove ( datasets_dir )
2024-05-15 09:19:12 +02:00
import itertools
import logging
import multiprocessing as mp
from argparse import ArgumentParser
2024-06-17 17:37:23 +02:00
from multiprocessing import Process , Queue
2024-05-15 09:19:12 +02:00
import numpy as np
2024-06-17 17:37:23 +02:00
import pandas as pd
2024-05-15 09:19:12 +02:00
import torch
import torchvision
import transformers
2024-06-17 17:37:23 +02:00
from decord import VideoReader , cpu
from PIL import Image
2024-05-15 09:19:12 +02:00
from tasks . eval . eval_utils import Conversation
2024-06-17 17:37:23 +02:00
from tasks . eval . model_utils import load_pllava
from torch . utils . data import Dataset
from tqdm import tqdm
from transformers . feature_extraction_utils import BatchFeature
2024-05-15 09:19:12 +02:00
# Prompt template shared by all captioning requests; `infer` works on a copy
# of it (`conv_template.copy()`), so this object itself starts with no messages.
# The system text asks for a description of at most six sentences and embeds
# three example descriptions.
conv_template = Conversation (
system = " Describe this video. Pay attention to all objects in the video. The description should be useful for AI to re-generate the video. The description should be no more than six sentences. Here are some examples of good descriptions: 1. A stylish woman walks down a Tokyo street filled with warm glowing neon and animated city signage. She wears a black leather jacket, a long red dress, and black boots, and carries a black purse. She wears sunglasses and red lipstick. She walks confidently and casually. The street is damp and reflective, creating a mirror effect of the colorful lights. Many pedestrians walk about. 2. Several giant wooly mammoths approach treading through a snowy meadow, their long wooly fur lightly blows in the wind as they walk, snow covered trees and dramatic snow capped mountains in the distance, mid afternoon light with wispy clouds and a sun high in the distance creates a warm glow, the low camera view is stunning capturing the large furry mammal with beautiful photography, depth of field. 3. Drone view of waves crashing against the rugged cliffs along Big Sur ' s garay point beach. The crashing blue waters create white-tipped waves, while the golden light of the setting sun illuminates the rocky shore. A small island with a lighthouse sits in the distance, and green shrubbery covers the cliff ' s edge. The steep drop from the road down to the beach is a dramatic feat, with the cliff’ s edges jutting out over the sea. This is a view that captures the raw beauty of the coast and the rugged landscape of the Pacific Coast Highway. " ,
# Role tags for the two speakers; `pllava_answer` splits model output on the last one.
roles = ( " USER: " , " ASSISTANT: " ) ,
messages = [ ] ,
# `pllava_answer` strips the second separator from the end of each answer.
sep = ( " " , " </s> " ) ,
2024-06-17 17:37:23 +02:00
mm_token = " <image> " ,
2024-05-15 09:19:12 +02:00
)
# Set up default logging handlers and a module-level logger at INFO level.
# `run` lowers this logger to ERROR on every rank other than 0.
logging . basicConfig ( )
logger = logging . getLogger ( __name__ )
logger . setLevel ( logging . INFO )
2024-06-17 17:37:23 +02:00
RESOLUTION = 672 # passed as `resolution` to load_video by CSVDataset.__getitem__
2024-05-15 09:19:12 +02:00
2024-06-17 17:37:23 +02:00
def pllava_answer(
    conv: Conversation,
    model,
    processor,
    video_list,
    do_sample=True,
    max_new_tokens=200,
    num_beams=1,
    min_length=1,
    top_p=0.9,
    repetition_penalty=1.0,
    length_penalty=1,
    temperature=1.0,
    stop_criteria_keywords=None,
    print_res=False,
):
    """Generate one answer per video from a single shared prompt.

    The prompt is taken from ``conv.get_prompt()``. Every entry of
    ``video_list`` is passed through ``processor`` together with that prompt,
    the per-video tensors are concatenated key by key, moved to
    ``model.device`` and handed to a single ``model.generate`` call. Each
    decoded text is then reduced to what follows the last assistant role tag,
    loses the conversation's end separator, is stripped, and has its
    newline pattern replaced by a single space.

    Args:
        conv: Conversation providing the prompt, ``roles`` and ``sep``.
        model: Object exposing ``generate`` and ``device``.
        processor: Callable producing model inputs; also provides ``batch_decode``.
        video_list: Non-empty sequence of videos accepted by ``processor``.
        do_sample, max_new_tokens, num_beams, min_length, top_p,
        repetition_penalty, length_penalty, temperature: Forwarded unchanged
            to ``model.generate``.
        stop_criteria_keywords: Accepted for interface compatibility; not used.
        print_res: When true, print the prompt and each raw decoded text.

    Returns:
        ``(output_texts, conv)``: the list of cleaned answers (one per decoded
        sequence) and the same ``conv`` object. ``conv.messages[-1][1]`` is
        overwritten with each answer in turn, so it ends up holding the last one.
    """
    # torch.cuda.empty_cache()
    prompt = conv.get_prompt()

    encoded = [processor(text=prompt, images=video, return_tensors=" pt ") for video in video_list]
    # add batch dimension by cat
    merged = {name: torch.cat([item[name] for item in encoded]) for name in list(encoded[0].keys())}
    model_inputs = BatchFeature(merged, tensor_type=" pt ").to(model.device)

    with torch.no_grad():
        generated = model.generate(
            **model_inputs,
            media_type=" video ",
            do_sample=do_sample,
            max_new_tokens=max_new_tokens,
            num_beams=num_beams,
            min_length=min_length,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            temperature=temperature,
        )
        output_texts = processor.batch_decode(
            generated, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

    # The tag to split on and the trailing separator are the same for every answer.
    last_role = conv.roles[-1]
    split_tag = " <|im_start|> assistant \n " if last_role == " <|im_start|>assistant \n " else last_role
    if isinstance(conv.sep, str):
        ending = conv.sep
    else:
        ending = conv.sep[1]

    for position, raw_text in enumerate(output_texts):
        if print_res:  # debug usage
            print(" ### PROMPTING LM WITH: ", prompt)
            print(" ### LM OUTPUT TEXT: ", raw_text)
        answer = raw_text.split(split_tag)[-1]
        answer = answer.removesuffix(ending).strip()
        answer = answer.replace(" \n ", " ")
        output_texts[position] = answer
        conv.messages[-1][1] = answer
    return output_texts, conv
2024-06-17 17:37:23 +02:00
2024-05-15 09:19:12 +02:00
def get_index(num_frames, num_segments):
    """Return ``num_segments`` evenly spaced frame indices for a clip.

    The range ``0 .. num_frames - 1`` is divided into ``num_segments`` equal
    steps; the k-th index is half a step (truncated) plus the k-th multiple of
    the step rounded with ``np.round`` (ties go to the even integer).

    Args:
        num_frames: Total number of frames available.
        num_segments: Number of indices to produce; must be non-zero.

    Returns:
        1-D integer ``numpy.ndarray`` of length ``num_segments``.
    """
    step = float(num_frames - 1) / num_segments
    half_step = int(step / 2)
    multiples = np.round(step * np.arange(num_segments)).astype(int)
    return half_step + multiples
2024-06-17 17:37:23 +02:00
2024-07-04 05:17:05 +02:00
# def load_video(video_path, num_frames, return_msg=False, resolution=336):
# transforms = torchvision.transforms.Resize(size=resolution)
# vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
# total_num_frames = len(vr)
# frame_indices = get_index(total_num_frames, num_frames)
# images_group = list()
# for frame_index in frame_indices:
# img = Image.fromarray(vr[frame_index].asnumpy())
# images_group.append(transforms(img))
# if return_msg:
# fps = float(vr.get_avg_fps())
# sec = ", ".join([str(round(f / fps, 1)) for f in frame_indices])
# # " " should be added in the start and end
# msg = f"The video contains {len(frame_indices)} frames sampled at {sec} seconds."
# return images_group, msg
# else:
# return images_group
2024-05-15 09:19:12 +02:00
def load_video(video_path, num_frames, return_msg=False, resolution=336):
    """Decode a video and return evenly spaced, resized frames as PIL images.

    The file is decoded with ``read_video_av``; ``get_index`` picks
    ``num_frames`` indices out of the decoded frames, and every selected frame
    is converted to a PIL image and resized with
    ``torchvision.transforms.Resize(size=resolution)``.

    Args:
        video_path: Path of the video file to decode.
        num_frames: Number of frames to sample.
        return_msg: Not supported by this decoder path; must be false.
        resolution: ``size`` argument given to the resize transform.

    Returns:
        List of ``num_frames`` resized PIL images.

    Raises:
        NotImplementedError: If ``return_msg`` is true.
        ValueError: If decoding produced no frames.
    """
    if return_msg:
        # Reject the unsupported mode before decoding anything, and with an
        # exception rather than exit() so a caller is never terminated by a
        # helper function.
        raise NotImplementedError("return_msg not implemented yet")

    resize = torchvision.transforms.Resize(size=resolution)
    # Only the frame tensor is used; audio and metadata are discarded.
    # NOTE(review): indexing below assumes frames come first in the tensor - confirm
    # against read_video_av's output_format handling.
    vframes, _aframes, _info = read_video_av(
        video_path,
        pts_unit=" sec ",
        output_format=" THWC ",
    )
    logger.debug("decoded %s: frame tensor shape %s", video_path, tuple(vframes.shape))

    total_num_frames = len(vframes)
    if total_num_frames == 0:
        # get_index would otherwise produce indices into an empty tensor.
        raise ValueError(f"no frames decoded from {video_path!r}")

    frame_indices = get_index(total_num_frames, num_frames)
    return [resize(Image.fromarray(vframes[frame_index].numpy())) for frame_index in frame_indices]
2024-06-17 17:37:23 +02:00
2024-05-27 03:11:10 +02:00
def collate_fn(batch):
    """Return the list of samples untouched.

    Given to the ``DataLoader`` in ``run`` so that samples are not stacked
    into tensors; a sample may be a list of images or ``None`` when loading
    failed.
    """
    return batch
2024-06-17 17:37:23 +02:00
2024-05-15 09:19:12 +02:00
class CSVDataset(Dataset):
    """Dataset of videos listed in the ``path`` column of a CSV file.

    Each item is the list of frames returned by ``load_video`` for one path,
    or ``None`` when that video could not be loaded.
    """

    def __init__(self, csv_path, num_frames):
        """Read ``csv_path`` and remember how many frames to sample per video."""
        self.df = pd.read_csv(csv_path)
        self.data_list = self.df["path"].tolist()
        self.num_frames = num_frames

    def __len__(self):
        """Return the number of videos currently held (after any sharding)."""
        return len(self.data_list)

    def __getitem__(self, idx):
        """Load the video at ``idx``.

        Returns:
            The frames from ``load_video`` or ``None`` if loading raised.

        Raises:
            IndexError: If ``idx`` is negative or past the end.
        """
        if idx < 0 or idx >= len(self.data_list):
            raise IndexError(idx)
        video_path = self.data_list[idx]
        try:
            return load_video(video_path, self.num_frames, resolution=RESOLUTION)
        except Exception as exc:
            # Best effort: a broken file must not stop the whole run, and
            # `infer` treats None as "not loaded". Only ordinary errors are
            # absorbed, so KeyboardInterrupt/SystemExit still propagate.
            logger.warning("failed to load video %s: %s", video_path, exc)
            return None

    def set_rank_and_world_size(self, rank, world_size):
        """Keep only the contiguous slice of paths that belongs to ``rank``.

        Every rank gets ``len(self) // world_size`` paths; the last rank also
        takes the remainder.
        """
        self.rank = rank
        self.world_size = world_size
        self.data_per_gpu = len(self) // world_size
        start_index = rank * self.data_per_gpu
        end_index = (rank + 1) * self.data_per_gpu if rank != world_size - 1 else len(self)
        self.data_list = self.data_list[start_index:end_index]
2024-06-17 17:37:23 +02:00
2024-05-15 09:19:12 +02:00
def parse_args(argv=None):
    """Parse the command-line options of the captioning script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is what existing callers get.

    Returns:
        ``argparse.Namespace`` with ``pretrained_model_name_or_path``,
        ``batch_size``, ``csv_path``, ``num_frames``, ``use_lora``,
        ``lora_alpha``, ``weight_dir``, ``conv_mode``, ``pooling_shape`` and
        ``error_message``.
    """
    parser = ArgumentParser()
    # Option strings must begin with "-" to be registered as options; with a
    # leading space argparse treats them as positionals and rejects `required`.
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        required=True,
        # NOTE(review): unreachable while required=True; kept so the CLI contract is unchanged.
        default=" llava-hf/llava-1.5-7b-hf ",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        required=False,
        default=1,
    )
    parser.add_argument(
        "--csv_path",
        type=str,
        required=True,
    )
    parser.add_argument(
        "--num_frames",
        type=int,
        required=True,
        # NOTE(review): unreachable while required=True; kept so the CLI contract is unchanged.
        default=4,
    )
    parser.add_argument("--use_lora", action="store_true")
    parser.add_argument(
        "--lora_alpha",
        type=int,
        required=False,
        default=4,
    )
    parser.add_argument(
        "--weight_dir",
        type=str,
        required=False,
        default=None,
    )
    parser.add_argument(
        "--conv_mode",
        type=str,
        required=False,
        default=" eval_mvbench ",
    )
    parser.add_argument(
        "--pooling_shape",
        type=str,
        required=False,
        default=None,
    )
    # Placeholder caption written for videos that failed; `main` later drops
    # rows whose text equals this value.
    parser.add_argument(
        "--error_message",
        type=str,
        required=False,
        default=' error occured during captioning ',
    )
    return parser.parse_args(argv)
2024-06-17 17:37:23 +02:00
def load_model_and_dataset(
    rank,
    world_size,
    pretrained_model_name_or_path,
    num_frames,
    use_lora,
    lora_alpha,
    weight_dir,
    csv_path,
    pooling_shape=(16, 12, 12),
):
    """Load the model for one rank and build that rank's slice of the dataset.

    Args:
        rank: Index of this worker; also used as the torch device index.
        world_size: Total number of workers sharing the CSV.
        pretrained_model_name_or_path, num_frames, use_lora, lora_alpha,
        weight_dir, pooling_shape: Forwarded to ``load_pllava``.
        csv_path: CSV file handed to ``CSVDataset``.

    Returns:
        ``(model, processor, dataset)`` with the model on ``torch.device(rank)``
        in eval mode and the dataset already restricted to this rank.
    """
    # Note from the original author: very large models (30B+) may use up
    # memory heavily when loaded this way.
    pllava_options = {
        "num_frames": num_frames,
        "use_lora": use_lora,
        "weight_dir": weight_dir,
        "lora_alpha": lora_alpha,
        "pooling_shape": pooling_shape,
    }
    model, processor = load_pllava(pretrained_model_name_or_path, **pllava_options)
    logger.info(" done loading llava ")

    # position embedding
    target_device = torch.device(rank)
    model = model.to(target_device).eval()

    shard = CSVDataset(csv_path, num_frames)
    shard.set_rank_and_world_size(rank, world_size)
    return model, processor, shard
2024-06-17 17:37:23 +02:00
def infer(
    model,
    processor,
    video_list,
    conv_mode,
    print_res=False,
):
    """Caption every video of a batch with a fresh copy of the prompt template.

    Args:
        model: Model passed through to ``pllava_answer``.
        processor: Processor passed through to ``pllava_answer``.
        video_list: Batch of loaded videos; no entry may be ``None``.
        conv_mode: Accepted for interface compatibility; not used here.
        print_res: Forwarded to ``pllava_answer`` for debug printing.

    Returns:
        The list of responses produced by ``pllava_answer``.

    Raises:
        Exception: If any entry of ``video_list`` is ``None``.
    """
    # A None entry means the dataset could not load that video; refuse the batch.
    for video in video_list:
        if video is None:
            raise Exception(" Video not loaded properly ")

    conversation = conv_template.copy()
    conversation.user_query(" Describe the video in details. ", is_mm=True)
    responses, _conversation = pllava_answer(
        conv=conversation,
        model=model,
        processor=processor,
        video_list=video_list,
        max_new_tokens=256,
        do_sample=False,
        print_res=print_res,
    )
    return responses
2024-05-15 09:19:12 +02:00
2024-06-17 17:37:23 +02:00
2024-05-27 03:11:10 +02:00
def run(rank, args, world_size, output_queue):
    """Worker entry point: caption this rank's share of the CSV.

    Loads the model and the rank's dataset slice, iterates it in batches and
    collects one caption per video. A batch whose inference raises contributes
    ``args.error_message`` once per video instead.

    Args:
        rank: Index of this worker; used as device index and shard selector.
        args: Parsed command-line namespace (see ``parse_args``).
        world_size: Total number of workers.
        output_queue: Object with a ``put`` method; receives ``(rank, result_list)``.

    Returns:
        The list of captions for this rank, in dataset order.
    """
    if rank == 0:
        import os

        debug_port = os.getenv(" DEBUG_ADDRESS ")
        if debug_port is not None:
            import ptvsd

            ptvsd.enable_attach(address=(" localhost ", int(debug_port)), redirect_output=True)
            # Announce before blocking, so the message is visible while waiting.
            print(" waiting for debugger attachment ")
            ptvsd.wait_for_attach()
    else:
        # Keep the output of the other ranks quiet.
        transformers.utils.logging.set_verbosity_error()
        logger.setLevel(transformers.logging.ERROR)

    print_res = False
    conv_mode = args.conv_mode

    # Pass pooling_shape only when it was given, so load_model_and_dataset's
    # own default applies otherwise. int() ignores surrounding blanks, so both
    # "16-12-12" and "16 - 12 - 12" are accepted.
    shape_kwargs = {}
    if args.pooling_shape is not None:
        shape_kwargs["pooling_shape"] = tuple(int(part) for part in args.pooling_shape.split("-"))

    logger.info(f" loading model and constructing dataset to gpu {rank} ... ")
    model, processor, dataset = load_model_and_dataset(
        rank,
        world_size,
        pretrained_model_name_or_path=args.pretrained_model_name_or_path,
        num_frames=args.num_frames,
        use_lora=args.use_lora,
        lora_alpha=args.lora_alpha,
        weight_dir=args.weight_dir,
        csv_path=args.csv_path,
        **shape_kwargs,
    )
    logger.info(" done model and dataset... ")
    logger.info(" constructing dataset... ")
    logger.info(" single test... ")

    dataloader = torch.utils.data.DataLoader(
        dataset,
        num_workers=2,
        batch_size=args.batch_size,
        collate_fn=collate_fn,
        shuffle=False,  # order must match the CSV rows of this shard
    )

    result_list = []
    print(len(dataset))
    for batch in tqdm(dataloader):
        try:
            preds = infer(
                model,
                processor,
                batch,
                conv_mode=conv_mode,
                print_res=print_res,
            )
        except Exception as e:
            logger.error(f" error in {batch} : {str(e)} ")
            # One placeholder per video keeps results aligned with the rows.
            preds = [args.error_message] * len(batch)
        result_list.extend(preds)

    output_queue.put((rank, result_list))
    return result_list
2024-06-17 17:37:23 +02:00
2024-05-15 09:19:12 +02:00
def main():
    """Caption every video of the input CSV and write the result to a new CSV.

    One worker process per visible CUDA device runs ``run`` on its slice of the
    CSV. The per-rank results are concatenated in rank order, stored in a new
    column, rows whose caption equals ``args.error_message`` are dropped, and
    the table is written next to the input with a ``_text.csv`` suffix.

    Raises:
        RuntimeError: If no CUDA device is available or a worker exits with a
            non-zero code before delivering its results.
        ValueError: If the number of captions differs from the number of rows.
    """
    import queue  # stdlib; provides the Empty exception and an in-process queue

    multiprocess = True
    mp.set_start_method("spawn")
    args = parse_args()

    if multiprocess:
        world_size = torch.cuda.device_count()
        print(f" world_size: {world_size} ")
        if world_size == 0:
            raise RuntimeError("no CUDA devices available; cannot start captioning workers")

        # Create a queue to collect results from each process
        output_queue = Queue()
        processes = []
        for worker_rank in range(world_size):
            # Each process also takes the output queue as an argument
            p = Process(target=run, args=(worker_rank, args, world_size, output_queue))
            p.daemon = False
            processes.append(p)
            p.start()

        results_by_rank = {}
        while len(results_by_rank) < world_size:
            try:
                rank, results = output_queue.get(timeout=5)  # Retrieve results as they finish
            except queue.Empty:
                # Nothing arrived: make sure no worker died without reporting,
                # otherwise this loop would wait forever.
                failed = [
                    i
                    for i, proc in enumerate(processes)
                    if i not in results_by_rank and proc.exitcode not in (None, 0)
                ]
                if failed:
                    for proc in processes:
                        if proc.is_alive():
                            proc.terminate()
                    raise RuntimeError(f"worker rank(s) {failed} exited without returning results")
                continue
            results_by_rank[rank] = results
            print(f" Results received from rank {rank} ")

        logger.info(" finished running ")
        for p in processes:
            p.join()

        # Concatenate in rank order so captions line up with the CSV rows.
        results_list = list(itertools.chain.from_iterable(results_by_rank[i] for i in range(world_size)))
    else:
        # debug: run one worker in this process; run() still needs somewhere to put its result
        results_list = run(0, args=args, world_size=1, output_queue=queue.Queue())

    print(results_list)
    df = pd.read_csv(args.csv_path)
    if len(results_list) != len(df):
        raise ValueError(f"got {len(results_list)} captions for {len(df)} rows in {args.csv_path}")

    # add a new column to the dataframe
    df[" text "] = results_list
    drop_failed = True
    if drop_failed:
        # delete the entire row if captioning failed
        df = df[df[" text "] != args.error_message]

    # Write to '<input>_text.csv'; never fall back to overwriting the input file.
    new_csv_path = args.csv_path.replace(".csv", "_text.csv")
    if new_csv_path == args.csv_path:
        new_csv_path = args.csv_path + "_text.csv"
    df.to_csv(new_csv_path, index=False)
    print(f" Results saved to {new_csv_path} ")
2024-06-17 17:37:23 +02:00
2024-05-15 09:19:12 +02:00
# Script entry point. `__name__` is exactly "__main__" when the file is run
# directly, so the comparison literal must carry no surrounding blanks.
if __name__ == "__main__":
    main()