diff --git a/tools/caption/caption_llama3.py b/tools/caption/caption_llama3.py index 8ccd992..df22d6b 100644 --- a/tools/caption/caption_llama3.py +++ b/tools/caption/caption_llama3.py @@ -87,7 +87,7 @@ def main(args): dist.get_world_size() ) - output_file = args.output_prefix + f"_rank{dist.get_rank()}" + "_llama3.csv" + output_file = args.output_prefix + f"_rank{dist.get_rank()}" + f"_{args.key}.csv" output_file_handle = open(output_file, "w") writer = csv.writer(output_file_handle) columns = list(dataframe.columns) + [args.key] @@ -181,9 +181,12 @@ def main(args): return responses print("Processing starting...") - if args.prompt == "": + if args.prompt == "" and args.key == "objects": prompt = ("You are a AI assistant to extract objects from user's text. " "For example: user: 'In this video a dog is running around. In addition, a person is laughing at the dog.', you produce a list of objects separated by ',' and wrapped by '[' and ']': '[dog, person]' ") + elif args.prompt == "" and args.key == "actions": + prompt = ("You are a AI assistant to extract actions from user's text. " + "For example: user: 'In this video a dog is running around. In addition, a person is laughing at the dog.', you produce a list of actions separated by ',' and wrapped by '[' and ']': '[run, laugh]' ") else: prompt = args.prompt @@ -211,10 +214,10 @@ def main(args): dist.barrier() if dist.get_rank() == 0: - collated_file = args.output_prefix + "_llama3.csv" + collated_file = args.output_prefix + f"_{args.key}.csv" print("All ranks are finished. Collating the processed data to {}".format(collated_file)) import pandas as pd - csv_files = [args.output_prefix + f"_rank{i}" + "_llama3.csv" for i in range(dist.get_world_size())] + csv_files = [args.output_prefix + f"_rank{i}" + f"_{args.key}.csv" for i in range(dist.get_world_size())] # List to hold DataFrames dataframes = [] # Read each CSV into a DataFrame and append to list @@ -238,7 +241,7 @@ if __name__ == "__main__": parser.add_argument("--output_prefix", type=str, help="Path to the output CSV file") parser.add_argument("--prompt", type=str, default="") parser.add_argument("--batch_size", type=int, default=64) - parser.add_argument('--key', default='summary', type=str) + parser.add_argument('--key', type=str) args = parser.parse_args() main(args) \ No newline at end of file