added dataset processing scripts (#8)

* added dataset processing scripts

* added dataset processing scripts
This commit is contained in:
Frank Lee 2024-02-23 11:26:28 +08:00 committed by GitHub
parent 603b193d38
commit da9b00e808
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 271 additions and 10 deletions

.gitignore vendored

@@ -158,4 +158,7 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
.vscode/
# files needed to train
dataset/

README.md

@@ -1,14 +1,49 @@
# 🎥 Open-Sora
## 📎 Table of Contents
- [🎥 Open-Sora](#-open-sora)
- [📎 Table of Contents](#-table-of-contents)
- [📍 Overview](#-overview)
- [📂 Dataset Preparation](#-dataset-preparation)
- [Use MSR-VTT](#use-msr-vtt)
- [Use Customized Datasets](#use-customized-datasets)
- [🚀 Get Started](#-get-started)
- [Training](#training)
- [Inference](#inference)
- [🪄 Acknowledgement](#-acknowledgement)
## 📍 Overview
This repository is an unofficial implementation of OpenAI's Sora. We built this based on the [facebookresearch/DiT](https://github.com/facebookresearch/DiT) repository.
## 📂 Dataset Preparation
### Use MSR-VTT
We use the [MSR-VTT](https://cove.thecvf.com/datasets/839) dataset, a large-scale video description dataset. The raw videos must be preprocessed before the model can be trained. You can use the following scripts to perform data processing.
```bash
# Step 1: download the dataset to ./dataset/MSRVTT
bash scripts/data/download_msr_vtt_dataset.sh
# Step 2: collate the video and annotations
python scripts/data/collate_msr_vtt_dataset.py -d ./dataset/MSRVTT/ -o ./dataset/MSRVTT-collated
# Step 3: perform data processing
# NOTE: each script can take several minutes, so we run it on each dataset split individually
python scripts/data/preprocess_data.py -c ./dataset/MSRVTT-collated/train/annotations.json -v ./dataset/MSRVTT-collated/train/videos -o ./dataset/MSRVTT-processed/train
python scripts/data/preprocess_data.py -c ./dataset/MSRVTT-collated/val/annotations.json -v ./dataset/MSRVTT-collated/val/videos -o ./dataset/MSRVTT-processed/val
python scripts/data/preprocess_data.py -c ./dataset/MSRVTT-collated/test/annotations.json -v ./dataset/MSRVTT-collated/test/videos -o ./dataset/MSRVTT-processed/test
```
After completing the steps, you should have a processed MSR-VTT dataset in `./dataset/MSRVTT-processed`.
### Use Customized Datasets
You can also use other datasets by converting them to the required format. You should prepare a captions file and a video directory: the captions file should be a JSON or JSONL file, and the video directory should contain all the videos.
Here is an example of the captions file:
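For reference, each entry in the captions file pairs one video file name with a list of captions, matching the objects that the collation script emits. A minimal sketch of writing such a file with the standard library (the file names and captions below are invented for illustration):

```python
import json

# Hypothetical entries: one object per video, each with a "file" name
# and a list of "captions" strings.
captions = [
    {"file": "video0.mp4", "captions": ["a cat chases a laser pointer"]},
    {"file": "video1.mp4", "captions": ["a person rides a bike", "someone cycles downhill"]},
]

# Write a JSON captions file in the expected shape.
with open("captions.json", "w") as f:
    json.dump(captions, f, indent=4)
```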
@@ -40,10 +75,36 @@ We use [VQ-VAE](https://github.com/wilson1yan/VideoGPT/) to quantize the video f
The output is an arrow dataset, which contains the following columns: "video_file", "video_latent_states", "text_latent_states". The dimension of "video_latent_states" is (T, H, W), and the dimension of "text_latent_states" is (S, D).
Then you can run the data processing script with the command below:
```bash
python preprocess_data.py -c /path/to/captions.json -v /path/to/video_dir -o /path/to/output_dir
```
Note that this script needs to be run on a machine with a GPU. To avoid CUDA OOM, we filter out the videos that are too long.
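The length filter itself can be sketched independently of any GPU code: given per-video durations, drop everything over a threshold. The helper name and the 30-second cutoff below are illustrative, not the values used by `preprocess_data.py`:

```python
def filter_long_videos(durations, max_seconds=30.0):
    """Return the video names whose duration is within the limit.

    durations: dict mapping video file name -> duration in seconds.
    The 30-second default is an illustrative threshold only.
    """
    return [name for name, secs in durations.items() if secs <= max_seconds]

durations = {"video0.mp4": 12.5, "video1.mp4": 95.0, "video2.mp4": 28.0}
kept = filter_long_videos(durations)
print(kept)  # video1.mp4 (95 s) is dropped
```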
## 🚀 Get Started
In this section, we provide guidance on how to run training and inference. Before that, make sure you have installed the dependencies with the command below.
```bash
pip install -r requirements.txt
```
### Training
To be added.
### Inference
To be added.
## 🪄 Acknowledgement
During the development of this project, we learned a lot from the following public materials:
- [OpenAI Sora Technical Report](https://openai.com/research/video-generation-models-as-world-simulators)
- [VideoGPT Project](https://github.com/wilson1yan/VideoGPT)
- [Diffusion Transformers](https://github.com/facebookresearch/DiT)

requirements.txt Normal file

@@ -0,0 +1,5 @@
torch
torchvision
datasets
transformers
av

scripts/data/collate_msr_vtt_dataset.py Normal file

@@ -0,0 +1,182 @@
import argparse
import json
import multiprocessing
import os
import shutil
import warnings
from typing import Dict, Tuple
from tqdm import tqdm

DEFAULT_TYPES = ["train", "val", "test"]


def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("-d", "--data-path", type=str, help="The path to the MSR-VTT dataset")
    parser.add_argument("-o", "--output-path", type=str, help="The output path to the collated MSR-VTT dataset")
    return parser.parse_args()


def get_annotations(root_path: str):
"""
Get the annotation data from the MSR-VTT dataset. The annotations are in the format of:
{
"annotations": [
{
                "image_id": "video1",
                "caption": "some caption"
}
]
}
Args:
root_path (str): The root path to the MSR-VTT dataset
"""
annotation_json_file = os.path.join(root_path, "annotation/MSR_VTT.json")
with open(annotation_json_file, "r") as f:
data = json.load(f)
    return data


def get_video_list(root_path: str, dataset_type: str):
"""
Get the list of videos in the dataset split.
Args:
root_path (str): The root path to the MSR-VTT dataset
dataset_type (str): The dataset split type. It should be one of "train", "val", or "test"
"""
assert dataset_type in DEFAULT_TYPES, f"Expected the dataset type to be in {DEFAULT_TYPES}, but got {dataset_type}"
dataset_file_path = os.path.join(root_path, f"structured-symlinks/{dataset_type}_list_full.txt")
with open(dataset_file_path, "r") as f:
video_list = f.readlines()
video_list = [x.strip() for x in video_list]
    return video_list


def copy_video(video_id: str, root_path: str, output_path: str, dataset_type: str):
"""
Copy the video from the source path to the destination path.
Args:
video_id (str): The video id
root_path (str): The root path to the MSR-VTT dataset
output_path (str): The output path to the collated MSR-VTT dataset
dataset_type (str): The dataset split type. It should be one of "train", "val", or "test"
"""
assert dataset_type in DEFAULT_TYPES, f"Expected the dataset type to be in {DEFAULT_TYPES}, but got {dataset_type}"
src_file = os.path.join(root_path, f"videos/all/{video_id}.mp4")
dst_folder = os.path.join(output_path, f"{dataset_type}/videos")
dst_file = os.path.join(dst_folder, f"{video_id}.mp4")
os.makedirs(dst_folder, exist_ok=True)
    # copy the video file unless a symlink already exists at the destination
assert os.path.isfile(src_file), f"Expected the source file {src_file} to exist"
if not os.path.islink(dst_file):
        shutil.copy(src_file, dst_file)


def get_annotation_file_path(output_path: str, dataset_type: str):
file_path = os.path.join(output_path, f"{dataset_type}/annotations.json")
    return file_path


def collate_annotation_files(
annotations: Dict,
root_path: str,
output_path: str,
):
"""
Collate the video and caption data into a single folder.
Args:
annotations (Dict): The annotations data
root_path (str): The root path to the MSR-VTT dataset
output_path (str): The output path to the collated MSR-VTT dataset
"""
# get all video list
train_video_list = get_video_list(root_path, "train")
val_video_list = get_video_list(root_path, "val")
test_video_list = get_video_list(root_path, "test")
# iterate over annotations
collated_train_data = []
collated_val_data = []
collated_test_data = []
print("Collating annotations files")
for anno in tqdm(annotations["annotations"]):
video_id = anno["image_id"]
caption = anno["caption"]
obj = {"file": f"{video_id}.mp4", "captions": [caption]}
if video_id in train_video_list:
collated_train_data.append(obj)
elif video_id in val_video_list:
collated_val_data.append(obj)
elif video_id in test_video_list:
collated_test_data.append(obj)
else:
            warnings.warn(f"Video {video_id} not found in any of the dataset splits")

    def _save_caption_files(obj, dataset_type):
dst_file = get_annotation_file_path(output_path, dataset_type)
os.makedirs(os.path.dirname(dst_file), exist_ok=True)
with open(dst_file, "w") as f:
json.dump(obj, f, indent=4)
_save_caption_files(collated_train_data, "train")
_save_caption_files(collated_val_data, "val")
    _save_caption_files(collated_test_data, "test")


def copy_file(path_pair: Tuple[str, str]):
src_path, dst_path = path_pair
    shutil.copyfile(src_path, dst_path)


def copy_videos(root_path: str, output_path: str, num_workers: int = 8):
"""
Batch copy the video files to the output path.
Args:
root_path (str): The root path to the MSR-VTT dataset
output_path (str): The output path to the collated MSR-VTT dataset
num_workers (int): The number of workers to use for the copy operation
"""
pool = multiprocessing.Pool(num_workers)
for dataset_type in DEFAULT_TYPES:
print(f"Copying videos for the {dataset_type} dataset")
annotation_file_path = get_annotation_file_path(output_path, dataset_type)
output_video_folder_path = os.path.join(output_path, f"{dataset_type}/videos")
os.makedirs(output_video_folder_path, exist_ok=True)
with open(annotation_file_path, "r") as f:
annotation_data = json.load(f)
video_ids = [obj["file"] for obj in annotation_data]
unique_video_ids = list(set(video_ids))
path_pairs = [
(os.path.join(root_path, f"videos/all/{video_id}"), os.path.join(output_video_folder_path, video_id))
for video_id in unique_video_ids
]
for _ in tqdm(pool.imap_unordered(copy_file, path_pairs), total=len(path_pairs)):
        pass


def main():
args = parse_args()
annotations = get_annotations(args.data_path)
collate_annotation_files(annotations, args.data_path, args.output_path)
    copy_videos(args.data_path, args.output_path)


if __name__ == "__main__":
main()
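One note on the collation loop above: `video_id in train_video_list` is a linear scan over a Python list, repeated for every annotation. The same split assignment can be sketched with set membership for constant-time lookups (a sketch with toy data, not the repo's code):

```python
def assign_splits(annotations, split_lists):
    """Group caption annotations by dataset split using set membership.

    annotations: iterable of {"image_id": ..., "caption": ...} dicts.
    split_lists: dict mapping split name -> iterable of video ids.
    Returns a dict mapping split name -> list of {"file", "captions"} objects.
    """
    split_sets = {name: set(ids) for name, ids in split_lists.items()}
    collated = {name: [] for name in split_lists}
    for anno in annotations:
        video_id, caption = anno["image_id"], anno["caption"]
        for name, ids in split_sets.items():
            if video_id in ids:
                collated[name].append({"file": f"{video_id}.mp4", "captions": [caption]})
                break
    return collated

# Toy data for illustration only.
annotations = [
    {"image_id": "video1", "caption": "a dog runs"},
    {"image_id": "video2", "caption": "a car drives"},
]
splits = assign_splits(annotations, {"train": ["video1"], "val": ["video2"], "test": []})
```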

scripts/data/download_msr_vtt_dataset.sh Normal file

@@ -0,0 +1,11 @@
#!/usr/bin/env bash
# get the repository root directory
FOLDER_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
ROOT_DIR="$FOLDER_DIR/../.."

# download the dataset into ./dataset at the repository root
cd "$ROOT_DIR"
mkdir -p dataset && cd ./dataset
wget https://www.robots.ox.ac.uk/~maxbain/frozen-in-time/data/MSRVTT.zip
unzip MSRVTT.zip

preprocess_data.py

@@ -89,7 +89,6 @@ def process_dataset(
ds = load_dataset("json", data_files=captions_file, keep_in_memory=False, split=train_splits)
for i, part_ds in enumerate(ds):
print(f"Processing part {i+1}/{len(ds)}")
part_ds = part_ds.map(
process_batch,
fn_kwargs={"video_dir": video_dir, "tokenizer": tokenizer, "text_model": text_model, "vqvae": vqvae},
@@ -105,10 +104,10 @@ def process_dataset(
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Preprocess data")
parser.add_argument(
"captions_file", type=str, help="Path to the captions file. It should be a JSON file or a JSONL file"
"-c", "--captions-file", type=str, help="Path to the captions file. It should be a JSON file or a JSONL file"
)
parser.add_argument("video_dir", type=str, help="Path to the video directory")
parser.add_argument("output_dir", type=str, help="Path to the output directory")
parser.add_argument("-v", "--video-dir", type=str, help="Path to the video directory")
    parser.add_argument("-o", "--output-dir", type=str, help="Path to the output directory")
parser.add_argument(
"-n", "--num_spliced_dataset_bins", type=int, default=10, help="Number of bins for spliced dataset"
)