mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-11 13:14:44 +02:00
[wip] image wrong with flashattn
This commit is contained in:
parent
a01f6da20e
commit
2d4a5df287
|
|
@ -160,6 +160,7 @@ class Bucket:
|
|||
|
||||
def info_bucket(self, dataset, frame_interval=1):
|
||||
infos = dict()
|
||||
infos_ar = dict()
|
||||
for i in range(len(dataset)):
|
||||
T, H, W = dataset.get_data_info(i)
|
||||
bucket_id = self.get_bucket_id(T, H, W, frame_interval)
|
||||
|
|
@ -167,9 +168,13 @@ class Bucket:
|
|||
continue
|
||||
if f"{(bucket_id[0], bucket_id[1])}" not in infos:
|
||||
infos[f"{(bucket_id[0], bucket_id[1])}"] = 0
|
||||
if f"{bucket_id[2]}" not in infos_ar:
|
||||
infos_ar[f"{bucket_id[2]}"] = 0
|
||||
infos[f"{(bucket_id[0], bucket_id[1])}"] += 1
|
||||
infos_ar[f"{bucket_id[2]}"] += 1
|
||||
print(f"Dataset contains {len(dataset)} samples.")
|
||||
print("Bucket info:", infos)
|
||||
print("Aspect ratio info:", infos_ar)
|
||||
|
||||
def get_bucket_id(self, T, H, W, frame_interval=1):
|
||||
# hw
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
import os
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
|
|
@ -38,11 +40,11 @@ class VideoTextDataset(torch.utils.data.Dataset):
|
|||
}
|
||||
|
||||
def get_type(self, path):
    """Classify ``path`` as ``"video"`` or ``"image"`` from its file extension.

    The extension is taken with ``os.path.splitext`` and lower-cased, so it
    includes the leading dot (e.g. ``".mp4"``) and is compared directly
    against the dotted entries of ``VID_EXTENSIONS`` / ``IMG_EXTENSIONS``.

    Args:
        path: File path (or file name) of the sample.

    Returns:
        ``"video"`` if the extension is in ``VID_EXTENSIONS``, otherwise
        ``"image"``.

    Raises:
        AssertionError: If the extension is in neither extension list.
    """
    # splitext keeps the dot, so no extra "." must be prepended when testing
    # membership (doing so would produce "..jpg" and reject every image).
    ext = os.path.splitext(path)[-1].lower()
    if ext in VID_EXTENSIONS:
        return "video"
    # Kept as an assert so callers observe the same exception type as before.
    assert ext in IMG_EXTENSIONS, f"Unsupported file format: {ext}"
    return "image"
|
||||
|
||||
def getitem(self, index):
|
||||
|
|
@ -139,13 +141,14 @@ class VariableVideoTextDataset(VideoTextDataset):
|
|||
image = transform(image)
|
||||
|
||||
# repeat
|
||||
video = image.unsqueeze(0).repeat(self.num_frames, 1, 1, 1)
|
||||
video = image.unsqueeze(0)
|
||||
|
||||
# TCHW -> CTHW
|
||||
video = video.permute(1, 0, 2, 3)
|
||||
return {"video": video, "text": text, "num_frames": num_frames, "height": height, "width": width, "ar": ar}
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.getitem(index)
|
||||
for _ in range(10):
|
||||
try:
|
||||
return self.getitem(index)
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from torchvision.utils import save_image
|
|||
|
||||
from . import video_transforms
|
||||
|
||||
VID_EXTENSIONS = ("mp4", "avi", "mov", "mkv")
|
||||
VID_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv")
|
||||
|
||||
|
||||
def temporal_random_crop(vframes, num_frames, frame_interval):
|
||||
|
|
|
|||
|
|
@ -1,16 +1,21 @@
|
|||
import argparse
|
||||
import csv
|
||||
import os
|
||||
|
||||
import pandas as pd
|
||||
from torchvision.datasets import ImageNet
|
||||
|
||||
|
||||
def get_filelist(file_path):
|
||||
IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp")
|
||||
# Lower-case, with a leading dot: get_filelist compares these against
# os.path.splitext(filename)[-1].lower(), which yields e.g. ".mp4".
VID_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv")
|
||||
|
||||
|
||||
def get_filelist(file_path, exts=None):
    """Recursively collect the paths of files under ``file_path``.

    Args:
        file_path: Root directory to walk with ``os.walk``.
        exts: Optional extension filter. May be a single string or an
            iterable of strings; matching is case-insensitive and the
            leading dot is optional (``".mp4"`` and ``"mp4"`` are
            equivalent). ``None`` keeps every file.

    Returns:
        A list of ``os.path.join(directory, filename)`` paths, in
        ``os.walk`` order, restricted to the requested extensions.
    """
    if exts is not None:
        if isinstance(exts, str):
            # A bare string would otherwise be iterated character by character.
            exts = (exts,)
        # os.path.splitext yields a dotted, here lower-cased, suffix; bring the
        # requested extensions to the same form so undotted entries still match.
        exts = {
            e.lower() if (not e or e.startswith(".")) else f".{e.lower()}"
            for e in exts
        }

    filelist = []
    for home, _dirs, files in os.walk(file_path):
        for filename in files:
            ext = os.path.splitext(filename)[-1].lower()
            if exts is None or ext in exts:
                filelist.append(os.path.join(home, filename))
    return filelist
|
||||
|
||||
|
||||
|
|
@ -61,9 +66,25 @@ def process_vidprom(root, info):
|
|||
print(f"Saved {len(df)} samples to vidprom.csv.")
|
||||
|
||||
|
||||
def process_general_images(root):
    """Write ``images.csv`` listing the image files found under ``root``.

    ``root`` is user-expanded (``~``), then passed to ``get_filelist``
    together with ``IMG_EXTENSIONS``. The resulting paths are stored in a
    single ``path`` column and the number of rows is printed.
    """
    image_root = os.path.expanduser(root)
    found = get_filelist(image_root, IMG_EXTENSIONS)
    table = pd.DataFrame({"path": found})
    table.to_csv("images.csv", index=False)
    print(f"Saved {len(table)} samples to images.csv.")
|
||||
|
||||
|
||||
def process_general_videos(root):
    """Write ``videos.csv`` listing the video files found under ``root``.

    ``root`` is user-expanded (``~``), then passed to ``get_filelist``
    together with ``VID_EXTENSIONS``. The resulting paths are stored in a
    single ``path`` column and the number of rows is printed.
    """
    video_root = os.path.expanduser(root)
    found = get_filelist(video_root, VID_EXTENSIONS)
    table = pd.DataFrame({"path": found})
    table.to_csv("videos.csv", index=False)
    print(f"Saved {len(table)} samples to videos.csv.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("dataset", type=str, choices=["imagenet", "ucf101", "vidprom"])
|
||||
parser.add_argument("dataset", type=str, choices=["imagenet", "ucf101", "vidprom", "image", "video"])
|
||||
parser.add_argument("root", type=str)
|
||||
parser.add_argument("--split", type=str, default="train")
|
||||
parser.add_argument("--info", type=str, default=None)
|
||||
|
|
@ -75,5 +96,9 @@ if __name__ == "__main__":
|
|||
process_ucf101(args.root, args.split)
|
||||
elif args.dataset == "vidprom":
|
||||
process_vidprom(args.root, args.info)
|
||||
elif args.dataset == "image":
|
||||
process_general_images(args.root)
|
||||
elif args.dataset == "video":
|
||||
process_general_videos(args.root)
|
||||
else:
|
||||
raise ValueError("Invalid dataset")
|
||||
|
|
|
|||
|
|
@ -42,7 +42,7 @@ def get_video_info(path):
|
|||
int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
|
||||
float(cap.get(cv2.CAP_PROP_FPS)),
|
||||
)
|
||||
aspect_ratio = height / width if width > 0 else np.nan
|
||||
aspect_ratio = height / width if width > 0 else np.nan
|
||||
return num_frames, height, width, aspect_ratio, fps
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue