Open-Sora/tools/datasets/convert.py

144 lines
4.8 KiB
Python
Raw Normal View History

2024-03-17 13:09:58 +01:00
import argparse
import os
2024-04-11 07:34:02 +02:00
import time
2024-03-17 13:09:58 +01:00
2024-03-24 15:03:31 +01:00
import pandas as pd
2024-03-17 13:09:58 +01:00
from torchvision.datasets import ImageNet
2024-03-28 15:04:43 +01:00
# Lowercase file extensions (leading dot included) treated as images / videos
# when scanning a directory tree.
IMG_EXTENSIONS = (
    ".jpg",
    ".jpeg",
    ".png",
    ".ppm",
    ".bmp",
    ".pgm",
    ".tif",
    ".tiff",
    ".webp",
)
VID_EXTENSIONS = (
    ".mp4",
    ".avi",
    ".mov",
    ".mkv",
)
2024-03-28 15:04:43 +01:00
2024-04-11 07:34:02 +02:00
def scan_recursively(root):
    """Yield an ``os.DirEntry`` for every file under ``root``, depth-first.

    Entries are produced in the same order as a recursive ``os.scandir`` walk:
    each directory's entries in scandir order, descending into a subdirectory
    as soon as it is encountered. Directory entries themselves are not yielded.

    A progress line is printed every 100 directories visited across the whole
    walk (not per parent directory).
    """
    num = 0
    # Stack of open scandir iterators; the top is the directory being read.
    # Iterating instead of recursing keeps one global counter and avoids
    # Python's recursion limit on very deep trees.
    stack = [os.scandir(root)]
    try:
        while stack:
            entry = next(stack[-1], None)
            if entry is None:
                # Current directory exhausted: release its handle and go back up.
                stack.pop().close()
                continue
            if entry.is_file():
                yield entry
            elif entry.is_dir():
                num += 1
                if num % 100 == 0:
                    print(f"Scanned {num} directories.")
                stack.append(os.scandir(entry.path))
    finally:
        # Runs on normal completion, errors, and early generator close(), so
        # no directory handles are leaked if the consumer stops early.
        for it in stack:
            it.close()
2024-03-28 15:04:43 +01:00
def get_filelist(file_path, exts=None):
    """Return the paths of all files under ``file_path`` (recursively).

    Args:
        file_path: directory to scan.
        exts: optional collection of extensions to keep, e.g. ``(".mp4", ".avi")``.
            Matching is case-insensitive against each file's extension
            (including the leading dot). A single string is treated as one
            extension. ``None`` keeps every file.

    Returns:
        List of file paths in scan order.
    """
    if isinstance(exts, str):
        # A bare string would otherwise be used for substring matching,
        # e.g. "" (no extension) is "in" every string.
        exts = (exts,)
    if exts is not None:
        exts = {e.lower() for e in exts}

    time_start = time.perf_counter()
    filelist = []
    for entry in scan_recursively(file_path):
        # scan_recursively yields only file entries, so no is_file() check here.
        ext = os.path.splitext(entry.name)[-1].lower()
        if exts is None or ext in exts:
            filelist.append(entry.path)
    time_end = time.perf_counter()
    print(f"Scanned {len(filelist)} files in {time_end - time_start:.2f} seconds.")
    return filelist
2024-03-17 13:09:58 +01:00
def split_by_capital(name):
    """Insert a space before every uppercase character except the first one.

    Example: ``BoxingPunchingBag`` -> ``Boxing Punching Bag``.
    """
    pieces = []
    for idx, ch in enumerate(name):
        if idx and ch.isupper():
            pieces.append(" ")
        pieces.append(ch)
    return "".join(pieces)
def process_imagenet(root, split):
    """Write ``imagenet_<split>.csv`` (columns ``path``, ``text``) to the working directory.

    Samples come from torchvision's ``ImageNet(root, split=split)``. ``text`` is
    element 0 of ``classes[label]`` for each sample (presumably the first
    class name listed by torchvision; confirm against its docs).
    """
    dataset = ImageNet(os.path.expanduser(root), split=split)
    rows = []
    for img_path, class_idx in dataset.samples:
        rows.append((img_path, dataset.classes[class_idx][0]))
    csv_name = f"imagenet_{split}.csv"
    pd.DataFrame(rows, columns=["path", "text"]).to_csv(csv_name, index=False)
    print(f"Saved {len(rows)} samples to {csv_name}.")
def process_ucf101(root, split):
    """Write ``ucf101_<split>.csv`` (columns ``path``, ``text``) to the working directory.

    Every file found under ``<root>/<split>`` is listed. Its caption is the name
    of its parent directory split on capital letters, e.g.
    ``.../BoxingPunchingBag/clip.avi`` -> ``Boxing Punching Bag``.
    """
    root = os.path.expanduser(root)
    video_lists = get_filelist(os.path.join(root, split))
    # Take the parent directory name via os.path instead of splitting on "/",
    # so paths built with the platform separator (e.g. "\\" on Windows) work.
    classes = [split_by_capital(os.path.basename(os.path.dirname(x))) for x in video_lists]
    samples = list(zip(video_lists, classes))
    output = f"ucf101_{split}.csv"
    df = pd.DataFrame(samples, columns=["path", "text"])
    df.to_csv(output, index=False)
    print(f"Saved {len(samples)} samples to {output}.")
2024-03-23 09:02:26 +01:00
def process_vidprom(root, info):
    """Write ``vidprom.csv`` (columns ``path``, ``text``) to the working directory.

    Args:
        root: directory scanned recursively for files; ``~`` is expanded.
        info: path to a CSV with ``uuid`` and ``prompt`` columns. The row for
            ``uuid`` is expected at ``<root>/pika-<uuid>.mp4``.

    Only rows whose expected file path was found by the scan are kept.

    Raises:
        ValueError: if ``info`` is None.
    """
    if info is None:
        # Fail before the (potentially slow) directory scan.
        raise ValueError("process_vidprom requires an info CSV (pass --info).")
    root = os.path.expanduser(root)
    video_set = set(get_filelist(root))
    infos = pd.read_csv(info)
    abs_path = infos["uuid"].apply(lambda x: os.path.join(root, f"pika-{x}.mp4"))
    is_exist = abs_path.isin(video_set)
    df = pd.DataFrame(dict(path=abs_path[is_exist], text=infos["prompt"][is_exist]))
    df.to_csv("vidprom.csv", index=False)
    print(f"Saved {len(df)} samples to vidprom.csv.")
2024-03-23 09:02:26 +01:00
def process_general_images(root, output):
    """Write a CSV (columns ``id``, ``path``) listing every image under ``root``.

    Args:
        root: directory to scan recursively; ``~`` is expanded.
        output: destination CSV path. Its parent directory is created if it
            has one.

    Images are selected by ``IMG_EXTENSIONS``; ``id`` is the file name without
    its extension. If ``root`` does not exist, a message is printed and
    nothing is written.
    """
    root = os.path.expanduser(root)
    if not os.path.exists(root):
        print(f"Root {root} does not exist; nothing to do.")
        return
    path_list = get_filelist(root, IMG_EXTENSIONS)
    fname_list = [os.path.splitext(os.path.basename(x))[0] for x in path_list]
    df = pd.DataFrame(dict(id=fname_list, path=path_list))
    # A bare filename has dirname "" and os.makedirs("") raises, so only create
    # the directory when output actually names one.
    out_dir = os.path.dirname(output)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    df.to_csv(output, index=False)
    print(f"Saved {len(df)} samples to {output}.")
2024-03-28 15:04:43 +01:00
def process_general_videos(root, output):
    """Write a CSV (columns ``path``, ``id``, ``relpath``) listing every video under ``root``.

    Args:
        root: directory to scan recursively; ``~`` is expanded.
        output: destination CSV path. Its parent directory is created if it
            has one.

    Videos are selected by ``VID_EXTENSIONS``; ``id`` is the file name without
    its extension and ``relpath`` is the path relative to ``root``. If ``root``
    does not exist, a message is printed and nothing is written.
    """
    root = os.path.expanduser(root)
    if not os.path.exists(root):
        print(f"Root {root} does not exist; nothing to do.")
        return
    path_list = get_filelist(root, VID_EXTENSIONS)
    # Remove duplicates while keeping scan order; list(set(...)) would order
    # rows differently on every run because str hashing is randomized.
    path_list = list(dict.fromkeys(path_list))
    fname_list = [os.path.splitext(os.path.basename(x))[0] for x in path_list]
    relpath_list = [os.path.relpath(x, root) for x in path_list]
    df = pd.DataFrame(dict(path=path_list, id=fname_list, relpath=relpath_list))
    # A bare filename has dirname "" and os.makedirs("") raises, so only create
    # the directory when output actually names one.
    out_dir = os.path.dirname(output)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    df.to_csv(output, index=False)
    print(f"Saved {len(df)} samples to {output}.")
2024-03-28 15:04:43 +01:00
2024-03-17 13:09:58 +01:00
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("dataset", type=str, choices=["imagenet", "ucf101", "vidprom", "image", "video"])
    parser.add_argument("root", type=str)
    parser.add_argument("--split", type=str, default="train")
    parser.add_argument("--info", type=str, default=None, help="Info CSV (required for 'vidprom')")
    # Only the generic "image"/"video" converters read --output; the others
    # write to a fixed file name in the working directory. So --output is
    # validated per dataset below rather than being globally required.
    parser.add_argument("--output", type=str, default=None, help="Output path (required for 'image' and 'video')")
    args = parser.parse_args()

    if args.dataset in ("image", "video") and args.output is None:
        parser.error(f"--output is required for dataset '{args.dataset}'")
    if args.dataset == "vidprom" and args.info is None:
        parser.error("--info is required for dataset 'vidprom'")

    if args.dataset == "imagenet":
        process_imagenet(args.root, args.split)
    elif args.dataset == "ucf101":
        process_ucf101(args.root, args.split)
    elif args.dataset == "vidprom":
        process_vidprom(args.root, args.info)
    elif args.dataset == "image":
        process_general_images(args.root, args.output)
    elif args.dataset == "video":
        process_general_videos(args.root, args.output)
    else:
        # Unreachable while `choices` above lists every handled dataset.
        raise ValueError("Invalid dataset")