automatically tabulate vbench

This commit is contained in:
Shen-Chenhui 2024-05-27 08:08:36 +00:00
parent 6bac2980f2
commit 3efa8127b6
2 changed files with 46 additions and 3 deletions

View file

@ -14,8 +14,8 @@ from ast import literal_eval
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--log_dir", type=str, default="/home/zhengzangwei/projs/Open-Sora-dev/logs/loss")
parser.add_argument("--ckpt_name", type=str, default="epoch0-global_step9000")
parser.add_argument("--log_dir", type=str, default="logs/loss")
parser.add_argument("--ckpt_name", type=str)
args = parser.parse_args()
return args
@ -47,5 +47,7 @@ if __name__ == "__main__":
loss_info[resolution][frame] = format(loss, ".4f")
# Convert and write JSON object to file
with open(os.path.join(output_dir, args.ckpt_name + "_loss.json"), "w") as outfile:
output_file_path = os.path.join(output_dir, args.ckpt_name + "_loss.json")
with open(output_file_path, "w") as outfile:
json.dump(loss_info, outfile, indent=4, sort_keys=True)
print(f"results saved to: {output_file_path}")

View file

@ -0,0 +1,41 @@
import argparse
import json
import os
from ast import literal_eval
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--score_dir", type=str) # evaluation_results/samples_...
args = parser.parse_args()
return args
if __name__ == "__main__":
args = parse_args()
res_postfix = "_eval_results.json"
info_postfix = "_full_info.json"
files = os.listdir(args.score_dir)
res_files = [x for x in files if res_postfix in x]
info_files = [x for x in files if info_postfix in x]
assert len(res_files) == len(info_files), f"got {len(res_files)} res files, but {len(info_files)} info files"
full_results = {}
for res_file in res_files:
# first check if results is normal
info_file = res_file.split(res_postfix)[0] + info_postfix
with open(info_file, "r", encoding="utf-8") as f:
info = json.load(f)
assert len(info[0]["video_list"]) > 0, f"Error: {info_file} has 0 video list"
# read results
with open(res_file, "r", encoding="utf-8") as f:
data = json.load(f)
for key, val in data.items():
full_results[key] = val[0]
output_file_path = os.path.join(args.score_dir, "all_results.json")
with open(output_file_path, "w") as outfile:
json.dump(full_results, outfile, indent=4, sort_keys=True)
print(f"results saved to: {output_file_path}")