mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-11 05:13:31 +02:00
accelerate aesthetic
This commit is contained in:
parent
13423c57a7
commit
4d3b68e3ad
|
|
@ -1,52 +1,47 @@
|
|||
# adapted from https://github.com/christophschuhmann/improved-aesthetic-predictor/blob/main/simple_inference.py
|
||||
import argparse
|
||||
|
||||
import av
|
||||
import clip
|
||||
import cv2
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from PIL import Image
|
||||
from einops import rearrange
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def get_video_length(cap):
    """Return the frame count that an opened OpenCV capture reports, as an int."""
    frame_count = cap.get(cv2.CAP_PROP_FRAME_COUNT)
    return int(frame_count)
|
||||
|
||||
|
||||
def extract_frames(video_path, points=(0.0, 0.5, 0.9)):
    """Decode one frame at each requested relative position of a video.

    Args:
        video_path: Path of the video to sample (passed straight to ``av.open``).
        points: Relative positions; each one is multiplied by the frame count
            reported by the first video stream to choose a target frame.

    Returns:
        A list with one image per entry of ``points``, in the same order. Each
        image is the result of ``VideoFrame.to_image()``.

    Raises:
        StopIteration: If no frame can be decoded after a seek.
    """
    frames = []
    # Use the container as a context manager so the underlying file is closed
    # even when seeking or decoding raises; otherwise every processed video
    # leaves an open handle behind until garbage collection.
    with av.open(video_path) as container:
        stream = container.streams.video[0]
        total_frames = stream.frames
        average_rate = stream.average_rate
        # NOTE(review): a stream that does not record its frame count reports 0
        # here, in which case every point seeks to the start — confirm that is
        # acceptable for the input data.
        for point in points:
            target_frame = total_frames * point
            # target_frame / average_rate is a time in seconds; scaling by
            # av.time_base converts it to the units container.seek() expects
            # when no stream argument is given.
            target_timestamp = int((target_frame * av.time_base) / average_rate)
            container.seek(target_timestamp)
            # Take the first frame decoded after the seek instead of stepping
            # frame by frame up to the exact target.
            frame = next(container.decode(video=0)).to_image()
            frames.append(frame)
    return frames
|
||||
|
||||
|
||||
class VideoTextDataset(torch.utils.data.Dataset):
    """Dataset that yields frames sampled from the videos listed in a CSV file.

    The CSV is read with its header row, and each row's ``path`` column is used
    as the video location.
    """

    def __init__(self, csv_path, transform=None, points=(0.1, 0.5, 0.9)):
        """Load the metadata table.

        Args:
            csv_path: Path to the CSV file describing the videos.
            transform: Optional callable applied to every extracted frame.
            points: Relative positions handed to ``extract_frames`` for each
                video.
        """
        self.csv_path = csv_path
        self.data = pd.read_csv(csv_path)
        self.transform = transform
        self.points = points

    def getitem(self, index):
        """Return ``dict(index=..., image=...)`` for the row at ``index``.

        With a transform, ``image`` is the transformed frames stacked along a
        new leading dimension. Without one, ``image`` is the untouched list of
        extracted frames.
        """
        sample = self.data.iloc[index]
        images = extract_frames(sample["path"], points=self.points)
        # transform defaults to None; calling it unconditionally would raise
        # TypeError, so only transform and stack when one was supplied.
        if self.transform is not None:
            images = torch.stack([self.transform(img) for img in images])

        return dict(index=index, image=images)

    def __len__(self):
        """Return the number of rows in the metadata table."""
        return len(self.data)

    def __getitem__(self, index):
        """Delegate to ``getitem`` so subclasses can override a single method."""
        return self.getitem(index)
|
||||
|
|
@ -87,29 +82,48 @@ class AestheticScorer(nn.Module):
|
|||
return self.mlp(image_features)
|
||||
|
||||
|
||||
def main(args):
    """Score every video listed in ``args.input`` and save the result as CSV.

    The output path is the input path with ``_aes`` inserted before the file
    extension. The written table is the input table plus an ``aesthetic``
    column holding, per video, the mean model score over its sampled frames.

    Args:
        args: Parsed command-line namespace providing ``input``, ``bs``,
            ``num_workers`` and ``prefetch_factor``.
    """
    import os  # local import: only needed to derive the output path

    # Insert the suffix before the extension only. A plain
    # str.replace(".csv", ...) would rewrite every ".csv" in the path and, for
    # an input without ".csv", return the input path itself so the source file
    # would be overwritten.
    root, ext = os.path.splitext(args.input)
    output_file = f"{root}_aes{ext}"

    # build model
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = AestheticScorer(768, device)
    # Grab the preprocessing callable before wrapping: DataParallel does not
    # forward custom attributes of the wrapped module.
    preprocess = model.preprocess
    model = torch.nn.DataParallel(model)

    # build dataset
    dataset = VideoTextDataset(args.input, transform=preprocess, points=(0.5,))
    loader_kwargs = dict(
        batch_size=args.bs,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True,
    )
    # DataLoader rejects an explicit prefetch_factor when loading happens in
    # the main process, so only pass it when worker processes are used.
    if args.num_workers > 0:
        loader_kwargs["prefetch_factor"] = args.prefetch_factor
    dataloader = torch.utils.data.DataLoader(dataset, **loader_kwargs)

    # compute aesthetic scores
    dataset.data["aesthetic"] = np.nan
    index = 0
    for batch in tqdm(dataloader):
        images = batch["image"].to(device)
        batch_size = images.shape[0]
        # Fold the per-video frame axis into the batch axis for the model.
        images = rearrange(images, "b p c h w -> (b p) c h w")
        with torch.no_grad():
            scores = model(images)
        # Unfold again and average the frame scores of each video.
        scores = rearrange(scores, "(b p) 1 -> b p", b=batch_size)
        scores = scores.mean(dim=1)
        scores_np = scores.cpu().numpy()
        # .loc slices include the end label, hence the "- 1".
        dataset.data.loc[index : index + batch_size - 1, "aesthetic"] = scores_np
        # Advance by the number of videos, not the number of flattened frames;
        # the two differ whenever more than one point is sampled per video.
        index += batch_size
    dataset.data.to_csv(output_file, index=False)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Command-line entry point: parse the options and hand them to main().
    # main() requires the parsed namespace, so it is called exactly once with
    # it; the argument-less call from the previous revision is dropped.
    parser = argparse.ArgumentParser()
    parser.add_argument("input", type=str, help="Path to the input CSV file")
    parser.add_argument("--bs", type=int, default=512, help="Batch size")
    parser.add_argument("--num_workers", type=int, default=64, help="Number of workers")
    parser.add_argument("--prefetch_factor", type=int, default=8, help="Prefetch factor")
    args = parser.parse_args()
    main(args)
|
||||
|
|
|
|||
Loading…
Reference in a new issue