[wip] image wrong with flashattn

This commit is contained in:
Zangwei Zheng 2024-03-28 22:04:43 +08:00
parent a01f6da20e
commit 2d4a5df287
5 changed files with 42 additions and 9 deletions

View file

@ -160,6 +160,7 @@ class Bucket:
def info_bucket(self, dataset, frame_interval=1):
infos = dict()
infos_ar = dict()
for i in range(len(dataset)):
T, H, W = dataset.get_data_info(i)
bucket_id = self.get_bucket_id(T, H, W, frame_interval)
@ -167,9 +168,13 @@ class Bucket:
continue
if f"{(bucket_id[0], bucket_id[1])}" not in infos:
infos[f"{(bucket_id[0], bucket_id[1])}"] = 0
if f"{bucket_id[2]}" not in infos_ar:
infos_ar[f"{bucket_id[2]}"] = 0
infos[f"{(bucket_id[0], bucket_id[1])}"] += 1
infos_ar[f"{bucket_id[2]}"] += 1
print(f"Dataset contains {len(dataset)} samples.")
print("Bucket info:", infos)
print("Aspect ratio info:", infos_ar)
def get_bucket_id(self, T, H, W, frame_interval=1):
# hw

View file

@ -1,3 +1,5 @@
import os
import numpy as np
import pandas as pd
import torch
@ -38,11 +40,11 @@ class VideoTextDataset(torch.utils.data.Dataset):
}
def get_type(self, path):
    """Classify ``path`` as ``"video"`` or ``"image"`` from its file extension.

    Args:
        path: File path (or file name) of the sample.

    Returns:
        ``"video"`` when the extension is listed in ``VID_EXTENSIONS``,
        otherwise ``"image"``.

    Raises:
        AssertionError: If the extension is in neither ``VID_EXTENSIONS`` nor
            ``IMG_EXTENSIONS``.
    """
    # os.path.splitext keeps the leading dot (".mp4") and only looks at the
    # last path component, so dots in directory names cannot be mistaken for
    # an extension. Lower-casing once here makes the lookups case-insensitive.
    # NOTE(review): assumes both extension tuples hold lowercase, dot-prefixed
    # entries (e.g. ".mp4", ".jpg") -- confirm against their definitions.
    ext = os.path.splitext(path)[-1].lower()
    if ext in VID_EXTENSIONS:
        return "video"
    assert ext in IMG_EXTENSIONS, f"Unsupported file format: {ext}"
    return "image"
def getitem(self, index):
@ -139,13 +141,14 @@ class VariableVideoTextDataset(VideoTextDataset):
image = transform(image)
# repeat
video = image.unsqueeze(0).repeat(self.num_frames, 1, 1, 1)
video = image.unsqueeze(0)
# TCHW -> CTHW
video = video.permute(1, 0, 2, 3)
return {"video": video, "text": text, "num_frames": num_frames, "height": height, "width": width, "ar": ar}
def __getitem__(self, index):
return self.getitem(index)
for _ in range(10):
try:
return self.getitem(index)

View file

@ -11,7 +11,7 @@ from torchvision.utils import save_image
from . import video_transforms
VID_EXTENSIONS = ("mp4", "avi", "mov", "mkv")
VID_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv")
def temporal_random_crop(vframes, num_frames, frame_interval):

View file

@ -1,16 +1,21 @@
import argparse
import csv
import os
import pandas as pd
from torchvision.datasets import ImageNet
# Extensions are lowercase and carry their leading dot because they are matched
# against os.path.splitext(), which returns the suffix *with* the dot.
IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp")
VID_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv")


def get_filelist(file_path, exts=None):
    """Recursively list the files under ``file_path``, optionally by extension.

    Args:
        file_path: Directory to walk.
        exts: Extensions to keep, or ``None`` to keep every file. May be a
            single string or an iterable of strings; matching is
            case-insensitive and the leading dot is optional (``"mp4"`` and
            ``".MP4"`` both select ``*.mp4``).

    Returns:
        A list of paths (``file_path`` joined with the relative location of
        each file), in ``os.walk`` order.
    """
    if exts is None:
        wanted = None
    else:
        if isinstance(exts, str):
            # A bare string would otherwise be iterated character by character.
            exts = (exts,)
        wanted = set()
        for ext in exts:
            ext = ext.lower()
            # Keep "" untouched so callers can still select extension-less files.
            if ext and not ext.startswith("."):
                ext = "." + ext
            wanted.add(ext)

    file_list = []
    for home, _dirs, files in os.walk(file_path):
        for filename in files:
            if wanted is None or os.path.splitext(filename)[-1].lower() in wanted:
                file_list.append(os.path.join(home, filename))
    return file_list
@ -61,9 +66,25 @@ def process_vidprom(root, info):
print(f"Saved {len(df)} samples to vidprom.csv.")
def process_general_images(root):
    """Index the image files under ``root`` into ``images.csv``.

    ``root`` is user-expanded (``~``), walked via ``get_filelist`` with
    ``IMG_EXTENSIONS`` as the filter, and the resulting paths are written as a
    single ``path`` column to ``images.csv`` in the current working directory.
    """
    image_root = os.path.expanduser(root)
    frame = pd.DataFrame({"path": get_filelist(image_root, IMG_EXTENSIONS)})
    frame.to_csv("images.csv", index=False)
    print(f"Saved {len(frame)} samples to images.csv.")
def process_general_videos(root):
    """Index the video files under ``root`` into ``videos.csv``.

    ``root`` is user-expanded (``~``), walked via ``get_filelist`` with
    ``VID_EXTENSIONS`` as the filter, and the resulting paths are written as a
    single ``path`` column to ``videos.csv`` in the current working directory.
    """
    video_root = os.path.expanduser(root)
    frame = pd.DataFrame({"path": get_filelist(video_root, VID_EXTENSIONS)})
    frame.to_csv("videos.csv", index=False)
    print(f"Saved {len(frame)} samples to videos.csv.")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("dataset", type=str, choices=["imagenet", "ucf101", "vidprom"])
parser.add_argument("dataset", type=str, choices=["imagenet", "ucf101", "vidprom", "image", "video"])
parser.add_argument("root", type=str)
parser.add_argument("--split", type=str, default="train")
parser.add_argument("--info", type=str, default=None)
@ -75,5 +96,9 @@ if __name__ == "__main__":
process_ucf101(args.root, args.split)
elif args.dataset == "vidprom":
process_vidprom(args.root, args.info)
elif args.dataset == "image":
process_general_images(args.root)
elif args.dataset == "video":
process_general_videos(args.root)
else:
raise ValueError("Invalid dataset")

View file

@ -42,7 +42,7 @@ def get_video_info(path):
int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
float(cap.get(cv2.CAP_PROP_FPS)),
)
aspect_ratio = height / width if width > 0 else np.nan
aspect_ratio = height / width if width > 0 else np.nan
return num_frames, height, width, aspect_ratio, fps