specify output file

This commit is contained in:
Tom Young 2024-06-13 09:15:42 +00:00
parent f96d15ead0
commit 811739dbba

View file

@ -87,9 +87,9 @@ def main(args):
dist.get_world_size()
)
output_file = os.path.splitext(args.input)[0] + f"_rank{dist.get_rank()}" + "_llama3.csv"
writer = csv.writer(open(output_file, "w"))
output_file = args.output_prefix + f"_rank{dist.get_rank()}" + "_llama3.csv"
output_file_handle = open(output_file, "w")
writer = csv.writer(output_file_handle)
columns = list(dataframe.columns) + [args.key]
writer.writerow(columns)
@ -206,12 +206,15 @@ def main(args):
row = batch[idx]
writer.writerow([*row, keywords])
output_file_handle.close()
dist.barrier()
if dist.get_rank() == 0:
print("All ranks are finished. Collating the processed data to {}".format(output_file))
collated_file = args.output_prefix + "_llama3.csv"
print("All ranks are finished. Collating the processed data to {}".format(collated_file))
import pandas as pd
csv_files = [os.path.splitext(args.input)[0] + f"_rank{i}" + "_llama3.csv" for i in range(dist.get_world_size())]
csv_files = [args.output_prefix + f"_rank{i}" + "_llama3.csv" for i in range(dist.get_world_size())]
# List to hold DataFrames
dataframes = []
# Read each CSV into a DataFrame and append to list
@ -221,7 +224,6 @@ def main(args):
# Concatenate all DataFrames
combined_df = pd.concat(dataframes, ignore_index=True)
collated_file = os.path.splitext(args.input)[0] + "_llama3.csv"
# Save the combined DataFrame to a new CSV file
combined_df.to_csv(collated_file, index=False)
print("Collated data saved to {}".format(collated_file))
@ -233,6 +235,7 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model-id", default="meta-llama/Meta-Llama-3-8B-Instruct")
parser.add_argument("input", type=str, help="Path to the input CSV file")
parser.add_argument("--output_prefix", type=str, help="Path to the output CSV file")
parser.add_argument("--prompt", type=str, default="")
parser.add_argument("--batch_size", type=int, default=64)
parser.add_argument('--key', default='summary', type=str)