mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-12 13:54:53 +02:00
Feature/llava speedup (#2)
* [caption] accelerated llava with flash attention and parallel frame extraction * supported dp and tp in llava * code formatting
This commit is contained in:
parent
62097da2d3
commit
b704d6c0f8
|
|
@ -19,7 +19,8 @@ The cost is approximately $0.01 per video (3 frames per video). The output is a
|
|||
First, install LLaVA according to their [official instructions](https://github.com/haotian-liu/LLaVA?tab=readme-ov-file#install). We use the `liuhaotian/llava-v1.6-34b` model for captioning, which can be downloaded [here](https://huggingface.co/liuhaotian/llava-v1.6-vicuna-7b). Then, run the following command to generate captions for videos with LLaVA:
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0,1 python -m tools.caption.caption_llava samples output.csv
|
||||
# we run this on 8xH800 GPUs
|
||||
torchrun --nproc_per_node 8 --standalone -m tools.caption.caption_llava samples output.csv --tp-size 2 --dp-size 4 --bs 16
|
||||
```
|
||||
|
||||
The Yi-34B requires 2 80GB GPUs and 3s/sample. The output is a CSV file with path and caption.
|
||||
|
|
|
|||
0
tools/caption/acceleration/__init__.py
Normal file
0
tools/caption/acceleration/__init__.py
Normal file
0
tools/caption/acceleration/llava/__init__.py
Normal file
0
tools/caption/acceleration/llava/__init__.py
Normal file
103
tools/caption/acceleration/llava/policy.py
Normal file
103
tools/caption/acceleration/llava/policy.py
Normal file
|
|
@ -0,0 +1,103 @@
|
|||
import warnings
|
||||
from typing import Dict, Union
|
||||
|
||||
import torch.nn as nn
|
||||
from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row
|
||||
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||
|
||||
__all__ = ["LlavaPolicy", "LlavaForCausalLMPolicy"]
|
||||
|
||||
|
||||
class LlavaPolicy(Policy):
    """Shardformer policy that tensor-parallelizes the LLaMA backbone used by LLaVA."""

    def config_sanity_check(self):
        pass

    def preprocess(self):
        # Pad the vocabulary so the embedding table divides evenly across TP ranks.
        if self.shard_config.enable_tensor_parallelism:
            vocab_size = self.model.config.vocab_size
            world_size = self.shard_config.tensor_parallel_size
            remainder = vocab_size % world_size
            if remainder != 0:
                self.model.resize_token_embeddings(vocab_size + world_size - remainder)
        return self.model

    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
        from transformers.models.llama.modeling_llama import LlamaDecoderLayer

        policy = {}

        if self.shard_config.enable_sequence_parallelism:
            self.shard_config.enable_sequence_parallelism = False
            warnings.warn("Llama doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")

        if self.shard_config.enable_tensor_parallelism:
            tp_size = self.shard_config.tensor_parallel_size
            # Each TP rank owns hidden_size // tp_size channels and num_heads // tp_size heads.
            attr_replacement = {
                "self_attn.hidden_size": self.model.config.hidden_size // tp_size,
                "self_attn.num_heads": self.model.config.num_attention_heads // tp_size,
            }
            if getattr(self.model.config, "num_key_value_heads", False):
                # GQA models also shard the KV heads.
                attr_replacement["self_attn.num_key_value_heads"] = self.model.config.num_key_value_heads // tp_size

            # Input projections are column-parallel; output projections are
            # row-parallel so activations are reduced back to full size.
            replacement_plan = [
                ("self_attn.q_proj", Linear1D_Col),
                ("self_attn.k_proj", Linear1D_Col),
                ("self_attn.v_proj", Linear1D_Col),
                ("self_attn.o_proj", Linear1D_Row),
                ("mlp.gate_proj", Linear1D_Col),
                ("mlp.up_proj", Linear1D_Col),
                ("mlp.down_proj", Linear1D_Row),
            ]
            policy[LlamaDecoderLayer] = ModulePolicyDescription(
                attribute_replacement=attr_replacement,
                sub_module_replacement=[
                    SubModuleReplacementDescription(suffix=name, target_module=module)
                    for name, module in replacement_plan
                ],
            )

        return policy

    def postprocess(self):
        return self.model
class LlavaForCausalLMPolicy(LlavaPolicy):
    """Extends the backbone policy with a tensor-parallel language-model head."""

    def module_policy(self):
        from transformers import LlamaForCausalLM

        policy = super().module_policy()
        if self.shard_config.enable_tensor_parallelism:
            # Shard lm_head column-wise and gather so every rank sees full logits.
            lm_head_description = ModulePolicyDescription(
                sub_module_replacement=[
                    SubModuleReplacementDescription(
                        suffix="lm_head", target_module=Linear1D_Col, kwargs={"gather_output": True}
                    )
                ],
            )
            policy[LlamaForCausalLM] = lm_head_description
        return policy
|
|
@ -1,9 +1,17 @@
|
|||
import argparse
|
||||
import csv
|
||||
import math
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
import time
|
||||
import warnings
|
||||
|
||||
import colossalai
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from colossalai.cluster import DistCoordinator, ProcessGroupMesh
|
||||
from colossalai.shardformer import ShardConfig, ShardFormer
|
||||
from colossalai.utils import get_current_device
|
||||
from llava.constants import DEFAULT_IMAGE_TOKEN, IGNORE_INDEX, IMAGE_TOKEN_INDEX
|
||||
from llava.conversation import conv_templates
|
||||
from llava.mm_utils import get_anyres_image_grid_shape, get_model_name_from_path, process_images, tokenizer_image_token
|
||||
|
|
@ -12,7 +20,8 @@ from llava.model.llava_arch import unpad_image
|
|||
from llava.utils import disable_torch_init
|
||||
from tqdm import tqdm
|
||||
|
||||
from .utils import extract_frames, prompts, read_video_list
|
||||
from .acceleration.llava.policy import LlavaForCausalLMPolicy
|
||||
from .utils import Timer, extract_frames, prompts, read_video_list
|
||||
|
||||
disable_torch_init()
|
||||
|
||||
|
|
@ -244,14 +253,53 @@ def prepare_inputs_labels_for_multimodal(
|
|||
@torch.inference_mode()
|
||||
def main(args):
|
||||
# ======================================================
|
||||
# 1. read video list
|
||||
# 1. init environment
|
||||
# ======================================================
|
||||
videos = read_video_list(args.video_folder, args.output_file)
|
||||
f = open(args.output_file, "a")
|
||||
writer = csv.writer(f)
|
||||
colossalai.launch_from_torch({})
|
||||
coordinator = DistCoordinator()
|
||||
|
||||
# prepare the dp and tp groups
|
||||
assert (
|
||||
args.dp_size * args.tp_size == coordinator.world_size
|
||||
), f"DP size {args.dp_size} * TP size {args.tp_size} must equal to world size {coordinator.world_size}"
|
||||
mesh = ProcessGroupMesh(args.dp_size, args.tp_size)
|
||||
dp_group = mesh.get_group_along_axis(0)
|
||||
tp_group = mesh.get_group_along_axis(1)
|
||||
|
||||
# ======================================================
|
||||
# 2. load model and prepare prompts
|
||||
# 2. read video list
|
||||
# ======================================================
|
||||
videos = read_video_list(args.video_folder, args.output_file)
|
||||
|
||||
if len(videos) == 0:
|
||||
print("No videos are found or all videos have been processed.")
|
||||
return
|
||||
|
||||
# shard by DP
|
||||
dp_size = dist.get_world_size(dp_group)
|
||||
dp_rank = dist.get_rank(dp_group)
|
||||
video_partition_size = math.ceil(len(videos) / dp_size)
|
||||
videos = videos[dp_rank * video_partition_size : (dp_rank + 1) * video_partition_size]
|
||||
|
||||
# create csv writer
|
||||
has_main_writer = dist.get_rank() == 0
|
||||
has_dp_writter = dist.get_rank(tp_group) == 0
|
||||
|
||||
if has_main_writer:
|
||||
# we keep track of the processed videos in main file
|
||||
# so we use append mode
|
||||
main_file = open(args.output_file, "a")
|
||||
main_writer = csv.writer(main_file)
|
||||
|
||||
if has_dp_writter:
|
||||
# the dp writer takes care of the files processed on the current dp rank
|
||||
# so we use write mode
|
||||
output_file_split = f"{args.output_file}.part{dp_rank}"
|
||||
dp_file = open(output_file_split, "w")
|
||||
dp_writer = csv.writer(dp_file)
|
||||
|
||||
# ======================================================
|
||||
# 3. load model and prepare prompts
|
||||
# ======================================================
|
||||
model_path = "liuhaotian/llava-v1.6-34b"
|
||||
query = prompts[args.prompt]
|
||||
|
|
@ -266,39 +314,77 @@ def main(args):
|
|||
model_path=model_path,
|
||||
model_base=None,
|
||||
model_name=get_model_name_from_path(model_path),
|
||||
device=get_current_device(),
|
||||
torch_dtype=torch.float16,
|
||||
attn_implementation="flash_attention_2",
|
||||
)
|
||||
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
|
||||
input_ids = input_ids.unsqueeze(0).to(model.device)
|
||||
input_ids = input_ids.unsqueeze(0).cuda()
|
||||
|
||||
# ======================================================
|
||||
# 3. generate captions
|
||||
# 4. Apply system optimization
|
||||
# ======================================================
|
||||
# create huggingface model as normal
|
||||
tp_size = dist.get_world_size(tp_group)
|
||||
shard_config = ShardConfig(
|
||||
tensor_parallel_process_group=tp_group if tp_size > 1 else None,
|
||||
enable_tensor_parallelism=True if tp_size > 1 else False,
|
||||
)
|
||||
shard_former = ShardFormer(shard_config=shard_config)
|
||||
model = shard_former.optimize(model, policy=LlavaForCausalLMPolicy())[0].cuda()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# ======================================================
|
||||
# 5. generate captions
|
||||
# ======================================================
|
||||
bs = args.bs
|
||||
for i in tqdm(range(0, len(videos), bs)):
|
||||
# prepare a batch of inputs
|
||||
|
||||
if dist.get_rank() == 0:
|
||||
pbar = tqdm(range(0, len(videos), bs))
|
||||
else:
|
||||
pbar = tqdm(range(0, len(videos), bs))
|
||||
|
||||
# set up a multiprocessing pool
|
||||
pool = mp.Pool(args.num_worker)
|
||||
|
||||
if args.profile:
|
||||
encode_time = []
|
||||
frame_extraction_time = []
|
||||
generate_time = []
|
||||
total_time = []
|
||||
|
||||
for i in pbar:
|
||||
# measure time
|
||||
if args.profile:
|
||||
torch.cuda.synchronize()
|
||||
start_time = time.time()
|
||||
|
||||
video_files = videos[i : i + bs]
|
||||
frames = []
|
||||
video_lengths = []
|
||||
for video_file in video_files:
|
||||
frame, length = extract_frames(os.path.join(args.video_folder, video_file))
|
||||
if len(frame) < 3:
|
||||
|
||||
with Timer() as frame_extraction_timer:
|
||||
# prepare a batch of inputs with parallel frame extraction
|
||||
for frame, length in pool.map(extract_frames, video_files):
|
||||
if len(frame) < 3:
|
||||
continue
|
||||
frames.append(frame)
|
||||
video_lengths.append(length)
|
||||
|
||||
if len(frames) == 0:
|
||||
continue
|
||||
frames.append(frame)
|
||||
video_lengths.append(length)
|
||||
if len(frames) == 0:
|
||||
continue
|
||||
|
||||
# encode the batch of inputs
|
||||
samples = []
|
||||
for imgs in frames:
|
||||
imgs_size = [img.size for img in imgs]
|
||||
imgs = process_images(imgs, image_processor, model.config)
|
||||
imgs = imgs.to(model.device, dtype=torch.float16)
|
||||
with torch.inference_mode():
|
||||
with Timer() as encode_timer:
|
||||
samples = []
|
||||
for imgs in frames:
|
||||
imgs_size = [img.size for img in imgs]
|
||||
imgs = process_images(imgs, image_processor, model.config)
|
||||
imgs = imgs.cuda().to(dtype=torch.float16)
|
||||
_, _, _, _, inputs_embeds, _ = prepare_inputs_labels_for_multimodal(
|
||||
model, input_ids, None, None, None, None, images=imgs, image_sizes=imgs_size
|
||||
)
|
||||
samples.append(inputs_embeds)
|
||||
samples.append(inputs_embeds)
|
||||
|
||||
# padding
|
||||
max_len = max([sample.shape[1] for sample in samples])
|
||||
|
|
@ -322,23 +408,66 @@ def main(args):
|
|||
inputs_embeds = torch.cat(inputs_embeds, dim=0)
|
||||
|
||||
# generate outputs
|
||||
output_ids = super(type(model), model).generate(
|
||||
inputs_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
do_sample=True,
|
||||
temperature=0.2,
|
||||
max_new_tokens=512,
|
||||
use_cache=True,
|
||||
)
|
||||
outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
|
||||
outputs = [output.replace("\n", " ").strip() for output in outputs]
|
||||
with Timer() as generate_timer:
|
||||
output_ids = super(type(model), model).generate(
|
||||
inputs_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
do_sample=args.tp_size == 1, # sampling is not deterministic and may cause TP to hang
|
||||
temperature=0.2,
|
||||
max_new_tokens=512,
|
||||
use_cache=True,
|
||||
)
|
||||
outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
|
||||
outputs = [output.replace("\n", " ").strip() for output in outputs]
|
||||
|
||||
# warmup for 1 iter
|
||||
if args.profile and i < args.profile_warmup:
|
||||
# measure time
|
||||
torch.cuda.synchronize()
|
||||
time_taken = time.time() - start_time
|
||||
|
||||
total_time.append(time_taken)
|
||||
frame_extraction_time.append(frame_extraction_timer.time_taken)
|
||||
encode_time.append(encode_timer.time_taken)
|
||||
generate_time.append(generate_timer.time_taken)
|
||||
|
||||
# save results
|
||||
result = list(zip(video_files, outputs, video_lengths))
|
||||
for t in result:
|
||||
writer.writerow(t)
|
||||
if has_dp_writter:
|
||||
result = list(zip(video_files, outputs, video_lengths))
|
||||
for t in result:
|
||||
dp_writer.writerow(t)
|
||||
|
||||
f.close()
|
||||
# display profiling info
|
||||
if args.profile:
|
||||
num_samples_after_warmup = len(videos[1 * bs :]) * args.dp_size
|
||||
print(f"throughput (video/s): {num_samples_after_warmup / sum(total_time)}")
|
||||
print(f"average frame extraction time per sample: {sum(frame_extraction_time) / num_samples_after_warmup}")
|
||||
print(f"average encode time per sample: {sum(encode_time) / num_samples_after_warmup}")
|
||||
print(f"average generate time per sample: {sum(generate_time) / num_samples_after_warmup}")
|
||||
print(f"Max GPU allocated / GB: {torch.cuda.max_memory_allocated() / 1024**3}")
|
||||
print(f"Max GPU reserved / GB: {torch.cuda.max_memory_reserved() / 1024**3}")
|
||||
|
||||
# ======================================================
|
||||
# 6. shutdown
|
||||
# ======================================================
|
||||
# close file writing
|
||||
if has_dp_writter:
|
||||
dp_file.close()
|
||||
dist.barrier()
|
||||
|
||||
# merge files
|
||||
if has_main_writer:
|
||||
for i in range(dp_size):
|
||||
output_file_split = f"{args.output_file}.part{i}"
|
||||
with open(output_file_split, "r") as f:
|
||||
reader = csv.reader(f)
|
||||
for row in reader:
|
||||
main_writer.writerow(row)
|
||||
os.remove(output_file_split)
|
||||
main_file.close()
|
||||
|
||||
# terminate distributed env
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
@ -347,6 +476,10 @@ if __name__ == "__main__":
|
|||
parser.add_argument("output_file", type=str)
|
||||
parser.add_argument("--bs", type=int, default=32)
|
||||
parser.add_argument("--prompt", type=str, default="three_frames")
|
||||
parser.add_argument("--tp-size", type=int, default=1)
|
||||
parser.add_argument("--dp-size", type=int, default=1)
|
||||
parser.add_argument("--num-worker", type=int, default=8)
|
||||
parser.add_argument("--profile", action="store_true")
|
||||
parser.add_argument("--profile-warmup", type=int, default=1)
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import base64
|
||||
import csv
|
||||
import os
|
||||
import time
|
||||
|
||||
import cv2
|
||||
from PIL import Image
|
||||
|
|
@ -18,7 +19,7 @@ def get_filelist(file_path):
|
|||
for filename in files:
|
||||
ext = filename.split(".")[-1]
|
||||
if ext in VID_EXTENSIONS:
|
||||
Filelist.append(filename)
|
||||
Filelist.append(os.path.join(home, filename))
|
||||
return Filelist
|
||||
|
||||
|
||||
|
|
@ -65,3 +66,18 @@ def read_video_list(video_folder, output_file):
|
|||
videos = [video for video in videos if video not in processed_videos]
|
||||
print(f"Processing {len(videos)} new videos.")
|
||||
return videos
|
||||
|
||||
|
||||
class Timer:
    """Context manager that measures wall-clock time spent inside a ``with`` block.

    After the block exits, ``time_taken`` holds the elapsed seconds.
    """

    def __init__(self):
        # All timestamps in seconds; valid only after the managed block exits.
        self.time_taken = 0
        self.start_time = 0
        self.end_time = 0

    def __enter__(self):
        self.start_time = time.time()
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        self.end_time = time.time()
        self.time_taken = self.end_time - self.start_time
|
||||
|
|
|
|||
Loading…
Reference in a new issue