From d8203d5cc15b274fdd72c9a4d6fe61942ffbf29d Mon Sep 17 00:00:00 2001 From: Shen-Chenhui Date: Fri, 5 Apr 2024 11:21:00 +0800 Subject: [PATCH] add intersect func --- tools/datasets/csvutil.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/tools/datasets/csvutil.py b/tools/datasets/csvutil.py index a93d252..8ad7294 100644 --- a/tools/datasets/csvutil.py +++ b/tools/datasets/csvutil.py @@ -251,7 +251,9 @@ def text_preprocessing(text, use_text_preprocessing: bool = True): return text else: return text.lower().strip() - + +def get_key_val_given_path(path, key, df): + return df.loc[df["path"]==path, key].item() def is_video_valid(path): import decord @@ -395,6 +397,13 @@ def main(args): print(f"Intersection csv contains {len(data_int)} samples.") data = data[data["path"].isin(data_int["path"])] input_name += f"-{os.path.basename(args.intersection).split('.')[0]}" + # get additional data column from intersection + int_new_keys = [k for k in data_int.keys() if not k in data] + for k in int_new_keys: + # get corresponding values + data[k] = apply(data['path'], lambda x: get_key_val_given_path(x, k, data_int)) + + print(f"added {len(int_new_keys)} keys from {args.intersection}: {int_new_keys}") print(f"Filtered number of samples: {len(data)}.") # get output path