diff --git a/tools/caption/README.md b/tools/caption/README.md
index e9289d0..e0aff58 100644
--- a/tools/caption/README.md
+++ b/tools/caption/README.md
@@ -19,7 +19,8 @@ The cost is approximately $0.01 per video (3 frames per video). The output is a
 First, install LLaVA according to their [official instructions](https://github.com/haotian-liu/LLaVA?tab=readme-ov-file#install). We use the `liuhaotian/llava-v1.6-34b` model for captioning, which can be downloaded [here](https://huggingface.co/liuhaotian/llava-v1.6-34b). Then, run the following command to generate captions for videos with LLaVA:
 
 ```bash
-CUDA_VISIBLE_DEVICES=0,1 python -m tools.caption.caption_llava samples output.csv
+# we run this on 8xH800 GPUs
+torchrun --nproc_per_node 8 --standalone -m tools.caption.caption_llava samples output.csv --tp-size 2 --dp-size 4 --bs 16
 ```
 
 The Yi-34B-based model requires two 80GB GPUs and takes about 3s per sample. The output is a CSV file with the path and caption of each video.
diff --git a/tools/caption/acceleration/__init__.py b/tools/caption/acceleration/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tools/caption/acceleration/llava/__init__.py b/tools/caption/acceleration/llava/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tools/caption/acceleration/llava/policy.py b/tools/caption/acceleration/llava/policy.py
new file mode 100644
index 0000000..235d904
--- /dev/null
+++ b/tools/caption/acceleration/llava/policy.py
@@ -0,0 +1,103 @@
+import warnings
+from typing import Dict, Union
+
+import torch.nn as nn
+from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row
+from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
+
+__all__ = ["LlavaPolicy", "LlavaForCausalLMPolicy"]
+
+
+class LlavaPolicy(Policy):
+    def config_sanity_check(self):
+        pass
+
+    def preprocess(self):
+        if self.shard_config.enable_tensor_parallelism:
+            # resize the embedding so the vocab size is divisible by the TP size
+            vocab_size = self.model.config.vocab_size
+            world_size = self.shard_config.tensor_parallel_size
+
+            if vocab_size % world_size != 0:
+                new_vocab_size = vocab_size + world_size - vocab_size % world_size
+                self.model.resize_token_embeddings(new_vocab_size)
+
+        return self.model
+
+    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
+        from transformers.models.llama.modeling_llama import LlamaDecoderLayer
+
+        policy = {}
+
+        if self.shard_config.enable_sequence_parallelism:
+            self.shard_config.enable_sequence_parallelism = False
+            warnings.warn("LLaVA does not support sequence parallelism yet; the sequence parallelism flag will be ignored.")
+
+        if self.shard_config.enable_tensor_parallelism:
+            decoder_attribute_replacement = {
+                "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
+                "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
+            }
+            if getattr(self.model.config, "num_key_value_heads", False):
+                decoder_attribute_replacement["self_attn.num_key_value_heads"] = (
+                    self.model.config.num_key_value_heads // self.shard_config.tensor_parallel_size
+                )
+
+            policy[LlamaDecoderLayer] = ModulePolicyDescription(
+                attribute_replacement=decoder_attribute_replacement,
+                sub_module_replacement=[
+                    SubModuleReplacementDescription(
+                        suffix="self_attn.q_proj",
+                        target_module=Linear1D_Col,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="self_attn.k_proj",
+                        target_module=Linear1D_Col,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="self_attn.v_proj",
+                        target_module=Linear1D_Col,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="self_attn.o_proj",
+                        target_module=Linear1D_Row,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="mlp.gate_proj",
+                        target_module=Linear1D_Col,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="mlp.up_proj",
+                        target_module=Linear1D_Col,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="mlp.down_proj",
+                        target_module=Linear1D_Row,
+                    ),
+                ],
+            )
+
+        return policy
+
+    def postprocess(self):
+        return self.model
+
+
+class LlavaForCausalLMPolicy(LlavaPolicy):
+    def module_policy(self):
+        from transformers import LlamaForCausalLM
+
+        policy = super().module_policy()
+        if self.shard_config.enable_tensor_parallelism:
+            # add a new item for causal lm
+            new_item = {
+                LlamaForCausalLM: ModulePolicyDescription(
+                    sub_module_replacement=[
+                        SubModuleReplacementDescription(
+                            suffix="lm_head", target_module=Linear1D_Col, kwargs={"gather_output": True}
+                        )
+                    ],
+                )
+            }
+            policy.update(new_item)
+        return policy
diff --git a/tools/caption/caption_llava.py b/tools/caption/caption_llava.py
index 8f4278c..1c5df73 100644
--- a/tools/caption/caption_llava.py
+++ b/tools/caption/caption_llava.py
@@ -1,9 +1,17 @@
 import argparse
 import csv
+import math
+import multiprocessing as mp
 import os
+import time
 import warnings
 
+import colossalai
 import torch
+import torch.distributed as dist
+from colossalai.cluster import DistCoordinator, ProcessGroupMesh
+from colossalai.shardformer import ShardConfig, ShardFormer
+from colossalai.utils import get_current_device
 from llava.constants import DEFAULT_IMAGE_TOKEN, IGNORE_INDEX, IMAGE_TOKEN_INDEX
 from llava.conversation import conv_templates
 from llava.mm_utils import get_anyres_image_grid_shape, get_model_name_from_path, process_images, tokenizer_image_token
@@ -12,7 +20,8 @@ from llava.model.llava_arch import unpad_image
 from llava.utils import disable_torch_init
 from tqdm import tqdm
 
-from .utils import extract_frames, prompts, read_video_list
+from .acceleration.llava.policy import LlavaForCausalLMPolicy
+from .utils import Timer, extract_frames, prompts, read_video_list
 
 disable_torch_init()
 
@@ -244,14 +253,53 @@ def prepare_inputs_labels_for_multimodal(
 @torch.inference_mode()
 def main(args):
     # ======================================================
-    # 1. read video list
+    # 1. init environment
     # ======================================================
-    videos = read_video_list(args.video_folder, args.output_file)
-    f = open(args.output_file, "a")
-    writer = csv.writer(f)
+    colossalai.launch_from_torch({})
+    coordinator = DistCoordinator()
+
+    # prepare the dp and tp groups
+    assert (
+        args.dp_size * args.tp_size == coordinator.world_size
+    ), f"DP size {args.dp_size} * TP size {args.tp_size} must equal the world size {coordinator.world_size}"
+    mesh = ProcessGroupMesh(args.dp_size, args.tp_size)
+    dp_group = mesh.get_group_along_axis(0)
+    tp_group = mesh.get_group_along_axis(1)
 
     # ======================================================
-    # 2. load model and prepare prompts
+    # 2. read video list
+    # ======================================================
+    videos = read_video_list(args.video_folder, args.output_file)
+
+    if len(videos) == 0:
+        print("No videos found, or all videos have already been processed.")
+        return
+
+    # shard the video list across DP ranks
+    dp_size = dist.get_world_size(dp_group)
+    dp_rank = dist.get_rank(dp_group)
+    video_partition_size = math.ceil(len(videos) / dp_size)
+    videos = videos[dp_rank * video_partition_size : (dp_rank + 1) * video_partition_size]
+
+    # create csv writers
+    has_main_writer = dist.get_rank() == 0
+    has_dp_writer = dist.get_rank(tp_group) == 0
+
+    if has_main_writer:
+        # the main file keeps track of all processed videos,
+        # so we open it in append mode
+        main_file = open(args.output_file, "a")
+        main_writer = csv.writer(main_file)
+
+    if has_dp_writer:
+        # a dp writer only records the videos processed on its own dp rank,
+        # so we open its part file in write mode
+        output_file_split = f"{args.output_file}.part{dp_rank}"
+        dp_file = open(output_file_split, "w")
+        dp_writer = csv.writer(dp_file)
+
+    # ======================================================
+    # 3. load model and prepare prompts
     # ======================================================
     model_path = "liuhaotian/llava-v1.6-34b"
     query = prompts[args.prompt]
@@ -266,39 +314,77 @@
         model_path=model_path,
         model_base=None,
         model_name=get_model_name_from_path(model_path),
+        device=get_current_device(),
+        torch_dtype=torch.float16,
+        attn_implementation="flash_attention_2",
     )
 
     input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
-    input_ids = input_ids.unsqueeze(0).to(model.device)
+    input_ids = input_ids.unsqueeze(0).cuda()
 
     # ======================================================
-    # 3. generate captions
+    # 4. apply system optimization
     # ======================================================
+    # shard the huggingface model with ShardFormer for tensor parallelism
+    tp_size = dist.get_world_size(tp_group)
+    shard_config = ShardConfig(
+        tensor_parallel_process_group=tp_group if tp_size > 1 else None,
+        enable_tensor_parallelism=True if tp_size > 1 else False,
+    )
+    shard_former = ShardFormer(shard_config=shard_config)
+    model = shard_former.optimize(model, policy=LlavaForCausalLMPolicy())[0].cuda()
+    torch.cuda.empty_cache()
+
+    # ======================================================
+    # 5. generate captions
+    # ======================================================
     bs = args.bs
-    for i in tqdm(range(0, len(videos), bs)):
-        # prepare a batch of inputs
+
+    if dist.get_rank() == 0:  # only the global rank 0 displays a progress bar
+        pbar = tqdm(range(0, len(videos), bs))
+    else:
+        pbar = range(0, len(videos), bs)
+
+    # set up a multiprocessing pool for parallel frame extraction
+    pool = mp.Pool(args.num_worker)
+
+    if args.profile:
+        encode_time = []
+        frame_extraction_time = []
+        generate_time = []
+        total_time = []
+
+    for i in pbar:
+        # measure time
+        if args.profile:
+            torch.cuda.synchronize()
+            start_time = time.time()
+
         video_files = videos[i : i + bs]
         frames = []
         video_lengths = []
-        for video_file in video_files:
-            frame, length = extract_frames(os.path.join(args.video_folder, video_file))
-            if len(frame) < 3:
+
+        with Timer() as frame_extraction_timer:
+            # prepare a batch of inputs with parallel frame extraction;
+            # track the kept files so captions stay aligned with file names
+            valid_video_files = []
+            for video_file, (frame, length) in zip(video_files, pool.map(extract_frames, video_files)):
+                if len(frame) < 3:
+                    continue
+                valid_video_files.append(video_file)
+                frames.append(frame)
+                video_lengths.append(length)
+
+        if len(frames) == 0:
             continue
-            frames.append(frame)
-            video_lengths.append(length)
-        if len(frames) == 0:
-            continue
 
         # encode the batch of inputs
-        samples = []
-        for imgs in frames:
-            imgs_size = [img.size for img in imgs]
-            imgs = process_images(imgs, image_processor, model.config)
-            imgs = imgs.to(model.device, dtype=torch.float16)
-            with torch.inference_mode():
+        with Timer() as encode_timer:
+            samples = []
+            for imgs in frames:
+                imgs_size = [img.size for img in imgs]
+                imgs = process_images(imgs, image_processor, model.config)
+                imgs = imgs.cuda().to(dtype=torch.float16)
                 _, _, _, _, inputs_embeds, _ = prepare_inputs_labels_for_multimodal(
                     model, input_ids, None, None, None, None, images=imgs, image_sizes=imgs_size
                 )
-            samples.append(inputs_embeds)
+                samples.append(inputs_embeds)
 
         # padding
         max_len = max([sample.shape[1] for sample in samples])
@@ -322,23 +408,66 @@
         inputs_embeds = torch.cat(inputs_embeds, dim=0)
 
         # generate outputs
-        output_ids = super(type(model), model).generate(
-            inputs_embeds=inputs_embeds,
-            attention_mask=attention_mask,
-            do_sample=True,
-            temperature=0.2,
-            max_new_tokens=512,
-            use_cache=True,
-        )
-        outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
-        outputs = [output.replace("\n", " ").strip() for output in outputs]
+        with Timer() as generate_timer:
+            output_ids = super(type(model), model).generate(
+                inputs_embeds=inputs_embeds,
+                attention_mask=attention_mask,
+                do_sample=args.tp_size == 1,  # sampling is not deterministic and may cause TP to hang
+                temperature=0.2,
+                max_new_tokens=512,
+                use_cache=True,
+            )
+            outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+            outputs = [output.replace("\n", " ").strip() for output in outputs]
+
+        # record timings, skipping the warmup iterations
+        if args.profile and i >= args.profile_warmup * bs:
+            # measure time
+            torch.cuda.synchronize()
+            time_taken = time.time() - start_time
+
+            total_time.append(time_taken)
+            frame_extraction_time.append(frame_extraction_timer.time_taken)
+            encode_time.append(encode_timer.time_taken)
+            generate_time.append(generate_timer.time_taken)
 
         # save results
-        result = list(zip(video_files, outputs, video_lengths))
-        for t in result:
-            writer.writerow(t)
+        if has_dp_writer:
+            result = list(zip(valid_video_files, outputs, video_lengths))
+            for t in result:
+                dp_writer.writerow(t)
 
-    f.close()
+    # display profiling info
+    if args.profile:
+        num_samples_after_warmup = len(videos[args.profile_warmup * bs :]) * args.dp_size
+        print(f"throughput (video/s): {num_samples_after_warmup / sum(total_time)}")
+        print(f"average frame extraction time per sample: {sum(frame_extraction_time) / num_samples_after_warmup}")
+        print(f"average encode time per sample: {sum(encode_time) / num_samples_after_warmup}")
+        print(f"average generate time per sample: {sum(generate_time) / num_samples_after_warmup}")
+        print(f"max GPU memory allocated / GB: {torch.cuda.max_memory_allocated() / 1024**3}")
+        print(f"max GPU memory reserved / GB: {torch.cuda.max_memory_reserved() / 1024**3}")
+
+    # ======================================================
+    # 6. shutdown
+    # ======================================================
+    # close file writing
+    if has_dp_writer:
+        dp_file.close()
+    dist.barrier()
+
+    # merge the per-rank part files into the main output file
+    if has_main_writer:
+        for i in range(dp_size):
+            output_file_split = f"{args.output_file}.part{i}"
+            with open(output_file_split, "r") as f:
+                reader = csv.reader(f)
+                for row in reader:
+                    main_writer.writerow(row)
+            os.remove(output_file_split)
+        main_file.close()
+
+    # terminate the distributed environment
+    dist.destroy_process_group()
 
 
 if __name__ == "__main__":
@@ -347,6 +476,10 @@
     parser.add_argument("output_file", type=str)
     parser.add_argument("--bs", type=int, default=32)
     parser.add_argument("--prompt", type=str, default="three_frames")
+    parser.add_argument("--tp-size", type=int, default=1)
+    parser.add_argument("--dp-size", type=int, default=1)
+    parser.add_argument("--num-worker", type=int, default=8)
+    parser.add_argument("--profile", action="store_true")
+    parser.add_argument("--profile-warmup", type=int, default=1)
     args = parser.parse_args()
     main(args)
diff --git a/tools/caption/utils.py b/tools/caption/utils.py
index 3912f0c..fb6085e 100644
--- a/tools/caption/utils.py
+++ b/tools/caption/utils.py
@@ -1,6 +1,7 @@
 import base64
 import csv
 import os
+import time
 
 import cv2
 from PIL import Image
@@ -18,7 +19,7 @@ def get_filelist(file_path):
         for filename in files:
             ext = filename.split(".")[-1]
             if ext in VID_EXTENSIONS:
-                Filelist.append(filename)
+                Filelist.append(os.path.join(home, filename))
     return Filelist
@@ -65,3 +66,18 @@
     videos = [video for video in videos if video not in processed_videos]
     print(f"Processing {len(videos)} new videos.")
     return videos
+
+
+class Timer:
+    def __init__(self):
+        self.time_taken = 0
+        self.start_time = 0
+        self.end_time = 0
+
+    def __enter__(self):
+        self.start_time = time.time()
+        return self
+
+    def __exit__(self, exc_type, exc_value, exc_tb):
+        self.end_time = time.time()
+        self.time_taken = self.end_time - self.start_time
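A note on the policy in `tools/caption/acceleration/llava/policy.py`: `q_proj`/`k_proj`/`v_proj`/`gate_proj`/`up_proj` are mapped to `Linear1D_Col` (the output dimension is split across TP ranks) while `o_proj`/`down_proj` are mapped to `Linear1D_Row` (the input dimension is split), so each attention/MLP block needs only one all-reduce at its row-parallel output. The following stand-alone sketch (plain PyTorch, not part of the diff) checks that pairing numerically:

```python
import torch

# toy check: a column-split linear followed by a row-split linear
# reproduces the unsharded result after one final reduction (the all-reduce)
torch.manual_seed(0)
x = torch.randn(2, 8)    # (batch, hidden)
w1 = torch.randn(16, 8)  # column-parallel weight, split along its output dim
w2 = torch.randn(8, 16)  # row-parallel weight, split along its input dim

reference = x @ w1.T @ w2.T  # unsharded computation

# simulate tp_size = 2
w1_a, w1_b = w1.chunk(2, dim=0)
w2_a, w2_b = w2.chunk(2, dim=1)
partial_a = (x @ w1_a.T) @ w2_a.T  # what TP rank 0 computes
partial_b = (x @ w1_b.T) @ w2_b.T  # what TP rank 1 computes
assert torch.allclose(reference, partial_a + partial_b, atol=1e-5)
```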
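The DP sharding in `caption_llava.py` gives each data-parallel rank a contiguous chunk of `ceil(N / dp_size)` videos, so the last rank's chunk may be shorter (or empty) when `N` is not divisible by `dp_size`. A toy illustration of that slicing, with hypothetical file names:

```python
import math

videos = [f"video_{i}.mp4" for i in range(9)]  # hypothetical file names
dp_size = 4

video_partition_size = math.ceil(len(videos) / dp_size)  # ceil(9 / 4) = 3
for dp_rank in range(dp_size):
    shard = videos[dp_rank * video_partition_size : (dp_rank + 1) * video_partition_size]
    print(f"dp rank {dp_rank}: {shard}")
# ranks 0-2 take 3 videos each; rank 3 is left with an empty shard
```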
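Finally, the `Timer` added to `tools/caption/utils.py` is a plain context manager; a minimal usage sketch (`time.sleep` stands in for any profiled block):

```python
import time

from tools.caption.utils import Timer

with Timer() as timer:
    time.sleep(0.5)  # stand-in for frame extraction, encoding, or generation

# time_taken is populated when the block exits
print(f"block took {timer.time_taken:.3f}s")
```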