From da9b00e8082aa40c766595573dbd991b0cb2e098 Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Fri, 23 Feb 2024 11:26:28 +0800 Subject: [PATCH] added dataset processing scripts (#8) * added dataset processing scripts * added dataset processing scripts --- .gitignore | 5 +- README.md | 71 ++++++- requirements.txt | 5 + scripts/data/collate_msr_vtt_dataset.py | 182 ++++++++++++++++++ scripts/data/download_msr_vtt_dataset.sh | 11 ++ .../data/preprocess_data.py | 7 +- 6 files changed, 271 insertions(+), 10 deletions(-) create mode 100644 requirements.txt create mode 100644 scripts/data/collate_msr_vtt_dataset.py create mode 100644 scripts/data/download_msr_vtt_dataset.sh rename preprocess_data.py => scripts/data/preprocess_data.py (92%) diff --git a/.gitignore b/.gitignore index 06884cf..31a93e1 100644 --- a/.gitignore +++ b/.gitignore @@ -158,4 +158,7 @@ cython_debug/ # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ -.vscode/ \ No newline at end of file +.vscode/ + +# files needed to train +dataset/ diff --git a/README.md b/README.md index 27d0737..fb44ec2 100644 --- a/README.md +++ b/README.md @@ -1,14 +1,49 @@ # 🎥 Open-Sora +## 📎 Table of Contents + +- [🎥 Open-Sora](#-open-sora) + - [📎 Table of Contents](#-table-of-contents) + - [📍 Overview](#-overview) + - [📂 Dataset Preparation](#-dataset-preparation) + - [Use MSR-VTT](#use-msr-vtt) + - [Use Customized Datasets](#use-customized-datasets) + - [🚀 Get Started](#-get-started) + - [Training](#training) + - [Inference](#inference) + - [🪄 Acknowledgement](#-acknowledgement) + ## 📍 Overview This repository is an unofficial implementation of OpenAI's Sora. We built this based on the [facebookresearch/DiT](https://github.com/facebookresearch/DiT) repository. 
-## Dataset preparation
+## 📂 Dataset Preparation
 
-We use [MSR-VTT](https://cove.thecvf.com/datasets/839) dataset, which is a large-scale video description dataset. We should preprocess the raw videos before training the model.
+### Use MSR-VTT
 
-Before running `preprocess_data.py`, you should prepare a captions file and a video directory. The captions file should be a JSON file or a JSONL file. The video directory should contain all the videos.
+We use [MSR-VTT](https://cove.thecvf.com/datasets/839) dataset, which is a large-scale video description dataset. We should preprocess the raw videos before training the model. You can use the following scripts to perform data processing.
+
+
+```bash
+# Step 1: download the dataset to ./dataset/MSRVTT
+bash scripts/data/download_msr_vtt_dataset.sh
+
+# Step 2: collate the video and annotations
+python scripts/data/collate_msr_vtt_dataset.py -d ./dataset/MSRVTT/ -o ./dataset/MSRVTT-collated
+
+# Step 3: perform data processing
+# NOTE: each script could take several minutes so we apply the script to each dataset split individually
+python scripts/data/preprocess_data.py -c ./dataset/MSRVTT-collated/train/annotations.json -v ./dataset/MSRVTT-collated/train/videos -o ./dataset/MSRVTT-processed/train
+python scripts/data/preprocess_data.py -c ./dataset/MSRVTT-collated/val/annotations.json -v ./dataset/MSRVTT-collated/val/videos -o ./dataset/MSRVTT-processed/val
+python scripts/data/preprocess_data.py -c ./dataset/MSRVTT-collated/test/annotations.json -v ./dataset/MSRVTT-collated/test/videos -o ./dataset/MSRVTT-processed/test
+```
+
+After completing the steps, you should have a processed MSR-VTT dataset in `./dataset/MSRVTT-processed`.
+
+
+### Use Customized Datasets
+
+You can also use other datasets and transform the dataset to the required format. You should prepare a captions file and a video directory. The captions file should be a JSON file or a JSONL file. The video directory should contain all the videos.
Here is an example of the captions file:
 
@@ -40,10 +75,36 @@ We use [VQ-VAE](https://github.com/wilson1yan/VideoGPT/) to quantize the video f
 
 The output is an arrow dataset, which contains the following columns: "video_file", "video_latent_states", "text_latent_states". The dimension of "video_latent_states" is (T, H, W), and the dimension of "text_latent_states" is (S, D).
 
-How to run the script:
+Then you can run the data processing script with the command below:
 
 ```bash
-python preprocess_data.py /path/to/captions.json /path/to/video_dir /path/to/output_dir
+python preprocess_data.py -c /path/to/captions.json -v /path/to/video_dir -o /path/to/output_dir
 ```
 
 Note that this script needs to be run on a machine with a GPU. To avoid CUDA OOM, we filter out the videos that are too long.
+
+
+## 🚀 Get Started
+
+In this section, we will provide guidance on how to run training and inference. Before that, make sure you have installed the dependencies with the command below.
+
+```bash
+pip install -r requirements.txt
+```
+
+### Training
+
+To be added.
+
+### Inference
+
+To be added.
+ + +## 🪄 Acknowledgement + +During development of the project, we learnt a lot from the following public materials: + +- [OpenAI Sora Technical Report](https://openai.com/research/video-generation-models-as-world-simulators) +- [VideoGPT Project](https://github.com/wilson1yan/VideoGPT) +- [Diffusion Transformers](https://github.com/facebookresearch/DiT) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..1ec0061 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,5 @@ +torch +torchvision +datasets +transformers +av diff --git a/scripts/data/collate_msr_vtt_dataset.py b/scripts/data/collate_msr_vtt_dataset.py new file mode 100644 index 0000000..04f63cb --- /dev/null +++ b/scripts/data/collate_msr_vtt_dataset.py @@ -0,0 +1,182 @@ +import argparse +import json +import multiprocessing +import os +import shutil +import warnings +from typing import Dict, Tuple + +from tqdm import tqdm + +DEFAULT_TYPES = ["train", "val", "test"] + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("-d", "--data-path", type=str, help="The path to the MSR-VTT dataset") + parser.add_argument("-o", "--output-path", type=str, help="The output to the collated MSR-VTT dataset") + return parser.parse_args() + + +def get_annotations(root_path: str): + """ + Get the annotation data from the MSR-VTT dataset. The annotations are in the format of: + + { + "annotations": [ + { + "image_id": "video1", + "caption": "some + } + ] + } + + Args: + root_path (str): The root path to the MSR-VTT dataset + """ + annotation_json_file = os.path.join(root_path, "annotation/MSR_VTT.json") + with open(annotation_json_file, "r") as f: + data = json.load(f) + return data + + +def get_video_list(root_path: str, dataset_type: str): + """ + Get the list of videos in the dataset split. + + Args: + root_path (str): The root path to the MSR-VTT dataset + dataset_type (str): The dataset split type. 
It should be one of "train", "val", or "test" + """ + assert dataset_type in DEFAULT_TYPES, f"Expected the dataset type to be in {DEFAULT_TYPES}, but got {dataset_type}" + dataset_file_path = os.path.join(root_path, f"structured-symlinks/{dataset_type}_list_full.txt") + with open(dataset_file_path, "r") as f: + video_list = f.readlines() + video_list = [x.strip() for x in video_list] + return video_list + + +def copy_video(video_id: str, root_path: str, output_path: str, dataset_type: str): + """ + Copy the video from the source path to the destination path. + + Args: + video_id (str): The video id + root_path (str): The root path to the MSR-VTT dataset + output_path (str): The output path to the collated MSR-VTT dataset + dataset_type (str): The dataset split type. It should be one of "train", "val", or "test" + """ + assert dataset_type in DEFAULT_TYPES, f"Expected the dataset type to be in {DEFAULT_TYPES}, but got {dataset_type}" + src_file = os.path.join(root_path, f"videos/all/{video_id}.mp4") + dst_folder = os.path.join(output_path, f"{dataset_type}/videos") + dst_file = os.path.join(dst_folder, f"{video_id}.mp4") + os.makedirs(dst_folder, exist_ok=True) + + # create symlink + assert os.path.isfile(src_file), f"Expected the source file {src_file} to exist" + if not os.path.islink(dst_file): + shutil.copy(src_file, dst_file) + + +def get_annotation_file_path(output_path: str, dataset_type: str): + file_path = os.path.join(output_path, f"{dataset_type}/annotations.json") + return file_path + + +def collate_annotation_files( + annotations: Dict, + root_path: str, + output_path: str, +): + """ + Collate the video and caption data into a single folder. 
+ + Args: + annotations (Dict): The annotations data + root_path (str): The root path to the MSR-VTT dataset + output_path (str): The output path to the collated MSR-VTT dataset + """ + # get all video list + train_video_list = get_video_list(root_path, "train") + val_video_list = get_video_list(root_path, "val") + test_video_list = get_video_list(root_path, "test") + + # iterate over annotations + collated_train_data = [] + collated_val_data = [] + collated_test_data = [] + + print("Collating annotations files") + + for anno in tqdm(annotations["annotations"]): + video_id = anno["image_id"] + caption = anno["caption"] + + obj = {"file": f"{video_id}.mp4", "captions": [caption]} + + if video_id in train_video_list: + collated_train_data.append(obj) + elif video_id in val_video_list: + collated_val_data.append(obj) + elif video_id in test_video_list: + collated_test_data.append(obj) + else: + warnings.warn(f"Video {video_id} not found in any of the dataset splits") + + def _save_caption_files(obj, dataset_type): + dst_file = get_annotation_file_path(output_path, dataset_type) + os.makedirs(os.path.dirname(dst_file), exist_ok=True) + with open(dst_file, "w") as f: + json.dump(obj, f, indent=4) + + _save_caption_files(collated_train_data, "train") + _save_caption_files(collated_val_data, "val") + _save_caption_files(collated_test_data, "test") + + +def copy_file(path_pair: Tuple[str, str]): + src_path, dst_path = path_pair + shutil.copyfile(src_path, dst_path) + + +def copy_videos(root_path: str, output_path: str, num_workers: int = 8): + """ + Batch copy the video files to the output path. 
+ + Args: + root_path (str): The root path to the MSR-VTT dataset + output_path (str): The output path to the collated MSR-VTT dataset + num_workers (int): The number of workers to use for the copy operation + """ + pool = multiprocessing.Pool(num_workers) + + for dataset_type in DEFAULT_TYPES: + print(f"Copying videos for the {dataset_type} dataset") + annotation_file_path = get_annotation_file_path(output_path, dataset_type) + output_video_folder_path = os.path.join(output_path, f"{dataset_type}/videos") + os.makedirs(output_video_folder_path, exist_ok=True) + + with open(annotation_file_path, "r") as f: + annotation_data = json.load(f) + + video_ids = [obj["file"] for obj in annotation_data] + unique_video_ids = list(set(video_ids)) + + path_pairs = [ + (os.path.join(root_path, f"videos/all/{video_id}"), os.path.join(output_video_folder_path, video_id)) + for video_id in unique_video_ids + ] + + for _ in tqdm(pool.imap_unordered(copy_file, path_pairs), total=len(path_pairs)): + pass + + +def main(): + args = parse_args() + annotations = get_annotations(args.data_path) + collate_annotation_files(annotations, args.data_path, args.output_path) + copy_videos(args.data_path, args.output_path) + + +if __name__ == "__main__": + main() diff --git a/scripts/data/download_msr_vtt_dataset.sh b/scripts/data/download_msr_vtt_dataset.sh new file mode 100644 index 0000000..141a7d6 --- /dev/null +++ b/scripts/data/download_msr_vtt_dataset.sh @@ -0,0 +1,11 @@ +#!/usr/bin/env bash + +# get root dir +FOLDER_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )" +ROOT_DIR=$FOLDER_DIR/../.. 
+
+# download at root dir
+cd $ROOT_DIR
+mkdir -p dataset && cd ./dataset
+wget https://www.robots.ox.ac.uk/~maxbain/frozen-in-time/data/MSRVTT.zip
+unzip MSRVTT.zip
diff --git a/preprocess_data.py b/scripts/data/preprocess_data.py
similarity index 92%
rename from preprocess_data.py
rename to scripts/data/preprocess_data.py
index 12199c1..a92a2b2 100644
--- a/preprocess_data.py
+++ b/scripts/data/preprocess_data.py
@@ -89,7 +89,6 @@ def process_dataset(
     ds = load_dataset("json", data_files=captions_file, keep_in_memory=False, split=train_splits)
 
     for i, part_ds in enumerate(ds):
-        print(f"Processing part {i+1}/{len(ds)}")
         part_ds = part_ds.map(
             process_batch,
             fn_kwargs={"video_dir": video_dir, "tokenizer": tokenizer, "text_model": text_model, "vqvae": vqvae},
@@ -105,10 +104,10 @@ def process_dataset(
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Preprocess data")
     parser.add_argument(
-        "captions_file", type=str, help="Path to the captions file. It should be a JSON file or a JSONL file"
+        "-c", "--captions-file", type=str, help="Path to the captions file. It should be a JSON file or a JSONL file"
    )
-    parser.add_argument("video_dir", type=str, help="Path to the video directory")
-    parser.add_argument("output_dir", type=str, help="Path to the output directory")
+    parser.add_argument("-v", "--video-dir", type=str, help="Path to the video directory")
+    parser.add_argument("-o", "--output-dir", type=str, help="Path to the output directory")
    parser.add_argument(
         "-n", "--num_spliced_dataset_bins", type=int, default=10, help="Number of bins for spliced dataset"
     )