refactored llava captioning (#12)

This commit is contained in:
Frank Lee 2024-03-31 01:11:03 +08:00 committed by GitHub
parent fb6bd34443
commit 0f2bb1700b

View file

@ -1,24 +1,23 @@
import argparse
import csv
import math
import multiprocessing as mp
import os
import time
import warnings
from datetime import timedelta
import colossalai
import pandas as pd
import torch
import torch.distributed as dist
from colossalai.cluster import DistCoordinator, ProcessGroupMesh
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.utils import get_current_device
from colossalai.utils import get_current_device, set_seed
from llava.constants import DEFAULT_IMAGE_TOKEN, IGNORE_INDEX, IMAGE_TOKEN_INDEX
from llava.conversation import conv_templates
from llava.mm_utils import get_anyres_image_grid_shape, get_model_name_from_path, process_images, tokenizer_image_token
from llava.model.builder import load_pretrained_model
from llava.model.llava_arch import unpad_image
from llava.utils import disable_torch_init
from PIL import Image
from tqdm import tqdm
from .acceleration.llava.policy import LlavaForCausalLMPolicy
@ -26,6 +25,18 @@ from .utils import Timer, extract_frames, prompts
disable_torch_init()
# Recognised file suffixes for still images and for videos, respectively.
IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp")
VID_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv")


def is_video(filename):
    """Return True when *filename* carries one of the known video extensions."""
    _, extension = os.path.splitext(filename)
    return extension.lower() in VID_EXTENSIONS
def get_image(image_path):
    """Load the file at *image_path* and return it as an RGB PIL image."""
    image = Image.open(image_path)
    return image.convert("RGB")
def prepare_inputs_labels_for_multimodal(
self, input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes=None
@ -251,12 +262,52 @@ def prepare_inputs_labels_for_multimodal(
return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
class VideoTextDataset(torch.utils.data.Dataset):
    """Dataset over a CSV manifest whose ``path`` column names images or videos.

    For a video, a few frames are extracted at the relative temporal
    positions given by ``points``; for a still image the single frame is
    used. Each sample is returned as a dict so the dataloader never has to
    batch raw PIL objects directly.
    """

    def __init__(self, csv_path, transform=None, points=(0.1, 0.5, 0.9)):
        self.csv_path = csv_path
        self.transform = transform
        # the manifest is read eagerly; one row per media file
        self.data = pd.read_csv(csv_path)
        self.points = points

    def getitem(self, index):
        row = self.data.iloc[index]
        path = row["path"]
        if is_video(path):
            frames, length = extract_frames(path, points=self.points)
        else:
            frames, length = [get_image(path)], 1
        # record the original PIL sizes before the transform replaces the frames
        sizes = [frame.size for frame in frames]
        frames = self.transform(frames)
        # we put images into a list as pytorch dataloader does not accept PIL
        return {"path": path, "image": frames, "length": length, "img_size": sizes}

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.getitem(index)
def collate_fn(batch):
    """Unzip a list of sample dicts into four parallel lists.

    Returns (paths, images, lengths, img_sizes) in dataset-sample order.
    """
    keys = ("path", "image", "length", "img_size")
    return tuple([sample[key] for sample in batch] for key in keys)
@torch.inference_mode()
def main(args):
# ======================================================
# 1. init environment
# ======================================================
colossalai.launch_from_torch({})
# we set a very large timeout to avoid some processes exit early
dist.init_process_group(backend="nccl", timeout=timedelta(hours=24))
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
set_seed(1024)
coordinator = DistCoordinator()
# prepare the dp and tp groups
@ -267,36 +318,6 @@ def main(args):
dp_group = mesh.get_group_along_axis(0)
tp_group = mesh.get_group_along_axis(1)
# ======================================================
# 2. read video list
# ======================================================
output_file = args.input.replace(".csv", "_caption.csv")
data_info = pd.read_csv(args.input)
videos = data_info["path"].tolist()
# shard by DP
dp_size = dist.get_world_size(dp_group)
dp_rank = dist.get_rank(dp_group)
video_partition_size = math.ceil(len(videos) / dp_size)
videos = videos[dp_rank * video_partition_size : (dp_rank + 1) * video_partition_size]
# create csv writer
has_main_writer = dist.get_rank() == 0
has_dp_writter = dist.get_rank(tp_group) == 0
if has_main_writer:
# we keep track of the processed videos in main file
# so we use append mode
main_file = open(output_file, "a")
main_writer = csv.writer(main_file)
if has_dp_writter:
# the dp writer takes care of the files processed on the current dp rank
# so we use write mode
output_file_split = f"{output_file}.part{dp_rank}"
dp_file = open(output_file_split, "w")
dp_writer = csv.writer(dp_file)
# ======================================================
# 3. load model and prepare prompts
# ======================================================
@ -317,6 +338,7 @@ def main(args):
torch_dtype=torch.float16,
attn_implementation="flash_attention_2",
)
dist.barrier()
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
input_ids = input_ids.unsqueeze(0).cuda()
@ -333,18 +355,67 @@ def main(args):
model = shard_former.optimize(model, policy=LlavaForCausalLMPolicy())[0].cuda()
torch.cuda.empty_cache()
# ======================================================
# 5. Prepare dataloader
# ======================================================
# build dataset
def transform(imgs):
imgs = process_images(imgs, image_processor, model.config)
imgs = imgs.to(dtype=torch.float16)
return imgs
dataset = VideoTextDataset(args.input, points=(0.2, 0.5, 0.8), transform=transform)
total_num_videos = len(dataset)
# build sampler
dp_rank = dist.get_rank(dp_group)
dp_size = dist.get_world_size(dp_group)
sampler = torch.utils.data.distributed.DistributedSampler(
dataset, rank=dp_rank, num_replicas=dp_size, shuffle=False
)
# build dataloader
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=args.bs,
shuffle=False,
num_workers=args.num_workers,
pin_memory=True,
prefetch_factor=args.prefetch_factor,
sampler=sampler,
collate_fn=collate_fn,
)
# prepare output file reader
output_file = args.input.replace(".csv", "_caption.csv")
# create csv writer
has_main_writer = dist.get_rank() == 0
has_dp_writter = dist.get_rank(tp_group) == 0
if has_main_writer:
# we keep track of the processed videos in main file
# so we use append mode
main_file = open(output_file, "a")
main_writer = csv.writer(main_file)
if has_dp_writter:
# the dp writer takes care of the files processed on the current dp rank
# so we use write mode
output_file_split = f"{output_file}.part{dp_rank}"
dp_file = open(output_file_split, "w")
dp_writer = csv.writer(dp_file)
# ======================================================
# 5. generate captions
# ======================================================
bs = args.bs
args.bs
if dist.get_rank() == 0:
pbar = tqdm(range(0, len(videos), bs))
if dist.get_rank(tp_group) == 0:
pbar = tqdm(dataloader, position=dp_rank, desc=f"Data Parallel Rank {dist.get_rank(dp_group)}")
else:
pbar = tqdm(range(0, len(videos), bs))
# set up a multiprocessing pool
pool = mp.Pool(args.num_worker)
pbar = dataloader
if args.profile:
encode_time = []
@ -352,34 +423,19 @@ def main(args):
generate_time = []
total_time = []
for i in pbar:
for batch in pbar:
# measure time
if args.profile:
torch.cuda.synchronize()
start_time = time.time()
video_files = videos[i : i + bs]
frames = []
video_lengths = []
with Timer() as frame_extraction_timer:
# prepare a batch of inputs with parallel frame extraction
for frame, length in pool.map(extract_frames, video_files):
if len(frame) < 3:
continue
frames.append(frame)
video_lengths.append(length)
if len(frames) == 0:
continue
video_files, frames, video_lengths, img_size_list = batch
# encode the batch of inputs
with Timer() as encode_timer:
samples = []
for imgs in frames:
imgs_size = [img.size for img in imgs]
imgs = process_images(imgs, image_processor, model.config)
imgs = imgs.cuda().to(dtype=torch.float16)
for imgs, imgs_size in zip(frames, img_size_list):
imgs = imgs.cuda()
_, _, _, _, inputs_embeds, _ = prepare_inputs_labels_for_multimodal(
model, input_ids, None, None, None, None, images=imgs, image_sizes=imgs_size
)
@ -426,7 +482,6 @@ def main(args):
time_taken = time.time() - start_time
total_time.append(time_taken)
frame_extraction_time.append(frame_extraction_timer.time_taken)
encode_time.append(encode_timer.time_taken)
generate_time.append(generate_timer.time_taken)
@ -438,7 +493,7 @@ def main(args):
# display profiling info
if args.profile:
num_samples_after_warmup = len(videos[1 * bs :]) * args.dp_size
num_samples_after_warmup = total_num_videos - args.bs * args.profile_warmup * dp_size
print(f"throughput (video/s): {num_samples_after_warmup / sum(total_time)}")
print(f"average frame extraction time per sample: {sum(frame_extraction_time) / num_samples_after_warmup}")
print(f"average encode time per sample: {sum(encode_time) / num_samples_after_warmup}")
@ -477,8 +532,9 @@ if __name__ == "__main__":
parser.add_argument("--prompt", type=str, default="three_frames")
parser.add_argument("--tp-size", type=int, default=1)
parser.add_argument("--dp-size", type=int, default=1)
parser.add_argument("--num-worker", type=int, default=8)
parser.add_argument("--num-workers", type=int, default=8)
parser.add_argument("--profile", action="store_true")
parser.add_argument("--profile-warmup", type=int, default=1)
parser.add_argument("--prefetch-factor", type=int, default=8, help="Prefetch factor")
args = parser.parse_args()
main(args)