add intersect func

This commit is contained in:
Shen-Chenhui 2024-04-05 11:21:00 +08:00
parent 57a1e7f6c4
commit d8203d5cc1

View file

@ -252,6 +252,8 @@ def text_preprocessing(text, use_text_preprocessing: bool = True):
else: else:
return text.lower().strip() return text.lower().strip()
def get_key_val_given_path(path, key, df):
    """Return the value of column ``key`` for the row of ``df`` matching ``path``.

    Args:
        path: Value to look up in the ``"path"`` column of ``df``.
        key: Name of the column whose value is returned.
        df: DataFrame that has a ``"path"`` column and a ``key`` column.

    Returns:
        The single matching cell, converted to a Python scalar by
        ``Series.item()``.

    Raises:
        ValueError: If ``path`` matches zero rows or more than one row.
        KeyError: If ``"path"`` or ``key`` is not a column of ``df``.
    """
    matches = df.loc[df["path"] == path, key]
    # A bare ``.item()`` also raises ValueError here, but its message does not
    # say which path/column was being looked up; report both explicitly.
    if len(matches) != 1:
        raise ValueError(
            f"expected exactly one row with path={path!r} for key {key!r}, "
            f"found {len(matches)}"
        )
    return matches.item()
def is_video_valid(path): def is_video_valid(path):
import decord import decord
@ -395,6 +397,13 @@ def main(args):
print(f"Intersection csv contains {len(data_int)} samples.") print(f"Intersection csv contains {len(data_int)} samples.")
data = data[data["path"].isin(data_int["path"])] data = data[data["path"].isin(data_int["path"])]
input_name += f"-{os.path.basename(args.intersection).split('.')[0]}" input_name += f"-{os.path.basename(args.intersection).split('.')[0]}"
# get additional data column from intersection
int_new_keys = [k for k in data_int.keys() if not k in data]
for k in int_new_keys:
# get corresponding values
data[k] = apply(data['path'], lambda x: get_key_val_given_path(x, k, data_int))
print(f"added {len(int_new_keys)} keys from {args.intersection}: {int_new_keys}")
print(f"Filtered number of samples: {len(data)}.") print(f"Filtered number of samples: {len(data)}.")
# get output path # get output path