mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-11 05:13:31 +02:00
update convert_dataset and launch exp
This commit is contained in:
parent
0344b24264
commit
b1f6e128dc
|
|
@ -4,6 +4,7 @@ diffusers
|
|||
ftfy
|
||||
gdown
|
||||
mmengine
|
||||
pandas
|
||||
pre-commit
|
||||
pyav
|
||||
tensorboard
|
||||
|
|
|
|||
|
|
@ -2,11 +2,12 @@
|
|||
|
||||
## Dataset Format
|
||||
|
||||
The training data should be provided in a CSV file with the following format:
|
||||
The training data should be provided in a CSV file whose first line is a header naming the columns, with each subsequent row containing the following information:
|
||||
|
||||
```csv
|
||||
path, text, num_frames, aesthetic_score
|
||||
/absolute/path/to/image1.jpg, caption1, num_of_frames
|
||||
/absolute/path/to/image2.jpg, caption2, num_of_frames
|
||||
/absolute/path/to/video2.mp4, caption2, num_of_frames
|
||||
```
|
||||
|
||||
## HD-VG-130M
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import argparse
|
|||
import csv
|
||||
import os
|
||||
|
||||
import pandas as pd
|
||||
from torchvision.datasets import ImageNet
|
||||
|
||||
|
||||
|
|
@ -29,10 +30,8 @@ def process_imagenet(root, split):
|
|||
samples = [(path, data.classes[label][0]) for path, label in data.samples]
|
||||
output = f"imagenet_{split}.csv"
|
||||
|
||||
with open(output, "w") as f:
|
||||
writer = csv.writer(f)
|
||||
writer.writerows(samples)
|
||||
|
||||
df = pd.DataFrame(samples, columns=["path", "text"])
|
||||
df.to_csv(output, index=False)
|
||||
print(f"Saved {len(samples)} samples to {output}.")
|
||||
|
||||
|
||||
|
|
@ -44,10 +43,8 @@ def process_ucf101(root, split):
|
|||
samples = list(zip(video_lists, classes))
|
||||
output = f"ucf101_{split}.csv"
|
||||
|
||||
with open(output, "w") as f:
|
||||
writer = csv.writer(f)
|
||||
writer.writerows(samples)
|
||||
|
||||
df = pd.DataFrame(samples, columns=["path", "text"])
|
||||
df.to_csv(output, index=False)
|
||||
print(f"Saved {len(samples)} samples to {output}.")
|
||||
|
||||
|
||||
|
|
@ -56,18 +53,12 @@ def process_vidprom(root, info):
|
|||
video_lists = get_filelist(root)
|
||||
video_set = set(video_lists)
|
||||
# read info csv
|
||||
with open(info, "r") as f:
|
||||
reader = csv.reader(f)
|
||||
lines = list(reader)
|
||||
infos = [(os.path.join(root, f"pika-{x[0]}.mp4"), x[1]) for x in lines[1:]]
|
||||
samples = [x for x in infos if x[0] in video_set]
|
||||
output = "vidprom.csv"
|
||||
|
||||
with open(output, "w") as f:
|
||||
writer = csv.writer(f)
|
||||
writer.writerows(samples)
|
||||
|
||||
print(f"Saved {len(samples)} samples to {output}.")
|
||||
infos = pd.read_csv(info)
|
||||
abs_path = infos["uuid"].apply(lambda x: os.path.join(root, f"pika-{x}.mp4"))
|
||||
is_exist = abs_path.apply(lambda x: x in video_set)
|
||||
df = pd.DataFrame(dict(path=abs_path[is_exist], text=infos["prompt"][is_exist]))
|
||||
df.to_csv("vidprom.csv", index=False)
|
||||
print(f"Saved {len(df)} samples to vidprom.csv.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
Loading…
Reference in a new issue