auto-collating

This commit is contained in:
Tom Young 2024-06-13 06:31:17 +00:00
parent a397173104
commit f96d15ead0

View file

@ -207,6 +207,24 @@ def main(args):
row = batch[idx]
writer.writerow([*row, keywords])
dist.barrier()
if dist.get_rank() == 0:
print("All ranks are finished. Collating the processed data to {}".format(output_file))
import pandas as pd
csv_files = [os.path.splitext(args.input)[0] + f"_rank{i}" + "_llama3.csv" for i in range(dist.get_world_size())]
# One DataFrame per per-rank CSV file, collected in rank order
dataframes = []
# Read each CSV into a DataFrame and append to list
for file in csv_files:
df = pd.read_csv(file)
dataframes.append(df)
# Concatenate all DataFrames
combined_df = pd.concat(dataframes, ignore_index=True)
collated_file = os.path.splitext(args.input)[0] + "_llama3.csv"
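# Write the combined DataFrame to collated_file, omitting the index column.
# NOTE(review): the message printed before collation reports `output_file`, not `collated_file` — confirm they refer to the same path.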
combined_df.to_csv(collated_file, index=False)
print("Collated data saved to {}".format(collated_file))
# Tear down the distributed process group
dist.destroy_process_group()