From 4d3b68e3ad3c281a17735bf3916a031f5ec7c940 Mon Sep 17 00:00:00 2001 From: Zangwei Zheng Date: Mon, 25 Mar 2024 20:54:02 +0800 Subject: [PATCH] accelerate aesthetic --- tools/aesthetic/inference.py | 102 ++++++++++++++++++++--------------- 1 file changed, 58 insertions(+), 44 deletions(-) diff --git a/tools/aesthetic/inference.py b/tools/aesthetic/inference.py index 99287af..1e2a3f7 100644 --- a/tools/aesthetic/inference.py +++ b/tools/aesthetic/inference.py @@ -1,52 +1,47 @@ # adapted from https://github.com/christophschuhmann/improved-aesthetic-predictor/blob/main/simple_inference.py +import argparse + +import av import clip -import cv2 +import numpy as np import pandas as pd import torch import torch.nn as nn import torch.nn.functional as F -from PIL import Image +from einops import rearrange from tqdm import tqdm -def get_video_length(cap): - return int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) - - -def extract_frames(video_path, points=(0.5,)): - cap = cv2.VideoCapture(video_path) - length = get_video_length(cap) - points = [int(length * point) for point in points] +def extract_frames(video_path, points=(0.0, 0.5, 0.9)): + container = av.open(video_path) + total_frames = container.streams.video[0].frames frames = [] - if length < 3: - return frames, length for point in points: - cap.set(cv2.CAP_PROP_POS_FRAMES, point) - ret, frame = cap.read() - frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) - frame = Image.fromarray(frame) + target_frame = total_frames * point + target_timestamp = int((target_frame * av.time_base) / container.streams.video[0].average_rate) + container.seek(target_timestamp) + frame = next(container.decode(video=0)).to_image() frames.append(frame) - if len(frames) == 1: - frames = frames[0] - return frames, length + return frames class VideoTextDataset(torch.utils.data.Dataset): - def __init__(self, csv_path, transform=None): + def __init__(self, csv_path, transform=None, points=(0.1, 0.5, 0.9)): self.csv_path = csv_path - self.samples = pd.read_csv(csv_path, header=None) + self.data = pd.read_csv(csv_path) self.transform = transform + self.points = points def getitem(self, index): - sample = self.samples.iloc[index] - img = extract_frames(sample[0])[0] - img = self.transform(img) - text = sample[1] + sample = self.data.iloc[index] + images = extract_frames(sample["path"], points=self.points) + images = [self.transform(img) for img in images] + images = torch.stack(images) - return dict(index=index, image=img, text=text) + return dict(index=index, image=images) def __len__(self): - return len(self.samples) + return len(self.data) def __getitem__(self, index): return self.getitem(index) @@ -87,29 +82,48 @@ class AestheticScorer(nn.Module): return self.mlp(image_features) -def main(): +def main(args): + output_file = args.input.replace(".csv", "_aes.csv") + + # build model device = "cuda" if torch.cuda.is_available() else "cpu" model = AestheticScorer(768, device) - - dataset = VideoTextDataset( - "/mnt/hdd/data/VidProM/VidProM_pika/meta/vidprom_relength_fmin_48_clean_en_unescape_nourl.csv", - transform=model.preprocess, - ) - dataloader = torch.utils.data.DataLoader(dataset, batch_size=1024, shuffle=False, num_workers=16, pin_memory=True) - dataset.samples["aesthetic"] = "" + preprocess = model.preprocess model = torch.nn.DataParallel(model) - output_file = "vidprom_aes.csv" + + # build dataset + dataset = VideoTextDataset(args.input, transform=preprocess, points=(0.5,)) + dataloader = torch.utils.data.DataLoader( + dataset, + batch_size=args.bs, + shuffle=False, + num_workers=args.num_workers, + pin_memory=True, + prefetch_factor=args.prefetch_factor, + ) + + # compute aesthetic scores + dataset.data["aesthetic"] = np.nan index = 0 for batch in tqdm(dataloader): - image = batch["image"].to(device) + images = batch["image"].to(device) + B = images.shape[0] + images = rearrange(images, "b p c h w -> (b p) c h w") with torch.no_grad(): - score = model(image) - dataset.samples.loc[index : index + len(score) - 1, "aesthetic"] = score.cpu().numpy().flatten() - index += len(score) - - dataset.samples.to_csv(output_file, index=False, header=False) - print(f"Saved {index} samples") + scores = model(images) + scores = rearrange(scores, "(b p) 1 -> b p", b=B) + scores = scores.mean(dim=1) + scores_np = scores.cpu().numpy() + dataset.data.loc[index : index + len(scores_np) - 1, "aesthetic"] = scores_np + index += len(images) + dataset.data.to_csv(output_file, index=False) if __name__ == "__main__": - main() + parser = argparse.ArgumentParser() + parser.add_argument("input", type=str, help="Path to the input CSV file") + parser.add_argument("--bs", type=int, default=512, help="Batch size") + parser.add_argument("--num_workers", type=int, default=64, help="Number of workers") + parser.add_argument("--prefetch_factor", type=int, default=8, help="Prefetch factor") + args = parser.parse_args() + main(args)