mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-15 03:15:20 +02:00
refactored llava captioning (#12)
This commit is contained in:
parent
fb6bd34443
commit
0f2bb1700b
|
|
@ -1,24 +1,23 @@
|
|||
import argparse
|
||||
import csv
|
||||
import math
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
import time
|
||||
import warnings
|
||||
from datetime import timedelta
|
||||
|
||||
import colossalai
|
||||
import pandas as pd
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from colossalai.cluster import DistCoordinator, ProcessGroupMesh
|
||||
from colossalai.shardformer import ShardConfig, ShardFormer
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.utils import get_current_device, set_seed
|
||||
from llava.constants import DEFAULT_IMAGE_TOKEN, IGNORE_INDEX, IMAGE_TOKEN_INDEX
|
||||
from llava.conversation import conv_templates
|
||||
from llava.mm_utils import get_anyres_image_grid_shape, get_model_name_from_path, process_images, tokenizer_image_token
|
||||
from llava.model.builder import load_pretrained_model
|
||||
from llava.model.llava_arch import unpad_image
|
||||
from llava.utils import disable_torch_init
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
from .acceleration.llava.policy import LlavaForCausalLMPolicy
|
||||
|
|
@ -26,6 +25,18 @@ from .utils import Timer, extract_frames, prompts
|
|||
|
||||
disable_torch_init()
|
||||
|
||||
IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp")
|
||||
VID_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv")
|
||||
|
||||
|
||||
def is_video(filename):
    """Tell whether ``filename`` carries one of the ``VID_EXTENSIONS`` suffixes.

    The suffix is taken with ``os.path.splitext`` and lower-cased before the
    lookup, so the check is case-insensitive.
    """
    _, suffix = os.path.splitext(filename)
    return suffix.lower() in VID_EXTENSIONS
|
||||
|
||||
|
||||
def get_image(image_path):
    """Load the image stored at ``image_path`` and return it in RGB mode.

    The file is opened inside a ``with`` block so that its handle is released
    deterministically as soon as the RGB image has been produced, instead of
    being left for the garbage collector. ``convert`` returns a new image
    object, so the result stays usable after the source file is closed.
    """
    with Image.open(image_path) as img:
        return img.convert("RGB")
|
||||
|
||||
|
||||
def prepare_inputs_labels_for_multimodal(
|
||||
self, input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes=None
|
||||
|
|
@ -251,12 +262,52 @@ def prepare_inputs_labels_for_multimodal(
|
|||
return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
|
||||
|
||||
|
||||
class VideoTextDataset(torch.utils.data.Dataset):
    """Dataset over the images/videos listed in the ``path`` column of a CSV file.

    Each item is a dict with the keys ``path``, ``image``, ``length`` and
    ``img_size``. Image files yield a single frame with ``length == 1``;
    video files are delegated to ``extract_frames``.
    """

    def __init__(self, csv_path, transform=None, points=(0.1, 0.5, 0.9)):
        """Read the sample list.

        Args:
            csv_path: CSV file loaded with ``pandas.read_csv``; a ``path``
                column is read from it for every sample.
            transform: optional callable applied to the list of frames of a
                sample. When ``None`` the frames are returned untouched.
            points: forwarded as ``points=`` to ``extract_frames`` for video
                samples (presumably relative sampling positions -- confirm
                against ``extract_frames``).
        """
        self.csv_path = csv_path
        self.transform = transform
        self.data = pd.read_csv(csv_path)
        self.points = points

    def getitem(self, index):
        """Load, and optionally transform, the frames of sample ``index``."""
        sample = self.data.iloc[index]
        path = sample["path"]
        if is_video(path):
            images, length = extract_frames(path, points=self.points)
        else:
            images = [get_image(path)]
            length = 1

        # record the sizes first: the transform replaces the loaded frames
        imgs_size = [img.size for img in images]

        # ``transform`` defaults to None, so only call it when one was given
        if self.transform is not None:
            images = self.transform(images)

        # frames are kept in a list because the pytorch dataloader does not
        # accept PIL images
        return dict(path=path, image=images, length=length, img_size=imgs_size)

    def __len__(self):
        """Number of rows in the CSV file."""
        return len(self.data)

    def __getitem__(self, index):
        return self.getitem(index)
|
||||
|
||||
|
||||
def collate_fn(batch):
    """Regroup a batch of sample dicts into one list per field.

    Returns:
        A 4-tuple ``(paths, images, lengths, img_sizes)``; every element is a
        list holding the corresponding field of each item, in batch order.
    """
    fields = ("path", "image", "length", "img_size")
    # one pass over the batch per field, nothing is stacked or converted
    paths, images, lengths, img_sizes = ([item[key] for item in batch] for key in fields)
    return paths, images, lengths, img_sizes
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def main(args):
|
||||
# ======================================================
|
||||
# 1. init environment
|
||||
# ======================================================
|
||||
colossalai.launch_from_torch({})
|
||||
# we set a very large timeout to prevent some processes from exiting early
|
||||
dist.init_process_group(backend="nccl", timeout=timedelta(hours=24))
|
||||
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
|
||||
set_seed(1024)
|
||||
coordinator = DistCoordinator()
|
||||
|
||||
# prepare the dp and tp groups
|
||||
|
|
@ -267,36 +318,6 @@ def main(args):
|
|||
dp_group = mesh.get_group_along_axis(0)
|
||||
tp_group = mesh.get_group_along_axis(1)
|
||||
|
||||
# ======================================================
|
||||
# 2. read video list
|
||||
# ======================================================
|
||||
output_file = args.input.replace(".csv", "_caption.csv")
|
||||
data_info = pd.read_csv(args.input)
|
||||
videos = data_info["path"].tolist()
|
||||
|
||||
# shard by DP
|
||||
dp_size = dist.get_world_size(dp_group)
|
||||
dp_rank = dist.get_rank(dp_group)
|
||||
video_partition_size = math.ceil(len(videos) / dp_size)
|
||||
videos = videos[dp_rank * video_partition_size : (dp_rank + 1) * video_partition_size]
|
||||
|
||||
# create csv writer
|
||||
has_main_writer = dist.get_rank() == 0
|
||||
has_dp_writter = dist.get_rank(tp_group) == 0
|
||||
|
||||
if has_main_writer:
|
||||
# we keep track of the processed videos in main file
|
||||
# so we use append mode
|
||||
main_file = open(output_file, "a")
|
||||
main_writer = csv.writer(main_file)
|
||||
|
||||
if has_dp_writter:
|
||||
# the dp writer takes care of the files processed on the current dp rank
|
||||
# so we use write mode
|
||||
output_file_split = f"{output_file}.part{dp_rank}"
|
||||
dp_file = open(output_file_split, "w")
|
||||
dp_writer = csv.writer(dp_file)
|
||||
|
||||
# ======================================================
|
||||
# 3. load model and prepare prompts
|
||||
# ======================================================
|
||||
|
|
@ -317,6 +338,7 @@ def main(args):
|
|||
torch_dtype=torch.float16,
|
||||
attn_implementation="flash_attention_2",
|
||||
)
|
||||
dist.barrier()
|
||||
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
|
||||
input_ids = input_ids.unsqueeze(0).cuda()
|
||||
|
||||
|
|
@ -333,18 +355,67 @@ def main(args):
|
|||
model = shard_former.optimize(model, policy=LlavaForCausalLMPolicy())[0].cuda()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# ======================================================
|
||||
# 5. Prepare dataloader
|
||||
# ======================================================
|
||||
# build dataset
|
||||
def transform(imgs):
|
||||
imgs = process_images(imgs, image_processor, model.config)
|
||||
imgs = imgs.to(dtype=torch.float16)
|
||||
return imgs
|
||||
|
||||
dataset = VideoTextDataset(args.input, points=(0.2, 0.5, 0.8), transform=transform)
|
||||
total_num_videos = len(dataset)
|
||||
|
||||
# build sampler
|
||||
dp_rank = dist.get_rank(dp_group)
|
||||
dp_size = dist.get_world_size(dp_group)
|
||||
sampler = torch.utils.data.distributed.DistributedSampler(
|
||||
dataset, rank=dp_rank, num_replicas=dp_size, shuffle=False
|
||||
)
|
||||
|
||||
# build dataloader
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size=args.bs,
|
||||
shuffle=False,
|
||||
num_workers=args.num_workers,
|
||||
pin_memory=True,
|
||||
prefetch_factor=args.prefetch_factor,
|
||||
sampler=sampler,
|
||||
collate_fn=collate_fn,
|
||||
)
|
||||
|
||||
# prepare output file reader
|
||||
output_file = args.input.replace(".csv", "_caption.csv")
|
||||
|
||||
# create csv writer
|
||||
has_main_writer = dist.get_rank() == 0
|
||||
has_dp_writter = dist.get_rank(tp_group) == 0
|
||||
|
||||
if has_main_writer:
|
||||
# we keep track of the processed videos in main file
|
||||
# so we use append mode
|
||||
main_file = open(output_file, "a")
|
||||
main_writer = csv.writer(main_file)
|
||||
|
||||
if has_dp_writter:
|
||||
# the dp writer takes care of the files processed on the current dp rank
|
||||
# so we use write mode
|
||||
output_file_split = f"{output_file}.part{dp_rank}"
|
||||
dp_file = open(output_file_split, "w")
|
||||
dp_writer = csv.writer(dp_file)
|
||||
|
||||
# ======================================================
|
||||
# 5. generate captions
|
||||
# ======================================================
|
||||
bs = args.bs
|
||||
args.bs
|
||||
|
||||
if dist.get_rank() == 0:
|
||||
pbar = tqdm(range(0, len(videos), bs))
|
||||
if dist.get_rank(tp_group) == 0:
|
||||
pbar = tqdm(dataloader, position=dp_rank, desc=f"Data Parallel Rank {dist.get_rank(dp_group)}")
|
||||
else:
|
||||
pbar = tqdm(range(0, len(videos), bs))
|
||||
|
||||
# set up a multiprocessing pool
|
||||
pool = mp.Pool(args.num_worker)
|
||||
pbar = dataloader
|
||||
|
||||
if args.profile:
|
||||
encode_time = []
|
||||
|
|
@ -352,34 +423,19 @@ def main(args):
|
|||
generate_time = []
|
||||
total_time = []
|
||||
|
||||
for i in pbar:
|
||||
for batch in pbar:
|
||||
# measure time
|
||||
if args.profile:
|
||||
torch.cuda.synchronize()
|
||||
start_time = time.time()
|
||||
|
||||
video_files = videos[i : i + bs]
|
||||
frames = []
|
||||
video_lengths = []
|
||||
|
||||
with Timer() as frame_extraction_timer:
|
||||
# prepare a batch of inputs with parallel frame extraction
|
||||
for frame, length in pool.map(extract_frames, video_files):
|
||||
if len(frame) < 3:
|
||||
continue
|
||||
frames.append(frame)
|
||||
video_lengths.append(length)
|
||||
|
||||
if len(frames) == 0:
|
||||
continue
|
||||
video_files, frames, video_lengths, img_size_list = batch
|
||||
|
||||
# encode the batch of inputs
|
||||
with Timer() as encode_timer:
|
||||
samples = []
|
||||
for imgs in frames:
|
||||
imgs_size = [img.size for img in imgs]
|
||||
imgs = process_images(imgs, image_processor, model.config)
|
||||
imgs = imgs.cuda().to(dtype=torch.float16)
|
||||
for imgs, imgs_size in zip(frames, img_size_list):
|
||||
imgs = imgs.cuda()
|
||||
_, _, _, _, inputs_embeds, _ = prepare_inputs_labels_for_multimodal(
|
||||
model, input_ids, None, None, None, None, images=imgs, image_sizes=imgs_size
|
||||
)
|
||||
|
|
@ -426,7 +482,6 @@ def main(args):
|
|||
time_taken = time.time() - start_time
|
||||
|
||||
total_time.append(time_taken)
|
||||
frame_extraction_time.append(frame_extraction_timer.time_taken)
|
||||
encode_time.append(encode_timer.time_taken)
|
||||
generate_time.append(generate_timer.time_taken)
|
||||
|
||||
|
|
@ -438,7 +493,7 @@ def main(args):
|
|||
|
||||
# display profiling info
|
||||
if args.profile:
|
||||
num_samples_after_warmup = len(videos[1 * bs :]) * args.dp_size
|
||||
num_samples_after_warmup = total_num_videos - args.bs * args.profile_warmup * dp_size
|
||||
print(f"throughput (video/s): {num_samples_after_warmup / sum(total_time)}")
|
||||
print(f"average frame extraction time per sample: {sum(frame_extraction_time) / num_samples_after_warmup}")
|
||||
print(f"average encode time per sample: {sum(encode_time) / num_samples_after_warmup}")
|
||||
|
|
@ -477,8 +532,9 @@ if __name__ == "__main__":
|
|||
parser.add_argument("--prompt", type=str, default="three_frames")
|
||||
parser.add_argument("--tp-size", type=int, default=1)
|
||||
parser.add_argument("--dp-size", type=int, default=1)
|
||||
parser.add_argument("--num-worker", type=int, default=8)
|
||||
parser.add_argument("--num-workers", type=int, default=8)
|
||||
parser.add_argument("--profile", action="store_true")
|
||||
parser.add_argument("--profile-warmup", type=int, default=1)
|
||||
parser.add_argument("--prefetch-factor", type=int, default=8, help="Prefetch factor")
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
|
|
|
|||
Loading…
Reference in a new issue