mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-10 21:01:26 +02:00
specify output file
This commit is contained in:
parent
f96d15ead0
commit
811739dbba
|
|
@ -87,9 +87,9 @@ def main(args):
|
|||
dist.get_world_size()
|
||||
)
|
||||
|
||||
output_file = os.path.splitext(args.input)[0] + f"_rank{dist.get_rank()}" + "_llama3.csv"
|
||||
|
||||
writer = csv.writer(open(output_file, "w"))
|
||||
output_file = args.output_prefix + f"_rank{dist.get_rank()}" + "_llama3.csv"
|
||||
output_file_handle = open(output_file, "w")
|
||||
writer = csv.writer(output_file_handle)
|
||||
columns = list(dataframe.columns) + [args.key]
|
||||
|
||||
writer.writerow(columns)
|
||||
|
|
@ -206,12 +206,15 @@ def main(args):
|
|||
|
||||
row = batch[idx]
|
||||
writer.writerow([*row, keywords])
|
||||
|
||||
output_file_handle.close()
|
||||
dist.barrier()
|
||||
|
||||
if dist.get_rank() == 0:
|
||||
print("All ranks are finished. Collating the processed data to {}".format(output_file))
|
||||
collated_file = args.output_prefix + "_llama3.csv"
|
||||
print("All ranks are finished. Collating the processed data to {}".format(collated_file))
|
||||
import pandas as pd
|
||||
csv_files = [os.path.splitext(args.input)[0] + f"_rank{i}" + "_llama3.csv" for i in range(dist.get_world_size())]
|
||||
csv_files = [args.output_prefix + f"_rank{i}" + "_llama3.csv" for i in range(dist.get_world_size())]
|
||||
# List to hold DataFrames
|
||||
dataframes = []
|
||||
# Read each CSV into a DataFrame and append to list
|
||||
|
|
@ -221,7 +224,6 @@ def main(args):
|
|||
# Concatenate all DataFrames
|
||||
combined_df = pd.concat(dataframes, ignore_index=True)
|
||||
|
||||
collated_file = os.path.splitext(args.input)[0] + "_llama3.csv"
|
||||
# Save the combined DataFrame to a new CSV file
|
||||
combined_df.to_csv(collated_file, index=False)
|
||||
print("Collated data saved to {}".format(collated_file))
|
||||
|
|
@ -233,6 +235,7 @@ if __name__ == "__main__":
|
|||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--model-id", default="meta-llama/Meta-Llama-3-8B-Instruct")
|
||||
parser.add_argument("input", type=str, help="Path to the input CSV file")
|
||||
parser.add_argument("--output_prefix", type=str, help="Path to the output CSV file")
|
||||
parser.add_argument("--prompt", type=str, default="")
|
||||
parser.add_argument("--batch_size", type=int, default=64)
|
||||
parser.add_argument('--key', default='summary', type=str)
|
||||
|
|
|
|||
Loading…
Reference in a new issue