From d6a4aeda74aefbbfdf7f068c5c053bc2267f5bb6 Mon Sep 17 00:00:00 2001 From: xyupeng Date: Tue, 2 Apr 2024 11:13:41 +0800 Subject: [PATCH] update scoring --- tools/scoring/README.md | 40 +++++++++++++++++++++---- tools/scoring/aesthetic/inference.py | 6 ++-- tools/scoring/optical_flow/inference.py | 10 +++---- 3 files changed, 43 insertions(+), 13 deletions(-) diff --git a/tools/scoring/README.md b/tools/scoring/README.md index 64cb73e..6db738c 100644 --- a/tools/scoring/README.md +++ b/tools/scoring/README.md @@ -1,8 +1,38 @@ -## Data Scoring and Filtering +# Data Scoring and Filtering +Important!!! All scoring jobs require these columns in meta files: +- `path`: absolute path to a sample -### Aesthetic Score +## Aesthetic Score +First prepare the environment and pretrained models. +```bash +# install clip +pip install git+https://github.com/openai/CLIP.git +pip install decord -### Optical Flow Score -`python tools/scoring/optical_flow/inference.py --meta_path ./data/Panda-70M/processed/meta/test_intact_cut_head-100.csv` +# get pretrained model +wget https://github.com/christophschuhmann/improved-aesthetic-predictor/raw/main/sac+logos+ava1-l14-linearMSE.pth -O pretrained_models/aesthetic.pth +``` -### Matching Score \ No newline at end of file +Then run: +```bash +# output: DATA_aes.csv +python -m tools.scoring.aesthetic.inference /path/to/meta.csv +``` +The output should be `/path/to/meta_aes.csv` with column `aes`. Aesthetic scores range from 1 to 10, with 10 being the best quality. + +## Optical Flow Score +First get the pretrained model. +```bash +wget https://s3.eu-central-1.amazonaws.com/avg-projects/unimatch/pretrained/gmflow-scale2-regrefine6-mixdata-train320x576-4e7b215d.pth -P pretrained_models/unimatch +``` + +Then run: +``` +python tools/scoring/optical_flow/inference.py /path/to/meta.csv +``` +The output should be `/path/to/meta_flow.csv` with column `flow`. Higher optical flow scores indicate larger movement. 
+ +## Matching Score +Requires column `text` in meta files, which is the caption of the sample. + +TODO. diff --git a/tools/scoring/aesthetic/inference.py b/tools/scoring/aesthetic/inference.py index 40ff78a..f28c4bc 100644 --- a/tools/scoring/aesthetic/inference.py +++ b/tools/scoring/aesthetic/inference.py @@ -116,7 +116,7 @@ def main(args): ) # compute aesthetic scores - dataset.data["aesthetic_score"] = np.nan + dataset.data["aes"] = np.nan index = 0 for batch in tqdm(dataloader): images = batch["image"].to(device) @@ -127,10 +127,10 @@ scores = rearrange(scores, "(b p) 1 -> b p", b=B) scores = scores.mean(dim=1) scores_np = scores.cpu().numpy() - dataset.data.loc[index : index + len(scores_np) - 1, "aesthetic_score"] = scores_np + dataset.data.loc[index : index + len(scores_np) - 1, "aes"] = scores_np index += len(images) dataset.data.to_csv(output_file, index=False) - print(f"Saved aesthetic scores to {output_file}.") + print(f"New meta with aesthetic scores saved to \'{output_file}\'.") if __name__ == "__main__": diff --git a/tools/scoring/optical_flow/inference.py b/tools/scoring/optical_flow/inference.py index 2ccb132..be33822 100644 --- a/tools/scoring/optical_flow/inference.py +++ b/tools/scoring/optical_flow/inference.py @@ -71,7 +71,7 @@ class VideoTextDataset(torch.utils.data.Dataset): def main(): parser = argparse.ArgumentParser() - parser.add_argument("--meta_path", type=str, help="Path to the input CSV file") + parser.add_argument("meta_path", type=str, help="Path to the input CSV file") parser.add_argument("--bs", type=int, default=4, help="Batch size") parser.add_argument("--num_workers", type=int, default=16, help="Number of workers") args = parser.parse_args() @@ -92,10 +92,10 @@ def main(): reg_refine=True, task='flow', ) - # ckpt = torch.load( - # './checkpoints/pretrained_models/unimatch/gmflow-scale2-regrefine6-mixdata-train320x576-4e7b215d.pth' - # ) - # model.load_state_dict(ckpt['model']) + ckpt = torch.load( 
'./pretrained_models/unimatch/gmflow-scale2-regrefine6-mixdata-train320x576-4e7b215d.pth' + ) + model.load_state_dict(ckpt['model']) model = model.to(device) model = torch.nn.DataParallel(model)