accelerate aesthetic scoring (#32)

* accelerate aesthetic scoring

* polish
This commit is contained in:
Frank Lee, 2024-04-04 16:03:43 +08:00, committed by GitHub
parent 1bb49f75e7
commit 3bbee00436
2 changed files with 56 additions and 20 deletions

View file

@ -32,9 +32,11 @@ With `meta.csv` containing the paths to the videos, run the following command:
```bash
# output: meta_aes.csv
python -m tools.scoring.aesthetic.inference meta.csv
torchrun --nproc_per_node 8 -m tools.scoring.aesthetic.inference meta.csv --bs 1024 --num_workers 16
```
This will generate multiple part files, you can use `python -m tools.datasets.csvutil DATA1.csv DATA2.csv` to merge these part files.
## Optical Flow Score
Optical flow scores are used to assess the motion of a video. Higher optical flow scores indicate larger movement.
TODO: acknowledge UniMatch.

View file

@ -1,19 +1,30 @@
# adapted from https://github.com/christophschuhmann/improved-aesthetic-predictor/blob/main/simple_inference.py
import argparse
import os
from datetime import timedelta
import clip
import decord
import numpy as np
import pandas as pd
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from colossalai.utils import set_seed
from einops import rearrange
from PIL import Image
from torchvision.datasets.folder import pil_loader
from tqdm import tqdm
try:
from torchvision.transforms import InterpolationMode
BICUBIC = InterpolationMode.BICUBIC
except ImportError:
BICUBIC = Image.BICUBIC
IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp")
VID_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv")
@ -49,8 +60,7 @@ class VideoTextDataset(torch.utils.data.Dataset):
images = extract_frames(sample["path"], points=self.points)
images = [self.transform(img) for img in images]
images = torch.stack(images)
return dict(index=index, image=images)
return dict(index=index, images=images)
def __len__(self):
    """Return the number of samples (rows) listed in the metadata table."""
    num_samples = len(self.data)
    return num_samples
@ -96,39 +106,62 @@ class AestheticScorer(nn.Module):
@torch.inference_mode()
def main(args):
    """Score every video listed in the input CSV with the aesthetic predictor.

    Runs data-parallel across ranks: each rank scores its shard of the rows
    and writes its own ``*_aes_part{rank}.csv`` next to the input file.
    Merge the part files afterwards (e.g. with ``tools.datasets.csvutil``).
    """
    # One process per GPU; generous timeout so slow ranks are not dropped
    # while other ranks wait at the final barrier.
    dist.init_process_group(backend="nccl", timeout=timedelta(hours=24))
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
    set_seed(1024)
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    # Each rank writes its own part file; no cross-rank gather is needed.
    output_file = args.input.replace(".csv", f"_aes_part{rank}.csv")

    # build model
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = AestheticScorer(768, device)
    preprocess = model.preprocess

    # build dataset; the sampler shards rows across ranks WITHOUT shuffling so
    # batch["index"] values stay aligned with the underlying dataframe index.
    dataset = VideoTextDataset(args.input, transform=preprocess)
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset=dataset, num_replicas=world_size, rank=rank, shuffle=False
    )
    dataloader = torch.utils.data.DataLoader(
        dataset,
        sampler=sampler,
        batch_size=args.bs,
        num_workers=args.num_workers,
        pin_memory=True,
        # prefetch_factor is only legal when worker processes are used
        prefetch_factor=args.prefetch_factor if args.num_workers > 0 else None,
    )

    # compute aesthetic scores; rows this rank never sees stay NaN and are
    # filtered out before saving.
    dataset.data["aes"] = np.nan
    with tqdm(dataloader, position=rank, desc=f"Data Parallel Rank {rank}") as t:
        # NOTE(review): removed a leftover debug guard (`if idx == 12: break`)
        # that silently capped scoring at 12 batches per rank.
        for batch in t:
            image_indices = batch["index"]
            images = batch["images"].to(device, non_blocking=True)
            B = images.shape[0]
            # fold the per-video frame dimension into the batch dimension
            images = rearrange(images, "b p c h w -> (b p) c h w")
            # compute score
            scores = model(images)
            scores = rearrange(scores, "(b p) 1 -> b p", b=B)
            # a video's score is the mean over its sampled frames
            scores = scores.mean(dim=1)
            scores_np = scores.to(torch.float32).cpu().numpy()
            # assign the score back by original dataframe row index
            dataset.data.loc[image_indices, "aes"] = scores_np

    # wait for all ranks to finish data processing
    dist.barrier()
    # exclude rows whose aes is nan (rows owned by other ranks) and save.
    # NOTE(review): `> 0` also drops non-positive scores — presumably the
    # predictor only emits positive values; confirm against AestheticScorer.
    dataset.data = dataset.data[dataset.data["aes"] > 0]
    dataset.data.to_csv(output_file, index=False)
    print(f"New meta with aesthetic scores saved to '{output_file}'.")
@ -137,8 +170,9 @@ if __name__ == "__main__":
# CLI entry point. The diff residue defined --num_workers and
# --prefetch_factor twice with conflicting defaults; argparse raises
# ArgumentError on a repeated option string, so only the updated
# definitions (num_workers=16, prefetch_factor=2) are kept.
parser = argparse.ArgumentParser()
parser.add_argument("input", type=str, help="Path to the input CSV file")
parser.add_argument("--bs", type=int, default=512, help="Batch size")
parser.add_argument("--num_workers", type=int, default=16, help="Number of workers")
parser.add_argument("--accumulate", type=int, default=1, help="batch to accumulate")
parser.add_argument("--prefetch_factor", type=int, default=2, help="Prefetch factor")
args = parser.parse_args()
main(args)