fvd debug

This commit is contained in:
Shen-Chenhui 2024-04-30 15:54:08 +08:00
parent 533fde627b
commit 76bbe32264
3 changed files with 13 additions and 7 deletions

View file

@ -2,7 +2,7 @@ import numpy as np
import torch
from tqdm import tqdm
def trans(x):
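def trans(x):  # NOTE(review): the greyscale check below reads x.shape[-3] and repeats dim 2, so the input appears to be BTCHW rather than BCTHW — confirm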
# if the images are greyscale (single channel), repeat the channel 3 times
if x.shape[-3] == 1:
x = x.repeat(1, 1, 3, 1, 1)
@ -12,6 +12,7 @@ def trans(x):
return x
def calculate_fvd(videos1, videos2, device, method='styleganv'):
if method == 'styleganv':
@ -39,6 +40,7 @@ def calculate_fvd(videos1, videos2, device, method='styleganv'):
fvd_results = {}
# to calculate FVD, each clip_timestamp must be >= 10
for clip_timestamp in tqdm(range(10, videos1.shape[-3]+1)):
@ -47,10 +49,12 @@ def calculate_fvd(videos1, videos2, device, method='styleganv'):
videos_clip1 = videos1[:, :, : clip_timestamp]
videos_clip2 = videos2[:, :, : clip_timestamp]
breakpoint()
# get FVD features
feats1 = get_fvd_feats(videos_clip1, i3d=i3d, device=device)
feats2 = get_fvd_feats(videos_clip2, i3d=i3d, device=device)
breakpoint()
# calculate FVD on the clips truncated to [:clip_timestamp]
fvd_results[clip_timestamp] = frechet_distance(feats1, feats2)

View file

@ -42,11 +42,11 @@ from torchvision.transforms import Lambda, Compose
from torchvision.transforms._transforms_video import CenterCropVideo
import sys
sys.path.append(".")
from .cal_lpips import calculate_lpips
from .cal_fvd import calculate_fvd
from .cal_psnr import calculate_psnr
from .cal_flolpips import calculate_flolpips
from .cal_ssim import calculate_ssim
# from cal_lpips import calculate_lpips
from cal_fvd import calculate_fvd
from cal_psnr import calculate_psnr
# from cal_flolpips import calculate_flolpips
from cal_ssim import calculate_ssim
try:
from tqdm import tqdm
@ -146,6 +146,7 @@ def calculate_common_metric(args, dataloader, device):
assert real_videos.shape[2] == generated_videos.shape[2]
if args.metric == 'fvd':
tmp_list = list(calculate_fvd(real_videos, generated_videos, args.device, method=args.fvd_method)['value'].values())
print("fvd list:", tmp_list)
elif args.metric == 'ssim':
tmp_list = list(calculate_ssim(real_videos, generated_videos)['value'].values())
elif args.metric == 'psnr':

View file

@ -10,6 +10,7 @@ def load_i3d_pretrained(device=torch.device('cpu')):
i3D_WEIGHTS_URL = "https://www.dropbox.com/s/ge9e5ujwgetktms/i3d_torchscript.pt"
filepath = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'i3d_torchscript.pt')
print(filepath)
breakpoint()
if not os.path.exists(filepath):
print(f"preparing for download {i3D_WEIGHTS_URL}, you can download it by yourself.")
os.system(f"wget {i3D_WEIGHTS_URL} -O {filepath}")