mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-11 05:13:31 +02:00
auto-collating
This commit is contained in:
parent
a397173104
commit
f96d15ead0
|
|
@ -207,6 +207,24 @@ def main(args):
|
|||
row = batch[idx]
|
||||
writer.writerow([*row, keywords])
|
||||
dist.barrier()
|
||||
|
||||
if dist.get_rank() == 0:
|
||||
print("All ranks are finished. Collating the processed data to {}".format(output_file))
|
||||
import pandas as pd
|
||||
csv_files = [os.path.splitext(args.input)[0] + f"_rank{i}" + "_llama3.csv" for i in range(dist.get_world_size())]
|
||||
# List to hold DataFrames
|
||||
dataframes = []
|
||||
# Read each CSV into a DataFrame and append to list
|
||||
for file in csv_files:
|
||||
df = pd.read_csv(file)
|
||||
dataframes.append(df)
|
||||
# Concatenate all DataFrames
|
||||
combined_df = pd.concat(dataframes, ignore_index=True)
|
||||
|
||||
collated_file = os.path.splitext(args.input)[0] + "_llama3.csv"
|
||||
# Save the combined DataFrame to a new CSV file
|
||||
combined_df.to_csv(collated_file, index=False)
|
||||
print("Collated data saved to {}".format(collated_file))
|
||||
# terminate distributed env
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue