mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-19 01:15:33 +02:00
fvd debug
This commit is contained in:
parent
533fde627b
commit
76bbe32264
|
|
@ -2,7 +2,7 @@ import numpy as np
|
|||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
def trans(x):
|
||||
def trans(x): # requires video to be BCTHW
|
||||
# if the images are greyscale (single channel), repeat the channel 3 times
|
||||
if x.shape[-3] == 1:
|
||||
x = x.repeat(1, 1, 3, 1, 1)
|
||||
|
|
@ -12,6 +12,7 @@ def trans(x):
|
|||
|
||||
return x
|
||||
|
||||
|
||||
def calculate_fvd(videos1, videos2, device, method='styleganv'):
|
||||
|
||||
if method == 'styleganv':
|
||||
|
|
@ -39,6 +40,7 @@ def calculate_fvd(videos1, videos2, device, method='styleganv'):
|
|||
|
||||
fvd_results = {}
|
||||
|
||||
|
||||
# to calculate FVD, each clip_timestamp must be >= 10
|
||||
for clip_timestamp in tqdm(range(10, videos1.shape[-3]+1)):
|
||||
|
||||
|
|
@ -47,10 +49,12 @@ def calculate_fvd(videos1, videos2, device, method='styleganv'):
|
|||
videos_clip1 = videos1[:, :, : clip_timestamp]
|
||||
videos_clip2 = videos2[:, :, : clip_timestamp]
|
||||
|
||||
breakpoint()
|
||||
# get FVD features
|
||||
feats1 = get_fvd_feats(videos_clip1, i3d=i3d, device=device)
|
||||
feats2 = get_fvd_feats(videos_clip2, i3d=i3d, device=device)
|
||||
|
||||
|
||||
breakpoint()
|
||||
# calculate FVD for the clips truncated to timestamps[:clip_timestamp]
|
||||
fvd_results[clip_timestamp] = frechet_distance(feats1, feats2)
|
||||
|
||||
|
|
|
|||
|
|
@ -42,11 +42,11 @@ from torchvision.transforms import Lambda, Compose
|
|||
from torchvision.transforms._transforms_video import CenterCropVideo
|
||||
import sys
|
||||
sys.path.append(".")
|
||||
from .cal_lpips import calculate_lpips
|
||||
from .cal_fvd import calculate_fvd
|
||||
from .cal_psnr import calculate_psnr
|
||||
from .cal_flolpips import calculate_flolpips
|
||||
from .cal_ssim import calculate_ssim
|
||||
# from cal_lpips import calculate_lpips
|
||||
from cal_fvd import calculate_fvd
|
||||
from cal_psnr import calculate_psnr
|
||||
# from cal_flolpips import calculate_flolpips
|
||||
from cal_ssim import calculate_ssim
|
||||
|
||||
try:
|
||||
from tqdm import tqdm
|
||||
|
|
@ -146,6 +146,7 @@ def calculate_common_metric(args, dataloader, device):
|
|||
assert real_videos.shape[2] == generated_videos.shape[2]
|
||||
if args.metric == 'fvd':
|
||||
tmp_list = list(calculate_fvd(real_videos, generated_videos, args.device, method=args.fvd_method)['value'].values())
|
||||
print("fvd list:", tmp_list)
|
||||
elif args.metric == 'ssim':
|
||||
tmp_list = list(calculate_ssim(real_videos, generated_videos)['value'].values())
|
||||
elif args.metric == 'psnr':
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ def load_i3d_pretrained(device=torch.device('cpu')):
|
|||
i3D_WEIGHTS_URL = "https://www.dropbox.com/s/ge9e5ujwgetktms/i3d_torchscript.pt"
|
||||
filepath = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'i3d_torchscript.pt')
|
||||
print(filepath)
|
||||
breakpoint()
|
||||
if not os.path.exists(filepath):
|
||||
print(f"preparing for download {i3D_WEIGHTS_URL}, you can download it by yourself.")
|
||||
os.system(f"wget {i3D_WEIGHTS_URL} -O {filepath}")
|
||||
|
|
|
|||
Loading…
Reference in a new issue