mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-12 22:38:53 +02:00
quick update flow
This commit is contained in:
parent
1295bec278
commit
0ec7a99dcf
|
|
@ -1,6 +1,6 @@
|
|||
# adapted from https://github.com/christophschuhmann/improved-aesthetic-predictor/blob/main/simple_inference.py
|
||||
import os
|
||||
import argparse
|
||||
import os
|
||||
from datetime import timedelta
|
||||
|
||||
import clip
|
||||
|
|
@ -10,9 +10,9 @@ import torch
|
|||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.utils.data import DataLoader, DistributedSampler
|
||||
from einops import rearrange
|
||||
from PIL import Image
|
||||
from torch.utils.data import DataLoader, DistributedSampler
|
||||
from torchvision.datasets.folder import pil_loader
|
||||
from tqdm import tqdm
|
||||
|
||||
|
|
@ -169,6 +169,7 @@ def main():
|
|||
dist.barrier()
|
||||
|
||||
gathered_list = [None] * dist.get_world_size()
|
||||
breakpoint()
|
||||
dist.all_gather_object(gathered_list, (indices_list, scores_list))
|
||||
if dist.get_rank() == 0:
|
||||
meta_new = merge_scores(gathered_list, dataset.meta, column='aes')
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
import cv2
|
||||
import cv2 # isort:skip
|
||||
|
||||
import argparse
|
||||
import os
|
||||
|
|
@ -17,6 +17,8 @@ from tqdm import tqdm
|
|||
from tools.datasets.utils import extract_frames
|
||||
from tools.scoring.optical_flow.unimatch import UniMatch
|
||||
|
||||
# torch.backends.cudnn.enabled = False # This line enables large batch, but the speed is similar
|
||||
|
||||
|
||||
def merge_scores(gathered_list: list, meta: pd.DataFrame, column):
|
||||
# reorder
|
||||
|
|
@ -150,8 +152,8 @@ def main():
|
|||
scores_list.extend(flow_scores)
|
||||
|
||||
# jun 3 quickfix
|
||||
meta_local = merge_scores([(indices_list, scores_list)], dataset.meta, column='flow')
|
||||
out_path_local = out_path.replace('.csv', f'_part_{dist.get_rank()}.csv')
|
||||
meta_local = merge_scores([(indices_list, scores_list)], dataset.meta, column="flow")
|
||||
out_path_local = out_path.replace(".csv", f"_part_{dist.get_rank()}.csv")
|
||||
meta_local.to_csv(out_path_local, index=False)
|
||||
|
||||
# wait for all ranks to finish data processing
|
||||
|
|
|
|||
Loading…
Reference in a new issue