From 2d4a5df287858e6b3c097e381ab34aac41aee103 Mon Sep 17 00:00:00 2001
From: Zangwei Zheng
Date: Thu, 28 Mar 2024 22:04:43 +0800
Subject: [PATCH] [wip] image wrong with flashattn

---
Review notes (free-form area; not part of the commit message):

* tools/datasets/convert_dataset.py: VID_EXTENSIONS now carries leading
  dots. get_filelist() filters on os.path.splitext(filename)[-1].lower(),
  which keeps the dot (".mp4"), so the undotted tuple matched no file and
  process_general_videos() always produced an empty videos.csv. This also
  agrees with the dotted tuple this patch introduces in
  opensora/datasets/utils.py. The post-image blob id on that file's
  "index" line predates this correction.
* opensora/datasets/datasets.py: the early "return self.getitem(index)"
  added to __getitem__ makes the retry loop below it unreachable. It
  looks like a temporary debugging aid (subject says [wip]); confirm
  before merging.
* tools/datasets/csvutil.py: that hunk is whitespace-only. The exact
  whitespace of the removed line could not be recovered from the
  flattened copy, so the removed and added lines read identically here.
* Indentation and blank context lines were reconstructed from a
  whitespace-collapsed copy so that each hunk matches its @@ counts;
  verify against the target tree before applying.

 opensora/datasets/dataloader.py   |  5 +++++
 opensora/datasets/datasets.py     |  9 ++++++---
 opensora/datasets/utils.py        |  2 +-
 tools/datasets/convert_dataset.py | 33 +++++++++++++++++++++++++++----
 tools/datasets/csvutil.py         |  2 +-
 5 files changed, 42 insertions(+), 9 deletions(-)

diff --git a/opensora/datasets/dataloader.py b/opensora/datasets/dataloader.py
index 1a6d466..66d1ff0 100644
--- a/opensora/datasets/dataloader.py
+++ b/opensora/datasets/dataloader.py
@@ -160,6 +160,7 @@ class Bucket:
 
     def info_bucket(self, dataset, frame_interval=1):
         infos = dict()
+        infos_ar = dict()
         for i in range(len(dataset)):
             T, H, W = dataset.get_data_info(i)
             bucket_id = self.get_bucket_id(T, H, W, frame_interval)
@@ -167,9 +168,13 @@ class Bucket:
                 continue
             if f"{(bucket_id[0], bucket_id[1])}" not in infos:
                 infos[f"{(bucket_id[0], bucket_id[1])}"] = 0
+            if f"{bucket_id[2]}" not in infos_ar:
+                infos_ar[f"{bucket_id[2]}"] = 0
             infos[f"{(bucket_id[0], bucket_id[1])}"] += 1
+            infos_ar[f"{bucket_id[2]}"] += 1
         print(f"Dataset contains {len(dataset)} samples.")
         print("Bucket info:", infos)
+        print("Aspect ratio info:", infos_ar)
 
     def get_bucket_id(self, T, H, W, frame_interval=1):
         # hw
diff --git a/opensora/datasets/datasets.py b/opensora/datasets/datasets.py
index 99505c1..e0cae5f 100644
--- a/opensora/datasets/datasets.py
+++ b/opensora/datasets/datasets.py
@@ -1,3 +1,5 @@
+import os
+
 import numpy as np
 import pandas as pd
 import torch
@@ -38,11 +40,11 @@ class VideoTextDataset(torch.utils.data.Dataset):
         }
 
     def get_type(self, path):
-        ext = path.split(".")[-1]
+        ext = os.path.splitext(path)[-1].lower()
         if ext.lower() in VID_EXTENSIONS:
             return "video"
         else:
-            assert f".{ext.lower()}" in IMG_EXTENSIONS, f"Unsupported file format: {ext}"
+            assert ext.lower() in IMG_EXTENSIONS, f"Unsupported file format: {ext}"
             return "image"
 
     def getitem(self, index):
@@ -139,13 +141,14 @@ class VariableVideoTextDataset(VideoTextDataset):
             image = transform(image)
 
             # repeat
-            video = image.unsqueeze(0).repeat(self.num_frames, 1, 1, 1)
+            video = image.unsqueeze(0)
 
         # TCHW -> CTHW
         video = video.permute(1, 0, 2, 3)
         return {"video": video, "text": text, "num_frames": num_frames, "height": height, "width": width, "ar": ar}
 
     def __getitem__(self, index):
+        return self.getitem(index)
         for _ in range(10):
             try:
                 return self.getitem(index)
diff --git a/opensora/datasets/utils.py b/opensora/datasets/utils.py
index 7cf7d32..f5c71c9 100644
--- a/opensora/datasets/utils.py
+++ b/opensora/datasets/utils.py
@@ -11,7 +11,7 @@ from torchvision.utils import save_image
 
 from . import video_transforms
 
-VID_EXTENSIONS = ("mp4", "avi", "mov", "mkv")
+VID_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv")
 
 
 def temporal_random_crop(vframes, num_frames, frame_interval):
diff --git a/tools/datasets/convert_dataset.py b/tools/datasets/convert_dataset.py
index 30472f0..0b01907 100644
--- a/tools/datasets/convert_dataset.py
+++ b/tools/datasets/convert_dataset.py
@@ -1,16 +1,21 @@
 import argparse
-import csv
 import os
 
 import pandas as pd
 from torchvision.datasets import ImageNet
 
 
-def get_filelist(file_path):
+IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp")
+VID_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv")
+
+
+def get_filelist(file_path, exts=None):
     Filelist = []
     for home, dirs, files in os.walk(file_path):
         for filename in files:
-            Filelist.append(os.path.join(home, filename))
+            ext = os.path.splitext(filename)[-1].lower()
+            if exts is None or ext in exts:
+                Filelist.append(os.path.join(home, filename))
     return Filelist
 
 
@@ -61,9 +66,25 @@ def process_vidprom(root, info):
     print(f"Saved {len(df)} samples to vidprom.csv.")
 
 
+def process_general_images(root):
+    root = os.path.expanduser(root)
+    image_lists = get_filelist(root, IMG_EXTENSIONS)
+    df = pd.DataFrame(dict(path=image_lists))
+    df.to_csv("images.csv", index=False)
+    print(f"Saved {len(df)} samples to images.csv.")
+
+
+def process_general_videos(root):
+    root = os.path.expanduser(root)
+    video_lists = get_filelist(root, VID_EXTENSIONS)
+    df = pd.DataFrame(dict(path=video_lists))
+    df.to_csv("videos.csv", index=False)
+    print(f"Saved {len(df)} samples to videos.csv.")
+
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("dataset", type=str, choices=["imagenet", "ucf101", "vidprom"])
+    parser.add_argument("dataset", type=str, choices=["imagenet", "ucf101", "vidprom", "image", "video"])
     parser.add_argument("root", type=str)
     parser.add_argument("--split", type=str, default="train")
     parser.add_argument("--info", type=str, default=None)
@@ -75,5 +96,9 @@ if __name__ == "__main__":
         process_ucf101(args.root, args.split)
     elif args.dataset == "vidprom":
         process_vidprom(args.root, args.info)
+    elif args.dataset == "image":
+        process_general_images(args.root)
+    elif args.dataset == "video":
+        process_general_videos(args.root)
     else:
         raise ValueError("Invalid dataset")
diff --git a/tools/datasets/csvutil.py b/tools/datasets/csvutil.py
index dd76482..8607ec6 100644
--- a/tools/datasets/csvutil.py
+++ b/tools/datasets/csvutil.py
@@ -42,7 +42,7 @@ def get_video_info(path):
         int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
         float(cap.get(cv2.CAP_PROP_FPS)),
     )
-    aspect_ratio = height / width if width > 0 else np.nan
+    aspect_ratio = height / width if width > 0 else np.nan
     return num_frames, height, width, aspect_ratio, fps
 
 