a bunch of update

This commit is contained in:
Zangwei Zheng 2024-04-11 11:48:06 +08:00
parent aced1fb80f
commit da1038ca5c
8 changed files with 313 additions and 217 deletions

View file

@ -1,7 +1,6 @@
num_frames = 1
num_frames = 16
fps = 24 // 3
image_size = (1358, 680)
# image_size = (256, 256)
image_size = (256, 256)
multi_resolution = "STDiT2"
# Define model
@ -20,7 +19,7 @@ vae = dict(
text_encoder = dict(
type="t5",
from_pretrained="DeepFloyd/t5-v1_1-xxl",
model_max_length=200,
model_max_length=300,
)
scheduler = dict(
type="iddpm",

View file

@ -16,7 +16,7 @@ bucket_config = { # 6s/it
"1024": {1: (0.3, 20)},
"1080p": {1: (0.3, 8)},
}
# mask_ratios = {
# mask_ratios = {
# "mask_no": 0.9,
# "mask_random": 0.06,
# "mask_head": 0.01,

View file

@ -356,6 +356,7 @@ class STDiT2(nn.Module):
if mask is not None:
if mask.shape[0] != y.shape[0]:
mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
# mask[:, 100:] = 0
mask = mask.squeeze(1).squeeze(1)
y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
y_lens = mask.sum(dim=1).tolist()

View file

@ -19,6 +19,9 @@ def parse_args(training=False):
# model config
parser.add_argument("config", help="model config file path")
# ======================================================
# General
# ======================================================
parser.add_argument("--seed", default=42, type=int, help="generation seed")
parser.add_argument("--ckpt-path", type=str, help="path to model ckpt; will overwrite cfg.ckpt_path if specified")
parser.add_argument("--batch-size", default=None, type=int, help="batch size")
@ -26,15 +29,22 @@ def parse_args(training=False):
# ======================================================
# Inference
# ======================================================
if not training:
# prompt
parser.add_argument("--prompt-path", default=None, type=str, help="path to prompt txt file")
parser.add_argument("--save-dir", default=None, type=str, help="path to save generated samples")
# image/video
parser.add_argument("--num-frames", default=None, type=int, help="number of frames")
parser.add_argument("--fps", default=None, type=int, help="fps")
parser.add_argument("--image-size", default=None, type=int, nargs=2, help="image size")
# hyperparameters
parser.add_argument("--num-sampling-steps", default=None, type=int, help="sampling steps")
parser.add_argument("--cfg-scale", default=None, type=float, help="balance between cond & uncond")
# ======================================================
# Training
# ======================================================
else:
parser.add_argument("--wandb", default=None, type=bool, help="enable wandb")
parser.add_argument("--load", default=None, type=str, help="path to continue training")

View file

@ -6,6 +6,7 @@ gdown
mmengine
pandas
pre-commit
pyarrow
pyav
tensorboard
timm

View file

@ -5,7 +5,9 @@
- [Dataset to CSV](#dataset-to-csv)
- [Manage datasets](#manage-datasets)
- [Requirement](#requirement)
- [Usage](#usage)
- [Basic Usage](#basic-usage)
- [Score filtering](#score-filtering)
- [Documentation](#documentation)
- [Transform datasets](#transform-datasets)
- [Resize](#resize)
- [Frame extraction](#frame-extraction)
@ -17,27 +19,29 @@ After preparing the raw dataset according to the [instructions](/docs/datasets.m
## Dataset Format
All dataset should be provided in a CSV file, which is used both for training and data preprocessing. The CSV file should only contain the following columns (some are optional).
All datasets should be provided in a `.csv` file (or `parquet.gzip` to save space), which is used for both training and data preprocessing. The columns should follow the definitions below:
- `path`: the relative/absolute path or url to the image or video file. The only required column.
- `text`: the caption or description of the image or video. Necessary for training.
- `num_frames`: the number of frames in the video. Necessary for training.
- `fps`: the frame rate of the video (optional)
- `width`: the width of the video frame. Necessary for STDiT2.
- `height`: the height of the video frame. Necessary for STDiT2.
- `resolution`: height x width (optional)
- `aspect_ratio`: the aspect ratio of the video frame (height / width) (optional)
- `aes`: aesthetic score calculated by [aesthetic scorer](/tools/aesthetic/README.md) (optional)
- `flow`: optical flow score calculated by [UniMatch](/tools/scoring/README.md) (optional)
- `match`: matching score of a image-text/video-text pair calculated by [CLIP](/tools/scoring/README.md) (optional)
- `cmotion`: the camera motion (optional)
- `path`: the relative/absolute path or url to the image or video file. Required.
- `text`: the caption or description of the image or video. Required for training.
- `num_frames`: the number of frames in the video. Required for training.
- `width`: the width of the video frame. Required for dynamic bucket.
- `height`: the height of the video frame. Required for dynamic bucket.
- `aspect_ratio`: the aspect ratio of the video frame (height / width). Required for dynamic bucket.
- `resolution`: height x width. For analysis.
- `text_len`: the number of tokens in the text. For analysis.
- `aes`: aesthetic score calculated by [aesthetic scorer](/tools/aesthetic/README.md). For filtering.
- `flow`: optical flow score calculated by [UniMatch](/tools/scoring/README.md). For filtering.
- `match`: matching score of a image-text/video-text pair calculated by [CLIP](/tools/scoring/README.md). For filtering.
- `fps`: the frame rate of the video. Optional.
- `cmotion`: the camera motion. Optional.
Example:
An example ready for training:
```csv
path, text, num_frames, fps, width, height, aspect_ratio, aes, flow, match, ...
/absolute/path/to/image1.jpg, caption1, num_of_frames
/absolute/path/to/video2.mp4, caption2, num_of_frames
path, text, num_frames, width, height, aspect_ratio
/absolute/path/to/image1.jpg, caption, 1, 720, 1280, 0.5625
/absolute/path/to/video1.mp4, caption, 120, 720, 1280, 0.5625
/absolute/path/to/video2.mp4, caption, 20, 256, 256, 1
```
We use pandas to manage the CSV files. The following code is for reading and writing the CSV files:
@ -53,10 +57,11 @@ As a start point, `convert.py` is used to convert the dataset to a CSV file. You
```bash
python -m tools.datasets.convert DATASET-TYPE DATA_FOLDER
# general video folder
python -m tools.datasets.convert video VIDEO_FOLDER
python -m tools.datasets.convert video VIDEO_FOLDER --output video.csv
# general image folder
python -m tools.datasets.convert image IMAGE_FOLDER
python -m tools.datasets.convert image IMAGE_FOLDER --output image.csv
# imagenet
python -m tools.datasets.convert imagenet IMAGENET_FOLDER --split train
# ucf101
@ -98,9 +103,9 @@ To filter a specific language, you need to install [lingua](https://github.com/p
pip install lingua-language-detector
```
### Usage
### Basic Usage
You can use the following commands to process the CSV files. The output csv file will be saved in the same directory as the input csv file, with different suffixes indicating the processing method.
You can use the following commands to process the CSV files. The output csv file will be saved in the same directory as the input csv file, with different suffixes indicating the processing method.
```bash
# csvutil takes multiple CSV files as input and merge them into one CSV file
@ -120,57 +125,23 @@ python -m tools.datasets.csvutil DATA.csv --fmin 128 --fmax 256 --disable-parall
# Remove corrupted video from the csv
python -m tools.datasets.csvutil DATA.csv --remove-corrupted
```
Here are more examples:
```bash
# modify the path to absolute path by root given
# output: DATA_abspath.csv
python -m tools.datasets.csvutil DATA.csv --abspath /absolute/path/to/dataset
# modify the path to relative path by root given
# output: DATA_relpath.csv
python -m tools.datasets.csvutil DATA.csv --relpath /relative/path/to/dataset
# remove the rows with empty captions
# output: DATA_noempty.csv
python -m tools.datasets.csvutil DATA.csv --remove-empty-caption
# remove the rows with urls
# output: DATA_nourl.csv
python -m tools.datasets.csvutil DATA.csv --remove-url
# unescape the caption
# output: DATA_unescape.csv
python -m tools.datasets.csvutil DATA.csv --unescape
# modify LLaVA caption
# output: DATA_rcp.csv
python -m tools.datasets.csvutil DATA.csv --remove-caption-prefix
# keep only the rows with english captions
# output: DATA_en.csv
python -m tools.datasets.csvutil DATA.csv --lang en
# compute num_frames, height, width, fps, aspect_ratio for videos or images
# Compute num_frames, height, width, fps, aspect_ratio for videos or images
# output: IMG_DATA+VID_DATA_vinfo.csv
python -m tools.datasets.csvutil IMG_DATA.csv VID_DATA --video-info
```
python -m tools.datasets.csvutil IMG_DATA.csv VID_DATA.csv --video-info
You can apply multiple operations at the same time:
```bash
# output: DATA_vinfo_noempty_nourl_en.csv
# You can run multiple operations at the same time.
python -m tools.datasets.csvutil DATA.csv --video-info --remove-empty-caption --remove-url --lang en
```
### Score filtering
To examine and filter the quality of the dataset by aesthetic score and clip score, you can use the following commands:
```bash
# sort the dataset by aesthetic score
# output: DATA_sort.csv
python -m tools.datasets.csvutil DATA.csv --sort-descending aesthetic_score
python -m tools.datasets.csvutil DATA.csv --sort aesthetic_score
# View examples of high aesthetic score
head -n 10 DATA_sort.csv
# View examples of low aesthetic score
@ -178,7 +149,7 @@ tail -n 10 DATA_sort.csv
# sort the dataset by clip score
# output: DATA_sort.csv
python -m tools.datasets.csvutil DATA.csv --sort-descending clip_score
python -m tools.datasets.csvutil DATA.csv --sort clip_score
# filter the dataset by aesthetic score
# output: DATA_aesmin_0.5.csv
@ -188,6 +159,43 @@ python -m tools.datasets.csvutil DATA.csv --aesmin 0.5
python -m tools.datasets.csvutil DATA.csv --matchmin 0.5
```
### Documentation
You can also use `python -m tools.datasets.csvutil --help` to see usage.
| Args | File suffix | Description |
| --------------------------- | -------------- | ------------------------------------------------------------- |
| `--output OUTPUT` | | Output path |
| `--format FORMAT` | | Output format (csv, parquet, parquet.gzip) |
| `--disable-parallel` | | Disable `pandarallel` |
| `--seed SEED` | | Random seed |
| `--shard SHARD` | `_0`,`_1` | Shard the dataset |
| `--sort KEY` | `_sort` | Sort the dataset by KEY |
| `--sort-descending KEY` | `_sort` | Sort the dataset by KEY in descending order |
| `--difference DATA.csv` | | Remove the paths in DATA.csv from the dataset |
| `--intersection DATA.csv` | | Keep the paths in DATA.csv from the dataset and merge columns |
| `--info` | `_info` | Get the basic information of each video and image |
| `--ext` | `_ext` | Remove rows if the file does not exist |
| `--remove-corrupted` | `_nocorrupted` | Remove the corrupted video and image |
| `--relpath` | `_relpath` | Modify the path to relative path by root given |
| `--abspath` | `_abspath` | Modify the path to absolute path by root given |
| `--remove-empty-caption` | `_noempty` | Remove rows with empty caption |
| `--remove-url` | `_nourl` | Remove rows with url in caption |
| `--lang LANG`               | `_{LANG}`      | Remove rows with other language |
| `--remove-path-duplication` | `_noduppath` | Remove rows with duplicated path |
| `--remove-text-duplication` | `_noduptext` | Remove rows with duplicated caption |
| `--refine-llm-caption` | `_llm` | Modify the caption generated by LLM |
| `--clean-caption MODEL` | `_clean` | Modify the caption according to T5 pipeline to suit training |
| `--unescape` | `_unescape` | Unescape the caption |
| `--merge-cmotion` | `_cmotion` | Merge the camera motion to the caption |
| `--count-num-token` | `_ntoken` | Count the number of tokens in the caption |
| `--fmin FMIN` | `_fmin` | Filter the dataset by minimum number of frames |
| `--fmax FMAX` | `_fmax` | Filter the dataset by maximum number of frames |
| `--hwmax HWMAX` | `_hwmax` | Filter the dataset by maximum height x width |
| `--aesmin AESMIN` | `_aesmin` | Filter the dataset by minimum aesthetic score |
| `--matchmin MATCHMIN` | `_matchmin` | Filter the dataset by minimum clip score |
| `--flowmin FLOWMIN` | `_flowmin` | Filter the dataset by minimum optical flow score |
## Transform datasets
The `tools.datasets.transform` module provides a set of tools to transform the dataset. The general usage is as follows:

View file

@ -0,0 +1,8 @@
# Script entry point for the dataset CSV utility.
if __name__ == "__main__":
    args = parse_args()
    if args.disable_parallel:
        # NOTE(review): rebinds the module-level flag read by apply();
        # assumes pandas_has_parallel is defined in this same module —
        # confirm, since rebinding a name imported from another module
        # would have no effect there.
        pandas_has_parallel = False
    if args.seed is not None:
        # Seed both stdlib and numpy RNGs so shuffling/sharding is reproducible.
        random.seed(args.seed)
        np.random.seed(args.seed)
    main(args)

View file

@ -21,12 +21,6 @@ except ImportError:
pandas_has_parallel = False
def apply(df, func, **kwargs):
    """Apply *func* over *df*, using pandarallel when it is available.

    Falls back to tqdm's ``progress_apply`` when parallel processing is
    disabled or pandarallel failed to import.
    """
    runner = df.parallel_apply if pandas_has_parallel else df.progress_apply
    return runner(func, **kwargs)
IMG_EXTENSIONS = (
".jpg",
".jpeg",
@ -40,7 +34,18 @@ IMG_EXTENSIONS = (
)
def get_video_info(path):
def apply(df, func, **kwargs):
    """Dispatch to pandarallel's ``parallel_apply`` when enabled, otherwise
    to tqdm's ``progress_apply``."""
    if not pandas_has_parallel:
        return df.progress_apply(func, **kwargs)
    return df.parallel_apply(func, **kwargs)
# ======================================================
# --info
# ======================================================
def get_info(path):
import cv2
ext = os.path.splitext(path)[1].lower()
@ -63,6 +68,25 @@ def get_video_info(path):
return num_frames, height, width, aspect_ratio, fps, hw
# ======================================================
# --remove-corrupted
# ======================================================
def is_video_valid(path):
    """Return True if the video at *path* can be opened by decord.

    Used by ``--remove-corrupted`` to drop unreadable files from the dataset.
    """
    import decord

    try:
        decord.VideoReader(path, num_threads=1)
        return True
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit still
        # propagate; any decoding error marks the file as corrupted.
        return False
# ======================================================
# --refine-llm-caption
# ======================================================
LLAVA_PREFIX = [
"The video shows",
"The video captures",
@ -86,7 +110,7 @@ LLAVA_PREFIX = [
def remove_caption_prefix(caption):
for prefix in LLAVA_PREFIX:
if caption.startswith(prefix):
if caption.startswith(prefix) or caption.startswith(prefix.lower()):
caption = caption[len(prefix) :].strip()
if caption[0].islower():
caption = caption[0].upper() + caption[1:]
@ -94,6 +118,10 @@ def remove_caption_prefix(caption):
return caption
# ======================================================
# --merge-cmotion
# ======================================================
CMOTION_TEXT = {
"static": "The camera is static.",
"dynamic": "The camera is moving.",
@ -129,6 +157,11 @@ def merge_cmotion(caption, cmotion):
return caption
# ======================================================
# --lang
# ======================================================
def build_lang_detector(lang_to_detect):
from lingua import Language, LanguageDetectorBuilder
@ -147,6 +180,11 @@ def build_lang_detector(lang_to_detect):
return detect_lang
# ======================================================
# --clean-caption
# ======================================================
def basic_clean(text):
import ftfy
@ -288,156 +326,56 @@ def text_preprocessing(text, use_text_preprocessing: bool = True):
return text.lower().strip()
def get_key_val_given_path(path, key, df):
    """Return column *key* of the single row whose "path" column equals *path*.

    Raises if zero or multiple rows match (``Series.item`` requires exactly one).
    """
    row_mask = df["path"] == path
    return df.loc[row_mask, key].item()
# ======================================================
# read & write
# ======================================================
def is_video_valid(path):
    """Return True if the video at *path* can be opened by decord.

    Used to detect and remove corrupted videos from the dataset.
    """
    import decord

    try:
        decord.VideoReader(path, num_threads=1)
        return True
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit still
        # propagate; any decoding error marks the file as corrupted.
        return False
def read_file(input_path):
    """Load a dataset table from a CSV or Parquet file.

    Accepts ``.parquet.gzip`` as well: the README advertises gzip-compressed
    parquet, and pandas handles parquet compression internally, so the suffix
    only needs to be recognized here.

    Raises:
        NotImplementedError: for any other file extension.
    """
    if input_path.endswith(".csv"):
        return pd.read_csv(input_path)
    elif input_path.endswith((".parquet", ".parquet.gzip")):
        return pd.read_parquet(input_path)
    else:
        raise NotImplementedError(f"Unsupported file format: {input_path}")
def parse_args():
    """Build and parse the command-line interface for the CSV utility."""
    p = argparse.ArgumentParser()
    p.add_argument("input", type=str, nargs="+")
    p.add_argument("--output", type=str, default=None)
    p.add_argument("--disable-parallel", action="store_true")
    p.add_argument("--seed", type=int, default=None)
    # dataset-level operations
    p.add_argument("--shard", type=int, default=None)
    p.add_argument("--sort", type=str, default=None)
    p.add_argument("--sort-ascending", type=str, default=None)
    p.add_argument("--difference", type=str, default=None)
    p.add_argument("--intersection", type=str, default=None)
    # path rewriting
    p.add_argument("--relpath", type=str, default=None)
    p.add_argument("--abspath", type=str, default=None)
    # file-existence filtering
    p.add_argument("--ext", action="store_true")
    # row filtering
    p.add_argument("--remove-empty-caption", action="store_true")
    p.add_argument("--remove-duplicate-path", action="store_true")
    p.add_argument("--lang", type=str, default=None)
    p.add_argument("--remove-url", action="store_true")
    p.add_argument("--remove-corrupted", action="store_true")
    p.add_argument("--remove-text-duplication", action="store_true")
    # caption rewriting
    p.add_argument("--remove-caption-prefix", action="store_true")
    p.add_argument("--unescape", action="store_true")
    p.add_argument("--clean-caption", action="store_true")
    p.add_argument("--merge-cmotion", action="store_true")
    p.add_argument("--count-text-token", type=str, choices=["t5"], default=None)
    # video/image metadata extraction
    p.add_argument("--info", action="store_true")
    # numeric threshold filtering
    p.add_argument("--fmin", type=int, default=None)
    p.add_argument("--fmax", type=int, default=None)
    p.add_argument("--hwmax", type=int, default=None)
    p.add_argument("--aesmin", type=float, default=None)
    p.add_argument("--matchmin", type=float, default=None)
    p.add_argument("--flowmin", type=float, default=None)
    return p.parse_args()
def save_file(data, output_path):
    """Write a DataFrame to CSV or Parquet based on the output extension.

    ``.gzip``-suffixed paths (e.g. ``data.parquet.gzip``) are written as
    gzip-compressed parquet, matching the ``parquet.gzip`` output format
    documented in the README.

    Raises:
        NotImplementedError: for any other extension.
    """
    if output_path.endswith(".csv"):
        return data.to_csv(output_path, index=False)
    elif output_path.endswith(".parquet"):
        return data.to_parquet(output_path, index=False)
    elif output_path.endswith(".gzip"):
        # gzip-compressed parquet; pandas selects the codec explicitly here.
        return data.to_parquet(output_path, index=False, compression="gzip")
    else:
        raise NotImplementedError(f"Unsupported file format: {output_path}")
def get_output_path(args, input_name):
    """Derive the output CSV path from the requested operations.

    An explicit ``--output`` always wins. Otherwise each enabled operation
    appends its marker suffix to *input_name*, in a fixed order, and the
    file is placed in the directory of the first input.
    """
    if args.output is not None:
        return args.output

    suffixes = []
    # path processing
    if args.relpath is not None:
        suffixes.append("_relpath")
    if args.abspath is not None:
        suffixes.append("_abspath")
    # path filtering
    if args.ext:
        suffixes.append("_ext")
    # caption filtering
    if args.remove_empty_caption:
        suffixes.append("_noempty")
    if args.remove_duplicate_path:
        suffixes.append("_noduppath")
    if args.lang is not None:
        suffixes.append(f"_{args.lang}")
    if args.remove_url:
        suffixes.append("_nourl")
    # caption processing
    if args.remove_caption_prefix:
        suffixes.append("_rcp")
    if args.unescape:
        suffixes.append("_unescape")
    if args.clean_caption:
        suffixes.append("_clean")
    if args.merge_cmotion:
        suffixes.append("_cmcaption")
    if args.count_text_token:
        suffixes.append("_textlen")
    # num_frames processing
    if args.info:
        suffixes.append("_info")
    # num_frames filtering
    if args.fmin is not None:
        suffixes.append(f"_fmin{args.fmin}")
    if args.fmax is not None:
        suffixes.append(f"_fmax{args.fmax}")
    if args.hwmax is not None:
        suffixes.append(f"_hwmax{args.hwmax}")
    # score filtering
    if args.aesmin is not None:
        suffixes.append(f"_aesmin{args.aesmin}")
    if args.matchmin is not None:
        suffixes.append(f"_matchmin{args.matchmin}")
    if args.flowmin is not None:
        suffixes.append(f"_flowmin{args.flowmin}")
    # sort (the two directions are mutually exclusive)
    if args.sort is not None:
        assert args.sort_ascending is None
        suffixes.append("_sort")
    if args.sort_ascending is not None:
        assert args.sort is None
        suffixes.append("_sort")
    if args.remove_corrupted:
        suffixes.append("_remove_corrupted")
    if args.remove_text_duplication:
        suffixes.append("_notd")

    name = input_name + "".join(suffixes)
    dir_path = os.path.dirname(args.input[0])
    return os.path.join(dir_path, f"{name}.csv")
def main(args):
# reading data
def read_data(input_paths):
data = []
input_name = ""
input_list = []
for input_path in args.input:
for input_path in input_paths:
input_list.extend(glob(input_path))
print("Input files:", input_list)
for i, input_path in enumerate(input_list):
data.append(pd.read_csv(input_path))
data.append(read_file(input_path))
input_name += os.path.basename(input_path).split(".")[0]
if i != len(input_list) - 1:
input_name += "+"
print(f"Loaded {len(data[-1])} samples from {input_path}.")
data = pd.concat(data, ignore_index=True, sort=False)
print(f"Total number of samples: {len(data)}.")
return data, input_name
# ======================================================
# main
# ======================================================
# To add a new method, register it in the main, parse_args, and get_output_path functions, and update the doc at /tools/datasets/README.md#documentation
def main(args):
# reading data
data, input_name = read_data(args.input)
# make difference
if args.difference is not None:
@ -462,7 +400,7 @@ def main(args):
# preparation
if args.lang is not None:
detect_lang = build_lang_detector(args.lang)
if args.count_text_token == "t5":
if args.count_num_token == "t5":
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("DeepFloyd/t5-v1_1-xxl")
@ -481,7 +419,7 @@ def main(args):
assert "text" in data.columns
data = data[data["text"].str.len() > 0]
data = data[~data["text"].isna()]
if args.remove_duplicate_path:
if args.remove_path_duplication:
assert "path" in data.columns
data = data.drop_duplicates(subset=["path"])
if args.remove_corrupted:
@ -493,7 +431,7 @@ def main(args):
data["path"] = apply(data["path"], lambda x: os.path.relpath(x, args.relpath))
if args.abspath is not None:
data["path"] = apply(data["path"], lambda x: os.path.join(args.abspath, x))
if args.remove_caption_prefix:
if args.refine_llm_caption:
assert "text" in data.columns
data["text"] = apply(data["text"], remove_caption_prefix)
if args.unescape:
@ -507,11 +445,11 @@ def main(args):
)
if args.merge_cmotion:
data["text"] = apply(data, lambda x: merge_cmotion(x["text"], x["cmotion"]), axis=1)
if args.count_text_token is not None:
if args.count_num_token is not None:
assert "text" in data.columns
data["text_len"] = apply(data["text"], lambda x: len(tokenizer(x)["input_ids"]))
if args.info:
info = apply(data["path"], get_video_info)
info = apply(data["path"], get_info)
(
data["num_frames"],
data["height"],
@ -558,14 +496,145 @@ def main(args):
if args.shard is not None:
sharded_data = np.array_split(data, args.shard)
for i in range(args.shard):
output_path_s = output_path.replace(".csv", f"_{i}.csv")
sharded_data[i].to_csv(output_path_s, index=False)
output_path_part = output_path.split(".")
output_path_s = ".".join(output_path_part[:-1]) + f"_{i}." + output_path_part[-1]
save_file(sharded_data[i], output_path_s)
print(f"Saved {len(sharded_data[i])} samples to {output_path_s}.")
else:
data.to_csv(output_path, index=False)
save_file(data, output_path)
print(f"Saved {len(data)} samples to {output_path}.")
def parse_args():
    """Build and parse the command-line interface for the CSV utility.

    Fix: ``--remove-empty-caption`` carried two ``help=`` keywords (stale
    diff residue), which is a SyntaxError; the newer help text is kept.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("input", type=str, nargs="+")
    parser.add_argument("--output", type=str, default=None, help="output path")
    # NOTE(review): the README also documents a "parquet.gzip" format; add it
    # here once save_file supports gzip-compressed parquet output.
    parser.add_argument("--format", type=str, default="csv", help="output format", choices=["csv", "parquet"])
    parser.add_argument("--disable-parallel", action="store_true", help="disable parallel processing")
    parser.add_argument("--seed", type=int, default=None, help="random seed")
    # special case
    parser.add_argument("--shard", type=int, default=None, help="shard the dataset")
    parser.add_argument("--sort", type=str, default=None, help="sort by column")
    parser.add_argument("--sort-ascending", type=str, default=None, help="sort by column (ascending order)")
    parser.add_argument("--difference", type=str, default=None, help="remove the paths in csv from the dataset")
    parser.add_argument(
        "--intersection", type=str, default=None, help="keep the paths in csv from the dataset and merge columns"
    )
    # IO-related
    parser.add_argument("--info", action="store_true", help="get the basic information of each video and image")
    parser.add_argument("--ext", action="store_true", help="check if the file exists")
    parser.add_argument("--remove-corrupted", action="store_true", help="remove the corrupted video and image")
    # path processing
    parser.add_argument("--relpath", type=str, default=None, help="modify the path to relative path by root given")
    parser.add_argument("--abspath", type=str, default=None, help="modify the path to absolute path by root given")
    # caption filtering
    parser.add_argument(
        "--remove-empty-caption",
        action="store_true",
        help="remove rows with empty caption",
    )
    parser.add_argument("--remove-url", action="store_true", help="remove rows with url in caption")
    parser.add_argument("--lang", type=str, default=None, help="remove rows with other language")
    parser.add_argument("--remove-path-duplication", action="store_true", help="remove rows with duplicated path")
    parser.add_argument("--remove-text-duplication", action="store_true", help="remove rows with duplicated caption")
    # caption processing
    parser.add_argument("--refine-llm-caption", action="store_true", help="modify the caption generated by LLM")
    parser.add_argument(
        "--clean-caption", action="store_true", help="modify the caption according to T5 pipeline to suit training"
    )
    parser.add_argument("--unescape", action="store_true", help="unescape the caption")
    parser.add_argument("--merge-cmotion", action="store_true", help="merge the camera motion to the caption")
    parser.add_argument(
        "--count-num-token", type=str, choices=["t5"], default=None, help="Count the number of tokens in the caption"
    )
    # score filtering
    parser.add_argument("--fmin", type=int, default=None, help="filter the dataset by minimum number of frames")
    parser.add_argument("--fmax", type=int, default=None, help="filter the dataset by maximum number of frames")
    parser.add_argument("--hwmax", type=int, default=None, help="filter the dataset by maximum resolution")
    parser.add_argument("--aesmin", type=float, default=None, help="filter the dataset by minimum aes score")
    parser.add_argument("--matchmin", type=float, default=None, help="filter the dataset by minimum match score")
    parser.add_argument("--flowmin", type=float, default=None, help="filter the dataset by minimum flow score")
    return parser.parse_args()
def get_output_path(args, input_name):
    """Derive the output path from the requested operations.

    An explicit ``--output`` always wins. Otherwise each enabled operation
    appends its marker suffix to *input_name*, in a fixed order; the file is
    placed in the directory of the first input with the ``--format`` extension.
    """
    if args.output is not None:
        return args.output

    parts = [input_name]
    # sort (the two directions are mutually exclusive)
    if args.sort is not None:
        assert args.sort_ascending is None
        parts.append("_sort")
    if args.sort_ascending is not None:
        assert args.sort is None
        parts.append("_sort")
    # IO-related
    if args.info:
        parts.append("_info")
    if args.ext:
        parts.append("_ext")
    if args.remove_corrupted:
        parts.append("_nocorrupted")
    # path processing
    if args.relpath is not None:
        parts.append("_relpath")
    if args.abspath is not None:
        parts.append("_abspath")
    # caption filtering
    if args.remove_empty_caption:
        parts.append("_noempty")
    if args.remove_url:
        parts.append("_nourl")
    if args.lang is not None:
        parts.append(f"_{args.lang}")
    if args.remove_path_duplication:
        parts.append("_noduppath")
    if args.remove_text_duplication:
        parts.append("_noduptext")
    # caption processing
    if args.refine_llm_caption:
        parts.append("_llm")
    if args.unescape:
        parts.append("_unescape")
    if args.clean_caption:
        parts.append("_clean")
    if args.merge_cmotion:
        parts.append("_cmcaption")
    if args.count_num_token:
        parts.append("_ntoken")
    # score filtering
    if args.fmin is not None:
        parts.append(f"_fmin{args.fmin}")
    if args.fmax is not None:
        parts.append(f"_fmax{args.fmax}")
    if args.hwmax is not None:
        parts.append(f"_hwmax{args.hwmax}")
    if args.aesmin is not None:
        parts.append(f"_aesmin{args.aesmin}")
    if args.matchmin is not None:
        parts.append(f"_matchmin{args.matchmin}")
    if args.flowmin is not None:
        parts.append(f"_flowmin{args.flowmin}")

    name = "".join(parts)
    dir_path = os.path.dirname(args.input[0])
    return os.path.join(dir_path, f"{name}.{args.format}")
if __name__ == "__main__":
args = parse_args()
if args.disable_parallel: