From 3bbee0043666d96099d2ccce866441c0407d5f0e Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Thu, 4 Apr 2024 16:03:43 +0800 Subject: [PATCH] accelerate aesthetic scoring (#32) * accelerate aesthetic scoring * polish --- tools/scoring/README.md | 4 +- tools/scoring/aesthetic/inference.py | 72 ++++++++++++++++++++-------- 2 files changed, 56 insertions(+), 20 deletions(-) diff --git a/tools/scoring/README.md b/tools/scoring/README.md index 8b8544a..6466d75 100644 --- a/tools/scoring/README.md +++ b/tools/scoring/README.md @@ -32,9 +32,11 @@ With `meta.csv` containing the paths to the videos, run the following command: ```bash # output: meta_aes.csv -python -m tools.scoring.aesthetic.inference meta.csv +torchrun --nproc_per_node 8 -m tools.scoring.aesthetic.inference meta.csv --bs 1024 --num_workers 16 ``` +This will generate multiple part files; you can use `python -m tools.datasets.csvutil DATA1.csv DATA2.csv` to merge these part files. + ## Optical Flow Score Optical flow scores are used to assess the motion of a video. Higher optical flow scores indicate larger movement. TODO: acknowledge UniMatch. 
diff --git a/tools/scoring/aesthetic/inference.py b/tools/scoring/aesthetic/inference.py index d73a133..243e765 100644 --- a/tools/scoring/aesthetic/inference.py +++ b/tools/scoring/aesthetic/inference.py @@ -1,19 +1,30 @@ # adapted from https://github.com/christophschuhmann/improved-aesthetic-predictor/blob/main/simple_inference.py import argparse import os +from datetime import timedelta import clip import decord import numpy as np import pandas as pd import torch +import torch.distributed as dist import torch.nn as nn import torch.nn.functional as F +from colossalai.utils import set_seed from einops import rearrange from PIL import Image from torchvision.datasets.folder import pil_loader from tqdm import tqdm +try: + from torchvision.transforms import InterpolationMode + + BICUBIC = InterpolationMode.BICUBIC +except ImportError: + BICUBIC = Image.BICUBIC + + IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp") VID_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv") @@ -49,8 +60,7 @@ class VideoTextDataset(torch.utils.data.Dataset): images = extract_frames(sample["path"], points=self.points) images = [self.transform(img) for img in images] images = torch.stack(images) - - return dict(index=index, image=images) + return dict(index=index, images=images) def __len__(self): return len(self.data) @@ -96,39 +106,62 @@ class AestheticScorer(nn.Module): @torch.inference_mode() def main(args): - output_file = args.input.replace(".csv", "_aes.csv") + dist.init_process_group(backend="nccl", timeout=timedelta(hours=24)) + torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count()) + set_seed(1024) + rank = dist.get_rank() + world_size = dist.get_world_size() + + output_file = args.input.replace(".csv", f"_aes_part{rank}.csv") # build model device = "cuda" if torch.cuda.is_available() else "cpu" model = AestheticScorer(768, device) preprocess = model.preprocess - model = torch.nn.DataParallel(model) # build dataset dataset = 
VideoTextDataset(args.input, transform=preprocess) + sampler = torch.utils.data.distributed.DistributedSampler( + dataset=dataset, num_replicas=world_size, rank=rank, shuffle=False + ) + dataloader = torch.utils.data.DataLoader( dataset, + sampler=sampler, batch_size=args.bs, shuffle=False, - num_workers=0, + num_workers=args.num_workers, pin_memory=True, - # prefetch_factor=args.prefetch_factor, + prefetch_factor=args.prefetch_factor if args.num_workers > 0 else None, ) # compute aesthetic scores dataset.data["aes"] = np.nan - index = 0 - for batch in tqdm(dataloader): - images = batch["image"].to(device) - B = images.shape[0] - images = rearrange(images, "b p c h w -> (b p) c h w") - with torch.no_grad(): + + with tqdm(dataloader, position=rank, desc=f"Data Parallel Rank {rank}") as t: + for idx, batch in enumerate(t): + if idx == 12: + break + + image_indices = batch["index"] + images = batch["images"].to(device, non_blocking=True) + B = images.shape[0] + images = rearrange(images, "b p c h w -> (b p) c h w") + + # compute score scores = model(images) - scores = rearrange(scores, "(b p) 1 -> b p", b=B) - scores = scores.mean(dim=1) - scores_np = scores.cpu().numpy() - dataset.data.loc[index : index + len(scores_np) - 1, "aes"] = scores_np - index += len(images) + scores = rearrange(scores, "(b p) 1 -> b p", b=B) + scores = scores.mean(dim=1) + scores_np = scores.to(torch.float32).cpu().numpy() + + # assign the score + dataset.data.loc[image_indices, "aes"] = scores_np + + # wait for all ranks to finish data processing + dist.barrier() + + # exclude rows whose aes is nan and save file + dataset.data = dataset.data[dataset.data["aes"] > 0] dataset.data.to_csv(output_file, index=False) print(f"New meta with aesthetic scores saved to '{output_file}'.") @@ -137,8 +170,9 @@ if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("input", type=str, help="Path to the input CSV file") parser.add_argument("--bs", type=int, default=512, 
help="Batch size") - parser.add_argument("--num_workers", type=int, default=64, help="Number of workers") - parser.add_argument("--prefetch_factor", type=int, default=8, help="Prefetch factor") + parser.add_argument("--num_workers", type=int, default=16, help="Number of workers") + parser.add_argument("--accumulate", type=int, default=1, help="batch to accumulate") + parser.add_argument("--prefetch_factor", type=int, default=2, help="Prefetch factor") args = parser.parse_args() main(args)