Feature/llava speedup (#2)

* [caption] accelerated llava with flash attention and parallel frame extraction

* supported dp and tp in llava

* code formatting
Frank Lee 2024-03-27 16:55:25 +08:00 committed by GitHub
parent 62097da2d3
commit b704d6c0f8
6 changed files with 294 additions and 41 deletions

View file

@@ -19,7 +19,8 @@ The cost is approximately $0.01 per video (3 frames per video). The output is a
First, install LLaVA according to their [official instructions](https://github.com/haotian-liu/LLaVA?tab=readme-ov-file#install). We use the `liuhaotian/llava-v1.6-34b` model for captioning, which can be downloaded [here](https://huggingface.co/liuhaotian/llava-v1.6-vicuna-7b). Then, run the following command to generate captions for videos with LLaVA:
```bash
CUDA_VISIBLE_DEVICES=0,1 python -m tools.caption.caption_llava samples output.csv
# we run this on 8xH800 GPUs
torchrun --nproc_per_node 8 --standalone -m tools.caption.caption_llava samples output.csv --tp-size 2 --dp-size 4 --bs 16
```
The Yi-34B model requires two 80GB GPUs and takes roughly 3 s per sample. The output is a CSV file containing each video's path and caption.
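Note that `--dp-size * --tp-size` must equal the number of processes launched by `torchrun`. As an illustrative example (assuming 4 GPUs rather than the 8xH800 setup above):
```bash
# 2-way tensor parallelism x 2-way data parallelism on 4 GPUs
torchrun --nproc_per_node 4 --standalone -m tools.caption.caption_llava samples output.csv --tp-size 2 --dp-size 2 --bs 16
```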

View file

View file

@@ -0,0 +1,103 @@
import warnings
from typing import Dict, Union

import torch.nn as nn
from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

__all__ = ["LlavaPolicy", "LlavaForCausalLMPolicy"]
class LlavaPolicy(Policy):
    def config_sanity_check(self):
        pass

    def preprocess(self):
        if self.shard_config.enable_tensor_parallelism:
            # Resize embedding so the vocabulary size is divisible by the TP degree
            vocab_size = self.model.config.vocab_size
            world_size = self.shard_config.tensor_parallel_size
            if vocab_size % world_size != 0:
                new_vocab_size = vocab_size + world_size - vocab_size % world_size
                self.model.resize_token_embeddings(new_vocab_size)
        return self.model

    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
        from transformers.models.llama.modeling_llama import LlamaDecoderLayer

        policy = {}

        if self.shard_config.enable_sequence_parallelism:
            self.shard_config.enable_sequence_parallelism = False
            warnings.warn("Llama doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")

        if self.shard_config.enable_tensor_parallelism:
            decoder_attribute_replacement = {
                "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
            }
            if getattr(self.model.config, "num_key_value_heads", False):
                decoder_attribute_replacement["self_attn.num_key_value_heads"] = (
                    self.model.config.num_key_value_heads // self.shard_config.tensor_parallel_size
                )

            # q/k/v and gate/up projections are column-sharded while o_proj and
            # down_proj are row-sharded, the standard Megatron-style pairing that
            # needs only one all-reduce per attention/MLP block in the forward pass
            policy[LlamaDecoderLayer] = ModulePolicyDescription(
                attribute_replacement=decoder_attribute_replacement,
                sub_module_replacement=[
                    SubModuleReplacementDescription(
                        suffix="self_attn.q_proj",
                        target_module=Linear1D_Col,
                    ),
                    SubModuleReplacementDescription(
                        suffix="self_attn.k_proj",
                        target_module=Linear1D_Col,
                    ),
                    SubModuleReplacementDescription(
                        suffix="self_attn.v_proj",
                        target_module=Linear1D_Col,
                    ),
                    SubModuleReplacementDescription(
                        suffix="self_attn.o_proj",
                        target_module=Linear1D_Row,
                    ),
                    SubModuleReplacementDescription(
                        suffix="mlp.gate_proj",
                        target_module=Linear1D_Col,
                    ),
                    SubModuleReplacementDescription(
                        suffix="mlp.up_proj",
                        target_module=Linear1D_Col,
                    ),
                    SubModuleReplacementDescription(
                        suffix="mlp.down_proj",
                        target_module=Linear1D_Row,
                    ),
                ],
            )
        return policy

    def postprocess(self):
        return self.model
class LlavaForCausalLMPolicy(LlavaPolicy):
    def module_policy(self):
        from transformers import LlamaForCausalLM

        policy = super().module_policy()

        if self.shard_config.enable_tensor_parallelism:
            # add a new item for causal lm: shard lm_head column-wise and gather outputs
            new_item = {
                LlamaForCausalLM: ModulePolicyDescription(
                    sub_module_replacement=[
                        SubModuleReplacementDescription(
                            suffix="lm_head", target_module=Linear1D_Col, kwargs={"gather_output": True}
                        )
                    ],
                )
            }
            policy.update(new_item)
        return policy
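
For orientation, a minimal sketch of how this policy is applied through ShardFormer (it mirrors the usage in `caption_llava.py` below; the single-group setup is an assumption for illustration):

# Illustrative sketch: shard an already-loaded LLaVA `model` with the policy above.
# Assumes a distributed environment is already launched (e.g. via torchrun).
import torch.distributed as dist
from colossalai.shardformer import ShardConfig, ShardFormer

tp_group = dist.group.WORLD  # assumption: all ranks form a single TP group
shard_config = ShardConfig(
    tensor_parallel_process_group=tp_group,
    enable_tensor_parallelism=True,
)
shard_former = ShardFormer(shard_config=shard_config)
# optimize() returns (sharded_model, shared_params); keep the sharded model
model = shard_former.optimize(model, policy=LlavaForCausalLMPolicy())[0].cuda()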

View file

@@ -1,9 +1,17 @@
import argparse
import csv
import math
import multiprocessing as mp
import os
import time
import warnings

import colossalai
import torch
import torch.distributed as dist
from colossalai.cluster import DistCoordinator, ProcessGroupMesh
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.utils import get_current_device
from llava.constants import DEFAULT_IMAGE_TOKEN, IGNORE_INDEX, IMAGE_TOKEN_INDEX
from llava.conversation import conv_templates
from llava.mm_utils import get_anyres_image_grid_shape, get_model_name_from_path, process_images, tokenizer_image_token
@@ -12,7 +20,8 @@ from llava.model.llava_arch import unpad_image
from llava.utils import disable_torch_init
from tqdm import tqdm
from .utils import extract_frames, prompts, read_video_list
from .acceleration.llava.policy import LlavaForCausalLMPolicy
from .utils import Timer, extract_frames, prompts, read_video_list
disable_torch_init()
@@ -244,14 +253,53 @@ def prepare_inputs_labels_for_multimodal(
@torch.inference_mode()
def main(args):
    # ======================================================
    # 1. read video list
    # 1. init environment
    # ======================================================
    videos = read_video_list(args.video_folder, args.output_file)
    f = open(args.output_file, "a")
    writer = csv.writer(f)
    colossalai.launch_from_torch({})
    coordinator = DistCoordinator()

    # prepare the dp and tp groups
    assert (
        args.dp_size * args.tp_size == coordinator.world_size
    ), f"DP size {args.dp_size} * TP size {args.tp_size} must equal world size {coordinator.world_size}"
    mesh = ProcessGroupMesh(args.dp_size, args.tp_size)
    dp_group = mesh.get_group_along_axis(0)
    tp_group = mesh.get_group_along_axis(1)
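    # Illustrative note (assuming ProcessGroupMesh's row-major rank layout): with
    # --dp-size 4 --tp-size 2 on 8 ranks, rank r maps to (dp = r // 2, tp = r % 2),
    # so TP groups are {0,1}, {2,3}, {4,5}, {6,7} and DP groups are {0,2,4,6}, {1,3,5,7}.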
    # ======================================================
    # 2. load model and prepare prompts
    # 2. read video list
    # ======================================================
    videos = read_video_list(args.video_folder, args.output_file)
    if len(videos) == 0:
        print("No videos are found or all videos have been processed.")
        return

    # shard by DP
    dp_size = dist.get_world_size(dp_group)
    dp_rank = dist.get_rank(dp_group)
    video_partition_size = math.ceil(len(videos) / dp_size)
    videos = videos[dp_rank * video_partition_size : (dp_rank + 1) * video_partition_size]
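    # Illustrative example: 10 videos with dp_size = 4 gives a partition size of
    # ceil(10 / 4) = 3, so dp ranks 0-2 each take 3 videos and rank 3 takes the last one.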
    # create csv writer
    has_main_writer = dist.get_rank() == 0
    has_dp_writer = dist.get_rank(tp_group) == 0
    if has_main_writer:
        # we keep track of the processed videos in the main file
        # so we use append mode
        main_file = open(args.output_file, "a")
        main_writer = csv.writer(main_file)
    if has_dp_writer:
        # the dp writer takes care of the files processed on the current dp rank
        # so we use write mode
        output_file_split = f"{args.output_file}.part{dp_rank}"
        dp_file = open(output_file_split, "w")
        dp_writer = csv.writer(dp_file)

    # ======================================================
    # 3. load model and prepare prompts
    # ======================================================
    model_path = "liuhaotian/llava-v1.6-34b"
    query = prompts[args.prompt]
@@ -266,39 +314,77 @@ def main(args):
        model_path=model_path,
        model_base=None,
        model_name=get_model_name_from_path(model_path),
        device=get_current_device(),
        torch_dtype=torch.float16,
        attn_implementation="flash_attention_2",
    )

    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
    input_ids = input_ids.unsqueeze(0).to(model.device)
    input_ids = input_ids.unsqueeze(0).cuda()
    # ======================================================
    # 3. generate captions
    # 4. Apply system optimization
    # ======================================================
    # create huggingface model as normal
    tp_size = dist.get_world_size(tp_group)
    shard_config = ShardConfig(
        tensor_parallel_process_group=tp_group if tp_size > 1 else None,
        enable_tensor_parallelism=True if tp_size > 1 else False,
    )
    shard_former = ShardFormer(shard_config=shard_config)
    model = shard_former.optimize(model, policy=LlavaForCausalLMPolicy())[0].cuda()
    torch.cuda.empty_cache()

    # ======================================================
    # 5. generate captions
    # ======================================================
    bs = args.bs
    for i in tqdm(range(0, len(videos), bs)):
        # prepare a batch of inputs
    if dist.get_rank() == 0:
        pbar = tqdm(range(0, len(videos), bs))
    else:
        # only rank 0 displays a progress bar; other ranks iterate silently
        pbar = range(0, len(videos), bs)

    # set up a multiprocessing pool
    pool = mp.Pool(args.num_worker)

    if args.profile:
        encode_time = []
        frame_extraction_time = []
        generate_time = []
        total_time = []

    for i in pbar:
        # measure time
        if args.profile:
            torch.cuda.synchronize()
            start_time = time.time()

        video_files = videos[i : i + bs]
        frames = []
        video_lengths = []
        for video_file in video_files:
            frame, length = extract_frames(os.path.join(args.video_folder, video_file))
            if len(frame) < 3:
        with Timer() as frame_extraction_timer:
            # prepare a batch of inputs with parallel frame extraction
            for frame, length in pool.map(extract_frames, video_files):
                if len(frame) < 3:
                    continue
                frames.append(frame)
                video_lengths.append(length)
        if len(frames) == 0:
            continue
            frames.append(frame)
            video_lengths.append(length)
        if len(frames) == 0:
            continue
        # encode the batch of inputs
        samples = []
        for imgs in frames:
            imgs_size = [img.size for img in imgs]
            imgs = process_images(imgs, image_processor, model.config)
            imgs = imgs.to(model.device, dtype=torch.float16)
            with torch.inference_mode():
        with Timer() as encode_timer:
            samples = []
            for imgs in frames:
                imgs_size = [img.size for img in imgs]
                imgs = process_images(imgs, image_processor, model.config)
                imgs = imgs.cuda().to(dtype=torch.float16)
                _, _, _, _, inputs_embeds, _ = prepare_inputs_labels_for_multimodal(
                    model, input_ids, None, None, None, None, images=imgs, image_sizes=imgs_size
                )
                samples.append(inputs_embeds)
            samples.append(inputs_embeds)

        # padding
        max_len = max([sample.shape[1] for sample in samples])
@@ -322,23 +408,66 @@ def main(args):
        inputs_embeds = torch.cat(inputs_embeds, dim=0)

        # generate outputs
        output_ids = super(type(model), model).generate(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            do_sample=True,
            temperature=0.2,
            max_new_tokens=512,
            use_cache=True,
        )
        outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        outputs = [output.replace("\n", " ").strip() for output in outputs]
        with Timer() as generate_timer:
            output_ids = super(type(model), model).generate(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                do_sample=args.tp_size == 1,  # sampling is not deterministic and may cause TP to hang
                temperature=0.2,
                max_new_tokens=512,
                use_cache=True,
            )
            outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
            outputs = [output.replace("\n", " ").strip() for output in outputs]
        # skip the first `--profile-warmup` iterations as warmup
        # (i advances by bs per iteration, so compare against profile_warmup * bs)
        if args.profile and i >= args.profile_warmup * bs:
            # measure time
            torch.cuda.synchronize()
            time_taken = time.time() - start_time

            total_time.append(time_taken)
            frame_extraction_time.append(frame_extraction_timer.time_taken)
            encode_time.append(encode_timer.time_taken)
            generate_time.append(generate_timer.time_taken)
        # save results
        result = list(zip(video_files, outputs, video_lengths))
        for t in result:
            writer.writerow(t)
        if has_dp_writer:
            result = list(zip(video_files, outputs, video_lengths))
            for t in result:
                dp_writer.writerow(t)
    f.close()
    # display profiling info
    if args.profile:
        num_samples_after_warmup = len(videos[args.profile_warmup * bs :]) * args.dp_size
        print(f"throughput (video/s): {num_samples_after_warmup / sum(total_time)}")
        print(f"average frame extraction time per sample: {sum(frame_extraction_time) / num_samples_after_warmup}")
        print(f"average encode time per sample: {sum(encode_time) / num_samples_after_warmup}")
        print(f"average generate time per sample: {sum(generate_time) / num_samples_after_warmup}")
        print(f"Max GPU allocated / GB: {torch.cuda.max_memory_allocated() / 1024**3}")
        print(f"Max GPU reserved / GB: {torch.cuda.max_memory_reserved() / 1024**3}")
    # ======================================================
    # 6. shutdown
    # ======================================================
    # close file writing
    if has_dp_writer:
        dp_file.close()
    dist.barrier()

    # merge files
    if has_main_writer:
        for i in range(dp_size):
            output_file_split = f"{args.output_file}.part{i}"
            with open(output_file_split, "r") as f:
                reader = csv.reader(f)
                for row in reader:
                    main_writer.writerow(row)
            os.remove(output_file_split)
        main_file.close()

    # terminate distributed env
    dist.destroy_process_group()
if __name__ == "__main__":
@@ -347,6 +476,10 @@ if __name__ == "__main__":
parser.add_argument("output_file", type=str)
parser.add_argument("--bs", type=int, default=32)
parser.add_argument("--prompt", type=str, default="three_frames")
parser.add_argument("--tp-size", type=int, default=1)
parser.add_argument("--dp-size", type=int, default=1)
parser.add_argument("--num-worker", type=int, default=8)
parser.add_argument("--profile", action="store_true")
parser.add_argument("--profile-warmup", type=int, default=1)
args = parser.parse_args()
main(args)

View file

@@ -1,6 +1,7 @@
import base64
import csv
import os
import time

import cv2
from PIL import Image
@@ -18,7 +19,7 @@ def get_filelist(file_path):
        for filename in files:
            ext = filename.split(".")[-1]
            if ext in VID_EXTENSIONS:
                Filelist.append(filename)
                Filelist.append(os.path.join(home, filename))
    return Filelist
@@ -65,3 +66,18 @@ def read_video_list(video_folder, output_file):
    videos = [video for video in videos if video not in processed_videos]
    print(f"Processing {len(videos)} new videos.")
    return videos
class Timer:
    def __init__(self):
        self.time_taken = 0
        self.start_time = 0
        self.end_time = 0

    def __enter__(self):
        self.start_time = time.time()
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        self.end_time = time.time()
        self.time_taken = self.end_time - self.start_time
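
`Timer` is a plain context manager over wall-clock time, matching how the profiling hooks in `caption_llava.py` use it. A quick illustrative usage (`do_work` is a hypothetical placeholder):

# measure how long a block takes
with Timer() as timer:
    do_work()  # hypothetical stand-in for the block being profiled
print(timer.time_taken)  # elapsed seconds as a float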