accelerate aesthetic

This commit is contained in:
Zangwei Zheng 2024-03-25 20:54:02 +08:00
parent 13423c57a7
commit 4d3b68e3ad

View file

@ -1,52 +1,47 @@
# adapted from https://github.com/christophschuhmann/improved-aesthetic-predictor/blob/main/simple_inference.py
import argparse
import av
import clip
import cv2
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from einops import rearrange
from tqdm import tqdm
def get_video_length(cap):
    """Return the frame count reported by an OpenCV capture, as an int.

    Args:
        cap: Capture object exposing ``get``; in this file it is created with
            ``cv2.VideoCapture``.

    Returns:
        The value of ``cv2.CAP_PROP_FRAME_COUNT`` truncated to ``int``.
    """
    frame_count = cap.get(cv2.CAP_PROP_FRAME_COUNT)
    return int(frame_count)
def extract_frames(video_path, points=(0.0, 0.5, 0.9)):
    """Decode one frame near each requested relative position of a video.

    Args:
        video_path: Path of the video file; it is opened with ``av.open``.
        points: Iterable of fractions of the first video stream's frame
            count at which to sample (``0.0`` is the start of the stream).

    Returns:
        A list holding one image per entry of ``points``, in the same order.
        Each image is whatever ``VideoFrame.to_image()`` returns (a PIL image
        in PyAV).

    Raises:
        StopIteration: If no frame can be decoded after a seek.
    """
    frames = []
    # Open through the context manager so the container is closed even when
    # seeking or decoding raises. This function runs once per sample inside
    # DataLoader workers, so an unclosed container would leak a file handle
    # per video.
    with av.open(video_path) as container:
        stream = container.streams.video[0]
        total_frames = stream.frames
        frame_rate = stream.average_rate
        for point in points:
            target_frame = total_frames * point
            # frame index -> seconds (divide by fps) -> av.time_base ticks,
            # which is the unit seek() expects when no stream is passed.
            target_timestamp = int((target_frame * av.time_base) / frame_rate)
            container.seek(target_timestamp)
            # Take the first frame the decoder yields after the seek.
            # NOTE(review): seek() presumably lands on the keyframe at or
            # before the target, so this can be slightly earlier than the
            # requested point - confirm against the PyAV docs.
            frame = next(container.decode(video=0)).to_image()
            frames.append(frame)
    return frames
class VideoTextDataset(torch.utils.data.Dataset):
    """Dataset that yields sampled video frames for every row of a CSV file.

    The CSV is read with its header row and must provide a ``path`` column
    holding the video file locations.
    """

    def __init__(self, csv_path, transform=None, points=(0.1, 0.5, 0.9)):
        """Load the metadata table.

        Args:
            csv_path: Path to the CSV file describing the videos.
            transform: Optional callable applied to every extracted frame.
            points: Relative positions forwarded to ``extract_frames``.
        """
        self.csv_path = csv_path
        self.data = pd.read_csv(csv_path)
        self.transform = transform
        self.points = points

    def getitem(self, index):
        """Return ``dict(index=..., image=...)`` for the row at ``index``.

        With a transform, ``image`` is the transformed frames stacked along a
        new leading dimension. Without one, ``image`` is the untouched list
        of frames returned by ``extract_frames``.
        """
        sample = self.data.iloc[index]
        frames = extract_frames(sample["path"], points=self.points)
        if self.transform is None:
            # ``transform`` defaults to None; calling it unconditionally
            # would fail with "'NoneType' object is not callable".
            return dict(index=index, image=frames)
        images = torch.stack([self.transform(img) for img in frames])
        return dict(index=index, image=images)

    def __len__(self):
        """Return the number of rows in the metadata table."""
        return len(self.data)

    def __getitem__(self, index):
        """Delegate to :meth:`getitem`."""
        return self.getitem(index)
@ -87,29 +82,48 @@ class AestheticScorer(nn.Module):
return self.mlp(image_features)
def main(args):
    """Score every video listed in ``args.input`` and save the result as CSV.

    Args:
        args: Parsed command-line namespace providing ``input`` (CSV path),
            ``bs`` (batch size), ``num_workers`` and ``prefetch_factor``.

    Side effects:
        Writes the input table plus an ``aesthetic`` column to
        ``<input without .csv>_aes.csv``.
    """
    # Strip only a trailing ".csv". str.replace would rewrite every ".csv" in
    # the path, and would leave the name unchanged (overwriting the input
    # file on save) when the path has no ".csv" in it at all.
    csv_suffix = ".csv"
    stem = args.input[: -len(csv_suffix)] if args.input.endswith(csv_suffix) else args.input
    output_file = f"{stem}_aes{csv_suffix}"

    # build model
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = AestheticScorer(768, device)
    # Read the preprocessing callable before wrapping: after DataParallel the
    # scorer's own attributes are only reachable through ``model.module``.
    preprocess = model.preprocess
    model = torch.nn.DataParallel(model)

    # build dataset
    dataset = VideoTextDataset(args.input, transform=preprocess, points=(0.5,))
    loader_kwargs = dict(
        batch_size=args.bs,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True,
    )
    if args.num_workers > 0:
        # DataLoader rejects an explicit prefetch_factor when loading happens
        # in the main process, so only pass it when workers are used.
        loader_kwargs["prefetch_factor"] = args.prefetch_factor
    dataloader = torch.utils.data.DataLoader(dataset, **loader_kwargs)

    # compute aesthetic scores
    dataset.data["aesthetic"] = np.nan
    index = 0  # row of the table that the next batch starts at (shuffle=False)
    for batch in tqdm(dataloader):
        images = batch["image"].to(device)
        B = images.shape[0]
        # Fold the per-video frame axis into the batch axis for the model.
        images = rearrange(images, "b p c h w -> (b p) c h w")
        with torch.no_grad():
            scores = model(images)
        # Unfold again and average the frame scores of each video.
        scores = rearrange(scores, "(b p) 1 -> b p", b=B)
        scores = scores.mean(dim=1)
        scores_np = scores.cpu().numpy()
        dataset.data.loc[index : index + B - 1, "aesthetic"] = scores_np
        # Advance by the number of videos, not by len(images): after the
        # rearrange above ``images`` holds b * p frames, which would make the
        # cursor drift as soon as more than one point is sampled per video.
        index += B
    dataset.data.to_csv(output_file, index=False)
if __name__ == "__main__":
    # Command-line entry point: one positional CSV path plus DataLoader
    # tuning options, all handed to main() as a single namespace.
    # main() requires that namespace, so it is only called after parsing.
    parser = argparse.ArgumentParser()
    parser.add_argument("input", type=str, help="Path to the input CSV file")
    parser.add_argument("--bs", type=int, default=512, help="Batch size")
    parser.add_argument("--num_workers", type=int, default=64, help="Number of workers")
    parser.add_argument("--prefetch_factor", type=int, default=8, help="Prefetch factor")
    args = parser.parse_args()
    main(args)