mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-15 03:15:20 +02:00
add intersect func
This commit is contained in:
parent
57a1e7f6c4
commit
d8203d5cc1
|
|
@ -251,7 +251,9 @@ def text_preprocessing(text, use_text_preprocessing: bool = True):
|
|||
return text
|
||||
else:
|
||||
return text.lower().strip()
|
||||
|
||||
|
||||
def get_key_val_given_path(path, key, df):
    """Look up the value of column ``key`` for the row whose "path" equals ``path``.

    Args:
        path: Value to match against the "path" column of ``df``.
        key: Name of the column whose value should be returned.
        df: pandas DataFrame containing a "path" column and a ``key`` column.

    Returns:
        The scalar stored in column ``key`` of the single matching row.

    Raises:
        ValueError: Raised by ``Series.item()`` when the selection does not
            contain exactly one element, i.e. ``path`` is absent from ``df``
            or appears in more than one row.
    """
    # Boolean mask selecting the row(s) that belong to the requested path.
    row_mask = df["path"] == path
    # .item() enforces that exactly one row matched and unwraps it to a scalar.
    return df.loc[row_mask, key].item()
|
||||
|
||||
def is_video_valid(path):
|
||||
import decord
|
||||
|
|
@ -395,6 +397,13 @@ def main(args):
|
|||
print(f"Intersection csv contains {len(data_int)} samples.")
|
||||
data = data[data["path"].isin(data_int["path"])]
|
||||
input_name += f"-{os.path.basename(args.intersection).split('.')[0]}"
|
||||
# get additional data column from intersection
|
||||
int_new_keys = [k for k in data_int.keys() if not k in data]
|
||||
for k in int_new_keys:
|
||||
# get corresponding values
|
||||
data[k] = apply(data['path'], lambda x: get_key_val_given_path(x, k, data_int))
|
||||
|
||||
print(f"added {len(int_new_keys)} keys from {args.intersection}: {int_new_keys}")
|
||||
print(f"Filtered number of samples: {len(data)}.")
|
||||
|
||||
# get output path
|
||||
|
|
|
|||
Loading…
Reference in a new issue