mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-24 17:44:47 +02:00
67 lines
1.8 KiB
Python
67 lines
1.8 KiB
Python
import argparse
|
|
import csv
|
|
import os
|
|
|
|
from torchvision.datasets import ImageNet
|
|
|
|
|
|
def get_filelist(file_path):
|
|
Filelist = []
|
|
for home, dirs, files in os.walk(file_path):
|
|
for filename in files:
|
|
Filelist.append(os.path.join(home, filename))
|
|
return Filelist
|
|
|
|
|
|
def split_by_capital(name):
|
|
# BoxingPunchingBag -> Boxing Punching Bag
|
|
new_name = ""
|
|
for i in range(len(name)):
|
|
if name[i].isupper() and i != 0:
|
|
new_name += " "
|
|
new_name += name[i]
|
|
return new_name
|
|
|
|
|
|
def process_imagenet(root, split):
|
|
root = os.path.expanduser(root)
|
|
data = ImageNet(root, split=split)
|
|
samples = [(path, data.classes[label][0]) for path, label in data.samples]
|
|
output = f"imagenet_{split}.csv"
|
|
|
|
with open(output, "w") as f:
|
|
writer = csv.writer(f)
|
|
writer.writerows(samples)
|
|
|
|
print(f"Saved {len(samples)} samples to {output}.")
|
|
|
|
|
|
def process_ucf101(root, split):
|
|
root = os.path.expanduser(root)
|
|
video_lists = get_filelist(os.path.join(root, split))
|
|
classes = [x.split("/")[-2] for x in video_lists]
|
|
classes = [split_by_capital(x) for x in classes]
|
|
samples = list(zip(video_lists, classes))
|
|
output = f"ucf101_{split}.csv"
|
|
|
|
with open(output, "w") as f:
|
|
writer = csv.writer(f)
|
|
writer.writerows(samples)
|
|
|
|
print(f"Saved {len(samples)} samples to {output}.")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
parser = argparse.ArgumentParser()
|
|
parser.add_argument("dataset", type=str, choices=["imagenet", "ucf101"])
|
|
parser.add_argument("root", type=str)
|
|
parser.add_argument("--split", type=str, default="train")
|
|
args = parser.parse_args()
|
|
|
|
if args.dataset == "imagenet":
|
|
process_imagenet(args.root, args.split)
|
|
elif args.dataset == "ucf101":
|
|
process_ucf101(args.root, args.split)
|
|
else:
|
|
raise ValueError("Invalid dataset")
|