mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-11 05:13:31 +02:00
more prompts
This commit is contained in:
parent
823ddfb436
commit
83693675f2
|
|
@ -87,7 +87,7 @@ def main(args):
|
|||
dist.get_world_size()
|
||||
)
|
||||
|
||||
output_file = args.output_prefix + f"_rank{dist.get_rank()}" + "_llama3.csv"
|
||||
output_file = args.output_prefix + f"_rank{dist.get_rank()}" + f"_{args.key}.csv"
|
||||
output_file_handle = open(output_file, "w")
|
||||
writer = csv.writer(output_file_handle)
|
||||
columns = list(dataframe.columns) + [args.key]
|
||||
|
|
@ -181,9 +181,12 @@ def main(args):
|
|||
return responses
|
||||
|
||||
print("Processing starting...")
|
||||
if args.prompt == "":
|
||||
if args.prompt == "" and args.key == "objects":
|
||||
prompt = ("You are a AI assistant to extract objects from user's text. "
|
||||
"For example: user: 'In this video a dog is running around. In addition, a person is laughing at the dog.', you produce a list of objects separated by ',' and wrapped by '[' and ']': '[dog, person]' ")
|
||||
elif args.prompt == "" and args.key == "actions":
|
||||
prompt = ("You are a AI assistant to extract actions from user's text. "
|
||||
"For example: user: 'In this video a dog is running around. In addition, a person is laughing at the dog.', you produce a list of actions separated by ',' and wrapped by '[' and ']': '[run, laugh]' ")
|
||||
else:
|
||||
prompt = args.prompt
|
||||
|
||||
|
|
@ -211,10 +214,10 @@ def main(args):
|
|||
dist.barrier()
|
||||
|
||||
if dist.get_rank() == 0:
|
||||
collated_file = args.output_prefix + "_llama3.csv"
|
||||
collated_file = args.output_prefix + f"_{args.key}.csv"
|
||||
print("All ranks are finished. Collating the processed data to {}".format(collated_file))
|
||||
import pandas as pd
|
||||
csv_files = [args.output_prefix + f"_rank{i}" + "_llama3.csv" for i in range(dist.get_world_size())]
|
||||
csv_files = [args.output_prefix + f"_rank{i}" + f"_{args.key}.csv" for i in range(dist.get_world_size())]
|
||||
# List to hold DataFrames
|
||||
dataframes = []
|
||||
# Read each CSV into a DataFrame and append to list
|
||||
|
|
@ -238,7 +241,7 @@ if __name__ == "__main__":
|
|||
parser.add_argument("--output_prefix", type=str, help="Path to the output CSV file")
|
||||
parser.add_argument("--prompt", type=str, default="")
|
||||
parser.add_argument("--batch_size", type=int, default=64)
|
||||
parser.add_argument('--key', default='summary', type=str)
|
||||
parser.add_argument('--key', type=str)
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
||||
Loading…
Reference in a new issue