From cabf1d7746b376c39da4d8d1f70ea51d16fad34d Mon Sep 17 00:00:00 2001
From: Hongxin Liu
Date: Wed, 3 Apr 2024 14:42:59 +0800
Subject: [PATCH] [feature] support dp for matching (#29)

---
 tools/scoring/README.md                      |   4 +-
 tools/scoring/matching/inference.py          |  30 ++--
 tools/scoring/matching/inference_parallel.py | 154 +++++++++++++++++++
 3 files changed, 176 insertions(+), 12 deletions(-)
 create mode 100644 tools/scoring/matching/inference_parallel.py

diff --git a/tools/scoring/README.md b/tools/scoring/README.md
index d3270e1..8b8544a 100644
--- a/tools/scoring/README.md
+++ b/tools/scoring/README.md
@@ -46,7 +46,7 @@ wget https://s3.eu-central-1.amazonaws.com/avg-projects/unimatch/pretrained/gmfl
 
 Then run:
 ```
-python tools/scoring/optical_flow/inference.py /path/to/meta.csv
+torchrun --standalone --nproc_per_node 8 tools/scoring/optical_flow/inference_parallel.py /path/to/meta.csv
 ```
 The output should be `/path/to/meta_flow.csv` with column `flow`.
 
@@ -57,6 +57,6 @@ For videos, we compute the matching score of the middle frame and the caption.
 
 **Make sure** meta files contain the column `text`, which is the caption of the sample. Then run:
 ```
-python tools/scoring/matching/inference.py /path/to/meta.csv
+torchrun --standalone --nproc_per_node 8 tools/scoring/matching/inference_parallel.py /path/to/meta.csv
 ```
 The output should be `/path/to/meta_match.csv` with column `match`. Higher matching scores indicate better image-text/video-text alignment.
diff --git a/tools/scoring/matching/inference.py b/tools/scoring/matching/inference.py
index f162cc2..a332583 100644
--- a/tools/scoring/matching/inference.py
+++ b/tools/scoring/matching/inference.py
@@ -2,17 +2,25 @@ import argparse
 import os
 
 import av
+import clip
 import numpy as np
 import pandas as pd
 import torch
 import torch.nn.functional as F
 from torchvision.datasets.folder import pil_loader
-
 from tqdm import tqdm
 
-import clip
-
-IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp")
+IMG_EXTENSIONS = (
+    ".jpg",
+    ".jpeg",
+    ".png",
+    ".ppm",
+    ".bmp",
+    ".pgm",
+    ".tif",
+    ".tiff",
+    ".webp",
+)
 VID_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv")
 
 
@@ -27,7 +35,9 @@ def extract_frames(video_path, points=[0.5]):
     frames = []
     for point in points:
         target_frame = total_frames * point
-        target_timestamp = int((target_frame * av.time_base) / container.streams.video[0].average_rate)
+        target_timestamp = int(
+            (target_frame * av.time_base) / container.streams.video[0].average_rate
+        )
         container.seek(target_timestamp)
         frame = next(container.decode(video=0)).to_image()
         frames.append(frame)
@@ -42,17 +52,17 @@ class VideoTextDataset(torch.utils.data.Dataset):
 
     def __getitem__(self, index):
         row = self.meta.iloc[index]
-        path = row['path']
-
+        path = row["path"]
+
         if is_video(path):
             img = extract_frames(path, points=[0.5])[0]
         else:
             img = pil_loader(path)
-
+
         img = self.transform(img)
 
-        text = row['text']
-        text = clip.tokenize(text).squeeze()
+        text = row["text"]
+        text = clip.tokenize(text, truncate=True).squeeze()
 
         return img, text
 
diff --git a/tools/scoring/matching/inference_parallel.py b/tools/scoring/matching/inference_parallel.py
new file mode 100644
index 0000000..e8a1704
--- /dev/null
+++ b/tools/scoring/matching/inference_parallel.py
@@ -0,0 +1,154 @@
+import argparse
+import os
+
+import av
+import clip
+import colossalai
+import numpy as np
+import pandas as pd
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+from torch.utils.data import DataLoader, DistributedSampler
+from torchvision.datasets.folder import pil_loader
+from tqdm import tqdm
+
+IMG_EXTENSIONS = (
+    ".jpg",
+    ".jpeg",
+    ".png",
+    ".ppm",
+    ".bmp",
+    ".pgm",
+    ".tif",
+    ".tiff",
+    ".webp",
+)
+VID_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv")
+
+
+def is_video(filename):
+    ext = os.path.splitext(filename)[-1].lower()
+    return ext in VID_EXTENSIONS
+
+
+def extract_frames(video_path, points=[0.5]):
+    container = av.open(video_path)
+    total_frames = container.streams.video[0].frames
+    frames = []
+    for point in points:
+        target_frame = total_frames * point
+        target_timestamp = int(
+            (target_frame * av.time_base) / container.streams.video[0].average_rate
+        )
+        container.seek(target_timestamp)
+        frame = next(container.decode(video=0)).to_image()
+        frames.append(frame)
+    return frames
+
+
+class VideoTextDataset(torch.utils.data.Dataset):
+    def __init__(self, meta_path, transform):
+        self.meta_path = meta_path
+        self.meta = pd.read_csv(meta_path)
+        self.transform = transform
+
+    def __getitem__(self, index):
+        row = self.meta.iloc[index]
+        path = row["path"]
+
+        if is_video(path):
+            img = extract_frames(path, points=[0.5])[0]
+        else:
+            img = pil_loader(path)
+
+        img = self.transform(img)
+
+        text = row["text"]
+        text = clip.tokenize(text, truncate=True).squeeze()
+
+        return img, text, index
+
+    def __len__(self):
+        return len(self.meta)
+
+
+def merge_scores(gathered_list: list, meta: pd.DataFrame):
+    # reorder
+    indices_list = list(map(lambda x: x[0], gathered_list))
+    scores_list = list(map(lambda x: x[1], gathered_list))
+    flat_indices = []
+    for x in zip(*indices_list):
+        flat_indices.extend(x)
+    flat_scores = []
+    for x in zip(*scores_list):
+        flat_scores.extend(x)
+    flat_indices = np.array(flat_indices)
+    flat_scores = np.array(flat_scores)
+    # filter duplicates
+    unique_indices, unique_indices_idx = np.unique(flat_indices, return_index=True)
+    meta.loc[unique_indices, "match"] = flat_scores[unique_indices_idx]
+
+
+def main():
+    colossalai.launch_from_torch({})
+    parser = argparse.ArgumentParser()
+    parser.add_argument("meta_path", type=str, help="Path to the input CSV file")
+    parser.add_argument("--bs", type=int, default=16, help="Batch size")
+    parser.add_argument("--num_workers", type=int, default=16, help="Number of workers")
+    args = parser.parse_args()
+
+    meta_path = args.meta_path
+    wo_ext, ext = os.path.splitext(meta_path)
+    out_path = f"{wo_ext}_match{ext}"
+
+    # build model
+    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+    model, preprocess = clip.load("ViT-L/14", device=device)
+    logit_scale = model.logit_scale.exp().item()
+
+    # build dataset
+    dataset = VideoTextDataset(meta_path=meta_path, transform=preprocess)
+    dataloader = DataLoader(
+        dataset,
+        batch_size=args.bs,
+        num_workers=args.num_workers,
+        sampler=DistributedSampler(
+            dataset,
+            num_replicas=dist.get_world_size(),
+            rank=dist.get_rank(),
+            shuffle=False,
+            drop_last=False,
+        ),
+    )
+
+    # compute scores
+    dataset.meta["match"] = np.nan
+    indices_list = []
+    scores_list = []
+    model.eval()
+    for imgs, text, indices in tqdm(dataloader, disable=dist.get_rank() != 0):
+        imgs = imgs.to(device)
+        text = text.to(device)
+
+        with torch.no_grad():
+            feat_img = model.encode_image(imgs)
+            feat_text = model.encode_text(text)
+
+        feat_img = F.normalize(feat_img, dim=1)
+        feat_text = F.normalize(feat_text, dim=1)
+        clip_scores = logit_scale * (feat_img * feat_text).sum(dim=1)
+        clip_scores = clip_scores.cpu().tolist()
+        indices_list.extend(indices)
+        scores_list.extend(clip_scores)
+
+    gathered_list = [None] * dist.get_world_size()
+    dist.all_gather_object(gathered_list, (indices_list, scores_list))
+    if dist.get_rank() == 0:
+        merge_scores(gathered_list, dataset.meta)
+        dataset.meta.to_csv(out_path, index=False)
+        print(f"New meta with matching scores saved to '{out_path}'.")
+
+
+if __name__ == "__main__":
+    main()
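A note on the launch step: `colossalai.launch_from_torch({})` reads the environment variables that `torchrun` exports (`RANK`, `LOCAL_RANK`, `WORLD_SIZE`, `MASTER_ADDR`, `MASTER_PORT`), initializes the default process group, and pins each process to its local GPU, which is why the plain `torch.device("cuda")` later in the script resolves to a different device on each rank. A rough plain-`torch.distributed` equivalent, as a sketch only (it skips the ColossalAI-specific parallel context that `launch_from_torch` also sets up):

```
import os

import torch
import torch.distributed as dist

# Approximation of what colossalai.launch_from_torch({}) does for this script.
# init_process_group's default env:// method reads the RANK / WORLD_SIZE /
# MASTER_ADDR / MASTER_PORT variables that torchrun exports.
dist.init_process_group(backend="nccl")
# bind this process to its local GPU so torch.device("cuda") is per-rank
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
```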
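The subtlest part of the patch is `merge_scores`. With `shuffle=False`, `DistributedSampler` gives rank `r` of `W` ranks the indices `r, r + W, r + 2W, ...` and pads the tail by repeating early samples so every rank sees the same count; zipping the gathered per-rank lists therefore restores dataset order, and `np.unique` drops the padded duplicates. A minimal CPU-only sketch with made-up indices and scores (no `torch.distributed` required):

```
import numpy as np
import pandas as pd


def merge_scores(gathered_list: list, meta: pd.DataFrame):
    # same logic as in inference_parallel.py: interleave the per-rank
    # (indices, scores) pairs back into dataset order
    indices_list = list(map(lambda x: x[0], gathered_list))
    scores_list = list(map(lambda x: x[1], gathered_list))
    flat_indices = []
    for x in zip(*indices_list):
        flat_indices.extend(x)
    flat_scores = []
    for x in zip(*scores_list):
        flat_scores.extend(x)
    flat_indices = np.array(flat_indices)
    flat_scores = np.array(flat_scores)
    # drop the duplicates DistributedSampler padded in
    unique_indices, unique_indices_idx = np.unique(flat_indices, return_index=True)
    meta.loc[unique_indices, "match"] = flat_scores[unique_indices_idx]


# hypothetical gathered output for 5 samples on 2 ranks: the sampler pads
# index 0 onto rank 1 so both ranks process ceil(5 / 2) = 3 samples
meta = pd.DataFrame({"path": [f"video_{i}.mp4" for i in range(5)]})
meta["match"] = np.nan
gathered = [
    ([0, 2, 4], [10.0, 12.0, 14.0]),  # rank 0 saw indices 0, 2, 4
    ([1, 3, 0], [11.0, 13.0, 10.0]),  # rank 1 saw 1, 3, then padded 0
]
merge_scores(gathered, meta)
print(meta["match"].tolist())  # [10.0, 11.0, 12.0, 13.0, 14.0]
```

Because `np.unique` keeps the first occurrence of each index, the padded duplicate's score is simply ignored; since padding re-scores an identical sample, either copy would yield the same value.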