diff --git a/eval/cal_fvd.py b/eval/cal_fvd.py index 1f1a980..949b773 100644 --- a/eval/cal_fvd.py +++ b/eval/cal_fvd.py @@ -2,7 +2,7 @@ import numpy as np import torch from tqdm import tqdm -def trans(x): +def trans(x): # requires video to be BCTHW # if greyscale images add channel if x.shape[-3] == 1: x = x.repeat(1, 1, 3, 1, 1) @@ -12,6 +12,7 @@ def trans(x): return x + def calculate_fvd(videos1, videos2, device, method='styleganv'): if method == 'styleganv': @@ -39,6 +40,7 @@ def calculate_fvd(videos1, videos2, device, method='styleganv'): fvd_results = {} + # for calculate FVD, each clip_timestamp must >= 10 for clip_timestamp in tqdm(range(10, videos1.shape[-3]+1)): @@ -47,10 +49,12 @@ def calculate_fvd(videos1, videos2, device, method='styleganv'): videos_clip1 = videos1[:, :, : clip_timestamp] videos_clip2 = videos2[:, :, : clip_timestamp] + breakpoint() # get FVD features feats1 = get_fvd_feats(videos_clip1, i3d=i3d, device=device) feats2 = get_fvd_feats(videos_clip2, i3d=i3d, device=device) - + + breakpoint() # calculate FVD when timestamps[:clip] fvd_results[clip_timestamp] = frechet_distance(feats1, feats2) diff --git a/eval/eval_common_metric.py b/eval/eval_common_metric.py index abca5fa..4e03b4a 100644 --- a/eval/eval_common_metric.py +++ b/eval/eval_common_metric.py @@ -42,11 +42,11 @@ from torchvision.transforms import Lambda, Compose from torchvision.transforms._transforms_video import CenterCropVideo import sys sys.path.append(".") -from .cal_lpips import calculate_lpips -from .cal_fvd import calculate_fvd -from .cal_psnr import calculate_psnr -from .cal_flolpips import calculate_flolpips -from .cal_ssim import calculate_ssim +# from cal_lpips import calculate_lpips +from cal_fvd import calculate_fvd +from cal_psnr import calculate_psnr +# from cal_flolpips import calculate_flolpips +from cal_ssim import calculate_ssim try: from tqdm import tqdm @@ -146,6 +146,7 @@ def calculate_common_metric(args, dataloader, device): assert real_videos.shape[2] == generated_videos.shape[2] if args.metric == 'fvd': tmp_list = list(calculate_fvd(real_videos, generated_videos, args.device, method=args.fvd_method)['value'].values()) + print("fvd list:", tmp_list) elif args.metric == 'ssim': tmp_list = list(calculate_ssim(real_videos, generated_videos)['value'].values()) elif args.metric == 'psnr': diff --git a/eval/fvd/styleganv/fvd.py b/eval/fvd/styleganv/fvd.py index 3043a2a..e1e9fde 100644 --- a/eval/fvd/styleganv/fvd.py +++ b/eval/fvd/styleganv/fvd.py @@ -10,6 +10,7 @@ def load_i3d_pretrained(device=torch.device('cpu')): i3D_WEIGHTS_URL = "https://www.dropbox.com/s/ge9e5ujwgetktms/i3d_torchscript.pt" filepath = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'i3d_torchscript.pt') print(filepath) + breakpoint() if not os.path.exists(filepath): print(f"preparing for download {i3D_WEIGHTS_URL}, you can download it by yourself.") os.system(f"wget {i3D_WEIGHTS_URL} -O {filepath}")