more prompts

This commit is contained in:
Tom Young 2024-06-14 06:30:59 +00:00
parent 823ddfb436
commit 83693675f2

View file

@@ -87,7 +87,7 @@ def main(args):
dist.get_world_size()
)
output_file = args.output_prefix + f"_rank{dist.get_rank()}" + "_llama3.csv"
output_file = args.output_prefix + f"_rank{dist.get_rank()}" + f"_{args.key}.csv"
output_file_handle = open(output_file, "w")
writer = csv.writer(output_file_handle)
columns = list(dataframe.columns) + [args.key]
@@ -181,9 +181,12 @@ def main(args):
return responses
print("Processing starting...")
if args.prompt == "":
if args.prompt == "" and args.key == "objects":
prompt = ("You are a AI assistant to extract objects from user's text. "
"For example: user: 'In this video a dog is running around. In addition, a person is laughing at the dog.', you produce a list of objects separated by ',' and wrapped by '[' and ']': '[dog, person]' ")
elif args.prompt == "" and args.key == "actions":
prompt = ("You are a AI assistant to extract actions from user's text. "
"For example: user: 'In this video a dog is running around. In addition, a person is laughing at the dog.', you produce a list of actions separated by ',' and wrapped by '[' and ']': '[run, laugh]' ")
else:
prompt = args.prompt
@@ -211,10 +214,10 @@ def main(args):
dist.barrier()
if dist.get_rank() == 0:
collated_file = args.output_prefix + "_llama3.csv"
collated_file = args.output_prefix + f"_{args.key}.csv"
print("All ranks are finished. Collating the processed data to {}".format(collated_file))
import pandas as pd
csv_files = [args.output_prefix + f"_rank{i}" + "_llama3.csv" for i in range(dist.get_world_size())]
csv_files = [args.output_prefix + f"_rank{i}" + f"_{args.key}.csv" for i in range(dist.get_world_size())]
# List to hold DataFrames
dataframes = []
# Read each CSV into a DataFrame and append to list
@@ -238,7 +241,7 @@ if __name__ == "__main__":
parser.add_argument("--output_prefix", type=str, help="Path to the output CSV file")
parser.add_argument("--prompt", type=str, default="")
parser.add_argument("--batch_size", type=int, default=64)
parser.add_argument('--key', default='summary', type=str)
parser.add_argument('--key', type=str)
args = parser.parse_args()
main(args)