Open-Sora/opensora/datasets/utils.py
Zheng Zangwei (Alex Zheng) f1c6b8b88e open-sora v1.3 code upload (#786)
Co-authored-by: gxyes <gxynoz@gmail.com>
2025-02-20 16:50:24 +08:00

325 lines
11 KiB
Python

import os
import random
import re
import cv2
import numpy as np
import pandas as pd
import requests
import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image
from torchvision.datasets.folder import IMG_EXTENSIONS, pil_loader
from torchvision.io import write_video
from torchvision.utils import save_image
from . import video_transforms
# Lowercase file suffixes (leading dot included) that this module treats as video files.
VID_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv")
# Case-insensitive pattern that an entire string must satisfy to count as a URL in is_url().
# It is anchored at both ends (^ ... $), so text surrounding a URL makes the match fail.
regex = re.compile(
    r"^(?:http|ftp)s?://"  # scheme: http://, https://, ftp:// or ftps://
    r"(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)|"  # dotted domain name...
    r"localhost|"  # ...or the literal host "localhost"...
    r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})"  # ...or a dotted-quad IPv4 address (octet range is not validated)
    r"(?::\d+)?"  # optional ":port"
    r"(?:/?|[/?]\S+)$",  # optional trailing "/" or a path/query made of non-whitespace characters
    re.IGNORECASE,
)
def is_img(path):
    """Tell whether ``path`` carries an image file extension.

    The suffix returned by ``os.path.splitext`` is lowercased and looked up in
    ``IMG_EXTENSIONS`` (imported from ``torchvision.datasets.folder``), so the
    check is case-insensitive. Only the name is inspected; the file is not opened.
    """
    suffix = os.path.splitext(path)[-1]
    return suffix.lower() in IMG_EXTENSIONS
def is_vid(path):
    """Tell whether ``path`` carries one of the video suffixes in ``VID_EXTENSIONS``.

    The suffix is lowercased before the lookup, so ``clip.MP4`` counts as a video.
    Only the name is inspected; the file is not opened.
    """
    _, suffix = os.path.splitext(path)
    return suffix.lower() in VID_EXTENSIONS
def is_url(url):
    """Return True when the whole string ``url`` matches the module-level URL pattern ``regex``."""
    matched = regex.match(url)
    return matched is not None
def read_file(input_path):
    """Load a tabular file into a pandas DataFrame, choosing the reader from the file suffix.

    Args:
        input_path (str): path ending in ``.csv`` or ``.parquet`` (case-sensitive match).

    Returns:
        pandas.DataFrame: the result of ``pd.read_csv`` or ``pd.read_parquet``.

    Raises:
        NotImplementedError: if the path ends in neither supported suffix.
    """
    # Checked in order; the first matching suffix decides the reader.
    readers = (
        (".csv", pd.read_csv),
        (".parquet", pd.read_parquet),
    )
    for suffix, reader in readers:
        if input_path.endswith(suffix):
            return reader(input_path)
    raise NotImplementedError(f"Unsupported file format: {input_path}")
def download_url(input_path):
    """Download ``input_path`` into the local ``cache`` directory and return the saved path.

    The ``cache`` directory is created relative to the current working directory
    if it does not exist. An existing file with the same name is overwritten.

    Args:
        input_path (str): URL of the resource to fetch.

    Returns:
        str: path of the downloaded file inside ``cache``.

    Raises:
        requests.HTTPError: if the server answers with a 4xx/5xx status.
        requests.RequestException: on connection problems or when the server
            stops responding for longer than the timeout.
    """
    from urllib.parse import urlparse

    output_dir = "cache"
    os.makedirs(output_dir, exist_ok=True)

    # Name the file after the URL *path* only, so a query string or fragment
    # ("clip.mp4?token=...") does not end up in the file name and hide the real
    # extension from later suffix checks. Fall back to the raw basename when the
    # URL has no path component.
    base_name = os.path.basename(urlparse(input_path).path) or os.path.basename(input_path)
    output_path = os.path.join(output_dir, base_name)

    # A timeout keeps a stalled server from blocking the caller forever.
    response = requests.get(input_path, timeout=60)
    # Fail loudly on HTTP errors instead of caching the error page as if it were media.
    response.raise_for_status()

    with open(output_path, "wb") as handler:
        handler.write(response.content)
    print(f"URL {input_path} downloaded to {output_path}")
    return output_path
def temporal_random_crop(vframes, num_frames, frame_interval):
    """Pick ``num_frames`` frames, evenly spaced, from a randomly placed temporal window.

    A ``video_transforms.TemporalRandomCrop`` built with ``num_frames * frame_interval``
    is called with the clip length and returns a ``(start, end)`` index pair.
    ``num_frames`` indices are then spread linearly over ``[start, end - 1]``
    (truncated to int) and used to index ``vframes``.

    Args:
        vframes: frame container supporting ``len()`` and indexing by an integer
            index array (presumably a tensor/array with time as the first axis).
        num_frames (int): number of frames to return.
        frame_interval (int): nominal stride between sampled frames.

    Returns:
        The selected frames, ``vframes`` indexed by the chosen positions.

    Raises:
        AssertionError: if the sampled window holds fewer than ``num_frames`` frames.
    """
    window_length = num_frames * frame_interval
    sampler = video_transforms.TemporalRandomCrop(window_length)
    start_frame_ind, end_frame_ind = sampler(len(vframes))
    assert (
        end_frame_ind - start_frame_ind >= num_frames
    ), f"Not enough frames to sample, {end_frame_ind} - {start_frame_ind} < {num_frames}"
    # end is exclusive, hence the "- 1" for the last reachable index.
    picked = np.linspace(start_frame_ind, end_frame_ind - 1, num_frames, dtype=int)
    return vframes[picked]
def get_transforms_video(name="center", image_size=(256, 256)):
    """Build the preprocessing pipeline for a video clip.

    Every pipeline has three stages: ``ToTensorVideo`` (TCHW), one spatial crop
    chosen by ``name``, and normalization with mean 0.5 / std 0.5 per channel.

    Args:
        name (str | None): ``"center"``, ``"resize_crop"``, ``"rand_size_crop"``,
            or ``None`` to request no transform at all.
        image_size (tuple): passed to the crop transform; ``"center"`` requires
            both entries to be equal and uses the first one.

    Returns:
        A ``transforms.Compose`` instance, or ``None`` when ``name`` is ``None``.

    Raises:
        NotImplementedError: for an unknown ``name``.
        AssertionError: for ``"center"`` with a non-square ``image_size``.
    """
    if name is None:
        return None

    # Only the spatial crop differs between the supported pipelines.
    if name == "center":
        assert image_size[0] == image_size[1], "image_size must be square for center crop"
        spatial_crop = video_transforms.UCFCenterCropVideo(image_size[0])
    elif name == "resize_crop":
        spatial_crop = video_transforms.ResizeCrop(image_size)
    elif name == "rand_size_crop":
        spatial_crop = video_transforms.RandomSizedCrop(image_size)
    else:
        raise NotImplementedError(f"Transform {name} not implemented")

    pipeline = [
        video_transforms.ToTensorVideo(),  # TCHW
        spatial_crop,
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
    ]
    return transforms.Compose(pipeline)
def get_transforms_image(name="center", image_size=(256, 256)):
    """Build the preprocessing pipeline for a single PIL image.

    Every pipeline has three stages: a PIL-level crop chosen by ``name``,
    ``ToTensor``, and normalization with mean 0.5 / std 0.5 per channel.

    Args:
        name (str | None): ``"center"`` (uses ``center_crop_arr``),
            ``"resize_crop"`` (uses ``resize_crop_to_fill``),
            ``"rand_size_crop"`` (uses ``rand_size_crop_arr``), or ``None``
            to request no transform at all.
        image_size (tuple): forwarded to the crop helper; ``"center"`` requires
            both entries to be equal and forwards only the first one.

    Returns:
        A ``transforms.Compose`` instance, or ``None`` when ``name`` is ``None``.

    Raises:
        NotImplementedError: for an unknown ``name``.
        AssertionError: for ``"center"`` with a non-square ``image_size``.
    """
    if name is None:
        return None

    # Only the PIL-level crop differs between the supported pipelines.
    if name == "center":
        assert image_size[0] == image_size[1], "Image size must be square for center crop"

        def crop(pil_image):
            return center_crop_arr(pil_image, image_size[0])

    elif name == "resize_crop":

        def crop(pil_image):
            return resize_crop_to_fill(pil_image, image_size)

    elif name == "rand_size_crop":

        def crop(pil_image):
            return rand_size_crop_arr(pil_image, image_size)

    else:
        raise NotImplementedError(f"Transform {name} not implemented")

    pipeline = [
        transforms.Lambda(crop),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
    ]
    return transforms.Compose(pipeline)
def read_image_from_path(path, transform=None, transform_name="center", num_frames=1, image_size=(256, 256)):
    """Load an image and present it as a clip made of ``num_frames`` identical frames.

    Args:
        path (str): image file, opened with torchvision's ``pil_loader``.
        transform (callable | None): applied to the PIL image; when ``None`` a
            pipeline is built with ``get_transforms_image(transform_name, image_size)``.
        transform_name (str): pipeline name used only when ``transform`` is ``None``.
        num_frames (int): how many times the frame is repeated along the time axis.
        image_size (tuple): target size used only when ``transform`` is ``None``.

    Returns:
        Tensor: the transformed image (a 3-D tensor, presumably C x H x W) stacked
        ``num_frames`` times and permuted so the repeat axis comes second,
        i.e. [C, T, H, W] under that assumption.
    """
    pil_image = pil_loader(path)
    if transform is None:
        transform = get_transforms_image(image_size=image_size, name=transform_name)
    frame = transform(pil_image)
    # [C, H, W] -> [1, C, H, W] -> [T, C, H, W] -> [C, T, H, W]
    return frame.unsqueeze(0).repeat(num_frames, 1, 1, 1).permute(1, 0, 2, 3)
def read_video_from_path(path, transform=None, transform_name="center", image_size=(256, 256)):
    """Decode a video file and return its transformed frames with channels first.

    Args:
        path (str): video file, decoded with ``torchvision.io.read_video``
            (``pts_unit="sec"``, ``output_format="TCHW"``).
        transform (callable | None): applied to the decoded frames; when ``None``
            a pipeline is built with ``get_transforms_video(transform_name, image_size)``.
        transform_name (str): pipeline name used only when ``transform`` is ``None``.
        image_size (tuple): target size used only when ``transform`` is ``None``.

    Returns:
        Tensor: transformed frames permuted from [T, C, H, W] to [C, T, H, W].
    """
    # read_video returns (video frames, audio frames, metadata); only the video is used.
    decoded = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
    frames = decoded[0]
    if transform is None:
        transform = get_transforms_video(image_size=image_size, name=transform_name)
    clip = transform(frames)  # T C H W
    return clip.permute(1, 0, 2, 3)
def read_from_path(path, image_size, transform_name="center"):
    """Read an image or a video, given a local path or a URL, as a transformed tensor.

    URLs (as recognised by ``is_url``) are first downloaded with ``download_url``.
    The lowercased file suffix then decides which reader is used.

    Args:
        path (str): local file path or URL.
        image_size (tuple): forwarded to the reader.
        transform_name (str): forwarded to the reader.

    Returns:
        Tensor: result of ``read_video_from_path`` or ``read_image_from_path``.

    Raises:
        AssertionError: if the suffix is in neither ``VID_EXTENSIONS`` nor ``IMG_EXTENSIONS``.
    """
    local_path = download_url(path) if is_url(path) else path
    ext = os.path.splitext(local_path)[-1].lower()
    if ext in VID_EXTENSIONS:
        return read_video_from_path(local_path, image_size=image_size, transform_name=transform_name)
    assert ext in IMG_EXTENSIONS, f"Unsupported file format: {ext}"
    return read_image_from_path(local_path, image_size=image_size, transform_name=transform_name)
def write_video_cv2(
    filename: str,
    video: torch.Tensor,
    fps: float,
):
    """Encode ``video`` to ``filename`` with OpenCV using the ``avc1`` (H.264) fourcc.

    Args:
        filename (str): output file path.
        video (torch.Tensor): frames indexed as [T, H, W, C]; each frame is
            converted from RGB to BGR before writing, so channels are expected
            in RGB order (``save_sample`` passes a uint8 CPU tensor).
        fps (float): frame rate stored in the output file.

    Raises:
        RuntimeError: if OpenCV cannot open the writer (for example because the
            build lacks the ``avc1`` encoder). Without this check every
            ``write`` call is silently dropped and an unusable file is left behind.
    """
    frame_size = (video.size(2), video.size(1))  # OpenCV expects (width, height)
    fourcc = cv2.VideoWriter_fourcc(
        *"avc1"
    )  # NOTE: your opencv must be installed with `conda install -c conda-forge opencv` else the video may not display on browser
    output = cv2.VideoWriter(filename, fourcc, fps, frame_size)
    if not output.isOpened():
        output.release()
        raise RuntimeError(
            f"cv2.VideoWriter could not open {filename!r} with codec 'avc1' at size {frame_size}"
        )
    try:
        for frame_idx in range(video.size(0)):
            frame = np.array(video[frame_idx])  # H,W,C
            output.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    finally:
        # Always finalize the container, even if a frame conversion fails midway.
        output.release()
def save_sample(
    x,
    save_path=None,
    fps=8,
    normalize=True,
    value_range=(-1, 1),
    force_video=False,
    verbose=True,
    write_video_backend="cv2",
):
    """Save a generated sample as an image (single frame) or as a video.

    Args:
        x (Tensor): shape [C, T, H, W]. The tensor is not modified.
        save_path (str): output path. ``.png`` (image) or ``.mp4`` (video) is
            appended when the path does not already end in a known extension.
        fps (int): frame rate for video output.
        normalize (bool): map ``value_range`` to [0, 1] before saving.
        value_range (tuple): ``(low, high)`` range of the values in ``x``.
        force_video (bool): write a video even when ``T == 1``.
        verbose (bool): print the final path.
        write_video_backend (str): ``"cv2"`` uses ``write_video_cv2``; any other
            value uses ``torchvision.io.write_video`` with the h264 codec.

    Returns:
        str: the path actually written, including any appended extension.

    Raises:
        ValueError: if ``save_path`` is not given.
    """
    assert x.ndim == 4
    if save_path is None:
        raise ValueError("save_path must be provided")

    if not force_video and x.shape[1] == 1:  # T = 1: save as image
        if not save_path.endswith(tuple(IMG_EXTENSIONS)):
            save_path += ".png"
        x = x.squeeze(1)
        save_image([x], save_path, normalize=normalize, value_range=value_range)
    else:
        if not save_path.endswith(tuple(VID_EXTENSIONS)):
            save_path += ".mp4"
        if normalize:
            low, high = value_range
            # clamp() (not clamp_()) allocates a new tensor, so the in-place
            # arithmetic below never touches the tensor owned by the caller.
            x = x.clamp(min=low, max=high)
            x.sub_(low).div_(max(high - low, 1e-5))
        # [0, 1] floats -> rounded uint8, laid out as [T, H, W, C] on the CPU.
        x = x.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 3, 0).to("cpu", torch.uint8)
        if write_video_backend == "cv2":
            write_video_cv2(save_path, x, fps=fps)
        else:
            write_video(save_path, x, fps=fps, video_codec="h264")
    if verbose:
        print(f"Saved to {save_path}")
    return save_path
def center_crop_arr(pil_image, image_size):
    """
    Resize a PIL image so its shorter side equals ``image_size``, then cut out the central square.

    Center cropping implementation from ADM.
    https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126

    Args:
        pil_image (PIL.Image.Image): input image.
        image_size (int): side length of the square result.

    Returns:
        PIL.Image.Image: ``image_size`` x ``image_size`` crop taken from the centre.
    """
    # Halve with a box filter while the image is at least twice the target size,
    # so the final bicubic resize never has to shrink by more than 2x.
    while min(*pil_image.size) >= 2 * image_size:
        halved = tuple(side // 2 for side in pil_image.size)
        pil_image = pil_image.resize(halved, resample=Image.BOX)

    scale = image_size / min(*pil_image.size)
    scaled = tuple(round(side * scale) for side in pil_image.size)
    pil_image = pil_image.resize(scaled, resample=Image.BICUBIC)

    pixels = np.array(pil_image)
    top = (pixels.shape[0] - image_size) // 2
    left = (pixels.shape[1] - image_size) // 2
    return Image.fromarray(pixels[top : top + image_size, left : left + image_size])
def rand_size_crop_arr(pil_image, image_size):
    """
    Randomly crop image for height and width, ranging from image_size[0] to image_size[1]

    NOTE: this module defines ``rand_size_crop_arr`` a second time further down;
    the later definition is the one bound to the name after import.

    Args:
        pil_image (PIL.Image.Image): input image.
        image_size (tuple): ``(min_side, max_side)`` bounds, inclusive, from which
            the crop height and width are drawn independently.

    Returns:
        PIL.Image.Image: the cropped region. If the image is smaller than the
        drawn size along an axis, the crop is limited to what the image holds.
    """
    arr = np.array(pil_image)

    # get random target h w
    height = random.randint(image_size[0], image_size[1])
    width = random.randint(image_size[0], image_size[1])
    # ensure that h w are multiples of 8
    height = height - height % 8
    width = width - width % 8

    # get random start pos; the horizontal offset must be bounded by the crop
    # *width* (bounding it by the height skews or truncates non-square crops)
    h_start = random.randint(0, max(len(arr) - height, 0))
    w_start = random.randint(0, max(len(arr[0]) - width, 0))

    # crop
    return Image.fromarray(arr[h_start : h_start + height, w_start : w_start + width])
def resize_crop_to_fill(pil_image, image_size):
    """Scale a PIL image to cover ``image_size`` while keeping its aspect ratio, then centre-crop the excess.

    Args:
        pil_image (PIL.Image.Image): input image.
        image_size (tuple): target ``(height, width)``.

    Returns:
        PIL.Image.Image: image of exactly ``height`` x ``width`` pixels.

    Raises:
        AssertionError: if the resized image turns out smaller than the crop window.
    """
    src_w, src_h = pil_image.size  # PIL is (W, H)
    th, tw = image_size
    scale_h = th / src_h
    scale_w = tw / src_w

    # Use the larger scale factor so both sides reach the target; the other
    # side overshoots and is trimmed symmetrically.
    if scale_h > scale_w:
        new_h, new_w = th, round(src_w * scale_h)
        top, left = 0, int(round((new_w - tw) / 2.0))
    else:
        new_h, new_w = round(src_h * scale_w), tw
        top, left = int(round((new_h - th) / 2.0)), 0

    resized = pil_image.resize((new_w, new_h), Image.BICUBIC)
    arr = np.array(resized)
    assert top + th <= arr.shape[0] and left + tw <= arr.shape[1]
    return Image.fromarray(arr[top : top + th, left : left + tw])
def rand_size_crop_arr(pil_image, image_size, multiples_of=8):
    """Crop a random window whose height and width are drawn from ``image_size``.

    NOTE: this redefines the ``rand_size_crop_arr`` declared earlier in this
    module and is the version bound to the name after import.

    Args:
        pil_image (PIL.Image.Image): input image.
        image_size (tuple): ``(min_side, max_side)`` bounds, inclusive, from which
            the crop height and width are drawn independently.
        multiples_of (int): the crop height and width are rounded down to a
            multiple of this value. Defaults to 8.

    Returns:
        PIL.Image.Image: the cropped region. When the image is smaller than the
        drawn size along an axis, that side falls back to the largest multiple of
        ``multiples_of`` that fits (which is 0 for images smaller than
        ``multiples_of`` pixels on that axis).
    """
    arr = np.array(pil_image)
    w, h = pil_image.size  # PIL is (W, H)
    th = random.randint(image_size[0], image_size[1])
    tw = random.randint(image_size[0], image_size[1])
    # ensure that h w are multiples of `multiples_of`
    th = th - th % multiples_of
    tw = tw - tw % multiples_of
    # never ask for more pixels than the image has
    if h < th:
        th = h - h % multiples_of
    if w < tw:
        tw = w - w % multiples_of
    if w == tw and h == th:
        i = j = 0
    else:
        # get random start pos
        i = random.randint(0, h - th)
        j = random.randint(0, w - tw)
    return Image.fromarray(arr[i : i + th, j : j + tw])