From 0ec7a99dcf08025f5af43d644b5ee431b28beafe Mon Sep 17 00:00:00 2001 From: zhengzangw Date: Mon, 3 Jun 2024 15:33:19 +0000 Subject: [PATCH] quick update flow --- tools/scoring/aesthetic/inference.py | 5 +++-- tools/scoring/optical_flow/inference.py | 8 +++++--- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/tools/scoring/aesthetic/inference.py b/tools/scoring/aesthetic/inference.py index fa583b1..a291914 100644 --- a/tools/scoring/aesthetic/inference.py +++ b/tools/scoring/aesthetic/inference.py @@ -1,6 +1,6 @@ # adapted from https://github.com/christophschuhmann/improved-aesthetic-predictor/blob/main/simple_inference.py -import os import argparse +import os from datetime import timedelta import clip @@ -10,9 +10,9 @@ import torch import torch.distributed as dist import torch.nn as nn import torch.nn.functional as F -from torch.utils.data import DataLoader, DistributedSampler from einops import rearrange from PIL import Image +from torch.utils.data import DataLoader, DistributedSampler from torchvision.datasets.folder import pil_loader from tqdm import tqdm @@ -169,6 +169,7 @@ def main(): dist.barrier() gathered_list = [None] * dist.get_world_size() + breakpoint() dist.all_gather_object(gathered_list, (indices_list, scores_list)) if dist.get_rank() == 0: meta_new = merge_scores(gathered_list, dataset.meta, column='aes') diff --git a/tools/scoring/optical_flow/inference.py b/tools/scoring/optical_flow/inference.py index 7437d98..3f7d85f 100644 --- a/tools/scoring/optical_flow/inference.py +++ b/tools/scoring/optical_flow/inference.py @@ -1,4 +1,4 @@ -import cv2 +import cv2 # isort:skip import argparse import os @@ -17,6 +17,8 @@ from tqdm import tqdm from tools.datasets.utils import extract_frames from tools.scoring.optical_flow.unimatch import UniMatch +# torch.backends.cudnn.enabled = False # This line enables large batch, but the speed is similar + def merge_scores(gathered_list: list, meta: pd.DataFrame, column): # reorder @@ -150,8 +152,8 @@ def main(): scores_list.extend(flow_scores) # jun 3 quickfix - meta_local = merge_scores([(indices_list, scores_list)], dataset.meta, column='flow') - out_path_local = out_path.replace('.csv', f'_part_{dist.get_rank()}.csv') + meta_local = merge_scores([(indices_list, scores_list)], dataset.meta, column="flow") + out_path_local = out_path.replace(".csv", f"_part_{dist.get_rank()}.csv") meta_local.to_csv(out_path_local, index=False) # wait for all ranks to finish data processing