mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-10 21:01:26 +02:00
accelerate aesthetic scoring (#32)
* accelerate aesthetic scoring * polish
This commit is contained in:
parent
1bb49f75e7
commit
3bbee00436
|
|
@ -32,9 +32,11 @@ With `meta.csv` containing the paths to the videos, run the following command:
|
|||
|
||||
```bash
|
||||
# output: meta_aes.csv
|
||||
python -m tools.scoring.aesthetic.inference meta.csv
|
||||
torchrun --nproc_per_node 8 -m tools.scoring.aesthetic.inference meta.csv --bs 1024 --num_workers 16
|
||||
```
|
||||
|
||||
This will generate multiple part files. You can use `python -m tools.datasets.csvutil DATA1.csv DATA2.csv` to merge these part files.
|
||||
|
||||
## Optical Flow Score
|
||||
Optical flow scores are used to assess the motion of a video. Higher optical flow scores indicate larger movement.
|
||||
TODO: acknowledge UniMatch.
|
||||
|
|
|
|||
|
|
@ -1,19 +1,30 @@
|
|||
# adapted from https://github.com/christophschuhmann/improved-aesthetic-predictor/blob/main/simple_inference.py
|
||||
import argparse
|
||||
import os
|
||||
from datetime import timedelta
|
||||
|
||||
import clip
|
||||
import decord
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from colossalai.utils import set_seed
|
||||
from einops import rearrange
|
||||
from PIL import Image
|
||||
from torchvision.datasets.folder import pil_loader
|
||||
from tqdm import tqdm
|
||||
|
||||
try:
|
||||
from torchvision.transforms import InterpolationMode
|
||||
|
||||
BICUBIC = InterpolationMode.BICUBIC
|
||||
except ImportError:
|
||||
BICUBIC = Image.BICUBIC
|
||||
|
||||
|
||||
IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp")
|
||||
VID_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv")
|
||||
|
||||
|
|
@ -49,8 +60,7 @@ class VideoTextDataset(torch.utils.data.Dataset):
|
|||
images = extract_frames(sample["path"], points=self.points)
|
||||
images = [self.transform(img) for img in images]
|
||||
images = torch.stack(images)
|
||||
|
||||
return dict(index=index, image=images)
|
||||
return dict(index=index, images=images)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
|
@ -96,39 +106,62 @@ class AestheticScorer(nn.Module):
|
|||
|
||||
@torch.inference_mode()
|
||||
def main(args):
|
||||
output_file = args.input.replace(".csv", "_aes.csv")
|
||||
dist.init_process_group(backend="nccl", timeout=timedelta(hours=24))
|
||||
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
|
||||
set_seed(1024)
|
||||
rank = dist.get_rank()
|
||||
world_size = dist.get_world_size()
|
||||
|
||||
output_file = args.input.replace(".csv", f"_aes_part{rank}.csv")
|
||||
|
||||
# build model
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
model = AestheticScorer(768, device)
|
||||
preprocess = model.preprocess
|
||||
model = torch.nn.DataParallel(model)
|
||||
|
||||
# build dataset
|
||||
dataset = VideoTextDataset(args.input, transform=preprocess)
|
||||
sampler = torch.utils.data.distributed.DistributedSampler(
|
||||
dataset=dataset, num_replicas=world_size, rank=rank, shuffle=False
|
||||
)
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
sampler=sampler,
|
||||
batch_size=args.bs,
|
||||
shuffle=False,
|
||||
num_workers=0,
|
||||
num_workers=args.num_workers,
|
||||
pin_memory=True,
|
||||
# prefetch_factor=args.prefetch_factor,
|
||||
prefetch_factor=args.prefetch_factor if args.num_workers > 0 else None,
|
||||
)
|
||||
|
||||
# compute aesthetic scores
|
||||
dataset.data["aes"] = np.nan
|
||||
index = 0
|
||||
for batch in tqdm(dataloader):
|
||||
images = batch["image"].to(device)
|
||||
B = images.shape[0]
|
||||
images = rearrange(images, "b p c h w -> (b p) c h w")
|
||||
with torch.no_grad():
|
||||
|
||||
with tqdm(dataloader, position=rank, desc=f"Data Parallel Rank {rank}") as t:
|
||||
for idx, batch in enumerate(t):
|
||||
if idx == 12:
|
||||
break
|
||||
|
||||
image_indices = batch["index"]
|
||||
images = batch["images"].to(device, non_blocking=True)
|
||||
B = images.shape[0]
|
||||
images = rearrange(images, "b p c h w -> (b p) c h w")
|
||||
|
||||
# compute score
|
||||
scores = model(images)
|
||||
scores = rearrange(scores, "(b p) 1 -> b p", b=B)
|
||||
scores = scores.mean(dim=1)
|
||||
scores_np = scores.cpu().numpy()
|
||||
dataset.data.loc[index : index + len(scores_np) - 1, "aes"] = scores_np
|
||||
index += len(images)
|
||||
scores = rearrange(scores, "(b p) 1 -> b p", b=B)
|
||||
scores = scores.mean(dim=1)
|
||||
scores_np = scores.to(torch.float32).cpu().numpy()
|
||||
|
||||
# assign the score
|
||||
dataset.data.loc[image_indices, "aes"] = scores_np
|
||||
|
||||
# wait for all ranks to finish data processing
|
||||
dist.barrier()
|
||||
|
||||
# exclude rows whose aes is nan and save file
|
||||
dataset.data = dataset.data[dataset.data["aes"] > 0]
|
||||
dataset.data.to_csv(output_file, index=False)
|
||||
print(f"New meta with aesthetic scores saved to '{output_file}'.")
|
||||
|
||||
|
|
@ -137,8 +170,9 @@ if __name__ == "__main__":
|
|||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("input", type=str, help="Path to the input CSV file")
|
||||
parser.add_argument("--bs", type=int, default=512, help="Batch size")
|
||||
parser.add_argument("--num_workers", type=int, default=64, help="Number of workers")
|
||||
parser.add_argument("--prefetch_factor", type=int, default=8, help="Prefetch factor")
|
||||
parser.add_argument("--num_workers", type=int, default=16, help="Number of workers")
|
||||
parser.add_argument("--accumulate", type=int, default=1, help="batch to accumulate")
|
||||
parser.add_argument("--prefetch_factor", type=int, default=2, help="Prefetch factor")
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
||||
|
|
|
|||
Loading…
Reference in a new issue