From 0f2bb1700bfd076e8b859486ee1637a7f016fbf4 Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Sun, 31 Mar 2024 01:11:03 +0800 Subject: [PATCH] refactored llava captioning (#12) --- tools/caption/caption_llava.py | 184 +++++++++++++++++++++------------ 1 file changed, 120 insertions(+), 64 deletions(-) diff --git a/tools/caption/caption_llava.py b/tools/caption/caption_llava.py index b6333c7..58cdd03 100644 --- a/tools/caption/caption_llava.py +++ b/tools/caption/caption_llava.py @@ -1,24 +1,23 @@ import argparse import csv -import math -import multiprocessing as mp import os import time import warnings +from datetime import timedelta -import colossalai import pandas as pd import torch import torch.distributed as dist from colossalai.cluster import DistCoordinator, ProcessGroupMesh from colossalai.shardformer import ShardConfig, ShardFormer -from colossalai.utils import get_current_device +from colossalai.utils import get_current_device, set_seed from llava.constants import DEFAULT_IMAGE_TOKEN, IGNORE_INDEX, IMAGE_TOKEN_INDEX from llava.conversation import conv_templates from llava.mm_utils import get_anyres_image_grid_shape, get_model_name_from_path, process_images, tokenizer_image_token from llava.model.builder import load_pretrained_model from llava.model.llava_arch import unpad_image from llava.utils import disable_torch_init +from PIL import Image from tqdm import tqdm from .acceleration.llava.policy import LlavaForCausalLMPolicy @@ -26,6 +25,18 @@ from .utils import Timer, extract_frames, prompts disable_torch_init() +IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp") +VID_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv") + + +def is_video(filename): + ext = os.path.splitext(filename)[-1].lower() + return ext in VID_EXTENSIONS + + +def get_image(image_path): + return Image.open(image_path).convert("RGB") + def prepare_inputs_labels_for_multimodal( self, input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes=None @@ -251,12 +262,52 @@ def prepare_inputs_labels_for_multimodal( return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels +class VideoTextDataset(torch.utils.data.Dataset): + def __init__(self, csv_path, transform=None, points=(0.1, 0.5, 0.9)): + self.csv_path = csv_path + self.transform = transform + self.data = pd.read_csv(csv_path) + self.points = points + + def getitem(self, index): + sample = self.data.iloc[index] + path = sample["path"] + if not is_video(path): + images = [get_image(path)] + length = 1 + else: + images, length = extract_frames(sample["path"], points=self.points) + imgs_size = [img.size for img in images] + images = self.transform(images) + + # we put images into a list as pytorch dataloader does not accept Pill + out = dict(path=path, image=images, length=length, img_size=imgs_size) + return out + + def __len__(self): + return len(self.data) + + def __getitem__(self, index): + return self.getitem(index) + + +def collate_fn(batch): + paths = [item["path"] for item in batch] + images = [item["image"] for item in batch] + lengths = [item["length"] for item in batch] + img_sizes = [item["img_size"] for item in batch] + return paths, images, lengths, img_sizes + + @torch.inference_mode() def main(args): # ====================================================== # 1. init environment # ====================================================== - colossalai.launch_from_torch({}) + # we set a very large timeout to avoid some processes exit early + dist.init_process_group(backend="nccl", timeout=timedelta(hours=24)) + torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count()) + set_seed(1024) coordinator = DistCoordinator() # prepare the dp and tp groups @@ -267,36 +318,6 @@ def main(args): dp_group = mesh.get_group_along_axis(0) tp_group = mesh.get_group_along_axis(1) - # ====================================================== - # 2. read video list - # ====================================================== - output_file = args.input.replace(".csv", "_caption.csv") - data_info = pd.read_csv(args.input) - videos = data_info["path"].tolist() - - # shard by DP - dp_size = dist.get_world_size(dp_group) - dp_rank = dist.get_rank(dp_group) - video_partition_size = math.ceil(len(videos) / dp_size) - videos = videos[dp_rank * video_partition_size : (dp_rank + 1) * video_partition_size] - - # create csv writer - has_main_writer = dist.get_rank() == 0 - has_dp_writter = dist.get_rank(tp_group) == 0 - - if has_main_writer: - # we keep track of the processed videos in main file - # so we use append mode - main_file = open(output_file, "a") - main_writer = csv.writer(main_file) - - if has_dp_writter: - # the dp writer takes care of the files processed on the current dp rank - # so we use write mode - output_file_split = f"{output_file}.part{dp_rank}" - dp_file = open(output_file_split, "w") - dp_writer = csv.writer(dp_file) - # ====================================================== # 3. load model and prepare prompts # ====================================================== @@ -317,6 +338,7 @@ def main(args): torch_dtype=torch.float16, attn_implementation="flash_attention_2", ) + dist.barrier() input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt") input_ids = input_ids.unsqueeze(0).cuda() @@ -333,18 +355,67 @@ def main(args): model = shard_former.optimize(model, policy=LlavaForCausalLMPolicy())[0].cuda() torch.cuda.empty_cache() + # ====================================================== + # 5. Prepare dataloader + # ====================================================== + # build dataset + def transform(imgs): + imgs = process_images(imgs, image_processor, model.config) + imgs = imgs.to(dtype=torch.float16) + return imgs + + dataset = VideoTextDataset(args.input, points=(0.2, 0.5, 0.8), transform=transform) + total_num_videos = len(dataset) + + # build sampler + dp_rank = dist.get_rank(dp_group) + dp_size = dist.get_world_size(dp_group) + sampler = torch.utils.data.distributed.DistributedSampler( + dataset, rank=dp_rank, num_replicas=dp_size, shuffle=False + ) + + # build dataloader + + dataloader = torch.utils.data.DataLoader( + dataset, + batch_size=args.bs, + shuffle=False, + num_workers=args.num_workers, + pin_memory=True, + prefetch_factor=args.prefetch_factor, + sampler=sampler, + collate_fn=collate_fn, + ) + + # prepare output file reader + output_file = args.input.replace(".csv", "_caption.csv") + + # create csv writer + has_main_writer = dist.get_rank() == 0 + has_dp_writter = dist.get_rank(tp_group) == 0 + + if has_main_writer: + # we keep track of the processed videos in main file + # so we use append mode + main_file = open(output_file, "a") + main_writer = csv.writer(main_file) + + if has_dp_writter: + # the dp writer takes care of the files processed on the current dp rank + # so we use write mode + output_file_split = f"{output_file}.part{dp_rank}" + dp_file = open(output_file_split, "w") + dp_writer = csv.writer(dp_file) + # ====================================================== # 5. generate captions # ====================================================== - bs = args.bs + args.bs - if dist.get_rank() == 0: - pbar = tqdm(range(0, len(videos), bs)) + if dist.get_rank(tp_group) == 0: + pbar = tqdm(dataloader, position=dp_rank, desc=f"Data Parallel Rank {dist.get_rank(dp_group)}") else: - pbar = tqdm(range(0, len(videos), bs)) - - # set up a multiprocessing pool - pool = mp.Pool(args.num_worker) + pbar = dataloader if args.profile: encode_time = [] @@ -352,34 +423,19 @@ def main(args): generate_time = [] total_time = [] - for i in pbar: + for batch in pbar: # measure time if args.profile: torch.cuda.synchronize() start_time = time.time() - video_files = videos[i : i + bs] - frames = [] - video_lengths = [] - - with Timer() as frame_extraction_timer: - # prepare a batch of inputs with parallel frame extraction - for frame, length in pool.map(extract_frames, video_files): - if len(frame) < 3: - continue - frames.append(frame) - video_lengths.append(length) - - if len(frames) == 0: - continue + video_files, frames, video_lengths, img_size_list = batch # encode the batch of inputs with Timer() as encode_timer: samples = [] - for imgs in frames: - imgs_size = [img.size for img in imgs] - imgs = process_images(imgs, image_processor, model.config) - imgs = imgs.cuda().to(dtype=torch.float16) + for imgs, imgs_size in zip(frames, img_size_list): + imgs = imgs.cuda() _, _, _, _, inputs_embeds, _ = prepare_inputs_labels_for_multimodal( model, input_ids, None, None, None, None, images=imgs, image_sizes=imgs_size ) @@ -426,7 +482,6 @@ def main(args): time_taken = time.time() - start_time total_time.append(time_taken) - frame_extraction_time.append(frame_extraction_timer.time_taken) encode_time.append(encode_timer.time_taken) generate_time.append(generate_timer.time_taken) @@ -438,7 +493,7 @@ def main(args): # display profiling info if args.profile: - num_samples_after_warmup = len(videos[1 * bs :]) * args.dp_size + num_samples_after_warmup = total_num_videos - args.bs * args.profile_warmup * dp_size print(f"throughput (video/s): {num_samples_after_warmup / sum(total_time)}") print(f"average frame extraction time per sample: {sum(frame_extraction_time) / num_samples_after_warmup}") print(f"average encode time per sample: {sum(encode_time) / num_samples_after_warmup}") @@ -477,8 +532,9 @@ if __name__ == "__main__": parser.add_argument("--prompt", type=str, default="three_frames") parser.add_argument("--tp-size", type=int, default=1) parser.add_argument("--dp-size", type=int, default=1) - parser.add_argument("--num-worker", type=int, default=8) + parser.add_argument("--num-workers", type=int, default=8) parser.add_argument("--profile", action="store_true") parser.add_argument("--profile-warmup", type=int, default=1) + parser.add_argument("--prefetch-factor", type=int, default=8, help="Prefetch factor") args = parser.parse_args() main(args)