mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-12 22:38:53 +02:00
Automatically tabulate VBench evaluation scores
This commit is contained in:
parent
6bac2980f2
commit
3efa8127b6
|
|
@ -14,8 +14,8 @@ from ast import literal_eval
|
|||
|
||||
def parse_args():
    """Parse command-line arguments for the loss-tabulation script.

    Returns:
        argparse.Namespace with:
            log_dir (str): directory containing loss logs
                (relative default, not tied to one developer's machine).
            ckpt_name (str | None): checkpoint name, e.g.
                "epoch0-global_step9000"; no default, caller must supply it.
    """
    parser = argparse.ArgumentParser()
    # NOTE: this span was diff residue holding both the old and new
    # add_argument lines; keeping both would raise argparse.ArgumentError
    # (duplicate option). Only the post-change definitions are kept.
    parser.add_argument("--log_dir", type=str, default="logs/loss")
    parser.add_argument("--ckpt_name", type=str)
    args = parser.parse_args()
    return args
|
||||
|
||||
|
|
@ -47,5 +47,7 @@ if __name__ == "__main__":
|
|||
loss_info[resolution][frame] = format(loss, ".4f")
|
||||
|
||||
# Convert and write JSON object to file
|
||||
with open(os.path.join(output_dir, args.ckpt_name + "_loss.json"), "w") as outfile:
|
||||
output_file_path = os.path.join(output_dir, args.ckpt_name + "_loss.json")
|
||||
with open(output_file_path, "w") as outfile:
|
||||
json.dump(loss_info, outfile, indent=4, sort_keys=True)
|
||||
print(f"results saved to: {output_file_path}")
|
||||
41
eval/vbench/tabulate_vbench_scores.py
Normal file
41
eval/vbench/tabulate_vbench_scores.py
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
import argparse
|
||||
import json
|
||||
import os
|
||||
from ast import literal_eval
|
||||
|
||||
def parse_args():
    """Build and evaluate the CLI for the VBench score tabulator.

    Returns:
        argparse.Namespace with:
            score_dir (str | None): directory holding the per-dimension
                evaluation result files (e.g. evaluation_results/samples_...).
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--score_dir", type=str)  # evaluation_results/samples_...
    return cli.parse_args()
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
res_postfix = "_eval_results.json"
|
||||
info_postfix = "_full_info.json"
|
||||
files = os.listdir(args.score_dir)
|
||||
res_files = [x for x in files if res_postfix in x]
|
||||
info_files = [x for x in files if info_postfix in x]
|
||||
assert len(res_files) == len(info_files), f"got {len(res_files)} res files, but {len(info_files)} info files"
|
||||
|
||||
full_results = {}
|
||||
|
||||
for res_file in res_files:
|
||||
# first check if results is normal
|
||||
info_file = res_file.split(res_postfix)[0] + info_postfix
|
||||
with open(info_file, "r", encoding="utf-8") as f:
|
||||
info = json.load(f)
|
||||
assert len(info[0]["video_list"]) > 0, f"Error: {info_file} has 0 video list"
|
||||
# read results
|
||||
with open(res_file, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
for key, val in data.items():
|
||||
full_results[key] = val[0]
|
||||
|
||||
|
||||
output_file_path = os.path.join(args.score_dir, "all_results.json")
|
||||
with open(output_file_path, "w") as outfile:
|
||||
json.dump(full_results, outfile, indent=4, sort_keys=True)
|
||||
print(f"results saved to: {output_file_path}")
|
||||
|
||||
|
||||
Loading…
Reference in a new issue