diff --git a/tools/datasets/csvutil.py b/tools/datasets/csvutil.py index 9d66426..b1ac152 100644 --- a/tools/datasets/csvutil.py +++ b/tools/datasets/csvutil.py @@ -1,6 +1,7 @@ import argparse import html import os +from glob import glob import numpy as np import pandas as pd @@ -191,10 +192,14 @@ def main(args): # reading data data = [] input_name = "" - for i, input_path in enumerate(args.input): + input_list = [] + for input_path in args.input: + input_list.extend(glob(input_path)) + print("Input files:", input_list) + for i, input_path in enumerate(input_list): data.append(pd.read_csv(input_path)) input_name += os.path.basename(input_path).split(".")[0] - if i != len(args.input) - 1: + if i != len(input_list) - 1: input_name += "+" print(f"Loaded {len(data[-1])} samples from {input_path}.") data = pd.concat(data, ignore_index=True, sort=False) @@ -223,6 +228,21 @@ def main(args): if args.lang is not None: detect_lang = build_lang_detector(args.lang) + # filtering + if args.ext: + assert "path" in data.columns + data = data[apply(data["path"], os.path.exists)] + if args.remove_empty_caption: + assert "text" in data.columns + data = data[data["text"].str.len() > 0] + data = data[~data["text"].isna()] + if args.remove_url: + assert "text" in data.columns + data = data[~data["text"].str.contains(r"(?P<url>https?://[^\s]+)", regex=True)] + if args.lang is not None: + assert "text" in data.columns + data = data[data["text"].progress_apply(detect_lang)] # cannot parallelize + # processing if args.relpath is not None: data["path"] = apply(data["path"], lambda x: os.path.relpath(x, args.relpath)) @@ -239,19 +259,6 @@ def main(args): data["num_frames"], data["height"], data["width"], data["aspect_ratio"], data["fps"] = zip(*info) # filtering - if args.ext: - assert "path" in data.columns - data = data[apply(data["path"], os.path.exists)] - if args.remove_empty_caption: - assert "text" in data.columns - data = data[data["text"].str.len() > 0] - data = 
data[~data["text"].isna()] - if args.remove_url: - assert "text" in data.columns - data = data[~data["text"].str.contains(r"(?P<url>https?://[^\s]+)", regex=True)] - if args.lang is not None: - assert "text" in data.columns - data = data[data["text"].progress_apply(detect_lang)] # cannot parallelize if args.fmin is not None: assert "num_frames" in data.columns data = data[data["num_frames"] >= args.fmin]