mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-15 03:15:20 +02:00
add filter by flow score function
This commit is contained in:
parent
d8203d5cc1
commit
f3c417a61e
|
|
@ -305,6 +305,7 @@ def parse_args():
|
|||
# aesthetic filtering
|
||||
parser.add_argument("--aesmin", type=float, default=None)
|
||||
parser.add_argument("--matchmin", type=float, default=None)
|
||||
parser.add_argument("--flowmin", type=float, default=None)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
|
@ -352,6 +353,8 @@ def get_output_path(args, input_name):
|
|||
# clip score filtering
|
||||
if args.matchmin is not None:
|
||||
name += f"_matchmin{args.matchmin}"
|
||||
if args.flowmin is not None:
|
||||
name += f"_flowmin{args.flowmin}"
|
||||
# sort
|
||||
if args.sort_descending is not None:
|
||||
assert args.sort_ascending is None
|
||||
|
|
@ -476,6 +479,9 @@ def main(args):
|
|||
if args.matchmin is not None:
|
||||
assert "match" in data.columns
|
||||
data = data[data["match"] >= args.matchmin]
|
||||
if args.flowmin is not None:
|
||||
assert "flow" in data.columns
|
||||
data = data[data["flow"] >= args.flowmin]
|
||||
print(f"Filtered number of samples: {len(data)}.")
|
||||
|
||||
# sort
|
||||
|
|
|
|||
Loading…
Reference in a new issue