From da1038ca5c23fb1de14b6e2c49505032deb9e3d8 Mon Sep 17 00:00:00 2001 From: Zangwei Zheng Date: Thu, 11 Apr 2024 11:48:06 +0800 Subject: [PATCH] a bunch of update --- .../inference/{image.py => sample.py} | 7 +- configs/opensora-v1-1/train/video.py | 2 +- opensora/models/stdit/stdit2.py | 1 + opensora/utils/config_utils.py | 12 +- requirements.txt | 1 + tools/datasets/README.md | 134 ++++--- tools/datasets/analyze.py | 8 + tools/datasets/csvutil.py | 365 +++++++++++------- 8 files changed, 313 insertions(+), 217 deletions(-) rename configs/opensora-v1-1/inference/{image.py => sample.py} (89%) create mode 100644 tools/datasets/analyze.py diff --git a/configs/opensora-v1-1/inference/image.py b/configs/opensora-v1-1/inference/sample.py similarity index 89% rename from configs/opensora-v1-1/inference/image.py rename to configs/opensora-v1-1/inference/sample.py index 0823bcd..f929f56 100644 --- a/configs/opensora-v1-1/inference/image.py +++ b/configs/opensora-v1-1/inference/sample.py @@ -1,7 +1,6 @@ -num_frames = 1 +num_frames = 16 fps = 24 // 3 -image_size = (1358, 680) -# image_size = (256, 256) +image_size = (256, 256) multi_resolution = "STDiT2" # Define model @@ -20,7 +19,7 @@ vae = dict( text_encoder = dict( type="t5", from_pretrained="DeepFloyd/t5-v1_1-xxl", - model_max_length=200, + model_max_length=300, ) scheduler = dict( type="iddpm", diff --git a/configs/opensora-v1-1/train/video.py b/configs/opensora-v1-1/train/video.py index 776d3de..20325ae 100644 --- a/configs/opensora-v1-1/train/video.py +++ b/configs/opensora-v1-1/train/video.py @@ -16,7 +16,7 @@ bucket_config = { # 6s/it "1024": {1: (0.3, 20)}, "1080p": {1: (0.3, 8)}, } -# mask_ratios = { +# mask_ratios = {x # "mask_no": 0.9, # "mask_random": 0.06, # "mask_head": 0.01, diff --git a/opensora/models/stdit/stdit2.py b/opensora/models/stdit/stdit2.py index 506ef1e..0a4edee 100644 --- a/opensora/models/stdit/stdit2.py +++ b/opensora/models/stdit/stdit2.py @@ -356,6 +356,7 @@ class STDiT2(nn.Module): if mask is not None: if mask.shape[0] != y.shape[0]: mask = mask.repeat(y.shape[0] // mask.shape[0], 1) + # mask[:, 100:] = 0 mask = mask.squeeze(1).squeeze(1) y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1]) y_lens = mask.sum(dim=1).tolist() diff --git a/opensora/utils/config_utils.py b/opensora/utils/config_utils.py index d47c0a7..d6f1c01 100644 --- a/opensora/utils/config_utils.py +++ b/opensora/utils/config_utils.py @@ -19,6 +19,9 @@ def parse_args(training=False): # model config parser.add_argument("config", help="model config file path") + # ====================================================== + # General + # ====================================================== parser.add_argument("--seed", default=42, type=int, help="generation seed") parser.add_argument("--ckpt-path", type=str, help="path to model ckpt; will overwrite cfg.ckpt_path if specified") parser.add_argument("--batch-size", default=None, type=int, help="batch size") @@ -26,15 +29,22 @@ def parse_args(training=False): # ====================================================== # Inference # ====================================================== - if not training: # prompt parser.add_argument("--prompt-path", default=None, type=str, help="path to prompt txt file") parser.add_argument("--save-dir", default=None, type=str, help="path to save generated samples") + # image/video + parser.add_argument("--num-frames", default=None, type=int, help="number of frames") + parser.add_argument("--fps", default=None, type=int, help="fps") + parser.add_argument("--image-size", default=None, type=int, nargs=2, help="image size") + # hyperparameters parser.add_argument("--num-sampling-steps", default=None, type=int, help="sampling steps") parser.add_argument("--cfg-scale", default=None, type=float, help="balance between cond & uncond") + # ====================================================== + # Training + # ====================================================== else: parser.add_argument("--wandb", default=None, type=bool, help="enable wandb") parser.add_argument("--load", default=None, type=str, help="path to continue training") diff --git a/requirements.txt b/requirements.txt index 9c74d73..f093684 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,6 +6,7 @@ gdown mmengine pandas pre-commit +pyarrow pyav tensorboard timm diff --git a/tools/datasets/README.md b/tools/datasets/README.md index 3462e6e..fc685ba 100644 --- a/tools/datasets/README.md +++ b/tools/datasets/README.md @@ -5,7 +5,9 @@ - [Dataset to CSV](#dataset-to-csv) - [Manage datasets](#manage-datasets) - [Requirement](#requirement) - - [Usage](#usage) + - [Basic Usage](#basic-usage) + - [Score filtering](#score-filtering) + - [Documentation](#documentation) - [Transform datasets](#transform-datasets) - [Resize](#resize) - [Frame extraction](#frame-extraction) @@ -17,27 +19,29 @@ After preparing the raw dataset according to the [instructions](/docs/datasets.m ## Dataset Format -All dataset should be provided in a CSV file, which is used both for training and data preprocessing. The CSV file should only contain the following columns (some are optional). +All dataset should be provided in a `.csv` file (or `parquet.gzip` to save space), which is used for both training and data preprocessing. The columns should follow the words below: -- `path`: the relative/absolute path or url to the image or video file. The only required column. -- `text`: the caption or description of the image or video. Necessary for training. -- `num_frames`: the number of frames in the video. Necessary for training. -- `fps`: the frame rate of the video (optional) -- `width`: the width of the video frame. Necessary for STDiT2. -- `height`: the height of the video frame. Necessary for STDiT2. -- `resolution`: height x width (optional) -- `aspect_ratio`: the aspect ratio of the video frame (height / width) (optional) -- `aes`: aesthetic score calculated by [asethetic scorer](/tools/aesthetic/README.md) (optional) -- `flow`: optical flow score calculated by [UniMatch](/tools/scoring/README.md) (optional) -- `match`: matching score of a image-text/video-text pair calculated by [CLIP](/tools/scoring/README.md) (optional) -- `cmotion`: the camera motion (optional) +- `path`: the relative/absolute path or url to the image or video file. Required. +- `text`: the caption or description of the image or video. Required for training. +- `num_frames`: the number of frames in the video. Required for training. +- `width`: the width of the video frame. Required for dynamic bucket. +- `height`: the height of the video frame. Required for dynamic bucket. +- `aspect_ratio`: the aspect ratio of the video frame (height / width). Required for dynamic bucket. +- `resolution`: height x width. For analysis. +- `text_len`: the number of tokens in the text. For analysis. +- `aes`: aesthetic score calculated by [asethetic scorer](/tools/aesthetic/README.md). For filtering. +- `flow`: optical flow score calculated by [UniMatch](/tools/scoring/README.md). For filtering. +- `match`: matching score of a image-text/video-text pair calculated by [CLIP](/tools/scoring/README.md). For filtering. +- `fps`: the frame rate of the video. Optional. +- `cmotion`: the camera motion. -Example: +An example ready for training: ```csv -path, text, num_frames, fps, width, height, aspect_ratio, aes, flow, match, ... -/absolute/path/to/image1.jpg, caption1, num_of_frames -/absolute/path/to/video2.mp4, caption2, num_of_frames +path, text, num_frames, width, height, aspect_ratio +/absolute/path/to/image1.jpg, caption, 1, 720, 1280, 0.5625 +/absolute/path/to/video1.mp4, caption, 120, 720, 1280, 0.5625 +/absolute/path/to/video2.mp4, caption, 20, 256, 256, 1 ``` We use pandas to manage the CSV files. The following code is for reading and writing the CSV files: @@ -53,10 +57,11 @@ As a start point, `convert.py` is used to convert the dataset to a CSV file. You ```bash python -m tools.datasets.convert DATASET-TYPE DATA_FOLDER + # general video folder -python -m tools.datasets.convert video VIDEO_FOLDER +python -m tools.datasets.convert video VIDEO_FOLDER --output video.csv # general image folder -python -m tools.datasets.convert image IMAGE_FOLDER +python -m tools.datasets.convert image IMAGE_FOLDER --output image.csv # imagenet python -m tools.datasets.convert imagenet IMAGENET_FOLDER --split train # ucf101 @@ -98,9 +103,9 @@ To filter a specific language, you need to install [lingua](https://github.com/p pip install lingua-language-detector ``` -### Usage +### Basic Usage -You can use the following commands to process the CSV files. The output csv file will be saved in the same directory as the input csv file, with different suffixes indicating the processing method. +You can use the following commands to process the CSV files. The output csv file will be saved in the same directory as the input csv file, with different suffixes indicating the processed method. ```bash # csvutil takes multiple CSV files as input and merge them into one CSV file @@ -120,57 +125,23 @@ python -m tools.datasets.csvutil DATA.csv --fmin 128 --fmax 256 --disable-parall # Remove corrupted video from the csv python -m tools.datasets.csvutil DATA.csv --remove-corrupted -``` -Here are more examples: - -```bash -# modify the path to absolute path by root given -# output: DATA_abspath.csv -python -m tools.datasets.csvutil DATA.csv --abspath /absolute/path/to/dataset - -# modify the path to relative path by root given -# output: DATA_relpath.csv -python -m tools.datasets.csvutil DATA.csv --relpath /relative/path/to/dataset - -# remove the rows with empty captions -# output: DATA_noempty.csv -python -m tools.datasets.csvutil DATA.csv --remove-empty-caption - -# remove the rows with urls -# output: DATA_nourl.csv -python -m tools.datasets.csvutil DATA.csv --remove-url - -# unescape the caption -# output: DATA_unescape.csv -python -m tools.datasets.csvutil DATA.csv --unescape - -# modify LLaVA caption -# output: DATA_rcp.csv -python -m tools.datasets.csvutil DATA.csv --remove-caption-prefix - -# keep only the rows with english captions -# output: DATA_en.csv -python -m tools.datasets.csvutil DATA.csv --lang en - -# compute num_frames, height, width, fps, aspect_ratio for videos or images +# Compute num_frames, height, width, fps, aspect_ratio for videos or images # output: IMG_DATA+VID_DATA_vinfo.csv -python -m tools.datasets.csvutil IMG_DATA.csv VID_DATA --video-info -``` +python -m tools.datasets.csvutil IMG_DATA.csv VID_DATA.csv --video-info -You can apply multiple operations at the same time: - -```bash -# output: DATA_vinfo_noempty_nourl_en.csv +# You can run multiple operations at the same time. python -m tools.datasets.csvutil DATA.csv --video-info --remove-empty-caption --remove-url --lang en ``` +### Score filtering + To examine and filter the quality of the dataset by aesthetic score and clip score, you can use the following commands: ```bash # sort the dataset by aesthetic score # output: DATA_sort.csv -python -m tools.datasets.csvutil DATA.csv --sort-descending aesthetic_score +python -m tools.datasets.csvutil DATA.csv --sort aesthetic_score # View examples of high aesthetic score head -n 10 DATA_sort.csv # View examples of low aesthetic score @@ -178,7 +149,7 @@ tail -n 10 DATA_sort.csv # sort the dataset by clip score # output: DATA_sort.csv -python -m tools.datasets.csvutil DATA.csv --sort-descending clip_score +python -m tools.datasets.csvutil DATA.csv --sort clip_score # filter the dataset by aesthetic score # output: DATA_aesmin_0.5.csv @@ -188,6 +159,43 @@ python -m tools.datasets.csvutil DATA.csv --aesmin 0.5 python -m tools.datasets.csvutil DATA.csv --matchmin 0.5 ``` +### Documentation + +You can also use `python -m tools.datasets.csvutil --help` to see usage. + +| Args | File suffix | Description | +| --------------------------- | -------------- | ------------------------------------------------------------- | +| `--output OUTPUT` | | Output path | +| `--format FORMAT` | | Output format (csv, parquet, parquet.gzip) | +| `--disable-parallel` | | Disable `pandarallel` | +| `--seed SEED` | | Random seed | +| `--shard SHARD` | `_0`,`_1` | Shard the dataset | +| `--sort KEY` | `_sort` | Sort the dataset by KEY | +| `--sort-descending KEY` | `_sort` | Sort the dataset by KEY in descending order | +| `--difference DATA.csv` | | Remove the paths in DATA.csv from the dataset | +| `--intersection DATA.csv` | | Keep the paths in DATA.csv from the dataset and merge columns | +| `--info` | `_info` | Get the basic information of each video and image | +| `--ext` | `_ext` | Remove rows if the file does not exist | +| `--remove-corrupted` | `_nocorrupted` | Remove the corrupted video and image | +| `--relpath` | `_relpath` | Modify the path to relative path by root given | +| `--abspath` | `_abspath` | Modify the path to absolute path by root given | +| `--remove-empty-caption` | `_noempty` | Remove rows with empty caption | +| `--remove-url` | `_nourl` | Remove rows with url in caption | +| `--lang LANG` | `_lang` | Remove rows with other language | +| `--remove-path-duplication` | `_noduppath` | Remove rows with duplicated path | +| `--remove-text-duplication` | `_noduptext` | Remove rows with duplicated caption | +| `--refine-llm-caption` | `_llm` | Modify the caption generated by LLM | +| `--clean-caption MODEL` | `_clean` | Modify the caption according to T5 pipeline to suit training | +| `--unescape` | `_unescape` | Unescape the caption | +| `--merge-cmotion` | `_cmotion` | Merge the camera motion to the caption | +| `--count-num-token` | `_ntoken` | Count the number of tokens in the caption | +| `--fmin FMIN` | `_fmin` | Filter the dataset by minimum number of frames | +| `--fmax FMAX` | `_fmax` | Filter the dataset by maximum number of frames | +| `--hwmax HWMAX` | `_hwmax` | Filter the dataset by maximum height x width | +| `--aesmin AESMIN` | `_aesmin` | Filter the dataset by minimum aesthetic score | +| `--matchmin MATCHMIN` | `_matchmin` | Filter the dataset by minimum clip score | +| `--flowmin FLOWMIN` | `_flowmin` | Filter the dataset by minimum optical flow score | + ## Transform datasets The `tools.datasets.transform` module provides a set of tools to transform the dataset. The general usage is as follows: diff --git a/tools/datasets/analyze.py b/tools/datasets/analyze.py new file mode 100644 index 0000000..134db45 --- /dev/null +++ b/tools/datasets/analyze.py @@ -0,0 +1,8 @@ +if __name__ == "__main__": + args = parse_args() + if args.disable_parallel: + pandas_has_parallel = False + if args.seed is not None: + random.seed(args.seed) + np.random.seed(args.seed) + main(args) diff --git a/tools/datasets/csvutil.py b/tools/datasets/csvutil.py index db962fa..d1ad6da 100644 --- a/tools/datasets/csvutil.py +++ b/tools/datasets/csvutil.py @@ -21,12 +21,6 @@ except ImportError: pandas_has_parallel = False -def apply(df, func, **kwargs): - if pandas_has_parallel: - return df.parallel_apply(func, **kwargs) - return df.progress_apply(func, **kwargs) - - IMG_EXTENSIONS = ( ".jpg", ".jpeg", @@ -40,7 +34,18 @@ IMG_EXTENSIONS = ( ) -def get_video_info(path): +def apply(df, func, **kwargs): + if pandas_has_parallel: + return df.parallel_apply(func, **kwargs) + return df.progress_apply(func, **kwargs) + + +# ====================================================== +# --info +# ====================================================== + + +def get_info(path): import cv2 ext = os.path.splitext(path)[1].lower() @@ -63,6 +68,25 @@ def get_video_info(path): return num_frames, height, width, aspect_ratio, fps, hw +# ====================================================== +# --remove-corrupted +# ====================================================== + + +def is_video_valid(path): + import decord + + try: + decord.VideoReader(path, num_threads=1) + return True + except: + return False + + +# ====================================================== +# --refine-llm-caption +# ====================================================== + LLAVA_PREFIX = [ "The video shows", "The video captures", @@ -86,7 +110,7 @@ LLAVA_PREFIX = [ def remove_caption_prefix(caption): for prefix in LLAVA_PREFIX: - if caption.startswith(prefix): + if caption.startswith(prefix) or caption.startswith(prefix.lower()): caption = caption[len(prefix) :].strip() if caption[0].islower(): caption = caption[0].upper() + caption[1:] @@ -94,6 +118,10 @@ def remove_caption_prefix(caption): return caption +# ====================================================== +# --merge-cmotion +# ====================================================== + CMOTION_TEXT = { "static": "The camera is static.", "dynamic": "The camera is moving.", @@ -129,6 +157,11 @@ def merge_cmotion(caption, cmotion): return caption +# ====================================================== +# --lang +# ====================================================== + + def build_lang_detector(lang_to_detect): from lingua import Language, LanguageDetectorBuilder @@ -147,6 +180,11 @@ def build_lang_detector(lang_to_detect): return detect_lang +# ====================================================== +# --clean-caption +# ====================================================== + + def basic_clean(text): import ftfy @@ -288,156 +326,56 @@ def text_preprocessing(text, use_text_preprocessing: bool = True): return text.lower().strip() -def get_key_val_given_path(path, key, df): - return df.loc[df["path"] == path, key].item() +# ====================================================== +# read & write +# ====================================================== -def is_video_valid(path): - import decord - - try: - decord.VideoReader(path, num_threads=1) - return True - except: - return False +def read_file(input_path): + if input_path.endswith(".csv"): + return pd.read_csv(input_path) + elif input_path.endswith(".parquet"): + return pd.read_parquet(input_path) + else: + raise NotImplementedError(f"Unsupported file format: {input_path}") -def parse_args(): - parser = argparse.ArgumentParser() - parser.add_argument("input", type=str, nargs="+") - parser.add_argument("--output", type=str, default=None) - parser.add_argument("--disable-parallel", action="store_true") - parser.add_argument("--seed", type=int, default=None) - # special case - parser.add_argument("--shard", type=int, default=None) - parser.add_argument("--sort", type=str, default=None) - parser.add_argument("--sort-ascending", type=str, default=None) - parser.add_argument("--difference", type=str, default=None) - parser.add_argument("--intersection", type=str, default=None) - - # path processing - parser.add_argument("--relpath", type=str, default=None) - parser.add_argument("--abspath", type=str, default=None) - - # path filtering - parser.add_argument("--ext", action="store_true") - - # caption filtering - parser.add_argument("--remove-empty-caption", action="store_true") - parser.add_argument("--remove-duplicate-path", action="store_true") - parser.add_argument("--lang", type=str, default=None) - parser.add_argument("--remove-url", action="store_true") - parser.add_argument("--remove-corrupted", action="store_true") - parser.add_argument("--remove-text-duplication", action="store_true") - - # caption processing - parser.add_argument("--remove-caption-prefix", action="store_true") - parser.add_argument("--unescape", action="store_true") - parser.add_argument("--clean-caption", action="store_true") - parser.add_argument("--merge-cmotion", action="store_true") - parser.add_argument("--count-text-token", type=str, choices=["t5"], default=None) - - # num_frames processing - parser.add_argument("--info", action="store_true") - - # num_frames filtering - parser.add_argument("--fmin", type=int, default=None) - parser.add_argument("--fmax", type=int, default=None) - parser.add_argument("--hwmax", type=int, default=None) - - # aesthetic filtering - parser.add_argument("--aesmin", type=float, default=None) - parser.add_argument("--matchmin", type=float, default=None) - parser.add_argument("--flowmin", type=float, default=None) - - return parser.parse_args() +def save_file(data, output_path): + if output_path.endswith(".csv"): + return data.to_csv(output_path, index=False) + elif output_path.endswith(".parquet"): + return data.to_parquet(output_path, index=False) + else: + raise NotImplementedError(f"Unsupported file format: {output_path}") -def get_output_path(args, input_name): - if args.output is not None: - return args.output - - name = input_name - dir_path = os.path.dirname(args.input[0]) - - # path processing - if args.relpath is not None: - name += "_relpath" - if args.abspath is not None: - name += "_abspath" - # path filtering - if args.ext: - name += "_ext" - # caption filtering - if args.remove_empty_caption: - name += "_noempty" - if args.remove_duplicate_path: - name += "_noduppath" - if args.lang is not None: - name += f"_{args.lang}" - if args.remove_url: - name += "_nourl" - # caption processing - if args.remove_caption_prefix: - name += "_rcp" - if args.unescape: - name += "_unescape" - if args.clean_caption: - name += "_clean" - if args.merge_cmotion: - name += "_cmcaption" - if args.count_text_token: - name += "_textlen" - # num_frames processing - if args.info: - name += "_info" - # num_frames filtering - if args.fmin is not None: - name += f"_fmin{args.fmin}" - if args.fmax is not None: - name += f"_fmax{args.fmax}" - if args.hwmax is not None: - name += f"_hwmax{args.hwmax}" - # aesthetic filtering - if args.aesmin is not None: - name += f"_aesmin{args.aesmin}" - # clip score filtering - if args.matchmin is not None: - name += f"_matchmin{args.matchmin}" - if args.flowmin is not None: - name += f"_flowmin{args.flowmin}" - # sort - if args.sort is not None: - assert args.sort_ascending is None - name += "_sort" - if args.sort_ascending is not None: - assert args.sort is None - name += "_sort" - if args.remove_corrupted: - name += "_remove_corrupted" - if args.remove_text_duplication: - name += "_notd" - - output_path = os.path.join(dir_path, f"{name}.csv") - return output_path - - -def main(args): - # reading data +def read_data(input_paths): data = [] input_name = "" input_list = [] - for input_path in args.input: + for input_path in input_paths: input_list.extend(glob(input_path)) print("Input files:", input_list) for i, input_path in enumerate(input_list): - data.append(pd.read_csv(input_path)) + data.append(read_file(input_path)) input_name += os.path.basename(input_path).split(".")[0] if i != len(input_list) - 1: input_name += "+" print(f"Loaded {len(data[-1])} samples from {input_path}.") data = pd.concat(data, ignore_index=True, sort=False) print(f"Total number of samples: {len(data)}.") + return data, input_name + + +# ====================================================== +# main +# ====================================================== +# To add a new method, register it in the main, parse_args, and get_output_path functions, and update the doc at /tools/datasets/README.md#documentation + + +def main(args): + # reading data + data, input_name = read_data(args.input) # make difference if args.difference is not None: @@ -462,7 +400,7 @@ def main(args): # preparation if args.lang is not None: detect_lang = build_lang_detector(args.lang) - if args.count_text_token == "t5": + if args.count_num_token == "t5": from transformers import AutoTokenizer tokenizer = AutoTokenizer.from_pretrained("DeepFloyd/t5-v1_1-xxl") @@ -481,7 +419,7 @@ def main(args): assert "text" in data.columns data = data[data["text"].str.len() > 0] data = data[~data["text"].isna()] - if args.remove_duplicate_path: + if args.remove_path_duplication: assert "path" in data.columns data = data.drop_duplicates(subset=["path"]) if args.remove_corrupted: @@ -493,7 +431,7 @@ def main(args): data["path"] = apply(data["path"], lambda x: os.path.relpath(x, args.relpath)) if args.abspath is not None: data["path"] = apply(data["path"], lambda x: os.path.join(args.abspath, x)) - if args.remove_caption_prefix: + if args.refine_llm_caption: assert "text" in data.columns data["text"] = apply(data["text"], remove_caption_prefix) if args.unescape: @@ -507,11 +445,11 @@ def main(args): ) if args.merge_cmotion: data["text"] = apply(data, lambda x: merge_cmotion(x["text"], x["cmotion"]), axis=1) - if args.count_text_token is not None: + if args.count_num_token is not None: assert "text" in data.columns data["text_len"] = apply(data["text"], lambda x: len(tokenizer(x)["input_ids"])) if args.info: - info = apply(data["path"], get_video_info) + info = apply(data["path"], get_info) ( data["num_frames"], data["height"], @@ -558,14 +496,145 @@ def main(args): if args.shard is not None: sharded_data = np.array_split(data, args.shard) for i in range(args.shard): - output_path_s = output_path.replace(".csv", f"_{i}.csv") - sharded_data[i].to_csv(output_path_s, index=False) + output_path_part = output_path.split(".") + output_path_s = ".".join(output_path_part[:-1]) + f"_{i}." + output_path_part[-1] + save_file(sharded_data[i], output_path_s) print(f"Saved {len(sharded_data[i])} samples to {output_path_s}.") else: - data.to_csv(output_path, index=False) + save_file(data, output_path) print(f"Saved {len(data)} samples to {output_path}.") +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("input", type=str, nargs="+") + parser.add_argument("--output", type=str, default=None, help="output path") + parser.add_argument("--format", type=str, default="csv", help="output format", choices=["csv", "parquet"]) + parser.add_argument("--disable-parallel", action="store_true", help="disable parallel processing") + parser.add_argument("--seed", type=int, default=None, help="random seed") + + # special case + parser.add_argument("--shard", type=int, default=None, help="shard the dataset") + parser.add_argument("--sort", type=str, default=None, help="sort by column") + parser.add_argument("--sort-ascending", type=str, default=None, help="sort by column (ascending order)") + parser.add_argument("--difference", type=str, default=None, help="remove the paths in csv from the dataset") + parser.add_argument( + "--intersection", type=str, default=None, help="keep the paths in csv from the dataset and merge columns" + ) + + # IO-related + parser.add_argument("--info", action="store_true", help="get the basic information of each video and image") + parser.add_argument("--ext", action="store_true", help="check if the file exists") + parser.add_argument("--remove-corrupted", action="store_true", help="remove the corrupted video and image") + + # path processing + parser.add_argument("--relpath", type=str, default=None, help="modify the path to relative path by root given") + parser.add_argument("--abspath", type=str, default=None, help="modify the path to absolute path by root given") + + # caption filtering + parser.add_argument( + "--remove-empty-caption", + action="store_true", + help="remove the empty caption", + help="remove rows with empty caption", + ) + parser.add_argument("--remove-url", action="store_true", help="remove rows with url in caption") + parser.add_argument("--lang", type=str, default=None, help="remove rows with other language") + parser.add_argument("--remove-path-duplication", action="store_true", help="remove rows with duplicated path") + parser.add_argument("--remove-text-duplication", action="store_true", help="remove rows with duplicated caption") + + # caption processing + parser.add_argument("--refine-llm-caption", action="store_true", help="modify the caption generated by LLM") + parser.add_argument( + "--clean-caption", action="store_true", help="modify the caption according to T5 pipeline to suit training" + ) + parser.add_argument("--unescape", action="store_true", help="unescape the caption") + parser.add_argument("--merge-cmotion", action="store_true", help="merge the camera motion to the caption") + parser.add_argument( + "--count-num-token", type=str, choices=["t5"], default=None, help="Count the number of tokens in the caption" + ) + + # score filtering + parser.add_argument("--fmin", type=int, default=None, help="filter the dataset by minimum number of frames") + parser.add_argument("--fmax", type=int, default=None, help="filter the dataset by maximum number of frames") + parser.add_argument("--hwmax", type=int, default=None, help="filter the dataset by maximum resolution") + parser.add_argument("--aesmin", type=float, default=None, help="filter the dataset by minimum aes score") + parser.add_argument("--matchmin", type=float, default=None, help="filter the dataset by minimum match score") + parser.add_argument("--flowmin", type=float, default=None, help="filter the dataset by minimum flow score") + + return parser.parse_args() + + +def get_output_path(args, input_name): + if args.output is not None: + return args.output + name = input_name + dir_path = os.path.dirname(args.input[0]) + + # sort + if args.sort is not None: + assert args.sort_ascending is None + name += "_sort" + if args.sort_ascending is not None: + assert args.sort is None + name += "_sort" + + # IO-related + if args.info: + name += "_info" + if args.ext: + name += "_ext" + if args.remove_corrupted: + name += "_nocorrupted" + + # path processing + if args.relpath is not None: + name += "_relpath" + if args.abspath is not None: + name += "_abspath" + + # caption filtering + if args.remove_empty_caption: + name += "_noempty" + if args.remove_url: + name += "_nourl" + if args.lang is not None: + name += f"_{args.lang}" + if args.remove_path_duplication: + name += "_noduppath" + if args.remove_text_duplication: + name += "_noduptext" + + # caption processing + if args.refine_llm_caption: + name += "_llm" + if args.unescape: + name += "_unescape" + if args.clean_caption: + name += "_clean" + if args.merge_cmotion: + name += "_cmcaption" + if args.count_num_token: + name += "_ntoken" + + # score filtering + if args.fmin is not None: + name += f"_fmin{args.fmin}" + if args.fmax is not None: + name += f"_fmax{args.fmax}" + if args.hwmax is not None: + name += f"_hwmax{args.hwmax}" + if args.aesmin is not None: + name += f"_aesmin{args.aesmin}" + if args.matchmin is not None: + name += f"_matchmin{args.matchmin}" + if args.flowmin is not None: + name += f"_flowmin{args.flowmin}" + + output_path = os.path.join(dir_path, f"{name}.{args.format}") + return output_path + + if __name__ == "__main__": args = parse_args() if args.disable_parallel: