diff --git a/tools/caption/caption_llama3.py b/tools/caption/caption_llama3.py index 6e79596..f463c52 100644 --- a/tools/caption/caption_llama3.py +++ b/tools/caption/caption_llama3.py @@ -207,6 +207,24 @@ def main(args): row = batch[idx] writer.writerow([*row, keywords]) dist.barrier() + + if dist.get_rank() == 0: + print("All ranks are finished. Collating the processed data to {}".format(output_file)) + import pandas as pd + csv_files = [os.path.splitext(args.input)[0] + f"_rank{i}" + "_llama3.csv" for i in range(dist.get_world_size())] + # List to hold DataFrames + dataframes = [] + # Read each CSV into a DataFrame and append to list + for file in csv_files: + df = pd.read_csv(file) + dataframes.append(df) + # Concatenate all DataFrames + combined_df = pd.concat(dataframes, ignore_index=True) + + collated_file = os.path.splitext(args.input)[0] + "_llama3.csv" + # Save the combined DataFrame to a new CSV file + combined_df.to_csv(collated_file, index=False) + print("Collated data saved to {}".format(collated_file)) # terminate distributed env dist.destroy_process_group()