mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-05-21 03:33:55 +02:00
add evals
This commit is contained in:
parent
f5499aaf19
commit
c9785342c2
83
eval/cal_flolpips.py
Normal file
83
eval/cal_flolpips.py
Normal file
|
|
@ -0,0 +1,83 @@
|
|||
import numpy as np
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
import math
|
||||
from einops import rearrange
|
||||
import sys
|
||||
sys.path.append(".")
|
||||
from flolpips.pwcnet import Network as PWCNet
|
||||
from flolpips.flolpips import FloLPIPS
|
||||
|
||||
loss_fn = FloLPIPS(net='alex', version='0.1').eval().requires_grad_(False)
|
||||
flownet = PWCNet().eval().requires_grad_(False)
|
||||
|
||||
def trans(x):
|
||||
return x
|
||||
|
||||
|
||||
def calculate_flolpips(videos1, videos2, device):
|
||||
global loss_fn, flownet
|
||||
|
||||
print("calculate_flowlpips...")
|
||||
loss_fn = loss_fn.to(device)
|
||||
flownet = flownet.to(device)
|
||||
|
||||
if videos1.shape != videos2.shape:
|
||||
print("Warning: the shape of videos are not equal.")
|
||||
min_frames = min(videos1.shape[1], videos2.shape[1])
|
||||
videos1 = videos1[:, :min_frames]
|
||||
videos2 = videos2[:, :min_frames]
|
||||
|
||||
videos1 = trans(videos1)
|
||||
videos2 = trans(videos2)
|
||||
|
||||
flolpips_results = []
|
||||
for video_num in tqdm(range(videos1.shape[0])):
|
||||
video1 = videos1[video_num].to(device)
|
||||
video2 = videos2[video_num].to(device)
|
||||
frames_rec = video1[:-1]
|
||||
frames_rec_next = video1[1:]
|
||||
frames_gt = video2[:-1]
|
||||
frames_gt_next = video2[1:]
|
||||
t, c, h, w = frames_gt.shape
|
||||
flow_gt = flownet(frames_gt, frames_gt_next)
|
||||
flow_dis = flownet(frames_rec, frames_rec_next)
|
||||
flow_diff = flow_gt - flow_dis
|
||||
flolpips = loss_fn.forward(frames_gt, frames_rec, flow_diff, normalize=True)
|
||||
flolpips_results.append(flolpips.cpu().numpy().tolist())
|
||||
|
||||
flolpips_results = np.array(flolpips_results) # [batch_size, num_frames]
|
||||
flolpips = {}
|
||||
flolpips_std = {}
|
||||
|
||||
for clip_timestamp in range(flolpips_results.shape[1]):
|
||||
flolpips[clip_timestamp] = np.mean(flolpips_results[:,clip_timestamp], axis=-1)
|
||||
flolpips_std[clip_timestamp] = np.std(flolpips_results[:,clip_timestamp], axis=-1)
|
||||
|
||||
result = {
|
||||
"value": flolpips,
|
||||
"value_std": flolpips_std,
|
||||
"video_setting": video1.shape,
|
||||
"video_setting_name": "time, channel, heigth, width",
|
||||
"result": flolpips_results,
|
||||
"details": flolpips_results.tolist()
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
# test code / using example
|
||||
|
||||
def main():
|
||||
NUMBER_OF_VIDEOS = 8
|
||||
VIDEO_LENGTH = 50
|
||||
CHANNEL = 3
|
||||
SIZE = 64
|
||||
videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
|
||||
videos2 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
|
||||
|
||||
import json
|
||||
result = calculate_flolpips(videos1, videos2, "cuda:0")
|
||||
print(json.dumps(result, indent=4))
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
85
eval/cal_fvd.py
Normal file
85
eval/cal_fvd.py
Normal file
|
|
@ -0,0 +1,85 @@
|
|||
import numpy as np
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
def trans(x):
|
||||
# if greyscale images add channel
|
||||
if x.shape[-3] == 1:
|
||||
x = x.repeat(1, 1, 3, 1, 1)
|
||||
|
||||
# permute BTCHW -> BCTHW
|
||||
x = x.permute(0, 2, 1, 3, 4)
|
||||
|
||||
return x
|
||||
|
||||
def calculate_fvd(videos1, videos2, device, method='styleganv'):
|
||||
|
||||
if method == 'styleganv':
|
||||
from fvd.styleganv.fvd import get_fvd_feats, frechet_distance, load_i3d_pretrained
|
||||
elif method == 'videogpt':
|
||||
from fvd.videogpt.fvd import load_i3d_pretrained
|
||||
from fvd.videogpt.fvd import get_fvd_logits as get_fvd_feats
|
||||
from fvd.videogpt.fvd import frechet_distance
|
||||
|
||||
print("calculate_fvd...")
|
||||
|
||||
# videos [batch_size, timestamps, channel, h, w]
|
||||
|
||||
assert videos1.shape == videos2.shape
|
||||
|
||||
i3d = load_i3d_pretrained(device=device)
|
||||
fvd_results = []
|
||||
|
||||
# support grayscale input, if grayscale -> channel*3
|
||||
# BTCHW -> BCTHW
|
||||
# videos -> [batch_size, channel, timestamps, h, w]
|
||||
|
||||
videos1 = trans(videos1)
|
||||
videos2 = trans(videos2)
|
||||
|
||||
fvd_results = {}
|
||||
|
||||
# for calculate FVD, each clip_timestamp must >= 10
|
||||
for clip_timestamp in tqdm(range(10, videos1.shape[-3]+1)):
|
||||
|
||||
# get a video clip
|
||||
# videos_clip [batch_size, channel, timestamps[:clip], h, w]
|
||||
videos_clip1 = videos1[:, :, : clip_timestamp]
|
||||
videos_clip2 = videos2[:, :, : clip_timestamp]
|
||||
|
||||
# get FVD features
|
||||
feats1 = get_fvd_feats(videos_clip1, i3d=i3d, device=device)
|
||||
feats2 = get_fvd_feats(videos_clip2, i3d=i3d, device=device)
|
||||
|
||||
# calculate FVD when timestamps[:clip]
|
||||
fvd_results[clip_timestamp] = frechet_distance(feats1, feats2)
|
||||
|
||||
result = {
|
||||
"value": fvd_results,
|
||||
"video_setting": videos1.shape,
|
||||
"video_setting_name": "batch_size, channel, time, heigth, width",
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
# test code / using example
|
||||
|
||||
def main():
|
||||
NUMBER_OF_VIDEOS = 8
|
||||
VIDEO_LENGTH = 50
|
||||
CHANNEL = 3
|
||||
SIZE = 64
|
||||
videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
|
||||
videos2 = torch.ones(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
|
||||
device = torch.device("cuda")
|
||||
# device = torch.device("cpu")
|
||||
|
||||
import json
|
||||
result = calculate_fvd(videos1, videos2, device, method='videogpt')
|
||||
print(json.dumps(result, indent=4))
|
||||
|
||||
result = calculate_fvd(videos1, videos2, device, method='styleganv')
|
||||
print(json.dumps(result, indent=4))
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
97
eval/cal_lpips.py
Normal file
97
eval/cal_lpips.py
Normal file
|
|
@ -0,0 +1,97 @@
|
|||
import numpy as np
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
import math
|
||||
|
||||
import torch
|
||||
import lpips
|
||||
|
||||
spatial = True # Return a spatial map of perceptual distance.
|
||||
|
||||
# Linearly calibrated models (LPIPS)
|
||||
loss_fn = lpips.LPIPS(net='alex', spatial=spatial) # Can also set net = 'squeeze' or 'vgg'
|
||||
# loss_fn = lpips.LPIPS(net='alex', spatial=spatial, lpips=False) # Can also set net = 'squeeze' or 'vgg'
|
||||
|
||||
def trans(x):
|
||||
# if greyscale images add channel
|
||||
if x.shape[-3] == 1:
|
||||
x = x.repeat(1, 1, 3, 1, 1)
|
||||
|
||||
# value range [0, 1] -> [-1, 1]
|
||||
x = x * 2 - 1
|
||||
|
||||
return x
|
||||
|
||||
def calculate_lpips(videos1, videos2, device):
|
||||
# image should be RGB, IMPORTANT: normalized to [-1,1]
|
||||
print("calculate_lpips...")
|
||||
|
||||
assert videos1.shape == videos2.shape
|
||||
|
||||
# videos [batch_size, timestamps, channel, h, w]
|
||||
|
||||
# support grayscale input, if grayscale -> channel*3
|
||||
# value range [0, 1] -> [-1, 1]
|
||||
videos1 = trans(videos1)
|
||||
videos2 = trans(videos2)
|
||||
|
||||
lpips_results = []
|
||||
|
||||
for video_num in tqdm(range(videos1.shape[0])):
|
||||
# get a video
|
||||
# video [timestamps, channel, h, w]
|
||||
video1 = videos1[video_num]
|
||||
video2 = videos2[video_num]
|
||||
|
||||
lpips_results_of_a_video = []
|
||||
for clip_timestamp in range(len(video1)):
|
||||
# get a img
|
||||
# img [timestamps[x], channel, h, w]
|
||||
# img [channel, h, w] tensor
|
||||
|
||||
img1 = video1[clip_timestamp].unsqueeze(0).to(device)
|
||||
img2 = video2[clip_timestamp].unsqueeze(0).to(device)
|
||||
|
||||
loss_fn.to(device)
|
||||
|
||||
# calculate lpips of a video
|
||||
lpips_results_of_a_video.append(loss_fn.forward(img1, img2).mean().detach().cpu().tolist())
|
||||
lpips_results.append(lpips_results_of_a_video)
|
||||
|
||||
lpips_results = np.array(lpips_results)
|
||||
|
||||
lpips = {}
|
||||
lpips_std = {}
|
||||
|
||||
for clip_timestamp in range(len(video1)):
|
||||
lpips[clip_timestamp] = np.mean(lpips_results[:,clip_timestamp])
|
||||
lpips_std[clip_timestamp] = np.std(lpips_results[:,clip_timestamp])
|
||||
|
||||
|
||||
result = {
|
||||
"value": lpips,
|
||||
"value_std": lpips_std,
|
||||
"video_setting": video1.shape,
|
||||
"video_setting_name": "time, channel, heigth, width",
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
# test code / using example
|
||||
|
||||
def main():
|
||||
NUMBER_OF_VIDEOS = 8
|
||||
VIDEO_LENGTH = 50
|
||||
CHANNEL = 3
|
||||
SIZE = 64
|
||||
videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
|
||||
videos2 = torch.ones(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
|
||||
device = torch.device("cuda")
|
||||
# device = torch.device("cpu")
|
||||
|
||||
import json
|
||||
result = calculate_lpips(videos1, videos2, device)
|
||||
print(json.dumps(result, indent=4))
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
84
eval/cal_psnr.py
Normal file
84
eval/cal_psnr.py
Normal file
|
|
@ -0,0 +1,84 @@
|
|||
import numpy as np
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
import math
|
||||
|
||||
def img_psnr(img1, img2):
|
||||
# [0,1]
|
||||
# compute mse
|
||||
# mse = np.mean((img1-img2)**2)
|
||||
mse = np.mean((img1 / 1.0 - img2 / 1.0) ** 2)
|
||||
# compute psnr
|
||||
if mse < 1e-10:
|
||||
return 100
|
||||
psnr = 20 * math.log10(1 / math.sqrt(mse))
|
||||
return psnr
|
||||
|
||||
def trans(x):
|
||||
return x
|
||||
|
||||
def calculate_psnr(videos1, videos2):
|
||||
print("calculate_psnr...")
|
||||
|
||||
# videos [batch_size, timestamps, channel, h, w]
|
||||
|
||||
assert videos1.shape == videos2.shape
|
||||
|
||||
videos1 = trans(videos1)
|
||||
videos2 = trans(videos2)
|
||||
|
||||
psnr_results = []
|
||||
|
||||
for video_num in tqdm(range(videos1.shape[0])):
|
||||
# get a video
|
||||
# video [timestamps, channel, h, w]
|
||||
video1 = videos1[video_num]
|
||||
video2 = videos2[video_num]
|
||||
|
||||
psnr_results_of_a_video = []
|
||||
for clip_timestamp in range(len(video1)):
|
||||
# get a img
|
||||
# img [timestamps[x], channel, h, w]
|
||||
# img [channel, h, w] numpy
|
||||
|
||||
img1 = video1[clip_timestamp].numpy()
|
||||
img2 = video2[clip_timestamp].numpy()
|
||||
|
||||
# calculate psnr of a video
|
||||
psnr_results_of_a_video.append(img_psnr(img1, img2))
|
||||
|
||||
psnr_results.append(psnr_results_of_a_video)
|
||||
|
||||
psnr_results = np.array(psnr_results) # [batch_size, num_frames]
|
||||
psnr = {}
|
||||
psnr_std = {}
|
||||
|
||||
for clip_timestamp in range(len(video1)):
|
||||
psnr[clip_timestamp] = np.mean(psnr_results[:,clip_timestamp])
|
||||
psnr_std[clip_timestamp] = np.std(psnr_results[:,clip_timestamp])
|
||||
|
||||
result = {
|
||||
"value": psnr,
|
||||
"value_std": psnr_std,
|
||||
"video_setting": video1.shape,
|
||||
"video_setting_name": "time, channel, heigth, width",
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
# test code / using example
|
||||
|
||||
def main():
|
||||
NUMBER_OF_VIDEOS = 8
|
||||
VIDEO_LENGTH = 50
|
||||
CHANNEL = 3
|
||||
SIZE = 64
|
||||
videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
|
||||
videos2 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
|
||||
|
||||
import json
|
||||
result = calculate_psnr(videos1, videos2)
|
||||
print(json.dumps(result, indent=4))
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
113
eval/cal_ssim.py
Normal file
113
eval/cal_ssim.py
Normal file
|
|
@ -0,0 +1,113 @@
|
|||
import numpy as np
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
import cv2
|
||||
|
||||
def ssim(img1, img2):
|
||||
C1 = 0.01 ** 2
|
||||
C2 = 0.03 ** 2
|
||||
img1 = img1.astype(np.float64)
|
||||
img2 = img2.astype(np.float64)
|
||||
kernel = cv2.getGaussianKernel(11, 1.5)
|
||||
window = np.outer(kernel, kernel.transpose())
|
||||
mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
|
||||
mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
|
||||
mu1_sq = mu1 ** 2
|
||||
mu2_sq = mu2 ** 2
|
||||
mu1_mu2 = mu1 * mu2
|
||||
sigma1_sq = cv2.filter2D(img1 ** 2, -1, window)[5:-5, 5:-5] - mu1_sq
|
||||
sigma2_sq = cv2.filter2D(img2 ** 2, -1, window)[5:-5, 5:-5] - mu2_sq
|
||||
sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
|
||||
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
|
||||
(sigma1_sq + sigma2_sq + C2))
|
||||
return ssim_map.mean()
|
||||
|
||||
|
||||
def calculate_ssim_function(img1, img2):
|
||||
# [0,1]
|
||||
# ssim is the only metric extremely sensitive to gray being compared to b/w
|
||||
if not img1.shape == img2.shape:
|
||||
raise ValueError('Input images must have the same dimensions.')
|
||||
if img1.ndim == 2:
|
||||
return ssim(img1, img2)
|
||||
elif img1.ndim == 3:
|
||||
if img1.shape[0] == 3:
|
||||
ssims = []
|
||||
for i in range(3):
|
||||
ssims.append(ssim(img1[i], img2[i]))
|
||||
return np.array(ssims).mean()
|
||||
elif img1.shape[0] == 1:
|
||||
return ssim(np.squeeze(img1), np.squeeze(img2))
|
||||
else:
|
||||
raise ValueError('Wrong input image dimensions.')
|
||||
|
||||
def trans(x):
|
||||
return x
|
||||
|
||||
def calculate_ssim(videos1, videos2):
|
||||
print("calculate_ssim...")
|
||||
|
||||
# videos [batch_size, timestamps, channel, h, w]
|
||||
|
||||
assert videos1.shape == videos2.shape
|
||||
|
||||
videos1 = trans(videos1)
|
||||
videos2 = trans(videos2)
|
||||
|
||||
ssim_results = []
|
||||
|
||||
for video_num in tqdm(range(videos1.shape[0])):
|
||||
# get a video
|
||||
# video [timestamps, channel, h, w]
|
||||
video1 = videos1[video_num]
|
||||
video2 = videos2[video_num]
|
||||
|
||||
ssim_results_of_a_video = []
|
||||
for clip_timestamp in range(len(video1)):
|
||||
# get a img
|
||||
# img [timestamps[x], channel, h, w]
|
||||
# img [channel, h, w] numpy
|
||||
|
||||
img1 = video1[clip_timestamp].numpy()
|
||||
img2 = video2[clip_timestamp].numpy()
|
||||
|
||||
# calculate ssim of a video
|
||||
ssim_results_of_a_video.append(calculate_ssim_function(img1, img2))
|
||||
|
||||
ssim_results.append(ssim_results_of_a_video)
|
||||
|
||||
ssim_results = np.array(ssim_results)
|
||||
|
||||
ssim = {}
|
||||
ssim_std = {}
|
||||
|
||||
for clip_timestamp in range(len(video1)):
|
||||
ssim[clip_timestamp] = np.mean(ssim_results[:,clip_timestamp])
|
||||
ssim_std[clip_timestamp] = np.std(ssim_results[:,clip_timestamp])
|
||||
|
||||
result = {
|
||||
"value": ssim,
|
||||
"value_std": ssim_std,
|
||||
"video_setting": video1.shape,
|
||||
"video_setting_name": "time, channel, heigth, width",
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
# test code / using example
|
||||
|
||||
def main():
|
||||
NUMBER_OF_VIDEOS = 8
|
||||
VIDEO_LENGTH = 50
|
||||
CHANNEL = 3
|
||||
SIZE = 64
|
||||
videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
|
||||
videos2 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
|
||||
device = torch.device("cuda")
|
||||
|
||||
import json
|
||||
result = calculate_ssim(videos1, videos2)
|
||||
print(json.dumps(result, indent=4))
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
225
eval/eval_clip_score.py
Executable file
225
eval/eval_clip_score.py
Executable file
|
|
@ -0,0 +1,225 @@
|
|||
"""Calculates the CLIP Scores
|
||||
|
||||
The CLIP model is a contrasitively learned language-image model. There is
|
||||
an image encoder and a text encoder. It is believed that the CLIP model could
|
||||
measure the similarity of cross modalities. Please find more information from
|
||||
https://github.com/openai/CLIP.
|
||||
|
||||
The CLIP Score measures the Cosine Similarity between two embedded features.
|
||||
This repository utilizes the pretrained CLIP Model to calculate
|
||||
the mean average of cosine similarities.
|
||||
|
||||
See --help to see further details.
|
||||
|
||||
Code apapted from https://github.com/mseitzer/pytorch-fid and https://github.com/openai/CLIP.
|
||||
|
||||
Copyright 2023 The Hong Kong Polytechnic University
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
import os
|
||||
import os.path as osp
|
||||
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
|
||||
|
||||
import clip
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
|
||||
try:
|
||||
from tqdm import tqdm
|
||||
except ImportError:
|
||||
# If tqdm is not available, provide a mock version of it
|
||||
def tqdm(x):
|
||||
return x
|
||||
|
||||
|
||||
IMAGE_EXTENSIONS = {'bmp', 'jpg', 'jpeg', 'pgm', 'png', 'ppm',
|
||||
'tif', 'tiff', 'webp'}
|
||||
|
||||
TEXT_EXTENSIONS = {'txt'}
|
||||
|
||||
|
||||
class DummyDataset(Dataset):
|
||||
|
||||
FLAGS = ['img', 'txt']
|
||||
def __init__(self, real_path, generated_path,
|
||||
real_flag: str = 'img',
|
||||
generated_flag: str = 'img',
|
||||
transform = None,
|
||||
tokenizer = None) -> None:
|
||||
super().__init__()
|
||||
assert real_flag in self.FLAGS and generated_flag in self.FLAGS, \
|
||||
'CLIP Score only support modality of {}. However, get {} and {}'.format(
|
||||
self.FLAGS, real_flag, generated_flag
|
||||
)
|
||||
self.real_folder = self._combine_without_prefix(real_path)
|
||||
self.real_flag = real_flag
|
||||
self.fake_foler = self._combine_without_prefix(generated_path)
|
||||
self.generated_flag = generated_flag
|
||||
self.transform = transform
|
||||
self.tokenizer = tokenizer
|
||||
# assert self._check()
|
||||
|
||||
def __len__(self):
|
||||
return len(self.real_folder)
|
||||
|
||||
def __getitem__(self, index):
|
||||
if index >= len(self):
|
||||
raise IndexError
|
||||
real_path = self.real_folder[index]
|
||||
generated_path = self.fake_foler[index]
|
||||
real_data = self._load_modality(real_path, self.real_flag)
|
||||
fake_data = self._load_modality(generated_path, self.generated_flag)
|
||||
|
||||
sample = dict(real=real_data, fake=fake_data)
|
||||
return sample
|
||||
|
||||
def _load_modality(self, path, modality):
|
||||
if modality == 'img':
|
||||
data = self._load_img(path)
|
||||
elif modality == 'txt':
|
||||
data = self._load_txt(path)
|
||||
else:
|
||||
raise TypeError("Got unexpected modality: {}".format(modality))
|
||||
return data
|
||||
|
||||
def _load_img(self, path):
|
||||
img = Image.open(path)
|
||||
if self.transform is not None:
|
||||
img = self.transform(img)
|
||||
return img
|
||||
|
||||
def _load_txt(self, path):
|
||||
with open(path, 'r') as fp:
|
||||
data = fp.read()
|
||||
fp.close()
|
||||
if self.tokenizer is not None:
|
||||
data = self.tokenizer(data).squeeze()
|
||||
return data
|
||||
|
||||
def _check(self):
|
||||
for idx in range(len(self)):
|
||||
real_name = self.real_folder[idx].split('.')
|
||||
fake_name = self.fake_folder[idx].split('.')
|
||||
if fake_name != real_name:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _combine_without_prefix(self, folder_path, prefix='.'):
|
||||
folder = []
|
||||
for name in os.listdir(folder_path):
|
||||
if name[0] == prefix:
|
||||
continue
|
||||
folder.append(osp.join(folder_path, name))
|
||||
folder.sort()
|
||||
return folder
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def calculate_clip_score(dataloader, model, real_flag, generated_flag):
|
||||
score_acc = 0.
|
||||
sample_num = 0.
|
||||
logit_scale = model.logit_scale.exp()
|
||||
for batch_data in tqdm(dataloader):
|
||||
real = batch_data['real']
|
||||
real_features = forward_modality(model, real, real_flag)
|
||||
fake = batch_data['fake']
|
||||
fake_features = forward_modality(model, fake, generated_flag)
|
||||
|
||||
# normalize features
|
||||
real_features = real_features / real_features.norm(dim=1, keepdim=True).to(torch.float32)
|
||||
fake_features = fake_features / fake_features.norm(dim=1, keepdim=True).to(torch.float32)
|
||||
|
||||
# calculate scores
|
||||
# score = logit_scale * real_features @ fake_features.t()
|
||||
# score_acc += torch.diag(score).sum()
|
||||
score = logit_scale * (fake_features * real_features).sum()
|
||||
score_acc += score
|
||||
sample_num += real.shape[0]
|
||||
|
||||
return score_acc / sample_num
|
||||
|
||||
|
||||
def forward_modality(model, data, flag):
|
||||
device = next(model.parameters()).device
|
||||
if flag == 'img':
|
||||
features = model.encode_image(data.to(device))
|
||||
elif flag == 'txt':
|
||||
features = model.encode_text(data.to(device))
|
||||
else:
|
||||
raise TypeError
|
||||
return features
|
||||
|
||||
|
||||
def main():
|
||||
parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
|
||||
parser.add_argument('--batch-size', type=int, default=50,
|
||||
help='Batch size to use')
|
||||
parser.add_argument('--clip-model', type=str, default='ViT-B/32',
|
||||
help='CLIP model to use')
|
||||
parser.add_argument('--num-workers', type=int, default=8,
|
||||
help=('Number of processes to use for data loading. '
|
||||
'Defaults to `min(8, num_cpus)`'))
|
||||
parser.add_argument('--device', type=str, default=None,
|
||||
help='Device to use. Like cuda, cuda:0 or cpu')
|
||||
parser.add_argument('--real_flag', type=str, default='img',
|
||||
help=('The modality of real path. '
|
||||
'Default to img'))
|
||||
parser.add_argument('--generated_flag', type=str, default='txt',
|
||||
help=('The modality of generated path. '
|
||||
'Default to txt'))
|
||||
parser.add_argument('--real_path', type=str,
|
||||
help=('Paths to the real images or '
|
||||
'to .npz statistic files'))
|
||||
parser.add_argument('--generated_path', type=str,
|
||||
help=('Paths to the generated images or '
|
||||
'to .npz statistic files'))
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.device is None:
|
||||
device = torch.device('cuda' if (torch.cuda.is_available()) else 'cpu')
|
||||
else:
|
||||
device = torch.device(args.device)
|
||||
|
||||
if args.num_workers is None:
|
||||
try:
|
||||
num_cpus = len(os.sched_getaffinity(0))
|
||||
except AttributeError:
|
||||
# os.sched_getaffinity is not available under Windows, use
|
||||
# os.cpu_count instead (which may not return the *available* number
|
||||
# of CPUs).
|
||||
num_cpus = os.cpu_count()
|
||||
|
||||
num_workers = min(num_cpus, 8) if num_cpus is not None else 0
|
||||
else:
|
||||
num_workers = args.num_workers
|
||||
|
||||
print('Loading CLIP model: {}'.format(args.clip_model))
|
||||
model, preprocess = clip.load(args.clip_model, device=device)
|
||||
|
||||
dataset = DummyDataset(args.real_path, args.generated_path,
|
||||
args.real_flag, args.generated_flag,
|
||||
transform=preprocess, tokenizer=clip.tokenize)
|
||||
dataloader = DataLoader(dataset, args.batch_size,
|
||||
num_workers=num_workers, pin_memory=True)
|
||||
|
||||
print('Calculating CLIP Score:')
|
||||
clip_score = calculate_clip_score(dataloader, model,
|
||||
args.real_flag, args.generated_flag)
|
||||
clip_score = clip_score.cpu().item()
|
||||
print('CLIP Score: ', clip_score)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
224
eval/eval_common_metric.py
Normal file
224
eval/eval_common_metric.py
Normal file
|
|
@ -0,0 +1,224 @@
|
|||
"""Calculates the CLIP Scores
|
||||
|
||||
The CLIP model is a contrasitively learned language-image model. There is
|
||||
an image encoder and a text encoder. It is believed that the CLIP model could
|
||||
measure the similarity of cross modalities. Please find more information from
|
||||
https://github.com/openai/CLIP.
|
||||
|
||||
The CLIP Score measures the Cosine Similarity between two embedded features.
|
||||
This repository utilizes the pretrained CLIP Model to calculate
|
||||
the mean average of cosine similarities.
|
||||
|
||||
See --help to see further details.
|
||||
|
||||
Code apapted from https://github.com/mseitzer/pytorch-fid and https://github.com/openai/CLIP.
|
||||
|
||||
Copyright 2023 The Hong Kong Polytechnic University
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import os
|
||||
import os.path as osp
|
||||
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import Dataset, DataLoader, Subset
|
||||
from decord import VideoReader, cpu
|
||||
import random
|
||||
from pytorchvideo.transforms import ShortSideScale
|
||||
from torchvision.io import read_video
|
||||
from torchvision.transforms import Lambda, Compose
|
||||
from torchvision.transforms._transforms_video import CenterCropVideo
|
||||
import sys
|
||||
sys.path.append(".")
|
||||
from opensora.eval.cal_lpips import calculate_lpips
|
||||
from opensora.eval.cal_fvd import calculate_fvd
|
||||
from opensora.eval.cal_psnr import calculate_psnr
|
||||
from opensora.eval.cal_flolpips import calculate_flolpips
|
||||
from opensora.eval.cal_ssim import calculate_ssim
|
||||
|
||||
try:
|
||||
from tqdm import tqdm
|
||||
except ImportError:
|
||||
# If tqdm is not available, provide a mock version of it
|
||||
def tqdm(x):
|
||||
return x
|
||||
|
||||
class VideoDataset(Dataset):
|
||||
def __init__(self,
|
||||
real_video_dir,
|
||||
generated_video_dir,
|
||||
num_frames,
|
||||
sample_rate = 1,
|
||||
crop_size=None,
|
||||
resolution=128,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.real_video_files = self._combine_without_prefix(real_video_dir)
|
||||
self.generated_video_files = self._combine_without_prefix(generated_video_dir)
|
||||
self.num_frames = num_frames
|
||||
self.sample_rate = sample_rate
|
||||
self.crop_size = crop_size
|
||||
self.short_size = resolution
|
||||
|
||||
|
||||
def __len__(self):
|
||||
return len(self.real_video_files)
|
||||
|
||||
def __getitem__(self, index):
|
||||
if index >= len(self):
|
||||
raise IndexError
|
||||
real_video_file = self.real_video_files[index]
|
||||
generated_video_file = self.generated_video_files[index]
|
||||
print(real_video_file, generated_video_file)
|
||||
real_video_tensor = self._load_video(real_video_file)
|
||||
generated_video_tensor = self._load_video(generated_video_file)
|
||||
return {'real': real_video_tensor, 'generated':generated_video_tensor }
|
||||
|
||||
|
||||
def _load_video(self, video_path):
|
||||
num_frames = self.num_frames
|
||||
sample_rate = self.sample_rate
|
||||
decord_vr = VideoReader(video_path, ctx=cpu(0))
|
||||
total_frames = len(decord_vr)
|
||||
sample_frames_len = sample_rate * num_frames
|
||||
|
||||
if total_frames >= sample_frames_len:
|
||||
s = 0
|
||||
e = s + sample_frames_len
|
||||
num_frames = num_frames
|
||||
else:
|
||||
s = 0
|
||||
e = total_frames
|
||||
num_frames = int(total_frames / sample_frames_len * num_frames)
|
||||
print(f'sample_frames_len {sample_frames_len}, only can sample {num_frames * sample_rate}', video_path,
|
||||
total_frames)
|
||||
|
||||
|
||||
frame_id_list = np.linspace(s, e - 1, num_frames, dtype=int)
|
||||
video_data = decord_vr.get_batch(frame_id_list).asnumpy()
|
||||
video_data = torch.from_numpy(video_data)
|
||||
video_data = video_data.permute(0, 3, 1, 2) # (T, H, W, C) -> (C, T, H, W)
|
||||
return _preprocess(video_data, short_size=self.short_size, crop_size = self.crop_size)
|
||||
|
||||
|
||||
def _combine_without_prefix(self, folder_path, prefix='.'):
|
||||
folder = []
|
||||
os.makedirs(folder_path, exist_ok=True)
|
||||
for name in os.listdir(folder_path):
|
||||
if name[0] == prefix:
|
||||
continue
|
||||
if osp.isfile(osp.join(folder_path, name)):
|
||||
folder.append(osp.join(folder_path, name))
|
||||
folder.sort()
|
||||
return folder
|
||||
|
||||
def _preprocess(video_data, short_size=128, crop_size=None):
|
||||
transform = Compose(
|
||||
[
|
||||
Lambda(lambda x: x / 255.0),
|
||||
ShortSideScale(size=short_size),
|
||||
CenterCropVideo(crop_size=crop_size),
|
||||
]
|
||||
)
|
||||
video_outputs = transform(video_data)
|
||||
# video_outputs = torch.unsqueeze(video_outputs, 0) # (bz,c,t,h,w)
|
||||
return video_outputs
|
||||
|
||||
|
||||
def calculate_common_metric(args, dataloader, device):
|
||||
|
||||
score_list = []
|
||||
for batch_data in tqdm(dataloader): # {'real': real_video_tensor, 'generated':generated_video_tensor }
|
||||
real_videos = batch_data['real']
|
||||
generated_videos = batch_data['generated']
|
||||
assert real_videos.shape[2] == generated_videos.shape[2]
|
||||
if args.metric == 'fvd':
|
||||
tmp_list = list(calculate_fvd(real_videos, generated_videos, args.device, method=args.fvd_method)['value'].values())
|
||||
elif args.metric == 'ssim':
|
||||
tmp_list = list(calculate_ssim(real_videos, generated_videos)['value'].values())
|
||||
elif args.metric == 'psnr':
|
||||
tmp_list = list(calculate_psnr(real_videos, generated_videos)['value'].values())
|
||||
elif args.metric == 'flolpips':
|
||||
result = calculate_flolpips(real_videos, generated_videos, args.device)
|
||||
tmp_list = list(result['value'].values())
|
||||
else:
|
||||
tmp_list = list(calculate_lpips(real_videos, generated_videos, args.device)['value'].values())
|
||||
score_list += tmp_list
|
||||
return np.mean(score_list)
|
||||
|
||||
def main():
|
||||
parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
|
||||
parser.add_argument('--batch_size', type=int, default=2,
|
||||
help='Batch size to use')
|
||||
parser.add_argument('--real_video_dir', type=str,
|
||||
help=('the path of real videos`'))
|
||||
parser.add_argument('--generated_video_dir', type=str,
|
||||
help=('the path of generated videos`'))
|
||||
parser.add_argument('--device', type=str, default=None,
|
||||
help='Device to use. Like cuda, cuda:0 or cpu')
|
||||
parser.add_argument('--num_workers', type=int, default=8,
|
||||
help=('Number of processes to use for data loading. '
|
||||
'Defaults to `min(8, num_cpus)`'))
|
||||
parser.add_argument('--sample_fps', type=int, default=30)
|
||||
parser.add_argument('--resolution', type=int, default=336)
|
||||
parser.add_argument('--crop_size', type=int, default=None)
|
||||
parser.add_argument('--num_frames', type=int, default=100)
|
||||
parser.add_argument('--sample_rate', type=int, default=1)
|
||||
parser.add_argument('--subset_size', type=int, default=None)
|
||||
parser.add_argument("--metric", type=str, default="fvd",choices=['fvd','psnr','ssim','lpips', 'flolpips'])
|
||||
parser.add_argument("--fvd_method", type=str, default='styleganv',choices=['styleganv','videogpt'])
|
||||
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.device is None:
|
||||
device = torch.device('cuda' if (torch.cuda.is_available()) else 'cpu')
|
||||
else:
|
||||
device = torch.device(args.device)
|
||||
|
||||
if args.num_workers is None:
|
||||
try:
|
||||
num_cpus = len(os.sched_getaffinity(0))
|
||||
except AttributeError:
|
||||
# os.sched_getaffinity is not available under Windows, use
|
||||
# os.cpu_count instead (which may not return the *available* number
|
||||
# of CPUs).
|
||||
num_cpus = os.cpu_count()
|
||||
|
||||
num_workers = min(num_cpus, 8) if num_cpus is not None else 0
|
||||
else:
|
||||
num_workers = args.num_workers
|
||||
|
||||
|
||||
dataset = VideoDataset(args.real_video_dir,
|
||||
args.generated_video_dir,
|
||||
num_frames = args.num_frames,
|
||||
sample_rate = args.sample_rate,
|
||||
crop_size=args.crop_size,
|
||||
resolution=args.resolution)
|
||||
|
||||
if args.subset_size:
|
||||
indices = range(args.subset_size)
|
||||
dataset = Subset(dataset, indices=indices)
|
||||
|
||||
dataloader = DataLoader(dataset, args.batch_size,
|
||||
num_workers=num_workers, pin_memory=True)
|
||||
|
||||
|
||||
metric_score = calculate_common_metric(args, dataloader,device)
|
||||
print('metric: ', args.metric, " ",metric_score)
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
397
eval/flolpips/correlation/correlation.py
Normal file
397
eval/flolpips/correlation/correlation.py
Normal file
|
|
@ -0,0 +1,397 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
import torch
|
||||
|
||||
import cupy
|
||||
import re
|
||||
|
||||
kernel_Correlation_rearrange = '''
|
||||
extern "C" __global__ void kernel_Correlation_rearrange(
|
||||
const int n,
|
||||
const float* input,
|
||||
float* output
|
||||
) {
|
||||
int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x;
|
||||
|
||||
if (intIndex >= n) {
|
||||
return;
|
||||
}
|
||||
|
||||
int intSample = blockIdx.z;
|
||||
int intChannel = blockIdx.y;
|
||||
|
||||
float fltValue = input[(((intSample * SIZE_1(input)) + intChannel) * SIZE_2(input) * SIZE_3(input)) + intIndex];
|
||||
|
||||
__syncthreads();
|
||||
|
||||
int intPaddedY = (intIndex / SIZE_3(input)) + 4;
|
||||
int intPaddedX = (intIndex % SIZE_3(input)) + 4;
|
||||
int intRearrange = ((SIZE_3(input) + 8) * intPaddedY) + intPaddedX;
|
||||
|
||||
output[(((intSample * SIZE_1(output) * SIZE_2(output)) + intRearrange) * SIZE_1(input)) + intChannel] = fltValue;
|
||||
}
|
||||
'''
|
||||
|
||||
kernel_Correlation_updateOutput = '''
|
||||
extern "C" __global__ void kernel_Correlation_updateOutput(
|
||||
const int n,
|
||||
const float* rbot0,
|
||||
const float* rbot1,
|
||||
float* top
|
||||
) {
|
||||
extern __shared__ char patch_data_char[];
|
||||
|
||||
float *patch_data = (float *)patch_data_char;
|
||||
|
||||
// First (upper left) position of kernel upper-left corner in current center position of neighborhood in image 1
|
||||
int x1 = blockIdx.x + 4;
|
||||
int y1 = blockIdx.y + 4;
|
||||
int item = blockIdx.z;
|
||||
int ch_off = threadIdx.x;
|
||||
|
||||
// Load 3D patch into shared shared memory
|
||||
for (int j = 0; j < 1; j++) { // HEIGHT
|
||||
for (int i = 0; i < 1; i++) { // WIDTH
|
||||
int ji_off = (j + i) * SIZE_3(rbot0);
|
||||
for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS
|
||||
int idx1 = ((item * SIZE_1(rbot0) + y1+j) * SIZE_2(rbot0) + x1+i) * SIZE_3(rbot0) + ch;
|
||||
int idxPatchData = ji_off + ch;
|
||||
patch_data[idxPatchData] = rbot0[idx1];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
__shared__ float sum[32];
|
||||
|
||||
// Compute correlation
|
||||
for (int top_channel = 0; top_channel < SIZE_1(top); top_channel++) {
|
||||
sum[ch_off] = 0;
|
||||
|
||||
int s2o = top_channel % 9 - 4;
|
||||
int s2p = top_channel / 9 - 4;
|
||||
|
||||
for (int j = 0; j < 1; j++) { // HEIGHT
|
||||
for (int i = 0; i < 1; i++) { // WIDTH
|
||||
int ji_off = (j + i) * SIZE_3(rbot0);
|
||||
for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS
|
||||
int x2 = x1 + s2o;
|
||||
int y2 = y1 + s2p;
|
||||
|
||||
int idxPatchData = ji_off + ch;
|
||||
int idx2 = ((item * SIZE_1(rbot0) + y2+j) * SIZE_2(rbot0) + x2+i) * SIZE_3(rbot0) + ch;
|
||||
|
||||
sum[ch_off] += patch_data[idxPatchData] * rbot1[idx2];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
if (ch_off == 0) {
|
||||
float total_sum = 0;
|
||||
for (int idx = 0; idx < 32; idx++) {
|
||||
total_sum += sum[idx];
|
||||
}
|
||||
const int sumelems = SIZE_3(rbot0);
|
||||
const int index = ((top_channel*SIZE_2(top) + blockIdx.y)*SIZE_3(top))+blockIdx.x;
|
||||
top[index + item*SIZE_1(top)*SIZE_2(top)*SIZE_3(top)] = total_sum / (float)sumelems;
|
||||
}
|
||||
}
|
||||
}
|
||||
'''
|
||||
|
||||
kernel_Correlation_updateGradFirst = '''
|
||||
#define ROUND_OFF 50000
|
||||
|
||||
extern "C" __global__ void kernel_Correlation_updateGradFirst(
|
||||
const int n,
|
||||
const int intSample,
|
||||
const float* rbot0,
|
||||
const float* rbot1,
|
||||
const float* gradOutput,
|
||||
float* gradFirst,
|
||||
float* gradSecond
|
||||
) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
|
||||
int n = intIndex % SIZE_1(gradFirst); // channels
|
||||
int l = (intIndex / SIZE_1(gradFirst)) % SIZE_3(gradFirst) + 4; // w-pos
|
||||
int m = (intIndex / SIZE_1(gradFirst) / SIZE_3(gradFirst)) % SIZE_2(gradFirst) + 4; // h-pos
|
||||
|
||||
// round_off is a trick to enable integer division with ceil, even for negative numbers
|
||||
// We use a large offset, for the inner part not to become negative.
|
||||
const int round_off = ROUND_OFF;
|
||||
const int round_off_s1 = round_off;
|
||||
|
||||
// We add round_off before_s1 the int division and subtract round_off after it, to ensure the formula matches ceil behavior:
|
||||
int xmin = (l - 4 + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4)
|
||||
int ymin = (m - 4 + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4)
|
||||
|
||||
// Same here:
|
||||
int xmax = (l - 4 + round_off_s1) - round_off; // floor (l - 4)
|
||||
int ymax = (m - 4 + round_off_s1) - round_off; // floor (m - 4)
|
||||
|
||||
float sum = 0;
|
||||
if (xmax>=0 && ymax>=0 && (xmin<=SIZE_3(gradOutput)-1) && (ymin<=SIZE_2(gradOutput)-1)) {
|
||||
xmin = max(0,xmin);
|
||||
xmax = min(SIZE_3(gradOutput)-1,xmax);
|
||||
|
||||
ymin = max(0,ymin);
|
||||
ymax = min(SIZE_2(gradOutput)-1,ymax);
|
||||
|
||||
for (int p = -4; p <= 4; p++) {
|
||||
for (int o = -4; o <= 4; o++) {
|
||||
// Get rbot1 data:
|
||||
int s2o = o;
|
||||
int s2p = p;
|
||||
int idxbot1 = ((intSample * SIZE_1(rbot0) + (m+s2p)) * SIZE_2(rbot0) + (l+s2o)) * SIZE_3(rbot0) + n;
|
||||
float bot1tmp = rbot1[idxbot1]; // rbot1[l+s2o,m+s2p,n]
|
||||
|
||||
// Index offset for gradOutput in following loops:
|
||||
int op = (p+4) * 9 + (o+4); // index[o,p]
|
||||
int idxopoffset = (intSample * SIZE_1(gradOutput) + op);
|
||||
|
||||
for (int y = ymin; y <= ymax; y++) {
|
||||
for (int x = xmin; x <= xmax; x++) {
|
||||
int idxgradOutput = (idxopoffset * SIZE_2(gradOutput) + y) * SIZE_3(gradOutput) + x; // gradOutput[x,y,o,p]
|
||||
sum += gradOutput[idxgradOutput] * bot1tmp;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
const int sumelems = SIZE_1(gradFirst);
|
||||
const int bot0index = ((n * SIZE_2(gradFirst)) + (m-4)) * SIZE_3(gradFirst) + (l-4);
|
||||
gradFirst[bot0index + intSample*SIZE_1(gradFirst)*SIZE_2(gradFirst)*SIZE_3(gradFirst)] = sum / (float)sumelems;
|
||||
} }
|
||||
'''
|
||||
|
||||
kernel_Correlation_updateGradSecond = '''
|
||||
#define ROUND_OFF 50000
|
||||
|
||||
extern "C" __global__ void kernel_Correlation_updateGradSecond(
|
||||
const int n,
|
||||
const int intSample,
|
||||
const float* rbot0,
|
||||
const float* rbot1,
|
||||
const float* gradOutput,
|
||||
float* gradFirst,
|
||||
float* gradSecond
|
||||
) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
|
||||
int n = intIndex % SIZE_1(gradSecond); // channels
|
||||
int l = (intIndex / SIZE_1(gradSecond)) % SIZE_3(gradSecond) + 4; // w-pos
|
||||
int m = (intIndex / SIZE_1(gradSecond) / SIZE_3(gradSecond)) % SIZE_2(gradSecond) + 4; // h-pos
|
||||
|
||||
// round_off is a trick to enable integer division with ceil, even for negative numbers
|
||||
// We use a large offset, for the inner part not to become negative.
|
||||
const int round_off = ROUND_OFF;
|
||||
const int round_off_s1 = round_off;
|
||||
|
||||
float sum = 0;
|
||||
for (int p = -4; p <= 4; p++) {
|
||||
for (int o = -4; o <= 4; o++) {
|
||||
int s2o = o;
|
||||
int s2p = p;
|
||||
|
||||
//Get X,Y ranges and clamp
|
||||
// We add round_off before_s1 the int division and subtract round_off after it, to ensure the formula matches ceil behavior:
|
||||
int xmin = (l - 4 - s2o + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4 - s2o)
|
||||
int ymin = (m - 4 - s2p + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4 - s2o)
|
||||
|
||||
// Same here:
|
||||
int xmax = (l - 4 - s2o + round_off_s1) - round_off; // floor (l - 4 - s2o)
|
||||
int ymax = (m - 4 - s2p + round_off_s1) - round_off; // floor (m - 4 - s2p)
|
||||
|
||||
if (xmax>=0 && ymax>=0 && (xmin<=SIZE_3(gradOutput)-1) && (ymin<=SIZE_2(gradOutput)-1)) {
|
||||
xmin = max(0,xmin);
|
||||
xmax = min(SIZE_3(gradOutput)-1,xmax);
|
||||
|
||||
ymin = max(0,ymin);
|
||||
ymax = min(SIZE_2(gradOutput)-1,ymax);
|
||||
|
||||
// Get rbot0 data:
|
||||
int idxbot0 = ((intSample * SIZE_1(rbot0) + (m-s2p)) * SIZE_2(rbot0) + (l-s2o)) * SIZE_3(rbot0) + n;
|
||||
float bot0tmp = rbot0[idxbot0]; // rbot1[l+s2o,m+s2p,n]
|
||||
|
||||
// Index offset for gradOutput in following loops:
|
||||
int op = (p+4) * 9 + (o+4); // index[o,p]
|
||||
int idxopoffset = (intSample * SIZE_1(gradOutput) + op);
|
||||
|
||||
for (int y = ymin; y <= ymax; y++) {
|
||||
for (int x = xmin; x <= xmax; x++) {
|
||||
int idxgradOutput = (idxopoffset * SIZE_2(gradOutput) + y) * SIZE_3(gradOutput) + x; // gradOutput[x,y,o,p]
|
||||
sum += gradOutput[idxgradOutput] * bot0tmp;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
const int sumelems = SIZE_1(gradSecond);
|
||||
const int bot1index = ((n * SIZE_2(gradSecond)) + (m-4)) * SIZE_3(gradSecond) + (l-4);
|
||||
gradSecond[bot1index + intSample*SIZE_1(gradSecond)*SIZE_2(gradSecond)*SIZE_3(gradSecond)] = sum / (float)sumelems;
|
||||
} }
|
||||
'''
|
||||
|
||||
def cupy_kernel(strFunction, objVariables):
|
||||
strKernel = globals()[strFunction]
|
||||
|
||||
while True:
|
||||
objMatch = re.search('(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel)
|
||||
|
||||
if objMatch is None:
|
||||
break
|
||||
# end
|
||||
|
||||
intArg = int(objMatch.group(2))
|
||||
|
||||
strTensor = objMatch.group(4)
|
||||
intSizes = objVariables[strTensor].size()
|
||||
|
||||
strKernel = strKernel.replace(objMatch.group(), str(intSizes[intArg]))
|
||||
# end
|
||||
|
||||
while True:
|
||||
objMatch = re.search('(VALUE_)([0-4])(\()([^\)]+)(\))', strKernel)
|
||||
|
||||
if objMatch is None:
|
||||
break
|
||||
# end
|
||||
|
||||
intArgs = int(objMatch.group(2))
|
||||
strArgs = objMatch.group(4).split(',')
|
||||
|
||||
strTensor = strArgs[0]
|
||||
intStrides = objVariables[strTensor].stride()
|
||||
strIndex = [ '((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg]) + ')' for intArg in range(intArgs) ]
|
||||
|
||||
strKernel = strKernel.replace(objMatch.group(0), strTensor + '[' + str.join('+', strIndex) + ']')
|
||||
# end
|
||||
|
||||
return strKernel
|
||||
# end
|
||||
|
||||
@cupy.memoize(for_each_device=True)
|
||||
def cupy_launch(strFunction, strKernel):
|
||||
return cupy.RawKernel(strKernel, strFunction)
|
||||
# end
|
||||
|
||||
class _FunctionCorrelation(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(self, first, second):
|
||||
rbot0 = first.new_zeros([ first.shape[0], first.shape[2] + 8, first.shape[3] + 8, first.shape[1] ])
|
||||
rbot1 = first.new_zeros([ first.shape[0], first.shape[2] + 8, first.shape[3] + 8, first.shape[1] ])
|
||||
|
||||
self.save_for_backward(first, second, rbot0, rbot1)
|
||||
|
||||
first = first.contiguous(); assert(first.is_cuda == True)
|
||||
second = second.contiguous(); assert(second.is_cuda == True)
|
||||
|
||||
output = first.new_zeros([ first.shape[0], 81, first.shape[2], first.shape[3] ])
|
||||
|
||||
if first.is_cuda == True:
|
||||
n = first.shape[2] * first.shape[3]
|
||||
cupy_launch('kernel_Correlation_rearrange', cupy_kernel('kernel_Correlation_rearrange', {
|
||||
'input': first,
|
||||
'output': rbot0
|
||||
}))(
|
||||
grid=tuple([ int((n + 16 - 1) / 16), first.shape[1], first.shape[0] ]),
|
||||
block=tuple([ 16, 1, 1 ]),
|
||||
args=[ n, first.data_ptr(), rbot0.data_ptr() ]
|
||||
)
|
||||
|
||||
n = second.shape[2] * second.shape[3]
|
||||
cupy_launch('kernel_Correlation_rearrange', cupy_kernel('kernel_Correlation_rearrange', {
|
||||
'input': second,
|
||||
'output': rbot1
|
||||
}))(
|
||||
grid=tuple([ int((n + 16 - 1) / 16), second.shape[1], second.shape[0] ]),
|
||||
block=tuple([ 16, 1, 1 ]),
|
||||
args=[ n, second.data_ptr(), rbot1.data_ptr() ]
|
||||
)
|
||||
|
||||
n = output.shape[1] * output.shape[2] * output.shape[3]
|
||||
cupy_launch('kernel_Correlation_updateOutput', cupy_kernel('kernel_Correlation_updateOutput', {
|
||||
'rbot0': rbot0,
|
||||
'rbot1': rbot1,
|
||||
'top': output
|
||||
}))(
|
||||
grid=tuple([ output.shape[3], output.shape[2], output.shape[0] ]),
|
||||
block=tuple([ 32, 1, 1 ]),
|
||||
shared_mem=first.shape[1] * 4,
|
||||
args=[ n, rbot0.data_ptr(), rbot1.data_ptr(), output.data_ptr() ]
|
||||
)
|
||||
|
||||
elif first.is_cuda == False:
|
||||
raise NotImplementedError()
|
||||
|
||||
# end
|
||||
|
||||
return output
|
||||
# end
|
||||
|
||||
@staticmethod
|
||||
def backward(self, gradOutput):
|
||||
first, second, rbot0, rbot1 = self.saved_tensors
|
||||
|
||||
gradOutput = gradOutput.contiguous(); assert(gradOutput.is_cuda == True)
|
||||
|
||||
gradFirst = first.new_zeros([ first.shape[0], first.shape[1], first.shape[2], first.shape[3] ]) if self.needs_input_grad[0] == True else None
|
||||
gradSecond = first.new_zeros([ first.shape[0], first.shape[1], first.shape[2], first.shape[3] ]) if self.needs_input_grad[1] == True else None
|
||||
|
||||
if first.is_cuda == True:
|
||||
if gradFirst is not None:
|
||||
for intSample in range(first.shape[0]):
|
||||
n = first.shape[1] * first.shape[2] * first.shape[3]
|
||||
cupy_launch('kernel_Correlation_updateGradFirst', cupy_kernel('kernel_Correlation_updateGradFirst', {
|
||||
'rbot0': rbot0,
|
||||
'rbot1': rbot1,
|
||||
'gradOutput': gradOutput,
|
||||
'gradFirst': gradFirst,
|
||||
'gradSecond': None
|
||||
}))(
|
||||
grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]),
|
||||
block=tuple([ 512, 1, 1 ]),
|
||||
args=[ n, intSample, rbot0.data_ptr(), rbot1.data_ptr(), gradOutput.data_ptr(), gradFirst.data_ptr(), None ]
|
||||
)
|
||||
# end
|
||||
# end
|
||||
|
||||
if gradSecond is not None:
|
||||
for intSample in range(first.shape[0]):
|
||||
n = first.shape[1] * first.shape[2] * first.shape[3]
|
||||
cupy_launch('kernel_Correlation_updateGradSecond', cupy_kernel('kernel_Correlation_updateGradSecond', {
|
||||
'rbot0': rbot0,
|
||||
'rbot1': rbot1,
|
||||
'gradOutput': gradOutput,
|
||||
'gradFirst': None,
|
||||
'gradSecond': gradSecond
|
||||
}))(
|
||||
grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]),
|
||||
block=tuple([ 512, 1, 1 ]),
|
||||
args=[ n, intSample, rbot0.data_ptr(), rbot1.data_ptr(), gradOutput.data_ptr(), None, gradSecond.data_ptr() ]
|
||||
)
|
||||
# end
|
||||
# end
|
||||
|
||||
elif first.is_cuda == False:
|
||||
raise NotImplementedError()
|
||||
|
||||
# end
|
||||
|
||||
return gradFirst, gradSecond
|
||||
# end
|
||||
# end
|
||||
|
||||
def FunctionCorrelation(tenFirst, tenSecond):
|
||||
return _FunctionCorrelation.apply(tenFirst, tenSecond)
|
||||
# end
|
||||
|
||||
class ModuleCorrelation(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(ModuleCorrelation, self).__init__()
|
||||
# end
|
||||
|
||||
def forward(self, tenFirst, tenSecond):
|
||||
return _FunctionCorrelation.apply(tenFirst, tenSecond)
|
||||
# end
|
||||
# end
|
||||
308
eval/flolpips/flolpips.py
Normal file
308
eval/flolpips/flolpips.py
Normal file
|
|
@ -0,0 +1,308 @@
|
|||
|
||||
from __future__ import absolute_import
|
||||
import os
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.autograd import Variable
|
||||
from .pretrained_networks import vgg16, alexnet, squeezenet
|
||||
import torch.nn
|
||||
import torch.nn.functional as F
|
||||
import torchvision.transforms.functional as TF
|
||||
import cv2
|
||||
|
||||
from .pwcnet import Network as PWCNet
|
||||
from .utils import *
|
||||
|
||||
def spatial_average(in_tens, keepdim=True):
|
||||
return in_tens.mean([2,3],keepdim=keepdim)
|
||||
|
||||
def mw_spatial_average(in_tens, flow, keepdim=True):
|
||||
_,_,h,w = in_tens.shape
|
||||
flow = F.interpolate(flow, (h,w), align_corners=False, mode='bilinear')
|
||||
flow_mag = torch.sqrt(flow[:,0:1]**2 + flow[:,1:2]**2)
|
||||
flow_mag = flow_mag / torch.sum(flow_mag, dim=[1,2,3], keepdim=True)
|
||||
return torch.sum(in_tens*flow_mag, dim=[2,3],keepdim=keepdim)
|
||||
|
||||
|
||||
def mtw_spatial_average(in_tens, flow, texture, keepdim=True):
|
||||
_,_,h,w = in_tens.shape
|
||||
flow = F.interpolate(flow, (h,w), align_corners=False, mode='bilinear')
|
||||
texture = F.interpolate(texture, (h,w), align_corners=False, mode='bilinear')
|
||||
flow_mag = torch.sqrt(flow[:,0:1]**2 + flow[:,1:2]**2)
|
||||
flow_mag = (flow_mag - flow_mag.min()) / (flow_mag.max() - flow_mag.min()) + 1e-6
|
||||
texture = (texture - texture.min()) / (texture.max() - texture.min()) + 1e-6
|
||||
weight = flow_mag / texture
|
||||
weight /= torch.sum(weight)
|
||||
return torch.sum(in_tens*weight, dim=[2,3],keepdim=keepdim)
|
||||
|
||||
|
||||
|
||||
def m2w_spatial_average(in_tens, flow, keepdim=True):
|
||||
_,_,h,w = in_tens.shape
|
||||
flow = F.interpolate(flow, (h,w), align_corners=False, mode='bilinear')
|
||||
flow_mag = flow[:,0:1]**2 + flow[:,1:2]**2 # B,1,H,W
|
||||
flow_mag = flow_mag / torch.sum(flow_mag)
|
||||
return torch.sum(in_tens*flow_mag, dim=[2,3],keepdim=keepdim)
|
||||
|
||||
def upsample(in_tens, out_HW=(64,64)): # assumes scale factor is same for H and W
|
||||
in_H, in_W = in_tens.shape[2], in_tens.shape[3]
|
||||
return nn.Upsample(size=out_HW, mode='bilinear', align_corners=False)(in_tens)
|
||||
|
||||
# Learned perceptual metric
|
||||
class LPIPS(nn.Module):
|
||||
def __init__(self, pretrained=True, net='alex', version='0.1', lpips=True, spatial=False,
|
||||
pnet_rand=False, pnet_tune=False, use_dropout=True, model_path=None, eval_mode=True, verbose=False):
|
||||
# lpips - [True] means with linear calibration on top of base network
|
||||
# pretrained - [True] means load linear weights
|
||||
|
||||
super(LPIPS, self).__init__()
|
||||
if(verbose):
|
||||
print('Setting up [%s] perceptual loss: trunk [%s], v[%s], spatial [%s]'%
|
||||
('LPIPS' if lpips else 'baseline', net, version, 'on' if spatial else 'off'))
|
||||
|
||||
self.pnet_type = net
|
||||
self.pnet_tune = pnet_tune
|
||||
self.pnet_rand = pnet_rand
|
||||
self.spatial = spatial
|
||||
self.lpips = lpips # false means baseline of just averaging all layers
|
||||
self.version = version
|
||||
self.scaling_layer = ScalingLayer()
|
||||
|
||||
if(self.pnet_type in ['vgg','vgg16']):
|
||||
net_type = vgg16
|
||||
self.chns = [64,128,256,512,512]
|
||||
elif(self.pnet_type=='alex'):
|
||||
net_type = alexnet
|
||||
self.chns = [64,192,384,256,256]
|
||||
elif(self.pnet_type=='squeeze'):
|
||||
net_type = squeezenet
|
||||
self.chns = [64,128,256,384,384,512,512]
|
||||
self.L = len(self.chns)
|
||||
|
||||
self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune)
|
||||
|
||||
if(lpips):
|
||||
self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
|
||||
self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
|
||||
self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
|
||||
self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
|
||||
self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
|
||||
self.lins = [self.lin0,self.lin1,self.lin2,self.lin3,self.lin4]
|
||||
if(self.pnet_type=='squeeze'): # 7 layers for squeezenet
|
||||
self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout)
|
||||
self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout)
|
||||
self.lins+=[self.lin5,self.lin6]
|
||||
self.lins = nn.ModuleList(self.lins)
|
||||
|
||||
if(pretrained):
|
||||
if(model_path is None):
|
||||
import inspect
|
||||
import os
|
||||
model_path = os.path.abspath(os.path.join(inspect.getfile(self.__init__), '..', 'weights/v%s/%s.pth'%(version,net)))
|
||||
|
||||
if(verbose):
|
||||
print('Loading model from: %s'%model_path)
|
||||
self.load_state_dict(torch.load(model_path, map_location='cpu'), strict=False)
|
||||
|
||||
if(eval_mode):
|
||||
self.eval()
|
||||
|
||||
def forward(self, in0, in1, retPerLayer=False, normalize=False):
|
||||
if normalize: # turn on this flag if input is [0,1] so it can be adjusted to [-1, +1]
|
||||
in0 = 2 * in0 - 1
|
||||
in1 = 2 * in1 - 1
|
||||
|
||||
# v0.0 - original release had a bug, where input was not scaled
|
||||
in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version=='0.1' else (in0, in1)
|
||||
outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)
|
||||
feats0, feats1, diffs = {}, {}, {}
|
||||
|
||||
for kk in range(self.L):
|
||||
feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
|
||||
diffs[kk] = (feats0[kk]-feats1[kk])**2
|
||||
|
||||
if(self.lpips):
|
||||
if(self.spatial):
|
||||
res = [upsample(self.lins[kk](diffs[kk]), out_HW=in0.shape[2:]) for kk in range(self.L)]
|
||||
else:
|
||||
res = [spatial_average(self.lins[kk](diffs[kk]), keepdim=True) for kk in range(self.L)]
|
||||
else:
|
||||
if(self.spatial):
|
||||
res = [upsample(diffs[kk].sum(dim=1,keepdim=True), out_HW=in0.shape[2:]) for kk in range(self.L)]
|
||||
else:
|
||||
res = [spatial_average(diffs[kk].sum(dim=1,keepdim=True), keepdim=True) for kk in range(self.L)]
|
||||
|
||||
# val = res[0]
|
||||
# for l in range(1,self.L):
|
||||
# val += res[l]
|
||||
# print(val)
|
||||
|
||||
# a = spatial_average(self.lins[kk](diffs[kk]), keepdim=True)
|
||||
# b = torch.max(self.lins[kk](feats0[kk]**2))
|
||||
# for kk in range(self.L):
|
||||
# a += spatial_average(self.lins[kk](diffs[kk]), keepdim=True)
|
||||
# b = torch.max(b,torch.max(self.lins[kk](feats0[kk]**2)))
|
||||
# a = a/self.L
|
||||
# from IPython import embed
|
||||
# embed()
|
||||
# return 10*torch.log10(b/a)
|
||||
|
||||
# if(retPerLayer):
|
||||
# return (val, res)
|
||||
# else:
|
||||
return torch.sum(torch.cat(res, 1), dim=(1,2,3), keepdims=False)
|
||||
|
||||
|
||||
class ScalingLayer(nn.Module):
|
||||
def __init__(self):
|
||||
super(ScalingLayer, self).__init__()
|
||||
self.register_buffer('shift', torch.Tensor([-.030,-.088,-.188])[None,:,None,None])
|
||||
self.register_buffer('scale', torch.Tensor([.458,.448,.450])[None,:,None,None])
|
||||
|
||||
def forward(self, inp):
|
||||
return (inp - self.shift) / self.scale
|
||||
|
||||
|
||||
class NetLinLayer(nn.Module):
|
||||
''' A single linear layer which does a 1x1 conv '''
|
||||
def __init__(self, chn_in, chn_out=1, use_dropout=False):
|
||||
super(NetLinLayer, self).__init__()
|
||||
|
||||
layers = [nn.Dropout(),] if(use_dropout) else []
|
||||
layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),]
|
||||
self.model = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
return self.model(x)
|
||||
|
||||
class Dist2LogitLayer(nn.Module):
|
||||
''' takes 2 distances, puts through fc layers, spits out value between [0,1] (if use_sigmoid is True) '''
|
||||
def __init__(self, chn_mid=32, use_sigmoid=True):
|
||||
super(Dist2LogitLayer, self).__init__()
|
||||
|
||||
layers = [nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True),]
|
||||
layers += [nn.LeakyReLU(0.2,True),]
|
||||
layers += [nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True),]
|
||||
layers += [nn.LeakyReLU(0.2,True),]
|
||||
layers += [nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True),]
|
||||
if(use_sigmoid):
|
||||
layers += [nn.Sigmoid(),]
|
||||
self.model = nn.Sequential(*layers)
|
||||
|
||||
def forward(self,d0,d1,eps=0.1):
|
||||
return self.model.forward(torch.cat((d0,d1,d0-d1,d0/(d1+eps),d1/(d0+eps)),dim=1))
|
||||
|
||||
class BCERankingLoss(nn.Module):
|
||||
def __init__(self, chn_mid=32):
|
||||
super(BCERankingLoss, self).__init__()
|
||||
self.net = Dist2LogitLayer(chn_mid=chn_mid)
|
||||
# self.parameters = list(self.net.parameters())
|
||||
self.loss = torch.nn.BCELoss()
|
||||
|
||||
def forward(self, d0, d1, judge):
|
||||
per = (judge+1.)/2.
|
||||
self.logit = self.net.forward(d0,d1)
|
||||
return self.loss(self.logit, per)
|
||||
|
||||
# L2, DSSIM metrics
|
||||
class FakeNet(nn.Module):
|
||||
def __init__(self, use_gpu=True, colorspace='Lab'):
|
||||
super(FakeNet, self).__init__()
|
||||
self.use_gpu = use_gpu
|
||||
self.colorspace = colorspace
|
||||
|
||||
class L2(FakeNet):
|
||||
def forward(self, in0, in1, retPerLayer=None):
|
||||
assert(in0.size()[0]==1) # currently only supports batchSize 1
|
||||
|
||||
if(self.colorspace=='RGB'):
|
||||
(N,C,X,Y) = in0.size()
|
||||
value = torch.mean(torch.mean(torch.mean((in0-in1)**2,dim=1).view(N,1,X,Y),dim=2).view(N,1,1,Y),dim=3).view(N)
|
||||
return value
|
||||
elif(self.colorspace=='Lab'):
|
||||
value = l2(tensor2np(tensor2tensorlab(in0.data,to_norm=False)),
|
||||
tensor2np(tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float')
|
||||
ret_var = Variable( torch.Tensor((value,) ) )
|
||||
if(self.use_gpu):
|
||||
ret_var = ret_var.cuda()
|
||||
return ret_var
|
||||
|
||||
class DSSIM(FakeNet):
|
||||
|
||||
def forward(self, in0, in1, retPerLayer=None):
|
||||
assert(in0.size()[0]==1) # currently only supports batchSize 1
|
||||
|
||||
if(self.colorspace=='RGB'):
|
||||
value = dssim(1.*tensor2im(in0.data), 1.*tensor2im(in1.data), range=255.).astype('float')
|
||||
elif(self.colorspace=='Lab'):
|
||||
value = dssim(tensor2np(tensor2tensorlab(in0.data,to_norm=False)),
|
||||
tensor2np(tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float')
|
||||
ret_var = Variable( torch.Tensor((value,) ) )
|
||||
if(self.use_gpu):
|
||||
ret_var = ret_var.cuda()
|
||||
return ret_var
|
||||
|
||||
def print_network(net):
|
||||
num_params = 0
|
||||
for param in net.parameters():
|
||||
num_params += param.numel()
|
||||
print('Network',net)
|
||||
print('Total number of parameters: %d' % num_params)
|
||||
|
||||
|
||||
class FloLPIPS(LPIPS):
|
||||
def __init__(self, pretrained=True, net='alex', version='0.1', lpips=True, spatial=False, pnet_rand=False, pnet_tune=False, use_dropout=True, model_path=None, eval_mode=True, verbose=False):
|
||||
super(FloLPIPS, self).__init__(pretrained, net, version, lpips, spatial, pnet_rand, pnet_tune, use_dropout, model_path, eval_mode, verbose)
|
||||
|
||||
def forward(self, in0, in1, flow, retPerLayer=False, normalize=False):
|
||||
if normalize: # turn on this flag if input is [0,1] so it can be adjusted to [-1, +1]
|
||||
in0 = 2 * in0 - 1
|
||||
in1 = 2 * in1 - 1
|
||||
|
||||
in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version=='0.1' else (in0, in1)
|
||||
outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)
|
||||
feats0, feats1, diffs = {}, {}, {}
|
||||
|
||||
for kk in range(self.L):
|
||||
feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
|
||||
diffs[kk] = (feats0[kk]-feats1[kk])**2
|
||||
|
||||
res = [mw_spatial_average(self.lins[kk](diffs[kk]), flow, keepdim=True) for kk in range(self.L)]
|
||||
|
||||
return torch.sum(torch.cat(res, 1), dim=(1,2,3), keepdims=False)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
class Flolpips(nn.Module):
|
||||
def __init__(self):
|
||||
super(Flolpips, self).__init__()
|
||||
self.loss_fn = FloLPIPS(net='alex',version='0.1')
|
||||
self.flownet = PWCNet()
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, I0, I1, frame_dis, frame_ref):
|
||||
"""
|
||||
args:
|
||||
I0: first frame of the triplet, shape: [B, C, H, W]
|
||||
I1: third frame of the triplet, shape: [B, C, H, W]
|
||||
frame_dis: prediction of the intermediate frame, shape: [B, C, H, W]
|
||||
frame_ref: ground-truth of the intermediate frame, shape: [B, C, H, W]
|
||||
"""
|
||||
assert I0.size() == I1.size() == frame_dis.size() == frame_ref.size(), \
|
||||
"the 4 input tensors should have same size"
|
||||
|
||||
flow_ref = self.flownet(frame_ref, I0)
|
||||
flow_dis = self.flownet(frame_dis, I0)
|
||||
flow_diff = flow_ref - flow_dis
|
||||
flolpips_wrt_I0 = self.loss_fn.forward(frame_ref, frame_dis, flow_diff, normalize=True)
|
||||
|
||||
flow_ref = self.flownet(frame_ref, I1)
|
||||
flow_dis = self.flownet(frame_dis, I1)
|
||||
flow_diff = flow_ref - flow_dis
|
||||
flolpips_wrt_I1 = self.loss_fn.forward(frame_ref, frame_dis, flow_diff, normalize=True)
|
||||
|
||||
flolpips = (flolpips_wrt_I0 + flolpips_wrt_I1) / 2
|
||||
return flolpips
|
||||
180
eval/flolpips/pretrained_networks.py
Normal file
180
eval/flolpips/pretrained_networks.py
Normal file
|
|
@ -0,0 +1,180 @@
|
|||
from collections import namedtuple
|
||||
import torch
|
||||
from torchvision import models as tv
|
||||
|
||||
class squeezenet(torch.nn.Module):
|
||||
def __init__(self, requires_grad=False, pretrained=True):
|
||||
super(squeezenet, self).__init__()
|
||||
pretrained_features = tv.squeezenet1_1(pretrained=pretrained).features
|
||||
self.slice1 = torch.nn.Sequential()
|
||||
self.slice2 = torch.nn.Sequential()
|
||||
self.slice3 = torch.nn.Sequential()
|
||||
self.slice4 = torch.nn.Sequential()
|
||||
self.slice5 = torch.nn.Sequential()
|
||||
self.slice6 = torch.nn.Sequential()
|
||||
self.slice7 = torch.nn.Sequential()
|
||||
self.N_slices = 7
|
||||
for x in range(2):
|
||||
self.slice1.add_module(str(x), pretrained_features[x])
|
||||
for x in range(2,5):
|
||||
self.slice2.add_module(str(x), pretrained_features[x])
|
||||
for x in range(5, 8):
|
||||
self.slice3.add_module(str(x), pretrained_features[x])
|
||||
for x in range(8, 10):
|
||||
self.slice4.add_module(str(x), pretrained_features[x])
|
||||
for x in range(10, 11):
|
||||
self.slice5.add_module(str(x), pretrained_features[x])
|
||||
for x in range(11, 12):
|
||||
self.slice6.add_module(str(x), pretrained_features[x])
|
||||
for x in range(12, 13):
|
||||
self.slice7.add_module(str(x), pretrained_features[x])
|
||||
if not requires_grad:
|
||||
for param in self.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def forward(self, X):
|
||||
h = self.slice1(X)
|
||||
h_relu1 = h
|
||||
h = self.slice2(h)
|
||||
h_relu2 = h
|
||||
h = self.slice3(h)
|
||||
h_relu3 = h
|
||||
h = self.slice4(h)
|
||||
h_relu4 = h
|
||||
h = self.slice5(h)
|
||||
h_relu5 = h
|
||||
h = self.slice6(h)
|
||||
h_relu6 = h
|
||||
h = self.slice7(h)
|
||||
h_relu7 = h
|
||||
vgg_outputs = namedtuple("SqueezeOutputs", ['relu1','relu2','relu3','relu4','relu5','relu6','relu7'])
|
||||
out = vgg_outputs(h_relu1,h_relu2,h_relu3,h_relu4,h_relu5,h_relu6,h_relu7)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class alexnet(torch.nn.Module):
|
||||
def __init__(self, requires_grad=False, pretrained=True):
|
||||
super(alexnet, self).__init__()
|
||||
alexnet_pretrained_features = tv.alexnet(pretrained=pretrained).features
|
||||
self.slice1 = torch.nn.Sequential()
|
||||
self.slice2 = torch.nn.Sequential()
|
||||
self.slice3 = torch.nn.Sequential()
|
||||
self.slice4 = torch.nn.Sequential()
|
||||
self.slice5 = torch.nn.Sequential()
|
||||
self.N_slices = 5
|
||||
for x in range(2):
|
||||
self.slice1.add_module(str(x), alexnet_pretrained_features[x])
|
||||
for x in range(2, 5):
|
||||
self.slice2.add_module(str(x), alexnet_pretrained_features[x])
|
||||
for x in range(5, 8):
|
||||
self.slice3.add_module(str(x), alexnet_pretrained_features[x])
|
||||
for x in range(8, 10):
|
||||
self.slice4.add_module(str(x), alexnet_pretrained_features[x])
|
||||
for x in range(10, 12):
|
||||
self.slice5.add_module(str(x), alexnet_pretrained_features[x])
|
||||
if not requires_grad:
|
||||
for param in self.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def forward(self, X):
|
||||
h = self.slice1(X)
|
||||
h_relu1 = h
|
||||
h = self.slice2(h)
|
||||
h_relu2 = h
|
||||
h = self.slice3(h)
|
||||
h_relu3 = h
|
||||
h = self.slice4(h)
|
||||
h_relu4 = h
|
||||
h = self.slice5(h)
|
||||
h_relu5 = h
|
||||
alexnet_outputs = namedtuple("AlexnetOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5'])
|
||||
out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5)
|
||||
|
||||
return out
|
||||
|
||||
class vgg16(torch.nn.Module):
|
||||
def __init__(self, requires_grad=False, pretrained=True):
|
||||
super(vgg16, self).__init__()
|
||||
vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features
|
||||
self.slice1 = torch.nn.Sequential()
|
||||
self.slice2 = torch.nn.Sequential()
|
||||
self.slice3 = torch.nn.Sequential()
|
||||
self.slice4 = torch.nn.Sequential()
|
||||
self.slice5 = torch.nn.Sequential()
|
||||
self.N_slices = 5
|
||||
for x in range(4):
|
||||
self.slice1.add_module(str(x), vgg_pretrained_features[x])
|
||||
for x in range(4, 9):
|
||||
self.slice2.add_module(str(x), vgg_pretrained_features[x])
|
||||
for x in range(9, 16):
|
||||
self.slice3.add_module(str(x), vgg_pretrained_features[x])
|
||||
for x in range(16, 23):
|
||||
self.slice4.add_module(str(x), vgg_pretrained_features[x])
|
||||
for x in range(23, 30):
|
||||
self.slice5.add_module(str(x), vgg_pretrained_features[x])
|
||||
if not requires_grad:
|
||||
for param in self.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def forward(self, X):
|
||||
h = self.slice1(X)
|
||||
h_relu1_2 = h
|
||||
h = self.slice2(h)
|
||||
h_relu2_2 = h
|
||||
h = self.slice3(h)
|
||||
h_relu3_3 = h
|
||||
h = self.slice4(h)
|
||||
h_relu4_3 = h
|
||||
h = self.slice5(h)
|
||||
h_relu5_3 = h
|
||||
vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
|
||||
out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
|
||||
class resnet(torch.nn.Module):
|
||||
def __init__(self, requires_grad=False, pretrained=True, num=18):
|
||||
super(resnet, self).__init__()
|
||||
if(num==18):
|
||||
self.net = tv.resnet18(pretrained=pretrained)
|
||||
elif(num==34):
|
||||
self.net = tv.resnet34(pretrained=pretrained)
|
||||
elif(num==50):
|
||||
self.net = tv.resnet50(pretrained=pretrained)
|
||||
elif(num==101):
|
||||
self.net = tv.resnet101(pretrained=pretrained)
|
||||
elif(num==152):
|
||||
self.net = tv.resnet152(pretrained=pretrained)
|
||||
self.N_slices = 5
|
||||
|
||||
self.conv1 = self.net.conv1
|
||||
self.bn1 = self.net.bn1
|
||||
self.relu = self.net.relu
|
||||
self.maxpool = self.net.maxpool
|
||||
self.layer1 = self.net.layer1
|
||||
self.layer2 = self.net.layer2
|
||||
self.layer3 = self.net.layer3
|
||||
self.layer4 = self.net.layer4
|
||||
|
||||
def forward(self, X):
|
||||
h = self.conv1(X)
|
||||
h = self.bn1(h)
|
||||
h = self.relu(h)
|
||||
h_relu1 = h
|
||||
h = self.maxpool(h)
|
||||
h = self.layer1(h)
|
||||
h_conv2 = h
|
||||
h = self.layer2(h)
|
||||
h_conv3 = h
|
||||
h = self.layer3(h)
|
||||
h_conv4 = h
|
||||
h = self.layer4(h)
|
||||
h_conv5 = h
|
||||
|
||||
outputs = namedtuple("Outputs", ['relu1','conv2','conv3','conv4','conv5'])
|
||||
out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5)
|
||||
|
||||
return out
|
||||
344
eval/flolpips/pwcnet.py
Normal file
344
eval/flolpips/pwcnet.py
Normal file
|
|
@ -0,0 +1,344 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
import torch
|
||||
|
||||
import getopt
|
||||
import math
|
||||
import numpy
|
||||
import os
|
||||
import PIL
|
||||
import PIL.Image
|
||||
import sys
|
||||
|
||||
# try:
|
||||
from .correlation import correlation # the custom cost volume layer
|
||||
# except:
|
||||
# sys.path.insert(0, './correlation'); import correlation # you should consider upgrading python
|
||||
# end
|
||||
|
||||
##########################################################
|
||||
|
||||
# assert(int(str('').join(torch.__version__.split('.')[0:2])) >= 13) # requires at least pytorch version 1.3.0
|
||||
|
||||
# torch.set_grad_enabled(False) # make sure to not compute gradients for computational performance
|
||||
|
||||
# torch.backends.cudnn.enabled = True # make sure to use cudnn for computational performance
|
||||
|
||||
# ##########################################################
|
||||
|
||||
# arguments_strModel = 'default' # 'default', or 'chairs-things'
|
||||
# arguments_strFirst = './images/first.png'
|
||||
# arguments_strSecond = './images/second.png'
|
||||
# arguments_strOut = './out.flo'
|
||||
|
||||
# for strOption, strArgument in getopt.getopt(sys.argv[1:], '', [ strParameter[2:] + '=' for strParameter in sys.argv[1::2] ])[0]:
|
||||
# if strOption == '--model' and strArgument != '': arguments_strModel = strArgument # which model to use
|
||||
# if strOption == '--first' and strArgument != '': arguments_strFirst = strArgument # path to the first frame
|
||||
# if strOption == '--second' and strArgument != '': arguments_strSecond = strArgument # path to the second frame
|
||||
# if strOption == '--out' and strArgument != '': arguments_strOut = strArgument # path to where the output should be stored
|
||||
# end
|
||||
|
||||
##########################################################
|
||||
|
||||
|
||||
|
||||
def backwarp(tenInput, tenFlow):
|
||||
backwarp_tenGrid = {}
|
||||
backwarp_tenPartial = {}
|
||||
if str(tenFlow.shape) not in backwarp_tenGrid:
|
||||
tenHor = torch.linspace(-1.0 + (1.0 / tenFlow.shape[3]), 1.0 - (1.0 / tenFlow.shape[3]), tenFlow.shape[3]).view(1, 1, 1, -1).expand(-1, -1, tenFlow.shape[2], -1)
|
||||
tenVer = torch.linspace(-1.0 + (1.0 / tenFlow.shape[2]), 1.0 - (1.0 / tenFlow.shape[2]), tenFlow.shape[2]).view(1, 1, -1, 1).expand(-1, -1, -1, tenFlow.shape[3])
|
||||
|
||||
backwarp_tenGrid[str(tenFlow.shape)] = torch.cat([ tenHor, tenVer ], 1).cuda()
|
||||
# end
|
||||
|
||||
if str(tenFlow.shape) not in backwarp_tenPartial:
|
||||
backwarp_tenPartial[str(tenFlow.shape)] = tenFlow.new_ones([ tenFlow.shape[0], 1, tenFlow.shape[2], tenFlow.shape[3] ])
|
||||
# end
|
||||
|
||||
tenFlow = torch.cat([ tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0), tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0) ], 1)
|
||||
tenInput = torch.cat([ tenInput, backwarp_tenPartial[str(tenFlow.shape)] ], 1)
|
||||
|
||||
tenOutput = torch.nn.functional.grid_sample(input=tenInput, grid=(backwarp_tenGrid[str(tenFlow.shape)] + tenFlow).permute(0, 2, 3, 1), mode='bilinear', padding_mode='zeros', align_corners=False)
|
||||
|
||||
tenMask = tenOutput[:, -1:, :, :]; tenMask[tenMask > 0.999] = 1.0; tenMask[tenMask < 1.0] = 0.0
|
||||
|
||||
return tenOutput[:, :-1, :, :] * tenMask
|
||||
# end
|
||||
|
||||
##########################################################
|
||||
|
||||
class Network(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(Network, self).__init__()
|
||||
|
||||
class Extractor(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(Extractor, self).__init__()
|
||||
|
||||
self.netOne = torch.nn.Sequential(
|
||||
torch.nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=2, padding=1),
|
||||
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
|
||||
torch.nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, stride=1, padding=1),
|
||||
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
|
||||
torch.nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, stride=1, padding=1),
|
||||
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)
|
||||
)
|
||||
|
||||
self.netTwo = torch.nn.Sequential(
|
||||
torch.nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=2, padding=1),
|
||||
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
|
||||
torch.nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1),
|
||||
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
|
||||
torch.nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1),
|
||||
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)
|
||||
)
|
||||
|
||||
self.netThr = torch.nn.Sequential(
|
||||
torch.nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=2, padding=1),
|
||||
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
|
||||
torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
|
||||
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
|
||||
torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
|
||||
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)
|
||||
)
|
||||
|
||||
self.netFou = torch.nn.Sequential(
|
||||
torch.nn.Conv2d(in_channels=64, out_channels=96, kernel_size=3, stride=2, padding=1),
|
||||
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
|
||||
torch.nn.Conv2d(in_channels=96, out_channels=96, kernel_size=3, stride=1, padding=1),
|
||||
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
|
||||
torch.nn.Conv2d(in_channels=96, out_channels=96, kernel_size=3, stride=1, padding=1),
|
||||
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)
|
||||
)
|
||||
|
||||
self.netFiv = torch.nn.Sequential(
|
||||
torch.nn.Conv2d(in_channels=96, out_channels=128, kernel_size=3, stride=2, padding=1),
|
||||
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
|
||||
torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
|
||||
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
|
||||
torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
|
||||
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)
|
||||
)
|
||||
|
||||
self.netSix = torch.nn.Sequential(
|
||||
torch.nn.Conv2d(in_channels=128, out_channels=196, kernel_size=3, stride=2, padding=1),
|
||||
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
|
||||
torch.nn.Conv2d(in_channels=196, out_channels=196, kernel_size=3, stride=1, padding=1),
|
||||
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
|
||||
torch.nn.Conv2d(in_channels=196, out_channels=196, kernel_size=3, stride=1, padding=1),
|
||||
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)
|
||||
)
|
||||
# end
|
||||
|
||||
def forward(self, tenInput):
|
||||
tenOne = self.netOne(tenInput)
|
||||
tenTwo = self.netTwo(tenOne)
|
||||
tenThr = self.netThr(tenTwo)
|
||||
tenFou = self.netFou(tenThr)
|
||||
tenFiv = self.netFiv(tenFou)
|
||||
tenSix = self.netSix(tenFiv)
|
||||
|
||||
return [ tenOne, tenTwo, tenThr, tenFou, tenFiv, tenSix ]
|
||||
# end
|
||||
# end
|
||||
|
||||
class Decoder(torch.nn.Module):
|
||||
def __init__(self, intLevel):
|
||||
super(Decoder, self).__init__()
|
||||
|
||||
intPrevious = [ None, None, 81 + 32 + 2 + 2, 81 + 64 + 2 + 2, 81 + 96 + 2 + 2, 81 + 128 + 2 + 2, 81, None ][intLevel + 1]
|
||||
intCurrent = [ None, None, 81 + 32 + 2 + 2, 81 + 64 + 2 + 2, 81 + 96 + 2 + 2, 81 + 128 + 2 + 2, 81, None ][intLevel + 0]
|
||||
|
||||
if intLevel < 6: self.netUpflow = torch.nn.ConvTranspose2d(in_channels=2, out_channels=2, kernel_size=4, stride=2, padding=1)
|
||||
if intLevel < 6: self.netUpfeat = torch.nn.ConvTranspose2d(in_channels=intPrevious + 128 + 128 + 96 + 64 + 32, out_channels=2, kernel_size=4, stride=2, padding=1)
|
||||
if intLevel < 6: self.fltBackwarp = [ None, None, None, 5.0, 2.5, 1.25, 0.625, None ][intLevel + 1]
|
||||
|
||||
self.netOne = torch.nn.Sequential(
|
||||
torch.nn.Conv2d(in_channels=intCurrent, out_channels=128, kernel_size=3, stride=1, padding=1),
|
||||
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)
|
||||
)
|
||||
|
||||
self.netTwo = torch.nn.Sequential(
|
||||
torch.nn.Conv2d(in_channels=intCurrent + 128, out_channels=128, kernel_size=3, stride=1, padding=1),
|
||||
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)
|
||||
)
|
||||
|
||||
self.netThr = torch.nn.Sequential(
|
||||
torch.nn.Conv2d(in_channels=intCurrent + 128 + 128, out_channels=96, kernel_size=3, stride=1, padding=1),
|
||||
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)
|
||||
)
|
||||
|
||||
self.netFou = torch.nn.Sequential(
|
||||
torch.nn.Conv2d(in_channels=intCurrent + 128 + 128 + 96, out_channels=64, kernel_size=3, stride=1, padding=1),
|
||||
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)
|
||||
)
|
||||
|
||||
self.netFiv = torch.nn.Sequential(
|
||||
torch.nn.Conv2d(in_channels=intCurrent + 128 + 128 + 96 + 64, out_channels=32, kernel_size=3, stride=1, padding=1),
|
||||
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)
|
||||
)
|
||||
|
||||
self.netSix = torch.nn.Sequential(
|
||||
torch.nn.Conv2d(in_channels=intCurrent + 128 + 128 + 96 + 64 + 32, out_channels=2, kernel_size=3, stride=1, padding=1)
|
||||
)
|
||||
# end
|
||||
|
||||
def forward(self, tenFirst, tenSecond, objPrevious):
|
||||
tenFlow = None
|
||||
tenFeat = None
|
||||
|
||||
if objPrevious is None:
|
||||
tenFlow = None
|
||||
tenFeat = None
|
||||
|
||||
tenVolume = torch.nn.functional.leaky_relu(input=correlation.FunctionCorrelation(tenFirst=tenFirst, tenSecond=tenSecond), negative_slope=0.1, inplace=False)
|
||||
|
||||
tenFeat = torch.cat([ tenVolume ], 1)
|
||||
|
||||
elif objPrevious is not None:
|
||||
tenFlow = self.netUpflow(objPrevious['tenFlow'])
|
||||
tenFeat = self.netUpfeat(objPrevious['tenFeat'])
|
||||
|
||||
tenVolume = torch.nn.functional.leaky_relu(input=correlation.FunctionCorrelation(tenFirst=tenFirst, tenSecond=backwarp(tenInput=tenSecond, tenFlow=tenFlow * self.fltBackwarp)), negative_slope=0.1, inplace=False)
|
||||
|
||||
tenFeat = torch.cat([ tenVolume, tenFirst, tenFlow, tenFeat ], 1)
|
||||
|
||||
# end
|
||||
|
||||
tenFeat = torch.cat([ self.netOne(tenFeat), tenFeat ], 1)
|
||||
tenFeat = torch.cat([ self.netTwo(tenFeat), tenFeat ], 1)
|
||||
tenFeat = torch.cat([ self.netThr(tenFeat), tenFeat ], 1)
|
||||
tenFeat = torch.cat([ self.netFou(tenFeat), tenFeat ], 1)
|
||||
tenFeat = torch.cat([ self.netFiv(tenFeat), tenFeat ], 1)
|
||||
|
||||
tenFlow = self.netSix(tenFeat)
|
||||
|
||||
return {
|
||||
'tenFlow': tenFlow,
|
||||
'tenFeat': tenFeat
|
||||
}
|
||||
# end
|
||||
# end
|
||||
|
||||
class Refiner(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(Refiner, self).__init__()
|
||||
|
||||
self.netMain = torch.nn.Sequential(
|
||||
torch.nn.Conv2d(in_channels=81 + 32 + 2 + 2 + 128 + 128 + 96 + 64 + 32, out_channels=128, kernel_size=3, stride=1, padding=1, dilation=1),
|
||||
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
|
||||
torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=2, dilation=2),
|
||||
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
|
||||
torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=4, dilation=4),
|
||||
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
|
||||
torch.nn.Conv2d(in_channels=128, out_channels=96, kernel_size=3, stride=1, padding=8, dilation=8),
|
||||
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
|
||||
torch.nn.Conv2d(in_channels=96, out_channels=64, kernel_size=3, stride=1, padding=16, dilation=16),
|
||||
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
|
||||
torch.nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1, dilation=1),
|
||||
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
|
||||
torch.nn.Conv2d(in_channels=32, out_channels=2, kernel_size=3, stride=1, padding=1, dilation=1)
|
||||
)
|
||||
# end
|
||||
|
||||
def forward(self, tenInput):
|
||||
return self.netMain(tenInput)
|
||||
# end
|
||||
# end
|
||||
|
||||
self.netExtractor = Extractor()
|
||||
|
||||
self.netTwo = Decoder(2)
|
||||
self.netThr = Decoder(3)
|
||||
self.netFou = Decoder(4)
|
||||
self.netFiv = Decoder(5)
|
||||
self.netSix = Decoder(6)
|
||||
|
||||
self.netRefiner = Refiner()
|
||||
|
||||
self.load_state_dict({ strKey.replace('module', 'net'): tenWeight for strKey, tenWeight in torch.hub.load_state_dict_from_url(url='http://content.sniklaus.com/github/pytorch-pwc/network-' + 'default' + '.pytorch').items() })
|
||||
# end
|
||||
|
||||
def forward(self, tenFirst, tenSecond):
|
||||
intWidth = tenFirst.shape[3]
|
||||
intHeight = tenFirst.shape[2]
|
||||
|
||||
intPreprocessedWidth = int(math.floor(math.ceil(intWidth / 64.0) * 64.0))
|
||||
intPreprocessedHeight = int(math.floor(math.ceil(intHeight / 64.0) * 64.0))
|
||||
|
||||
tenPreprocessedFirst = torch.nn.functional.interpolate(input=tenFirst, size=(intPreprocessedHeight, intPreprocessedWidth), mode='bilinear', align_corners=False)
|
||||
tenPreprocessedSecond = torch.nn.functional.interpolate(input=tenSecond, size=(intPreprocessedHeight, intPreprocessedWidth), mode='bilinear', align_corners=False)
|
||||
|
||||
tenFirst = self.netExtractor(tenPreprocessedFirst)
|
||||
tenSecond = self.netExtractor(tenPreprocessedSecond)
|
||||
|
||||
|
||||
objEstimate = self.netSix(tenFirst[-1], tenSecond[-1], None)
|
||||
objEstimate = self.netFiv(tenFirst[-2], tenSecond[-2], objEstimate)
|
||||
objEstimate = self.netFou(tenFirst[-3], tenSecond[-3], objEstimate)
|
||||
objEstimate = self.netThr(tenFirst[-4], tenSecond[-4], objEstimate)
|
||||
objEstimate = self.netTwo(tenFirst[-5], tenSecond[-5], objEstimate)
|
||||
|
||||
tenFlow = objEstimate['tenFlow'] + self.netRefiner(objEstimate['tenFeat'])
|
||||
tenFlow = 20.0 * torch.nn.functional.interpolate(input=tenFlow, size=(intHeight, intWidth), mode='bilinear', align_corners=False)
|
||||
tenFlow[:, 0, :, :] *= float(intWidth) / float(intPreprocessedWidth)
|
||||
tenFlow[:, 1, :, :] *= float(intHeight) / float(intPreprocessedHeight)
|
||||
|
||||
return tenFlow
|
||||
# end
|
||||
# end
|
||||
|
||||
netNetwork = None
|
||||
|
||||
##########################################################
|
||||
|
||||
def estimate(tenFirst, tenSecond):
|
||||
global netNetwork
|
||||
|
||||
if netNetwork is None:
|
||||
netNetwork = Network().cuda().eval()
|
||||
# end
|
||||
|
||||
assert(tenFirst.shape[1] == tenSecond.shape[1])
|
||||
assert(tenFirst.shape[2] == tenSecond.shape[2])
|
||||
|
||||
intWidth = tenFirst.shape[2]
|
||||
intHeight = tenFirst.shape[1]
|
||||
|
||||
assert(intWidth == 1024) # remember that there is no guarantee for correctness, comment this line out if you acknowledge this and want to continue
|
||||
assert(intHeight == 436) # remember that there is no guarantee for correctness, comment this line out if you acknowledge this and want to continue
|
||||
|
||||
tenPreprocessedFirst = tenFirst.cuda().view(1, 3, intHeight, intWidth)
|
||||
tenPreprocessedSecond = tenSecond.cuda().view(1, 3, intHeight, intWidth)
|
||||
|
||||
intPreprocessedWidth = int(math.floor(math.ceil(intWidth / 64.0) * 64.0))
|
||||
intPreprocessedHeight = int(math.floor(math.ceil(intHeight / 64.0) * 64.0))
|
||||
|
||||
tenPreprocessedFirst = torch.nn.functional.interpolate(input=tenPreprocessedFirst, size=(intPreprocessedHeight, intPreprocessedWidth), mode='bilinear', align_corners=False)
|
||||
tenPreprocessedSecond = torch.nn.functional.interpolate(input=tenPreprocessedSecond, size=(intPreprocessedHeight, intPreprocessedWidth), mode='bilinear', align_corners=False)
|
||||
|
||||
tenFlow = 20.0 * torch.nn.functional.interpolate(input=netNetwork(tenPreprocessedFirst, tenPreprocessedSecond), size=(intHeight, intWidth), mode='bilinear', align_corners=False)
|
||||
|
||||
tenFlow[:, 0, :, :] *= float(intWidth) / float(intPreprocessedWidth)
|
||||
tenFlow[:, 1, :, :] *= float(intHeight) / float(intPreprocessedHeight)
|
||||
|
||||
return tenFlow[0, :, :, :].cpu()
|
||||
# end
|
||||
|
||||
##########################################################
|
||||
|
||||
# if __name__ == '__main__':
|
||||
# tenFirst = torch.FloatTensor(numpy.ascontiguousarray(numpy.array(PIL.Image.open(arguments_strFirst))[:, :, ::-1].transpose(2, 0, 1).astype(numpy.float32) * (1.0 / 255.0)))
|
||||
# tenSecond = torch.FloatTensor(numpy.ascontiguousarray(numpy.array(PIL.Image.open(arguments_strSecond))[:, :, ::-1].transpose(2, 0, 1).astype(numpy.float32) * (1.0 / 255.0)))
|
||||
|
||||
# tenOutput = estimate(tenFirst, tenSecond)
|
||||
|
||||
# objOutput = open(arguments_strOut, 'wb')
|
||||
|
||||
# numpy.array([ 80, 73, 69, 72 ], numpy.uint8).tofile(objOutput)
|
||||
# numpy.array([ tenOutput.shape[2], tenOutput.shape[1] ], numpy.int32).tofile(objOutput)
|
||||
# numpy.array(tenOutput.numpy().transpose(1, 2, 0), numpy.float32).tofile(objOutput)
|
||||
|
||||
# objOutput.close()
|
||||
# end
|
||||
95
eval/flolpips/utils.py
Normal file
95
eval/flolpips/utils.py
Normal file
|
|
@ -0,0 +1,95 @@
|
|||
import numpy as np
|
||||
import cv2
|
||||
import torch
|
||||
|
||||
|
||||
def normalize_tensor(in_feat,eps=1e-10):
|
||||
norm_factor = torch.sqrt(torch.sum(in_feat**2,dim=1,keepdim=True))
|
||||
return in_feat/(norm_factor+eps)
|
||||
|
||||
def l2(p0, p1, range=255.):
|
||||
return .5*np.mean((p0 / range - p1 / range)**2)
|
||||
|
||||
def dssim(p0, p1, range=255.):
|
||||
from skimage.measure import compare_ssim
|
||||
return (1 - compare_ssim(p0, p1, data_range=range, multichannel=True)) / 2.
|
||||
|
||||
def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.):
|
||||
image_numpy = image_tensor[0].cpu().float().numpy()
|
||||
image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor
|
||||
return image_numpy.astype(imtype)
|
||||
|
||||
def tensor2np(tensor_obj):
|
||||
# change dimension of a tensor object into a numpy array
|
||||
return tensor_obj[0].cpu().float().numpy().transpose((1,2,0))
|
||||
|
||||
def np2tensor(np_obj):
|
||||
# change dimenion of np array into tensor array
|
||||
return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
|
||||
|
||||
def tensor2tensorlab(image_tensor,to_norm=True,mc_only=False):
|
||||
# image tensor to lab tensor
|
||||
from skimage import color
|
||||
|
||||
img = tensor2im(image_tensor)
|
||||
img_lab = color.rgb2lab(img)
|
||||
if(mc_only):
|
||||
img_lab[:,:,0] = img_lab[:,:,0]-50
|
||||
if(to_norm and not mc_only):
|
||||
img_lab[:,:,0] = img_lab[:,:,0]-50
|
||||
img_lab = img_lab/100.
|
||||
|
||||
return np2tensor(img_lab)
|
||||
|
||||
def read_frame_yuv2rgb(stream, width, height, iFrame, bit_depth, pix_fmt='420'):
|
||||
if pix_fmt == '420':
|
||||
multiplier = 1
|
||||
uv_factor = 2
|
||||
elif pix_fmt == '444':
|
||||
multiplier = 2
|
||||
uv_factor = 1
|
||||
else:
|
||||
print('Pixel format {} is not supported'.format(pix_fmt))
|
||||
return
|
||||
|
||||
if bit_depth == 8:
|
||||
datatype = np.uint8
|
||||
stream.seek(iFrame*1.5*width*height*multiplier)
|
||||
Y = np.fromfile(stream, dtype=datatype, count=width*height).reshape((height, width))
|
||||
|
||||
# read chroma samples and upsample since original is 4:2:0 sampling
|
||||
U = np.fromfile(stream, dtype=datatype, count=(width//uv_factor)*(height//uv_factor)).\
|
||||
reshape((height//uv_factor, width//uv_factor))
|
||||
V = np.fromfile(stream, dtype=datatype, count=(width//uv_factor)*(height//uv_factor)).\
|
||||
reshape((height//uv_factor, width//uv_factor))
|
||||
|
||||
else:
|
||||
datatype = np.uint16
|
||||
stream.seek(iFrame*3*width*height*multiplier)
|
||||
Y = np.fromfile(stream, dtype=datatype, count=width*height).reshape((height, width))
|
||||
|
||||
U = np.fromfile(stream, dtype=datatype, count=(width//uv_factor)*(height//uv_factor)).\
|
||||
reshape((height//uv_factor, width//uv_factor))
|
||||
V = np.fromfile(stream, dtype=datatype, count=(width//uv_factor)*(height//uv_factor)).\
|
||||
reshape((height//uv_factor, width//uv_factor))
|
||||
|
||||
if pix_fmt == '420':
|
||||
yuv = np.empty((height*3//2, width), dtype=datatype)
|
||||
yuv[0:height,:] = Y
|
||||
|
||||
yuv[height:height+height//4,:] = U.reshape(-1, width)
|
||||
yuv[height+height//4:,:] = V.reshape(-1, width)
|
||||
|
||||
if bit_depth != 8:
|
||||
yuv = (yuv/(2**bit_depth-1)*255).astype(np.uint8)
|
||||
|
||||
#convert to rgb
|
||||
rgb = cv2.cvtColor(yuv, cv2.COLOR_YUV2RGB_I420)
|
||||
|
||||
else:
|
||||
yvu = np.stack([Y,V,U],axis=2)
|
||||
if bit_depth != 8:
|
||||
yvu = (yvu/(2**bit_depth-1)*255).astype(np.uint8)
|
||||
rgb = cv2.cvtColor(yvu, cv2.COLOR_YCrCb2RGB)
|
||||
|
||||
return rgb
|
||||
90
eval/fvd/styleganv/fvd.py
Normal file
90
eval/fvd/styleganv/fvd.py
Normal file
|
|
@ -0,0 +1,90 @@
|
|||
import torch
|
||||
import os
|
||||
import math
|
||||
import torch.nn.functional as F
|
||||
|
||||
# https://github.com/universome/fvd-comparison
|
||||
|
||||
|
||||
def load_i3d_pretrained(device=torch.device('cpu')):
|
||||
i3D_WEIGHTS_URL = "https://www.dropbox.com/s/ge9e5ujwgetktms/i3d_torchscript.pt"
|
||||
filepath = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'i3d_torchscript.pt')
|
||||
print(filepath)
|
||||
if not os.path.exists(filepath):
|
||||
print(f"preparing for download {i3D_WEIGHTS_URL}, you can download it by yourself.")
|
||||
os.system(f"wget {i3D_WEIGHTS_URL} -O {filepath}")
|
||||
i3d = torch.jit.load(filepath).eval().to(device)
|
||||
i3d = torch.nn.DataParallel(i3d)
|
||||
return i3d
|
||||
|
||||
|
||||
def get_feats(videos, detector, device, bs=10):
|
||||
# videos : torch.tensor BCTHW [0, 1]
|
||||
detector_kwargs = dict(rescale=False, resize=False, return_features=True) # Return raw features before the softmax layer.
|
||||
feats = np.empty((0, 400))
|
||||
with torch.no_grad():
|
||||
for i in range((len(videos)-1)//bs + 1):
|
||||
feats = np.vstack([feats, detector(torch.stack([preprocess_single(video) for video in videos[i*bs:(i+1)*bs]]).to(device), **detector_kwargs).detach().cpu().numpy()])
|
||||
return feats
|
||||
|
||||
|
||||
def get_fvd_feats(videos, i3d, device, bs=10):
|
||||
# videos in [0, 1] as torch tensor BCTHW
|
||||
# videos = [preprocess_single(video) for video in videos]
|
||||
embeddings = get_feats(videos, i3d, device, bs)
|
||||
return embeddings
|
||||
|
||||
|
||||
def preprocess_single(video, resolution=224, sequence_length=None):
|
||||
# video: CTHW, [0, 1]
|
||||
c, t, h, w = video.shape
|
||||
|
||||
# temporal crop
|
||||
if sequence_length is not None:
|
||||
assert sequence_length <= t
|
||||
video = video[:, :sequence_length]
|
||||
|
||||
# scale shorter side to resolution
|
||||
scale = resolution / min(h, w)
|
||||
if h < w:
|
||||
target_size = (resolution, math.ceil(w * scale))
|
||||
else:
|
||||
target_size = (math.ceil(h * scale), resolution)
|
||||
video = F.interpolate(video, size=target_size, mode='bilinear', align_corners=False)
|
||||
|
||||
# center crop
|
||||
c, t, h, w = video.shape
|
||||
w_start = (w - resolution) // 2
|
||||
h_start = (h - resolution) // 2
|
||||
video = video[:, :, h_start:h_start + resolution, w_start:w_start + resolution]
|
||||
|
||||
# [0, 1] -> [-1, 1]
|
||||
video = (video - 0.5) * 2
|
||||
|
||||
return video.contiguous()
|
||||
|
||||
|
||||
"""
|
||||
Copy-pasted from https://github.com/cvpr2022-stylegan-v/stylegan-v/blob/main/src/metrics/frechet_video_distance.py
|
||||
"""
|
||||
from typing import Tuple
|
||||
from scipy.linalg import sqrtm
|
||||
import numpy as np
|
||||
|
||||
|
||||
def compute_stats(feats: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
|
||||
mu = feats.mean(axis=0) # [d]
|
||||
sigma = np.cov(feats, rowvar=False) # [d, d]
|
||||
return mu, sigma
|
||||
|
||||
|
||||
def frechet_distance(feats_fake: np.ndarray, feats_real: np.ndarray) -> float:
|
||||
mu_gen, sigma_gen = compute_stats(feats_fake)
|
||||
mu_real, sigma_real = compute_stats(feats_real)
|
||||
m = np.square(mu_gen - mu_real).sum()
|
||||
if feats_fake.shape[0]>1:
|
||||
s, _ = sqrtm(np.dot(sigma_gen, sigma_real), disp=False) # pylint: disable=no-member
|
||||
fid = np.real(m + np.trace(sigma_gen + sigma_real - s * 2))
|
||||
else:
|
||||
fid = np.real(m)
|
||||
return float(fid)
|
||||
137
eval/fvd/videogpt/fvd.py
Normal file
137
eval/fvd/videogpt/fvd.py
Normal file
|
|
@ -0,0 +1,137 @@
|
|||
import torch
|
||||
import os
|
||||
import math
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
import einops
|
||||
|
||||
def load_i3d_pretrained(device=torch.device('cpu')):
|
||||
i3D_WEIGHTS_URL = "https://onedrive.live.com/download?cid=78EEF3EB6AE7DBCB&resid=78EEF3EB6AE7DBCB%21199&authkey=AApKdFHPXzWLNyI"
|
||||
filepath = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'i3d_pretrained_400.pt')
|
||||
print(filepath)
|
||||
if not os.path.exists(filepath):
|
||||
print(f"preparing for download {i3D_WEIGHTS_URL}, you can download it by yourself.")
|
||||
os.system(f"wget {i3D_WEIGHTS_URL} -O {filepath}")
|
||||
from .pytorch_i3d import InceptionI3d
|
||||
i3d = InceptionI3d(400, in_channels=3).eval().to(device)
|
||||
i3d.load_state_dict(torch.load(filepath, map_location=device))
|
||||
i3d = torch.nn.DataParallel(i3d)
|
||||
return i3d
|
||||
|
||||
def preprocess_single(video, resolution, sequence_length=None):
|
||||
# video: THWC, {0, ..., 255}
|
||||
video = video.permute(0, 3, 1, 2).float() / 255. # TCHW
|
||||
t, c, h, w = video.shape
|
||||
|
||||
# temporal crop
|
||||
if sequence_length is not None:
|
||||
assert sequence_length <= t
|
||||
video = video[:sequence_length]
|
||||
|
||||
# scale shorter side to resolution
|
||||
scale = resolution / min(h, w)
|
||||
if h < w:
|
||||
target_size = (resolution, math.ceil(w * scale))
|
||||
else:
|
||||
target_size = (math.ceil(h * scale), resolution)
|
||||
video = F.interpolate(video, size=target_size, mode='bilinear',
|
||||
align_corners=False)
|
||||
|
||||
# center crop
|
||||
t, c, h, w = video.shape
|
||||
w_start = (w - resolution) // 2
|
||||
h_start = (h - resolution) // 2
|
||||
video = video[:, :, h_start:h_start + resolution, w_start:w_start + resolution]
|
||||
video = video.permute(1, 0, 2, 3).contiguous() # CTHW
|
||||
|
||||
video -= 0.5
|
||||
|
||||
return video
|
||||
|
||||
def preprocess(videos, target_resolution=224):
|
||||
# we should tras videos in [0-1] [b c t h w] as th.float
|
||||
# -> videos in {0, ..., 255} [b t h w c] as np.uint8 array
|
||||
videos = einops.rearrange(videos, 'b c t h w -> b t h w c')
|
||||
videos = (videos*255).numpy().astype(np.uint8)
|
||||
|
||||
b, t, h, w, c = videos.shape
|
||||
videos = torch.from_numpy(videos)
|
||||
videos = torch.stack([preprocess_single(video, target_resolution) for video in videos])
|
||||
return videos * 2 # [-0.5, 0.5] -> [-1, 1]
|
||||
|
||||
def get_fvd_logits(videos, i3d, device, bs=10):
|
||||
videos = preprocess(videos)
|
||||
embeddings = get_logits(i3d, videos, device, bs=10)
|
||||
return embeddings
|
||||
|
||||
# https://github.com/tensorflow/gan/blob/de4b8da3853058ea380a6152bd3bd454013bf619/tensorflow_gan/python/eval/classifier_metrics.py#L161
|
||||
def _symmetric_matrix_square_root(mat, eps=1e-10):
|
||||
u, s, v = torch.svd(mat)
|
||||
si = torch.where(s < eps, s, torch.sqrt(s))
|
||||
return torch.matmul(torch.matmul(u, torch.diag(si)), v.t())
|
||||
|
||||
# https://github.com/tensorflow/gan/blob/de4b8da3853058ea380a6152bd3bd454013bf619/tensorflow_gan/python/eval/classifier_metrics.py#L400
|
||||
def trace_sqrt_product(sigma, sigma_v):
|
||||
sqrt_sigma = _symmetric_matrix_square_root(sigma)
|
||||
sqrt_a_sigmav_a = torch.matmul(sqrt_sigma, torch.matmul(sigma_v, sqrt_sigma))
|
||||
return torch.trace(_symmetric_matrix_square_root(sqrt_a_sigmav_a))
|
||||
|
||||
# https://discuss.pytorch.org/t/covariance-and-gradient-support/16217/2
|
||||
def cov(m, rowvar=False):
|
||||
'''Estimate a covariance matrix given data.
|
||||
|
||||
Covariance indicates the level to which two variables vary together.
|
||||
If we examine N-dimensional samples, `X = [x_1, x_2, ... x_N]^T`,
|
||||
then the covariance matrix element `C_{ij}` is the covariance of
|
||||
`x_i` and `x_j`. The element `C_{ii}` is the variance of `x_i`.
|
||||
|
||||
Args:
|
||||
m: A 1-D or 2-D array containing multiple variables and observations.
|
||||
Each row of `m` represents a variable, and each column a single
|
||||
observation of all those variables.
|
||||
rowvar: If `rowvar` is True, then each row represents a
|
||||
variable, with observations in the columns. Otherwise, the
|
||||
relationship is transposed: each column represents a variable,
|
||||
while the rows contain observations.
|
||||
|
||||
Returns:
|
||||
The covariance matrix of the variables.
|
||||
'''
|
||||
if m.dim() > 2:
|
||||
raise ValueError('m has more than 2 dimensions')
|
||||
if m.dim() < 2:
|
||||
m = m.view(1, -1)
|
||||
if not rowvar and m.size(0) != 1:
|
||||
m = m.t()
|
||||
|
||||
fact = 1.0 / (m.size(1) - 1) # unbiased estimate
|
||||
m -= torch.mean(m, dim=1, keepdim=True)
|
||||
mt = m.t() # if complex: mt = m.t().conj()
|
||||
return fact * m.matmul(mt).squeeze()
|
||||
|
||||
|
||||
def frechet_distance(x1, x2):
|
||||
x1 = x1.flatten(start_dim=1)
|
||||
x2 = x2.flatten(start_dim=1)
|
||||
m, m_w = x1.mean(dim=0), x2.mean(dim=0)
|
||||
sigma, sigma_w = cov(x1, rowvar=False), cov(x2, rowvar=False)
|
||||
mean = torch.sum((m - m_w) ** 2)
|
||||
if x1.shape[0]>1:
|
||||
sqrt_trace_component = trace_sqrt_product(sigma, sigma_w)
|
||||
trace = torch.trace(sigma + sigma_w) - 2.0 * sqrt_trace_component
|
||||
fd = trace + mean
|
||||
else:
|
||||
fd = np.real(mean)
|
||||
return float(fd)
|
||||
|
||||
|
||||
def get_logits(i3d, videos, device, bs=10):
|
||||
# assert videos.shape[0] % 16 == 0
|
||||
with torch.no_grad():
|
||||
logits = []
|
||||
for i in range(0, videos.shape[0], bs):
|
||||
batch = videos[i:i + bs].to(device)
|
||||
# logits.append(i3d.module.extract_features(batch)) # wrong
|
||||
logits.append(i3d(batch)) # right
|
||||
logits = torch.cat(logits, dim=0)
|
||||
return logits
|
||||
322
eval/fvd/videogpt/pytorch_i3d.py
Normal file
322
eval/fvd/videogpt/pytorch_i3d.py
Normal file
|
|
@ -0,0 +1,322 @@
|
|||
# Original code from https://github.com/piergiaj/pytorch-i3d
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
|
||||
class MaxPool3dSamePadding(nn.MaxPool3d):
|
||||
|
||||
def compute_pad(self, dim, s):
|
||||
if s % self.stride[dim] == 0:
|
||||
return max(self.kernel_size[dim] - self.stride[dim], 0)
|
||||
else:
|
||||
return max(self.kernel_size[dim] - (s % self.stride[dim]), 0)
|
||||
|
||||
def forward(self, x):
|
||||
# compute 'same' padding
|
||||
(batch, channel, t, h, w) = x.size()
|
||||
out_t = np.ceil(float(t) / float(self.stride[0]))
|
||||
out_h = np.ceil(float(h) / float(self.stride[1]))
|
||||
out_w = np.ceil(float(w) / float(self.stride[2]))
|
||||
pad_t = self.compute_pad(0, t)
|
||||
pad_h = self.compute_pad(1, h)
|
||||
pad_w = self.compute_pad(2, w)
|
||||
|
||||
pad_t_f = pad_t // 2
|
||||
pad_t_b = pad_t - pad_t_f
|
||||
pad_h_f = pad_h // 2
|
||||
pad_h_b = pad_h - pad_h_f
|
||||
pad_w_f = pad_w // 2
|
||||
pad_w_b = pad_w - pad_w_f
|
||||
|
||||
pad = (pad_w_f, pad_w_b, pad_h_f, pad_h_b, pad_t_f, pad_t_b)
|
||||
x = F.pad(x, pad)
|
||||
return super(MaxPool3dSamePadding, self).forward(x)
|
||||
|
||||
|
||||
class Unit3D(nn.Module):
|
||||
|
||||
def __init__(self, in_channels,
|
||||
output_channels,
|
||||
kernel_shape=(1, 1, 1),
|
||||
stride=(1, 1, 1),
|
||||
padding=0,
|
||||
activation_fn=F.relu,
|
||||
use_batch_norm=True,
|
||||
use_bias=False,
|
||||
name='unit_3d'):
|
||||
|
||||
"""Initializes Unit3D module."""
|
||||
super(Unit3D, self).__init__()
|
||||
|
||||
self._output_channels = output_channels
|
||||
self._kernel_shape = kernel_shape
|
||||
self._stride = stride
|
||||
self._use_batch_norm = use_batch_norm
|
||||
self._activation_fn = activation_fn
|
||||
self._use_bias = use_bias
|
||||
self.name = name
|
||||
self.padding = padding
|
||||
|
||||
self.conv3d = nn.Conv3d(in_channels=in_channels,
|
||||
out_channels=self._output_channels,
|
||||
kernel_size=self._kernel_shape,
|
||||
stride=self._stride,
|
||||
padding=0, # we always want padding to be 0 here. We will dynamically pad based on input size in forward function
|
||||
bias=self._use_bias)
|
||||
|
||||
if self._use_batch_norm:
|
||||
self.bn = nn.BatchNorm3d(self._output_channels, eps=1e-5, momentum=0.001)
|
||||
|
||||
def compute_pad(self, dim, s):
|
||||
if s % self._stride[dim] == 0:
|
||||
return max(self._kernel_shape[dim] - self._stride[dim], 0)
|
||||
else:
|
||||
return max(self._kernel_shape[dim] - (s % self._stride[dim]), 0)
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
# compute 'same' padding
|
||||
(batch, channel, t, h, w) = x.size()
|
||||
out_t = np.ceil(float(t) / float(self._stride[0]))
|
||||
out_h = np.ceil(float(h) / float(self._stride[1]))
|
||||
out_w = np.ceil(float(w) / float(self._stride[2]))
|
||||
pad_t = self.compute_pad(0, t)
|
||||
pad_h = self.compute_pad(1, h)
|
||||
pad_w = self.compute_pad(2, w)
|
||||
|
||||
pad_t_f = pad_t // 2
|
||||
pad_t_b = pad_t - pad_t_f
|
||||
pad_h_f = pad_h // 2
|
||||
pad_h_b = pad_h - pad_h_f
|
||||
pad_w_f = pad_w // 2
|
||||
pad_w_b = pad_w - pad_w_f
|
||||
|
||||
pad = (pad_w_f, pad_w_b, pad_h_f, pad_h_b, pad_t_f, pad_t_b)
|
||||
x = F.pad(x, pad)
|
||||
|
||||
x = self.conv3d(x)
|
||||
if self._use_batch_norm:
|
||||
x = self.bn(x)
|
||||
if self._activation_fn is not None:
|
||||
x = self._activation_fn(x)
|
||||
return x
|
||||
|
||||
|
||||
|
||||
class InceptionModule(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, name):
|
||||
super(InceptionModule, self).__init__()
|
||||
|
||||
self.b0 = Unit3D(in_channels=in_channels, output_channels=out_channels[0], kernel_shape=[1, 1, 1], padding=0,
|
||||
name=name+'/Branch_0/Conv3d_0a_1x1')
|
||||
self.b1a = Unit3D(in_channels=in_channels, output_channels=out_channels[1], kernel_shape=[1, 1, 1], padding=0,
|
||||
name=name+'/Branch_1/Conv3d_0a_1x1')
|
||||
self.b1b = Unit3D(in_channels=out_channels[1], output_channels=out_channels[2], kernel_shape=[3, 3, 3],
|
||||
name=name+'/Branch_1/Conv3d_0b_3x3')
|
||||
self.b2a = Unit3D(in_channels=in_channels, output_channels=out_channels[3], kernel_shape=[1, 1, 1], padding=0,
|
||||
name=name+'/Branch_2/Conv3d_0a_1x1')
|
||||
self.b2b = Unit3D(in_channels=out_channels[3], output_channels=out_channels[4], kernel_shape=[3, 3, 3],
|
||||
name=name+'/Branch_2/Conv3d_0b_3x3')
|
||||
self.b3a = MaxPool3dSamePadding(kernel_size=[3, 3, 3],
|
||||
stride=(1, 1, 1), padding=0)
|
||||
self.b3b = Unit3D(in_channels=in_channels, output_channels=out_channels[5], kernel_shape=[1, 1, 1], padding=0,
|
||||
name=name+'/Branch_3/Conv3d_0b_1x1')
|
||||
self.name = name
|
||||
|
||||
def forward(self, x):
|
||||
b0 = self.b0(x)
|
||||
b1 = self.b1b(self.b1a(x))
|
||||
b2 = self.b2b(self.b2a(x))
|
||||
b3 = self.b3b(self.b3a(x))
|
||||
return torch.cat([b0,b1,b2,b3], dim=1)
|
||||
|
||||
|
||||
class InceptionI3d(nn.Module):
|
||||
"""Inception-v1 I3D architecture.
|
||||
The model is introduced in:
|
||||
Quo Vadis, Action Recognition? A New Model and the Kinetics Dataset
|
||||
Joao Carreira, Andrew Zisserman
|
||||
https://arxiv.org/pdf/1705.07750v1.pdf.
|
||||
See also the Inception architecture, introduced in:
|
||||
Going deeper with convolutions
|
||||
Christian Szegedy, Wei Liu, Yangqing Jia, Pierre Sermanet, Scott Reed,
|
||||
Dragomir Anguelov, Dumitru Erhan, Vincent Vanhoucke, Andrew Rabinovich.
|
||||
http://arxiv.org/pdf/1409.4842v1.pdf.
|
||||
"""
|
||||
|
||||
# Endpoints of the model in order. During construction, all the endpoints up
|
||||
# to a designated `final_endpoint` are returned in a dictionary as the
|
||||
# second return value.
|
||||
VALID_ENDPOINTS = (
|
||||
'Conv3d_1a_7x7',
|
||||
'MaxPool3d_2a_3x3',
|
||||
'Conv3d_2b_1x1',
|
||||
'Conv3d_2c_3x3',
|
||||
'MaxPool3d_3a_3x3',
|
||||
'Mixed_3b',
|
||||
'Mixed_3c',
|
||||
'MaxPool3d_4a_3x3',
|
||||
'Mixed_4b',
|
||||
'Mixed_4c',
|
||||
'Mixed_4d',
|
||||
'Mixed_4e',
|
||||
'Mixed_4f',
|
||||
'MaxPool3d_5a_2x2',
|
||||
'Mixed_5b',
|
||||
'Mixed_5c',
|
||||
'Logits',
|
||||
'Predictions',
|
||||
)
|
||||
|
||||
def __init__(self, num_classes=400, spatial_squeeze=True,
|
||||
final_endpoint='Logits', name='inception_i3d', in_channels=3, dropout_keep_prob=0.5):
|
||||
"""Initializes I3D model instance.
|
||||
Args:
|
||||
num_classes: The number of outputs in the logit layer (default 400, which
|
||||
matches the Kinetics dataset).
|
||||
spatial_squeeze: Whether to squeeze the spatial dimensions for the logits
|
||||
before returning (default True).
|
||||
final_endpoint: The model contains many possible endpoints.
|
||||
`final_endpoint` specifies the last endpoint for the model to be built
|
||||
up to. In addition to the output at `final_endpoint`, all the outputs
|
||||
at endpoints up to `final_endpoint` will also be returned, in a
|
||||
dictionary. `final_endpoint` must be one of
|
||||
InceptionI3d.VALID_ENDPOINTS (default 'Logits').
|
||||
name: A string (optional). The name of this module.
|
||||
Raises:
|
||||
ValueError: if `final_endpoint` is not recognized.
|
||||
"""
|
||||
|
||||
if final_endpoint not in self.VALID_ENDPOINTS:
|
||||
raise ValueError('Unknown final endpoint %s' % final_endpoint)
|
||||
|
||||
super(InceptionI3d, self).__init__()
|
||||
self._num_classes = num_classes
|
||||
self._spatial_squeeze = spatial_squeeze
|
||||
self._final_endpoint = final_endpoint
|
||||
self.logits = None
|
||||
|
||||
if self._final_endpoint not in self.VALID_ENDPOINTS:
|
||||
raise ValueError('Unknown final endpoint %s' % self._final_endpoint)
|
||||
|
||||
self.end_points = {}
|
||||
end_point = 'Conv3d_1a_7x7'
|
||||
self.end_points[end_point] = Unit3D(in_channels=in_channels, output_channels=64, kernel_shape=[7, 7, 7],
|
||||
stride=(2, 2, 2), padding=(3,3,3), name=name+end_point)
|
||||
if self._final_endpoint == end_point: return
|
||||
|
||||
end_point = 'MaxPool3d_2a_3x3'
|
||||
self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[1, 3, 3], stride=(1, 2, 2),
|
||||
padding=0)
|
||||
if self._final_endpoint == end_point: return
|
||||
|
||||
end_point = 'Conv3d_2b_1x1'
|
||||
self.end_points[end_point] = Unit3D(in_channels=64, output_channels=64, kernel_shape=[1, 1, 1], padding=0,
|
||||
name=name+end_point)
|
||||
if self._final_endpoint == end_point: return
|
||||
|
||||
end_point = 'Conv3d_2c_3x3'
|
||||
self.end_points[end_point] = Unit3D(in_channels=64, output_channels=192, kernel_shape=[3, 3, 3], padding=1,
|
||||
name=name+end_point)
|
||||
if self._final_endpoint == end_point: return
|
||||
|
||||
end_point = 'MaxPool3d_3a_3x3'
|
||||
self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[1, 3, 3], stride=(1, 2, 2),
|
||||
padding=0)
|
||||
if self._final_endpoint == end_point: return
|
||||
|
||||
end_point = 'Mixed_3b'
|
||||
self.end_points[end_point] = InceptionModule(192, [64,96,128,16,32,32], name+end_point)
|
||||
if self._final_endpoint == end_point: return
|
||||
|
||||
end_point = 'Mixed_3c'
|
||||
self.end_points[end_point] = InceptionModule(256, [128,128,192,32,96,64], name+end_point)
|
||||
if self._final_endpoint == end_point: return
|
||||
|
||||
end_point = 'MaxPool3d_4a_3x3'
|
||||
self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[3, 3, 3], stride=(2, 2, 2),
|
||||
padding=0)
|
||||
if self._final_endpoint == end_point: return
|
||||
|
||||
end_point = 'Mixed_4b'
|
||||
self.end_points[end_point] = InceptionModule(128+192+96+64, [192,96,208,16,48,64], name+end_point)
|
||||
if self._final_endpoint == end_point: return
|
||||
|
||||
end_point = 'Mixed_4c'
|
||||
self.end_points[end_point] = InceptionModule(192+208+48+64, [160,112,224,24,64,64], name+end_point)
|
||||
if self._final_endpoint == end_point: return
|
||||
|
||||
end_point = 'Mixed_4d'
|
||||
self.end_points[end_point] = InceptionModule(160+224+64+64, [128,128,256,24,64,64], name+end_point)
|
||||
if self._final_endpoint == end_point: return
|
||||
|
||||
end_point = 'Mixed_4e'
|
||||
self.end_points[end_point] = InceptionModule(128+256+64+64, [112,144,288,32,64,64], name+end_point)
|
||||
if self._final_endpoint == end_point: return
|
||||
|
||||
end_point = 'Mixed_4f'
|
||||
self.end_points[end_point] = InceptionModule(112+288+64+64, [256,160,320,32,128,128], name+end_point)
|
||||
if self._final_endpoint == end_point: return
|
||||
|
||||
end_point = 'MaxPool3d_5a_2x2'
|
||||
self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[2, 2, 2], stride=(2, 2, 2),
|
||||
padding=0)
|
||||
if self._final_endpoint == end_point: return
|
||||
|
||||
end_point = 'Mixed_5b'
|
||||
self.end_points[end_point] = InceptionModule(256+320+128+128, [256,160,320,32,128,128], name+end_point)
|
||||
if self._final_endpoint == end_point: return
|
||||
|
||||
end_point = 'Mixed_5c'
|
||||
self.end_points[end_point] = InceptionModule(256+320+128+128, [384,192,384,48,128,128], name+end_point)
|
||||
if self._final_endpoint == end_point: return
|
||||
|
||||
end_point = 'Logits'
|
||||
self.avg_pool = nn.AvgPool3d(kernel_size=[2, 7, 7],
|
||||
stride=(1, 1, 1))
|
||||
self.dropout = nn.Dropout(dropout_keep_prob)
|
||||
self.logits = Unit3D(in_channels=384+384+128+128, output_channels=self._num_classes,
|
||||
kernel_shape=[1, 1, 1],
|
||||
padding=0,
|
||||
activation_fn=None,
|
||||
use_batch_norm=False,
|
||||
use_bias=True,
|
||||
name='logits')
|
||||
|
||||
self.build()
|
||||
|
||||
|
||||
def replace_logits(self, num_classes):
|
||||
self._num_classes = num_classes
|
||||
self.logits = Unit3D(in_channels=384+384+128+128, output_channels=self._num_classes,
|
||||
kernel_shape=[1, 1, 1],
|
||||
padding=0,
|
||||
activation_fn=None,
|
||||
use_batch_norm=False,
|
||||
use_bias=True,
|
||||
name='logits')
|
||||
|
||||
|
||||
def build(self):
|
||||
for k in self.end_points.keys():
|
||||
self.add_module(k, self.end_points[k])
|
||||
|
||||
def forward(self, x):
|
||||
for end_point in self.VALID_ENDPOINTS:
|
||||
if end_point in self.end_points:
|
||||
x = self._modules[end_point](x) # use _modules to work with dataparallel
|
||||
|
||||
x = self.logits(self.dropout(self.avg_pool(x)))
|
||||
if self._spatial_squeeze:
|
||||
logits = x.squeeze(3).squeeze(3)
|
||||
logits = logits.mean(dim=2)
|
||||
# logits is batch X time X classes, which is what we want to work with
|
||||
return logits
|
||||
|
||||
|
||||
def extract_features(self, x):
|
||||
for end_point in self.VALID_ENDPOINTS:
|
||||
if end_point in self.end_points:
|
||||
x = self._modules[end_point](x)
|
||||
return self.avg_pool(x)
|
||||
23
eval/script/cal_clip_score.sh
Normal file
23
eval/script/cal_clip_score.sh
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
# clip_score cross modality
|
||||
python eval_clip_score.py \
|
||||
--real_path path/to/image \
|
||||
--generated_path path/to/text \
|
||||
--batch-size 50 \
|
||||
--device "cuda"
|
||||
|
||||
# clip_score within the same modality
|
||||
python eval_clip_score.py \
|
||||
--real_path path/to/textA \
|
||||
--generated_path path/to/textB \
|
||||
--real_flag txt \
|
||||
--generated_flag txt \
|
||||
--batch-size 50 \
|
||||
--device "cuda"
|
||||
|
||||
python eval_clip_score.py \
|
||||
--real_path path/to/imageA \
|
||||
--generated_path path/to/imageB \
|
||||
--real_flag img \
|
||||
--generated_flag img \
|
||||
--batch-size 50 \
|
||||
--device "cuda"
|
||||
9
eval/script/cal_fvd.sh
Normal file
9
eval/script/cal_fvd.sh
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
python eval_common_metric.py \
|
||||
--real_video_dir path/to/imageA\
|
||||
--generated_video_dir path/to/imageB \
|
||||
--batch_size 10 \
|
||||
--crop_size 64 \
|
||||
--num_frames 20 \
|
||||
--device 'cuda' \
|
||||
--metric 'fvd' \
|
||||
--fvd_method 'styleganv'
|
||||
8
eval/script/cal_lpips.sh
Normal file
8
eval/script/cal_lpips.sh
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
python eval_common_metric.py \
|
||||
--real_video_dir path/to/imageA\
|
||||
--generated_video_dir path/to/imageB \
|
||||
--batch_size 10 \
|
||||
--num_frames 20 \
|
||||
--crop_size 64 \
|
||||
--device 'cuda' \
|
||||
--metric 'lpips'
|
||||
9
eval/script/cal_psnr.sh
Normal file
9
eval/script/cal_psnr.sh
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
|
||||
python eval_common_metric.py \
|
||||
--real_video_dir /data/xiaogeng_liu/data/video1 \
|
||||
--generated_video_dir /data/xiaogeng_liu/data/video2 \
|
||||
--batch_size 10 \
|
||||
--num_frames 20 \
|
||||
--crop_size 64 \
|
||||
--device 'cuda' \
|
||||
--metric 'psnr'
|
||||
8
eval/script/cal_ssim.sh
Normal file
8
eval/script/cal_ssim.sh
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
python eval_common_metric.py \
|
||||
--real_video_dir /data/xiaogeng_liu/data/video1 \
|
||||
--generated_video_dir /data/xiaogeng_liu/data/video2 \
|
||||
--batch_size 10 \
|
||||
--num_frames 20 \
|
||||
--crop_size 64 \
|
||||
--device 'cuda' \
|
||||
--metric 'ssim'
|
||||
12
eval/script/eval.sh
Normal file
12
eval/script/eval.sh
Normal file
|
|
@ -0,0 +1,12 @@
|
|||
python eval/eval_common_metric.py \
|
||||
--batch_size 2 \
|
||||
--real_video_dir ..//test_eval/release/origin \
|
||||
--generated_video_dir ../test_eval/release \
|
||||
--device cuda \
|
||||
--sample_fps 10 \
|
||||
--crop_size 256 \
|
||||
--resolution 256 \
|
||||
--num_frames 17 \
|
||||
--sample_rate 1 \
|
||||
--subset_size 100 \
|
||||
--metric ssim
|
||||
|
|
@ -73,4 +73,5 @@ CUDA_VISIBLE_DEVICES7 torchrun --master_port=29510 --nnodes=1 --nproc_per_node=1
|
|||
|
||||
### 2.4 Data
|
||||
|
||||
* ~/data/pixabay: `/home/data/sora_data/pixabay/raw/data/split-0`
|
||||
* ~/data/pixabay: `/home/data/sora_data/pixabay/raw/data/split-0`
|
||||
* pexels: `/home/litianyi/data/pexels/processed/meta/pexels_caption_vinfo_ready_noempty_clean.csv`
|
||||
Loading…
Reference in a new issue