mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-15 03:15:20 +02:00
[faeture] add script to split dataset by bucket (#43)
This commit is contained in:
parent
3210a691cf
commit
4bd85a3851
76
tools/datasets/split.py
Normal file
76
tools/datasets/split.py
Normal file
|
|
@ -0,0 +1,76 @@
|
|||
import argparse
|
||||
from typing import List
|
||||
|
||||
import pandas as pd
|
||||
from mmengine.config import Config
|
||||
|
||||
from opensora.datasets.bucket import Bucket
|
||||
|
||||
|
||||
def split_by_bucket(
|
||||
bucket: Bucket,
|
||||
input_files: List[str],
|
||||
output_path: str,
|
||||
limit: int,
|
||||
frame_interval: int,
|
||||
):
|
||||
print(f"Split {len(input_files)} files into {len(bucket)} buckets")
|
||||
total_limit = len(bucket) * limit
|
||||
bucket_cnt = {}
|
||||
# get all bucket id
|
||||
for hw_id, d in bucket.ar_criteria.items():
|
||||
for t_id, v in d.items():
|
||||
for ar_id in v.keys():
|
||||
bucket_id = (hw_id, t_id, ar_id)
|
||||
bucket_cnt[bucket_id] = 0
|
||||
output_df = None
|
||||
# split files
|
||||
for path in input_files:
|
||||
df = pd.read_csv(path)
|
||||
if output_df is None:
|
||||
output_df = pd.DataFrame(columns=df.columns)
|
||||
for i in range(len(df)):
|
||||
row = df.iloc[i]
|
||||
t, h, w = row["num_frames"], row["height"], row["width"]
|
||||
bucket_id = bucket.get_bucket_id(t, h, w, frame_interval)
|
||||
if bucket_id is None:
|
||||
continue
|
||||
if bucket_cnt[bucket_id] < limit:
|
||||
bucket_cnt[bucket_id] += 1
|
||||
output_df = pd.concat(
|
||||
[output_df, pd.DataFrame([row])], ignore_index=True
|
||||
)
|
||||
if len(output_df) >= total_limit:
|
||||
break
|
||||
if len(output_df) >= total_limit:
|
||||
break
|
||||
assert len(output_df) <= total_limit
|
||||
if len(output_df) == total_limit:
|
||||
print(f"All buckets are full ({total_limit} samples)")
|
||||
else:
|
||||
print(f"Only {len(output_df)} files are used")
|
||||
output_df.to_csv(output_path, index=False)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("input", type=str, nargs="+")
|
||||
parser.add_argument("-o", "--output", required=True)
|
||||
parser.add_argument("-c", "--config", required=True)
|
||||
parser.add_argument("-l", "--limit", default=200, type=int)
|
||||
args = parser.parse_args()
|
||||
assert args.limit > 0
|
||||
|
||||
cfg = Config.fromfile(args.config)
|
||||
bucket_config = cfg.bucket_config
|
||||
# rewrite bucket_config
|
||||
for ar, d in bucket_config.items():
|
||||
for frames, t in d.items():
|
||||
p, bs = t
|
||||
if p > 0.0:
|
||||
p = 1.0
|
||||
d[frames] = (p, bs)
|
||||
bucket = Bucket(bucket_config)
|
||||
split_by_bucket(
|
||||
bucket, args.input, args.output, args.limit, cfg.dataset.frame_interval
|
||||
)
|
||||
Loading…
Reference in a new issue