diff --git a/tools/caption/caption_llama3.py b/tools/caption/caption_llama3.py index f463c52..8ccd992 100644 --- a/tools/caption/caption_llama3.py +++ b/tools/caption/caption_llama3.py @@ -87,9 +87,9 @@ def main(args): dist.get_world_size() ) - output_file = os.path.splitext(args.input)[0] + f"_rank{dist.get_rank()}" + "_llama3.csv" - - writer = csv.writer(open(output_file, "w")) + output_file = args.output_prefix + f"_rank{dist.get_rank()}" + "_llama3.csv" + output_file_handle = open(output_file, "w") + writer = csv.writer(output_file_handle) columns = list(dataframe.columns) + [args.key] writer.writerow(columns) @@ -206,12 +206,15 @@ def main(args): row = batch[idx] writer.writerow([*row, keywords]) + + output_file_handle.close() dist.barrier() if dist.get_rank() == 0: - print("All ranks are finished. Collating the processed data to {}".format(output_file)) + collated_file = args.output_prefix + "_llama3.csv" + print("All ranks are finished. Collating the processed data to {}".format(collated_file)) import pandas as pd - csv_files = [os.path.splitext(args.input)[0] + f"_rank{i}" + "_llama3.csv" for i in range(dist.get_world_size())] + csv_files = [args.output_prefix + f"_rank{i}" + "_llama3.csv" for i in range(dist.get_world_size())] # List to hold DataFrames dataframes = [] # Read each CSV into a DataFrame and append to list @@ -221,7 +224,6 @@ def main(args): # Concatenate all DataFrames combined_df = pd.concat(dataframes, ignore_index=True) - collated_file = os.path.splitext(args.input)[0] + "_llama3.csv" # Save the combined DataFrame to a new CSV file combined_df.to_csv(collated_file, index=False) print("Collated data saved to {}".format(collated_file)) @@ -233,6 +235,7 @@ if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--model-id", default="meta-llama/Meta-Llama-3-8B-Instruct") parser.add_argument("input", type=str, help="Path to the input CSV file") + parser.add_argument("--output_prefix", type=str, help="Path to the output CSV file") parser.add_argument("--prompt", type=str, default="") parser.add_argument("--batch_size", type=int, default=64) parser.add_argument('--key', default='summary', type=str)