From 93ae382c2f5f0342c8edcdaf5306a104be131999 Mon Sep 17 00:00:00 2001 From: xyupeng <99191637+xyupeng@users.noreply.github.com> Date: Tue, 30 Apr 2024 14:44:45 +0800 Subject: [PATCH] update scoring (#82) * update scoring/matching * update scoring/matching * update scoring/matching * update scoring/matching * update scoring/matching * update scoring/matching * update scoring/matching * update scoring/matching * update scoring/matching * update scene_cut * update scene_cut * update scene_cut * update scene_cut * update scene_cut * update scene_cut * update scene_cut * update scene_cut * update scene_cut * m * m * m * m * m * m * m * m * m * m * m * m * m * m * update readme * update readme * extract frames using opencv everywhere * extract frames using opencv everywhere * extract frames using opencv everywhere * filter panda10m * filter panda10m * m * m * m * m * m * m * m * m * m * m * m * m * m * m * m * m * m * ocr * add ocr * add main.sh * add ocr * add ocr * add ocr * add ocr * add ocr * add ocr * update scene_cut * update remove main.sh * update scoring * update scoring * update scoring * update README * update readme * update scene_cut * update readme * update scoring * update readme * update readme * update filter_panda10m * update readme * update readme * update launch.ipynb * update scene_cut * update scene_cut * update readme * update launch.ipynb * update readme * add 1.1 demo * update readme * add 1.1 demo * update readme * Update README.md * add num_workers for pandarallel * update scene_cut * update readme * update datautil * update scoring * update scoring --- .gitignore | 1 + docs/data_processing.md | 2 +- tools/datasets/convert.py | 1 + tools/scene_cut/convert_id_to_path.py | 6 +- tools/scene_cut/cut.py | 27 +++--- tools/scene_cut/scene_detect.py | 6 +- tools/scoring/aesthetic/inference.py | 120 ++++++++++++++------ tools/scoring/matching/inference.py | 42 ++++----- 8 files changed, 117 insertions(+), 88 deletions(-) diff --git a/.gitignore 
b/.gitignore index b9f8121..4991c19 100644 --- a/.gitignore +++ b/.gitignore @@ -173,6 +173,7 @@ samples/ samples logs/ pretrained_models/ +pretrained_models evaluation_results/ cache/ *.swp diff --git a/docs/data_processing.md b/docs/data_processing.md index f907194..0d1c195 100644 --- a/docs/data_processing.md +++ b/docs/data_processing.md @@ -47,7 +47,7 @@ torchrun --nproc_per_node 8 -m tools.scoring.aesthetic.inference \ # 3.2 Merge files; This should output ${ROOT_META}/meta_clips_info_fmin1_aes.csv python -m tools.datasets.datautil ${ROOT_META}/meta_clips_info_fmin1_aes_part*.csv --output ${ROOT_META}/meta_clips_info_fmin1_aes.csv -# 3.2 Filter by aesthetic scores. This should output ${ROOT_META}/meta_clips_info_fmin1_aes_aesmin5.csv +# 3.3 Filter by aesthetic scores. This should output ${ROOT_META}/meta_clips_info_fmin1_aes_aesmin5.csv python -m tools.datasets.datautil ${ROOT_META}/meta_clips_info_fmin1_aes.csv --aesmin 5 # 4.1 Generate caption. This should output ${ROOT_META}/meta_clips_info_fmin1_aes_aesmin5_caption_part*.csv diff --git a/tools/datasets/convert.py b/tools/datasets/convert.py index ef6eee3..2c0db28 100644 --- a/tools/datasets/convert.py +++ b/tools/datasets/convert.py @@ -108,6 +108,7 @@ def process_general_videos(root, output): df = pd.DataFrame(dict(path=video_lists)) if output is None: output = "videos.csv" + os.makedirs(os.path.dirname(output), exist_ok=True) df.to_csv(output, index=False) print(f"Saved {len(df)} samples to {output}.") diff --git a/tools/scene_cut/convert_id_to_path.py b/tools/scene_cut/convert_id_to_path.py index eb7b1cb..d217a78 100644 --- a/tools/scene_cut/convert_id_to_path.py +++ b/tools/scene_cut/convert_id_to_path.py @@ -68,6 +68,7 @@ def parse_args(): parser.add_argument("meta_path", type=str) parser.add_argument("--folder_path", type=str, required=True) parser.add_argument("--mode", type=str, default=None) + parser.add_argument("--num_workers", type=int, default=None, help='#workers for pandarallel') args = 
parser.parse_args() return args @@ -104,7 +105,10 @@ def main(): meta_fname = os.path.basename(meta_path) wo_ext, ext = os.path.splitext(meta_fname) - pandarallel.initialize(progress_bar=True) + if args.num_workers is not None: + pandarallel.initialize(progress_bar=True, nb_workers=args.num_workers) + else: + pandarallel.initialize(progress_bar=True) is_intact_partial = partial(is_intact, mode=mode) meta = pd.read_csv(meta_path) diff --git a/tools/scene_cut/cut.py b/tools/scene_cut/cut.py index b3ecbe0..0e02aed 100644 --- a/tools/scene_cut/cut.py +++ b/tools/scene_cut/cut.py @@ -6,7 +6,6 @@ from functools import partial import pandas as pd from imageio_ffmpeg import get_ffmpeg_exe -from mmengine.logging import MMLogger, print_log from pandarallel import pandarallel from scenedetect import FrameTimecode from tqdm import tqdm @@ -14,12 +13,17 @@ from tqdm import tqdm tqdm.pandas() -def process_single_row(row, args, log_name=None): +def print_log(s, logger=None): + if logger is not None: + logger.info(s) + else: + print(s) + + +def process_single_row(row, args): video_path = row["path"] logger = None - if log_name is not None: - logger = MMLogger.get_instance(log_name) # check mp4 integrity # if not is_intact_video(video_path, logger=logger): @@ -132,6 +136,7 @@ def parse_args(): help='if not None, clip longer than max_seconds is truncated') parser.add_argument("--target_fps", type=int, default=30, help='target fps of clips') parser.add_argument("--shorter_size", type=int, default=720, help='resize the shorter size by keeping ratio') + parser.add_argument("--num_workers", type=int, default=None, help='#workers for pandarallel') args = parser.parse_args() return args @@ -144,16 +149,14 @@ def main(): os.makedirs(save_dir, exist_ok=True) # create logger - log_dir = os.path.dirname(save_dir) - log_name = os.path.basename(save_dir) - timestamp = time.strftime("%Y%m%d-%H%M%S", time.localtime(time.time())) - log_path = os.path.join(log_dir, f"{log_name}_{timestamp}.log") - 
logger = MMLogger.get_instance(log_name, log_file=log_path) - # logger = None + logger = None # initialize pandarallel - pandarallel.initialize(progress_bar=True) - process_single_row_partial = partial(process_single_row, args=args, log_name=log_name) + if args.num_workers is not None: + pandarallel.initialize(progress_bar=True, nb_workers=args.num_workers) + else: + pandarallel.initialize(progress_bar=True) + process_single_row_partial = partial(process_single_row, args=args) # process meta = pd.read_csv(args.meta_path) diff --git a/tools/scene_cut/scene_detect.py b/tools/scene_cut/scene_detect.py index eb7b003..a9acf54 100644 --- a/tools/scene_cut/scene_detect.py +++ b/tools/scene_cut/scene_detect.py @@ -34,6 +34,7 @@ def process_single_row(row): def parse_args(): parser = argparse.ArgumentParser() parser.add_argument("meta_path", type=str) + parser.add_argument("--num_workers", type=int, default=None, help='#workers for pandarallel') args = parser.parse_args() return args @@ -43,7 +44,10 @@ def main(): args = parse_args() meta_path = args.meta_path - pandarallel.initialize(progress_bar=True) + if args.num_workers is not None: + pandarallel.initialize(progress_bar=True, nb_workers=args.num_workers) + else: + pandarallel.initialize(progress_bar=True) meta = pd.read_csv(meta_path) ret = meta.parallel_apply(process_single_row, axis=1) diff --git a/tools/scoring/aesthetic/inference.py b/tools/scoring/aesthetic/inference.py index a527859..c07008b 100644 --- a/tools/scoring/aesthetic/inference.py +++ b/tools/scoring/aesthetic/inference.py @@ -1,4 +1,5 @@ # adapted from https://github.com/christophschuhmann/improved-aesthetic-predictor/blob/main/simple_inference.py +import os import argparse from datetime import timedelta @@ -9,7 +10,7 @@ import torch import torch.distributed as dist import torch.nn as nn import torch.nn.functional as F -from colossalai.utils import set_seed +from torch.utils.data import DataLoader, DistributedSampler from einops import rearrange from 
PIL import Image from torchvision.datasets.folder import pil_loader @@ -17,30 +18,41 @@ from tqdm import tqdm from tools.datasets.utils import extract_frames, is_video -try: - from torchvision.transforms import InterpolationMode - - BICUBIC = InterpolationMode.BICUBIC -except ImportError: - BICUBIC = Image.BICUBIC - - NUM_FRAMES_POINTS = { 1: (0.5,), 2: (0.25, 0.5), 3: (0.1, 0.5, 0.9), } +def merge_scores(gathered_list: list, meta: pd.DataFrame, column): + # reorder + indices_list = list(map(lambda x: x[0], gathered_list)) + scores_list = list(map(lambda x: x[1], gathered_list)) + + flat_indices = [] + for x in zip(*indices_list): + flat_indices.extend(x) + flat_scores = [] + for x in zip(*scores_list): + flat_scores.extend(x) + flat_indices = np.array(flat_indices) + flat_scores = np.array(flat_scores) + + # filter duplicates + unique_indices, unique_indices_idx = np.unique(flat_indices, return_index=True) + meta.loc[unique_indices, column] = flat_scores[unique_indices_idx] + return meta + class VideoTextDataset(torch.utils.data.Dataset): def __init__(self, csv_path, transform=None, num_frames=3): self.csv_path = csv_path - self.data = pd.read_csv(csv_path) + self.meta = pd.read_csv(csv_path) self.transform = transform self.points = NUM_FRAMES_POINTS[num_frames] - def getitem(self, index): - sample = self.data.iloc[index] + def __getitem__(self, index): + sample = self.meta.iloc[index] path = sample["path"] if not is_video(path): images = [pil_loader(path)] @@ -55,10 +67,7 @@ class VideoTextDataset(torch.utils.data.Dataset): return ret def __len__(self): - return len(self.data) - - def __getitem__(self, index): - return self.getitem(index) + return len(self.meta) class MLP(nn.Module): @@ -96,15 +105,15 @@ class AestheticScorer(nn.Module): return self.mlp(image_features) -@torch.inference_mode() -def main(args): +def main(): + args = parse_args() + dist.init_process_group(backend="nccl", timeout=timedelta(hours=24)) torch.cuda.set_device(dist.get_rank() % 
torch.cuda.device_count()) - set_seed(1024) - rank = dist.get_rank() - world_size = dist.get_world_size() - output_file = args.input.replace(".csv", f"_aes_part{rank}.csv") + meta_path = args.meta_path + wo_ext, ext = os.path.splitext(meta_path) + out_path = f"{wo_ext}_aes{ext}" # build model device = "cuda" if torch.cuda.is_available() else "cpu" @@ -112,57 +121,62 @@ def main(args): preprocess = model.preprocess # build dataset - dataset = VideoTextDataset(args.input, transform=preprocess, num_frames=args.num_frames) - sampler = torch.utils.data.distributed.DistributedSampler( - dataset=dataset, num_replicas=world_size, rank=rank, shuffle=False - ) - - dataloader = torch.utils.data.DataLoader( + dataset = VideoTextDataset(args.meta_path, transform=preprocess, num_frames=args.num_frames) + dataloader = DataLoader( dataset, - sampler=sampler, batch_size=args.bs, - shuffle=False, num_workers=args.num_workers, - pin_memory=True, - prefetch_factor=args.prefetch_factor if args.num_workers > 0 else None, + sampler=DistributedSampler( + dataset, + num_replicas=dist.get_world_size(), + rank=dist.get_rank(), + shuffle=False, + drop_last=False, + ), ) # compute aesthetic scores - dataset.data["aes"] = np.nan + indices_list = [] + scores_list = [] + model.eval() + for batch in tqdm(dataloader, disable=dist.get_rank() != 0): + indices = batch["index"] + images = batch["images"].to(device, non_blocking=True) - with tqdm(dataloader, position=rank, desc=f"Data Parallel Rank {rank}") as t: - for idx, batch in enumerate(t): - image_indices = batch["index"] - images = batch["images"].to(device, non_blocking=True) - B = images.shape[0] - images = rearrange(images, "b p c h w -> (b p) c h w") + B = images.shape[0] + images = rearrange(images, "B N C H W -> (B N) C H W") - # compute score + # compute score + with torch.no_grad(): scores = model(images) - scores = rearrange(scores, "(b p) 1 -> b p", b=B) - scores = scores.mean(dim=1) - scores_np = scores.to(torch.float32).cpu().numpy() 
+ scores = rearrange(scores, "(B N) 1 -> B N", B=B) + scores = scores.mean(dim=1) + scores_np = scores.to(torch.float32).cpu().numpy() - # assign the score - dataset.data.loc[image_indices, "aes"] = scores_np + indices_list.extend(indices) + scores_list.extend(scores_np) # wait for all ranks to finish data processing dist.barrier() - # exclude rows whose aes is nanĀ and save file - dataset.data = dataset.data[dataset.data["aes"] > 0] - dataset.data.to_csv(output_file, index=False) - print(f"New meta with aesthetic scores saved to '{output_file}'.") + gathered_list = [None] * dist.get_world_size() + dist.all_gather_object(gathered_list, (indices_list, scores_list)) + if dist.get_rank() == 0: + meta_new = merge_scores(gathered_list, dataset.meta, column='aes') + meta_new.to_csv(out_path, index=False) + print(f"New meta with aesthetic scores saved to '{out_path}'.") -if __name__ == "__main__": +def parse_args(): parser = argparse.ArgumentParser() - parser.add_argument("input", type=str, help="Path to the input CSV file") + parser.add_argument("meta_path", type=str, help="Path to the input CSV file") parser.add_argument("--bs", type=int, default=1024, help="Batch size") parser.add_argument("--num_workers", type=int, default=16, help="Number of workers") - parser.add_argument("--accumulate", type=int, default=1, help="batch to accumulate") parser.add_argument("--prefetch_factor", type=int, default=2, help="Prefetch factor") parser.add_argument("--num_frames", type=int, default=3, help="Number of frames to extract") args = parser.parse_args() - main(args) + return args + +if __name__ == "__main__": + main() diff --git a/tools/scoring/matching/inference.py b/tools/scoring/matching/inference.py index 7bedef1..78766d5 100644 --- a/tools/scoring/matching/inference.py +++ b/tools/scoring/matching/inference.py @@ -15,6 +15,26 @@ from tqdm import tqdm from tools.datasets.utils import extract_frames, is_video +def merge_scores(gathered_list: list, meta: pd.DataFrame, column): + 
# reorder + indices_list = list(map(lambda x: x[0], gathered_list)) + scores_list = list(map(lambda x: x[1], gathered_list)) + + flat_indices = [] + for x in zip(*indices_list): + flat_indices.extend(x) + flat_scores = [] + for x in zip(*scores_list): + flat_scores.extend(x) + flat_indices = np.array(flat_indices) + flat_scores = np.array(flat_scores) + + # filter duplicates + unique_indices, unique_indices_idx = np.unique(flat_indices, return_index=True) + meta.loc[unique_indices, column] = flat_scores[unique_indices_idx] + return meta + + class VideoTextDataset(torch.utils.data.Dataset): def __init__(self, meta_path, transform): self.meta_path = meta_path @@ -41,23 +61,6 @@ class VideoTextDataset(torch.utils.data.Dataset): return len(self.meta) -def merge_scores(gathered_list: list, meta: pd.DataFrame): - # reorder - indices_list = list(map(lambda x: x[0], gathered_list)) - scores_list = list(map(lambda x: x[1], gathered_list)) - flat_indices = [] - for x in zip(*indices_list): - flat_indices.extend(x) - flat_scores = [] - for x in zip(*scores_list): - flat_scores.extend(x) - flat_indices = np.array(flat_indices) - flat_scores = np.array(flat_scores) - # filter duplicates - unique_indices, unique_indices_idx = np.unique(flat_indices, return_index=True) - meta.loc[unique_indices, "match"] = flat_scores[unique_indices_idx] - - def parse_args(): parser = argparse.ArgumentParser() parser.add_argument("meta_path", type=str, help="Path to the input CSV file") @@ -96,7 +99,6 @@ def main(): ) # compute scores - dataset.meta["match"] = np.nan indices_list = [] scores_list = [] model.eval() @@ -118,8 +120,8 @@ def main(): gathered_list = [None] * dist.get_world_size() dist.all_gather_object(gathered_list, (indices_list, scores_list)) if dist.get_rank() == 0: - merge_scores(gathered_list, dataset.meta) - dataset.meta.to_csv(out_path, index=False) + meta_new = merge_scores(gathered_list, dataset.meta, column='match') + meta_new.to_csv(out_path, index=False) print(f"New meta 
with matching scores saved to '{out_path}'.")