quick update flow

This commit is contained in:
zhengzangw 2024-06-03 15:33:19 +00:00
parent 1295bec278
commit 0ec7a99dcf
2 changed files with 8 additions and 5 deletions

View file

@ -1,6 +1,6 @@
# adapted from https://github.com/christophschuhmann/improved-aesthetic-predictor/blob/main/simple_inference.py
import os
import argparse
import os
from datetime import timedelta
import clip
@ -10,9 +10,9 @@ import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, DistributedSampler
from einops import rearrange
from PIL import Image
from torch.utils.data import DataLoader, DistributedSampler
from torchvision.datasets.folder import pil_loader
from tqdm import tqdm
@ -169,6 +169,7 @@ def main():
dist.barrier()
gathered_list = [None] * dist.get_world_size()
breakpoint()
dist.all_gather_object(gathered_list, (indices_list, scores_list))
if dist.get_rank() == 0:
meta_new = merge_scores(gathered_list, dataset.meta, column='aes')

View file

@ -1,4 +1,4 @@
import cv2
import cv2 # isort:skip
import argparse
import os
@ -17,6 +17,8 @@ from tqdm import tqdm
from tools.datasets.utils import extract_frames
from tools.scoring.optical_flow.unimatch import UniMatch
# torch.backends.cudnn.enabled = False # This line enables large batch, but the speed is similar
def merge_scores(gathered_list: list, meta: pd.DataFrame, column):
# reorder
@ -150,8 +152,8 @@ def main():
scores_list.extend(flow_scores)
# jun 3 quickfix
meta_local = merge_scores([(indices_list, scores_list)], dataset.meta, column='flow')
out_path_local = out_path.replace('.csv', f'_part_{dist.get_rank()}.csv')
meta_local = merge_scores([(indices_list, scores_list)], dataset.meta, column="flow")
out_path_local = out_path.replace(".csv", f"_part_{dist.get_rank()}.csv")
meta_local.to_csv(out_path_local, index=False)
# wait for all ranks to finish data processing