mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-15 20:36:58 +02:00
update scoring (#82)
* update scoring/matching * update scoring/matching * update scoring/matching * update scoring/matching * update scoring/matching * update scoring/matching * update scoring/matching * update scoring/matching * update scoring/matching * update scene_cut * update scene_cut * update scene_cut[A * update scene_cut * update scene_cut * update scene_cut * update scene_cut * update scene_cut * update scene_cut * m * m * m * m * m * m * m * m * m * m * m * m * m * m * update readme * update readme * extract frames using opencv everywhere * extract frames using opencv everywhere * extract frames using opencv everywhere * filter panda10m * filter panda10m * m * m * m * m * m * m * m * m * m * m * m * m * m * m * m * m * m * ocr * add ocr * add main.sh * add ocr * add ocr * add ocr * add ocr * add ocr * add ocr * update scene_cut * update remove main.sh * update scoring * update scoring * update scoring * update README * update readme * update scene_cut * update readme * update scoring * update readme * update readme * update filter_panda10m * update readme * update readme * update launch.ipynb * update scene_cut * update scene_cut * update readme * update launch.ipynb * update readme * add 1.1 demo * update readme * add 1.1 demo * update readme * Update README.md * add num_workers for pandarallel * update scene_cut * update readme * update datautil * update scoring * update scoring
This commit is contained in:
parent
bf633f3e68
commit
93ae382c2f
1
.gitignore
vendored
1
.gitignore
vendored
|
|
@ -173,6 +173,7 @@ samples/
|
|||
samples
|
||||
logs/
|
||||
pretrained_models/
|
||||
pretrained_models
|
||||
evaluation_results/
|
||||
cache/
|
||||
*.swp
|
||||
|
|
|
|||
|
|
@ -47,7 +47,7 @@ torchrun --nproc_per_node 8 -m tools.scoring.aesthetic.inference \
|
|||
# 3.2 Merge files. This should output ${ROOT_META}/meta_clips_info_fmin1_aes.csv
|
||||
python -m tools.datasets.datautil ${ROOT_META}/meta_clips_info_fmin1_aes_part*.csv --output ${ROOT_META}/meta_clips_info_fmin1_aes.csv
|
||||
|
||||
# 3.2 Filter by aesthetic scores. This should output ${ROOT_META}/meta_clips_info_fmin1_aes_aesmin5.csv
|
||||
# 3.3 Filter by aesthetic scores. This should output ${ROOT_META}/meta_clips_info_fmin1_aes_aesmin5.csv
|
||||
python -m tools.datasets.datautil ${ROOT_META}/meta_clips_info_fmin1_aes.csv --aesmin 5
|
||||
|
||||
# 4.1 Generate caption. This should output ${ROOT_META}/meta_clips_info_fmin1_aes_aesmin5_caption_part*.csv
|
||||
|
|
|
|||
|
|
@ -108,6 +108,7 @@ def process_general_videos(root, output):
|
|||
df = pd.DataFrame(dict(path=video_lists))
|
||||
if output is None:
|
||||
output = "videos.csv"
|
||||
os.makedirs(os.path.dirname(output), exist_ok=True)
|
||||
df.to_csv(output, index=False)
|
||||
print(f"Saved {len(df)} samples to {output}.")
|
||||
|
||||
|
|
|
|||
|
|
@ -68,6 +68,7 @@ def parse_args():
|
|||
parser.add_argument("meta_path", type=str)
|
||||
parser.add_argument("--folder_path", type=str, required=True)
|
||||
parser.add_argument("--mode", type=str, default=None)
|
||||
parser.add_argument("--num_workers", type=int, default=None, help='#workers for pandarallel')
|
||||
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
|
@ -104,7 +105,10 @@ def main():
|
|||
meta_fname = os.path.basename(meta_path)
|
||||
wo_ext, ext = os.path.splitext(meta_fname)
|
||||
|
||||
pandarallel.initialize(progress_bar=True)
|
||||
if args.num_workers is not None:
|
||||
pandarallel.initialize(progress_bar=True, nb_workers=args.num_workers)
|
||||
else:
|
||||
pandarallel.initialize(progress_bar=True)
|
||||
is_intact_partial = partial(is_intact, mode=mode)
|
||||
|
||||
meta = pd.read_csv(meta_path)
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@ from functools import partial
|
|||
|
||||
import pandas as pd
|
||||
from imageio_ffmpeg import get_ffmpeg_exe
|
||||
from mmengine.logging import MMLogger, print_log
|
||||
from pandarallel import pandarallel
|
||||
from scenedetect import FrameTimecode
|
||||
from tqdm import tqdm
|
||||
|
|
@ -14,12 +13,17 @@ from tqdm import tqdm
|
|||
tqdm.pandas()
|
||||
|
||||
|
||||
def process_single_row(row, args, log_name=None):
|
||||
def print_log(s, logger=None):
    """Emit message ``s`` through ``logger.info`` when a logger is given, else print it.

    Args:
        s: the message to emit.
        logger: optional object exposing an ``info`` method; ``None`` means
            the message is written with ``print``.
    """
    if logger is None:
        print(s)
        return
    logger.info(s)
|
||||
|
||||
|
||||
def process_single_row(row, args):
|
||||
video_path = row["path"]
|
||||
|
||||
logger = None
|
||||
if log_name is not None:
|
||||
logger = MMLogger.get_instance(log_name)
|
||||
|
||||
# check mp4 integrity
|
||||
# if not is_intact_video(video_path, logger=logger):
|
||||
|
|
@ -132,6 +136,7 @@ def parse_args():
|
|||
help='if not None, clip longer than max_seconds is truncated')
|
||||
parser.add_argument("--target_fps", type=int, default=30, help='target fps of clips')
|
||||
parser.add_argument("--shorter_size", type=int, default=720, help='resize the shorter size by keeping ratio')
|
||||
parser.add_argument("--num_workers", type=int, default=None, help='#workers for pandarallel')
|
||||
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
|
@ -144,16 +149,14 @@ def main():
|
|||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
# create logger
|
||||
log_dir = os.path.dirname(save_dir)
|
||||
log_name = os.path.basename(save_dir)
|
||||
timestamp = time.strftime("%Y%m%d-%H%M%S", time.localtime(time.time()))
|
||||
log_path = os.path.join(log_dir, f"{log_name}_{timestamp}.log")
|
||||
logger = MMLogger.get_instance(log_name, log_file=log_path)
|
||||
# logger = None
|
||||
logger = None
|
||||
|
||||
# initialize pandarallel
|
||||
pandarallel.initialize(progress_bar=True)
|
||||
process_single_row_partial = partial(process_single_row, args=args, log_name=log_name)
|
||||
if args.num_workers is not None:
|
||||
pandarallel.initialize(progress_bar=True, nb_workers=args.num_workers)
|
||||
else:
|
||||
pandarallel.initialize(progress_bar=True)
|
||||
process_single_row_partial = partial(process_single_row, args=args)
|
||||
|
||||
# process
|
||||
meta = pd.read_csv(args.meta_path)
|
||||
|
|
|
|||
|
|
@ -34,6 +34,7 @@ def process_single_row(row):
|
|||
def parse_args():
    """Parse the command line: a positional ``meta_path`` and optional ``--num_workers``.

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``; ``num_workers``
        is ``None`` when the flag is not supplied.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("meta_path", type=str)
    cli.add_argument("--num_workers", type=int, default=None, help='#workers for pandarallel')
    return cli.parse_args()
|
||||
|
|
@ -43,7 +44,10 @@ def main():
|
|||
args = parse_args()
|
||||
meta_path = args.meta_path
|
||||
|
||||
pandarallel.initialize(progress_bar=True)
|
||||
if args.num_workers is not None:
|
||||
pandarallel.initialize(progress_bar=True, nb_workers=args.num_workers)
|
||||
else:
|
||||
pandarallel.initialize(progress_bar=True)
|
||||
|
||||
meta = pd.read_csv(meta_path)
|
||||
ret = meta.parallel_apply(process_single_row, axis=1)
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
# adapted from https://github.com/christophschuhmann/improved-aesthetic-predictor/blob/main/simple_inference.py
|
||||
import os
|
||||
import argparse
|
||||
from datetime import timedelta
|
||||
|
||||
|
|
@ -9,7 +10,7 @@ import torch
|
|||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from colossalai.utils import set_seed
|
||||
from torch.utils.data import DataLoader, DistributedSampler
|
||||
from einops import rearrange
|
||||
from PIL import Image
|
||||
from torchvision.datasets.folder import pil_loader
|
||||
|
|
@ -17,30 +18,41 @@ from tqdm import tqdm
|
|||
|
||||
from tools.datasets.utils import extract_frames, is_video
|
||||
|
||||
try:
|
||||
from torchvision.transforms import InterpolationMode
|
||||
|
||||
BICUBIC = InterpolationMode.BICUBIC
|
||||
except ImportError:
|
||||
BICUBIC = Image.BICUBIC
|
||||
|
||||
|
||||
NUM_FRAMES_POINTS = {
|
||||
1: (0.5,),
|
||||
2: (0.25, 0.5),
|
||||
3: (0.1, 0.5, 0.9),
|
||||
}
|
||||
|
||||
def merge_scores(gathered_list: list, meta: pd.DataFrame, column):
    """Write the scores gathered from every rank into ``meta[column]``.

    Args:
        gathered_list: one ``(indices, scores)`` pair per rank, where
            ``indices[i]`` is the ``meta`` row label that ``scores[i]`` belongs to.
        meta: table updated in place; the indices are used as ``meta.loc`` row labels.
        column: name of the column that receives the scores.

    Returns:
        The same ``meta`` object, after the in-place update.
    """
    indices_list = [rank_result[0] for rank_result in gathered_list]
    scores_list = [rank_result[1] for rank_result in gathered_list]

    # Interleave the per-rank results round-robin (rank 0 item 0, rank 1 item 0, ...).
    # Walk up to the LONGEST per-rank list: zip(*lists) would stop at the shortest
    # one and silently drop the remaining scores when ranks hold unequal counts.
    flat_indices = []
    flat_scores = []
    longest = max((len(rank_indices) for rank_indices in indices_list), default=0)
    for step in range(longest):
        for rank_indices, rank_scores in zip(indices_list, scores_list):
            if step < len(rank_indices):
                flat_indices.append(rank_indices[step])
                flat_scores.append(rank_scores[step])

    if not flat_indices:
        # Nothing was scored: still make sure the column exists so that callers
        # writing or filtering the table afterwards find it.
        if column not in meta.columns:
            meta[column] = np.nan
        return meta

    flat_indices = np.array(flat_indices)
    flat_scores = np.array(flat_scores)

    # filter duplicates: np.unique reports the first occurrence of each index,
    # so repeated rows (presumably padding added by the distributed sampler --
    # TODO confirm) keep the score seen first.
    unique_indices, first_positions = np.unique(flat_indices, return_index=True)
    meta.loc[unique_indices, column] = flat_scores[first_positions]
    return meta
|
||||
|
||||
|
||||
class VideoTextDataset(torch.utils.data.Dataset):
|
||||
def __init__(self, csv_path, transform=None, num_frames=3):
|
||||
self.csv_path = csv_path
|
||||
self.data = pd.read_csv(csv_path)
|
||||
self.meta = pd.read_csv(csv_path)
|
||||
self.transform = transform
|
||||
self.points = NUM_FRAMES_POINTS[num_frames]
|
||||
|
||||
def getitem(self, index):
|
||||
sample = self.data.iloc[index]
|
||||
def __getitem__(self, index):
|
||||
sample = self.meta.iloc[index]
|
||||
path = sample["path"]
|
||||
if not is_video(path):
|
||||
images = [pil_loader(path)]
|
||||
|
|
@ -55,10 +67,7 @@ class VideoTextDataset(torch.utils.data.Dataset):
|
|||
return ret
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.getitem(index)
|
||||
return len(self.meta)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
|
|
@ -96,15 +105,15 @@ class AestheticScorer(nn.Module):
|
|||
return self.mlp(image_features)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def main(args):
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
dist.init_process_group(backend="nccl", timeout=timedelta(hours=24))
|
||||
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
|
||||
set_seed(1024)
|
||||
rank = dist.get_rank()
|
||||
world_size = dist.get_world_size()
|
||||
|
||||
output_file = args.input.replace(".csv", f"_aes_part{rank}.csv")
|
||||
meta_path = args.meta_path
|
||||
wo_ext, ext = os.path.splitext(meta_path)
|
||||
out_path = f"{wo_ext}_aes{ext}"
|
||||
|
||||
# build model
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
|
@ -112,57 +121,62 @@ def main(args):
|
|||
preprocess = model.preprocess
|
||||
|
||||
# build dataset
|
||||
dataset = VideoTextDataset(args.input, transform=preprocess, num_frames=args.num_frames)
|
||||
sampler = torch.utils.data.distributed.DistributedSampler(
|
||||
dataset=dataset, num_replicas=world_size, rank=rank, shuffle=False
|
||||
)
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset = VideoTextDataset(args.meta_path, transform=preprocess, num_frames=args.num_frames)
|
||||
dataloader = DataLoader(
|
||||
dataset,
|
||||
sampler=sampler,
|
||||
batch_size=args.bs,
|
||||
shuffle=False,
|
||||
num_workers=args.num_workers,
|
||||
pin_memory=True,
|
||||
prefetch_factor=args.prefetch_factor if args.num_workers > 0 else None,
|
||||
sampler=DistributedSampler(
|
||||
dataset,
|
||||
num_replicas=dist.get_world_size(),
|
||||
rank=dist.get_rank(),
|
||||
shuffle=False,
|
||||
drop_last=False,
|
||||
),
|
||||
)
|
||||
|
||||
# compute aesthetic scores
|
||||
dataset.data["aes"] = np.nan
|
||||
indices_list = []
|
||||
scores_list = []
|
||||
model.eval()
|
||||
for batch in tqdm(dataloader, disable=dist.get_rank() != 0):
|
||||
indices = batch["index"]
|
||||
images = batch["images"].to(device, non_blocking=True)
|
||||
|
||||
with tqdm(dataloader, position=rank, desc=f"Data Parallel Rank {rank}") as t:
|
||||
for idx, batch in enumerate(t):
|
||||
image_indices = batch["index"]
|
||||
images = batch["images"].to(device, non_blocking=True)
|
||||
B = images.shape[0]
|
||||
images = rearrange(images, "b p c h w -> (b p) c h w")
|
||||
B = images.shape[0]
|
||||
images = rearrange(images, "B N C H W -> (B N) C H W")
|
||||
|
||||
# compute score
|
||||
# compute score
|
||||
with torch.no_grad():
|
||||
scores = model(images)
|
||||
scores = rearrange(scores, "(b p) 1 -> b p", b=B)
|
||||
scores = scores.mean(dim=1)
|
||||
scores_np = scores.to(torch.float32).cpu().numpy()
|
||||
scores = rearrange(scores, "(B N) 1 -> B N", B=B)
|
||||
scores = scores.mean(dim=1)
|
||||
scores_np = scores.to(torch.float32).cpu().numpy()
|
||||
|
||||
# assign the score
|
||||
dataset.data.loc[image_indices, "aes"] = scores_np
|
||||
indices_list.extend(indices)
|
||||
scores_list.extend(scores_np)
|
||||
|
||||
# wait for all ranks to finish data processing
|
||||
dist.barrier()
|
||||
|
||||
# exclude rows whose aes is nan and save file
|
||||
dataset.data = dataset.data[dataset.data["aes"] > 0]
|
||||
dataset.data.to_csv(output_file, index=False)
|
||||
print(f"New meta with aesthetic scores saved to '{output_file}'.")
|
||||
gathered_list = [None] * dist.get_world_size()
|
||||
dist.all_gather_object(gathered_list, (indices_list, scores_list))
|
||||
if dist.get_rank() == 0:
|
||||
meta_new = merge_scores(gathered_list, dataset.meta, column='aes')
|
||||
meta_new.to_csv(out_path, index=False)
|
||||
print(f"New meta with aesthetic scores saved to '{out_path}'.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("input", type=str, help="Path to the input CSV file")
|
||||
parser.add_argument("meta_path", type=str, help="Path to the input CSV file")
|
||||
parser.add_argument("--bs", type=int, default=1024, help="Batch size")
|
||||
parser.add_argument("--num_workers", type=int, default=16, help="Number of workers")
|
||||
parser.add_argument("--accumulate", type=int, default=1, help="batch to accumulate")
|
||||
parser.add_argument("--prefetch_factor", type=int, default=2, help="Prefetch factor")
|
||||
parser.add_argument("--num_frames", type=int, default=3, help="Number of frames to extract")
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
||||
return args
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
|
|||
|
|
@ -15,6 +15,26 @@ from tqdm import tqdm
|
|||
from tools.datasets.utils import extract_frames, is_video
|
||||
|
||||
|
||||
def merge_scores(gathered_list: list, meta: pd.DataFrame, column):
    """Write the scores gathered from every rank into ``meta[column]``.

    Args:
        gathered_list: one ``(indices, scores)`` pair per rank, where
            ``indices[i]`` is the ``meta`` row label that ``scores[i]`` belongs to.
        meta: table updated in place; the indices are used as ``meta.loc`` row labels.
        column: name of the column that receives the scores.

    Returns:
        The same ``meta`` object, after the in-place update.
    """
    indices_list = [rank_result[0] for rank_result in gathered_list]
    scores_list = [rank_result[1] for rank_result in gathered_list]

    # Interleave the per-rank results round-robin (rank 0 item 0, rank 1 item 0, ...).
    # Walk up to the LONGEST per-rank list: zip(*lists) would stop at the shortest
    # one and silently drop the remaining scores when ranks hold unequal counts.
    flat_indices = []
    flat_scores = []
    longest = max((len(rank_indices) for rank_indices in indices_list), default=0)
    for step in range(longest):
        for rank_indices, rank_scores in zip(indices_list, scores_list):
            if step < len(rank_indices):
                flat_indices.append(rank_indices[step])
                flat_scores.append(rank_scores[step])

    if not flat_indices:
        # Nothing was scored: still make sure the column exists so that callers
        # writing or filtering the table afterwards find it.
        if column not in meta.columns:
            meta[column] = np.nan
        return meta

    flat_indices = np.array(flat_indices)
    flat_scores = np.array(flat_scores)

    # filter duplicates: np.unique reports the first occurrence of each index,
    # so repeated rows (presumably padding added by the distributed sampler --
    # TODO confirm) keep the score seen first.
    unique_indices, first_positions = np.unique(flat_indices, return_index=True)
    meta.loc[unique_indices, column] = flat_scores[first_positions]
    return meta
|
||||
|
||||
|
||||
class VideoTextDataset(torch.utils.data.Dataset):
|
||||
def __init__(self, meta_path, transform):
|
||||
self.meta_path = meta_path
|
||||
|
|
@ -41,23 +61,6 @@ class VideoTextDataset(torch.utils.data.Dataset):
|
|||
return len(self.meta)
|
||||
|
||||
|
||||
def merge_scores(gathered_list: list, meta: pd.DataFrame, column="match"):
    """Write the scores gathered from every rank into ``meta[column]``.

    Args:
        gathered_list: one ``(indices, scores)`` pair per rank, where
            ``indices[i]`` is the ``meta`` row label that ``scores[i]`` belongs to.
        meta: table updated in place; the indices are used as ``meta.loc`` row labels.
        column: name of the column that receives the scores. Defaults to
            ``"match"``, the column this function always wrote before the
            parameter existed.

    Returns:
        The same ``meta`` object, after the in-place update. Callers that
        ignore the return value and rely on the in-place update keep working.
    """
    indices_list = [rank_result[0] for rank_result in gathered_list]
    scores_list = [rank_result[1] for rank_result in gathered_list]

    # Interleave the per-rank results round-robin (rank 0 item 0, rank 1 item 0, ...).
    # Walk up to the LONGEST per-rank list: zip(*lists) would stop at the shortest
    # one and silently drop the remaining scores when ranks hold unequal counts.
    flat_indices = []
    flat_scores = []
    longest = max((len(rank_indices) for rank_indices in indices_list), default=0)
    for step in range(longest):
        for rank_indices, rank_scores in zip(indices_list, scores_list):
            if step < len(rank_indices):
                flat_indices.append(rank_indices[step])
                flat_scores.append(rank_scores[step])

    if not flat_indices:
        # Nothing was scored: still make sure the column exists so that callers
        # writing or filtering the table afterwards find it.
        if column not in meta.columns:
            meta[column] = np.nan
        return meta

    flat_indices = np.array(flat_indices)
    flat_scores = np.array(flat_scores)

    # filter duplicates: np.unique reports the first occurrence of each index,
    # so a repeated row keeps the score seen first.
    unique_indices, first_positions = np.unique(flat_indices, return_index=True)
    meta.loc[unique_indices, column] = flat_scores[first_positions]
    return meta
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("meta_path", type=str, help="Path to the input CSV file")
|
||||
|
|
@ -96,7 +99,6 @@ def main():
|
|||
)
|
||||
|
||||
# compute scores
|
||||
dataset.meta["match"] = np.nan
|
||||
indices_list = []
|
||||
scores_list = []
|
||||
model.eval()
|
||||
|
|
@ -118,8 +120,8 @@ def main():
|
|||
gathered_list = [None] * dist.get_world_size()
|
||||
dist.all_gather_object(gathered_list, (indices_list, scores_list))
|
||||
if dist.get_rank() == 0:
|
||||
merge_scores(gathered_list, dataset.meta)
|
||||
dataset.meta.to_csv(out_path, index=False)
|
||||
meta_new = merge_scores(gathered_list, dataset.meta, column='match')
|
||||
meta_new.to_csv(out_path, index=False)
|
||||
print(f"New meta with matching scores saved to '{out_path}'.")
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue