mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-15 03:15:20 +02:00
add intersect func
This commit is contained in:
parent
57a1e7f6c4
commit
d8203d5cc1
|
|
@ -252,6 +252,8 @@ def text_preprocessing(text, use_text_preprocessing: bool = True):
|
||||||
else:
|
else:
|
||||||
return text.lower().strip()
|
return text.lower().strip()
|
||||||
|
|
||||||
|
def get_key_val_given_path(path, key, df):
    """Return the value stored in column ``key`` for the row whose "path" equals ``path``.

    Args:
        path: Value compared for equality against the "path" column of ``df``.
        key: Label of the column to read the value from.
        df: Table (pandas DataFrame) that has a "path" column and a ``key`` column.

    Returns:
        The single matching cell, converted to a Python scalar by ``Series.item()``.

    Raises:
        KeyError: If ``df`` has no "path" column or no ``key`` column.
        ValueError: If ``path`` matches zero rows or more than one row.
    """
    matches = df.loc[df["path"] == path, key]
    # ``Series.item()`` already raises ValueError unless exactly one element is
    # selected, but its message does not say which path/key was involved.
    # Check up front so the failure is actionable; the exception type is kept.
    if len(matches) != 1:
        raise ValueError(
            f"expected exactly one row with path={path!r} when reading {key!r}, "
            f"found {len(matches)}"
        )
    return matches.item()
|
||||||
|
|
||||||
def is_video_valid(path):
|
def is_video_valid(path):
|
||||||
import decord
|
import decord
|
||||||
|
|
@ -395,6 +397,13 @@ def main(args):
|
||||||
print(f"Intersection csv contains {len(data_int)} samples.")
|
print(f"Intersection csv contains {len(data_int)} samples.")
|
||||||
data = data[data["path"].isin(data_int["path"])]
|
data = data[data["path"].isin(data_int["path"])]
|
||||||
input_name += f"-{os.path.basename(args.intersection).split('.')[0]}"
|
input_name += f"-{os.path.basename(args.intersection).split('.')[0]}"
|
||||||
|
# get additional data column from intersection
|
||||||
|
int_new_keys = [k for k in data_int.keys() if not k in data]
|
||||||
|
for k in int_new_keys:
|
||||||
|
# get corresponding values
|
||||||
|
data[k] = apply(data['path'], lambda x: get_key_val_given_path(x, k, data_int))
|
||||||
|
|
||||||
|
print(f"added {len(int_new_keys)} keys from {args.intersection}: {int_new_keys}")
|
||||||
print(f"Filtered number of samples: {len(data)}.")
|
print(f"Filtered number of samples: {len(data)}.")
|
||||||
|
|
||||||
# get output path
|
# get output path
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue