diff --git a/eval/loss/tabulate_rl_loss.py b/eval/loss/tabulate_rl_loss.py index dee48c2..85396e4 100644 --- a/eval/loss/tabulate_rl_loss.py +++ b/eval/loss/tabulate_rl_loss.py @@ -14,8 +14,8 @@ from ast import literal_eval def parse_args(): parser = argparse.ArgumentParser() - parser.add_argument("--log_dir", type=str, default="/home/zhengzangwei/projs/Open-Sora-dev/logs/loss") - parser.add_argument("--ckpt_name", type=str, default="epoch0-global_step9000") + parser.add_argument("--log_dir", type=str, default="logs/loss") + parser.add_argument("--ckpt_name", type=str) args = parser.parse_args() return args @@ -47,5 +47,7 @@ if __name__ == "__main__": loss_info[resolution][frame] = format(loss, ".4f") # Convert and write JSON object to file - with open(os.path.join(output_dir, args.ckpt_name + "_loss.json"), "w") as outfile: + output_file_path = os.path.join(output_dir, args.ckpt_name + "_loss.json") + with open(output_file_path, "w") as outfile: json.dump(loss_info, outfile, indent=4, sort_keys=True) + print(f"results saved to: {output_file_path}") \ No newline at end of file diff --git a/eval/vbench/tabulate_vbench_scores.py b/eval/vbench/tabulate_vbench_scores.py new file mode 100644 index 0000000..3928eed --- /dev/null +++ b/eval/vbench/tabulate_vbench_scores.py @@ -0,0 +1,41 @@ +import argparse +import json +import os +from ast import literal_eval + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--score_dir", type=str) # evaluation_results/samples_... + args = parser.parse_args() + return args + +if __name__ == "__main__": + args = parse_args() + res_postfix = "_eval_results.json" + info_postfix = "_full_info.json" + files = os.listdir(args.score_dir) + res_files = [x for x in files if res_postfix in x] + info_files = [x for x in files if info_postfix in x] + assert len(res_files) == len(info_files), f"got {len(res_files)} res files, but {len(info_files)} info files" + + full_results = {} + + for res_file in res_files: + # first check if results is normal + info_file = res_file.split(res_postfix)[0] + info_postfix + with open(info_file, "r", encoding="utf-8") as f: + info = json.load(f) + assert len(info[0]["video_list"]) > 0, f"Error: {info_file} has 0 video list" + # read results + with open(res_file, "r", encoding="utf-8") as f: + data = json.load(f) + for key, val in data.items(): + full_results[key] = val[0] + + + output_file_path = os.path.join(args.score_dir, "all_results.json") + with open(output_file_path, "w") as outfile: + json.dump(full_results, outfile, indent=4, sort_keys=True) + print(f"results saved to: {output_file_path}") + +