Open-Sora/opensora/datasets/utils.py

163 lines
5.9 KiB
Python
Raw Normal View History

2024-03-15 15:00:46 +01:00
import os

import numpy as np
import torch
2024-03-23 13:28:34 +01:00
import torchvision
import torchvision.transforms as transforms
2024-03-15 15:00:46 +01:00
from PIL import Image
2024-03-23 13:28:34 +01:00
from torchvision.datasets.folder import IMG_EXTENSIONS, pil_loader
2024-03-15 15:00:46 +01:00
from torchvision.io import write_video
from torchvision.utils import save_image
2024-03-23 13:28:34 +01:00
from . import video_transforms
2024-03-28 15:04:43 +01:00
# Lower-case, dot-prefixed suffixes of files handled as videos. Compare against a
# lower-cased extension that *includes* its leading dot (e.g. ".mp4", not "mp4").
VID_EXTENSIONS = (
    ".mp4",
    ".avi",
    ".mov",
    ".mkv",
)
2024-03-23 09:32:51 +01:00
2024-03-26 17:24:46 +01:00
2024-03-26 10:32:15 +01:00
def temporal_random_crop(vframes, num_frames, frame_interval):
    """Pick a random temporal window and sample ``num_frames`` evenly spaced frames from it.

    Args:
        vframes: Frame sequence supporting ``len()`` and integer-array indexing along
            its first axis (presumably a (T, C, H, W) tensor -- confirm against callers).
        num_frames: Number of frames to return.
        frame_interval: Sampling stride; the window requested from
            ``video_transforms.TemporalRandomCrop`` spans ``num_frames * frame_interval``
            frames.

    Returns:
        ``vframes`` indexed at ``num_frames`` positions spread evenly (via
        ``np.linspace``) over the chosen window, endpoints included.

    Raises:
        ValueError: if the chosen window holds fewer than ``num_frames`` frames
            (presumably because the clip itself is too short).
    """
    temporal_sample = video_transforms.TemporalRandomCrop(num_frames * frame_interval)
    total_frames = len(vframes)
    start_frame_ind, end_frame_ind = temporal_sample(total_frames)
    # Explicit check instead of ``assert``: it must survive ``python -O``, otherwise
    # linspace below would silently repeat indices and return duplicated frames.
    if end_frame_ind - start_frame_ind < num_frames:
        raise ValueError(
            f"Cannot sample {num_frames} frames from window "
            f"[{start_frame_ind}, {end_frame_ind}) of a {total_frames}-frame clip"
        )
    frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, num_frames, dtype=int)
    return vframes[frame_indice]
2024-03-26 17:24:46 +01:00
def get_transforms_video(name="center", image_size=(256, 256)):
    """Build the preprocessing pipeline applied to decoded video frames.

    Args:
        name: Preset name. ``"center"`` crops with
            ``video_transforms.UCFCenterCropVideo(image_size[0])``;
            ``"resize_crop"`` crops with ``video_transforms.ResizeCrop(image_size)``;
            ``None`` disables preprocessing.
        image_size: Target spatial size handed to the crop op.

    Returns:
        ``transforms.Compose`` of ``ToTensorVideo`` -> crop -> per-channel
        ``Normalize`` (mean 0.5, std 0.5, in place), or ``None`` if ``name`` is ``None``.

    Raises:
        AssertionError: ``name == "center"`` with a non-square ``image_size``.
        NotImplementedError: ``name`` is not a known preset.
    """
    if name is None:
        return None

    # Only the crop step differs between presets; pick it first, then assemble.
    if name == "center":
        assert image_size[0] == image_size[1], "image_size must be square for center crop"
        crop_op = video_transforms.UCFCenterCropVideo(image_size[0])
    elif name == "resize_crop":
        crop_op = video_transforms.ResizeCrop(image_size)
    else:
        raise NotImplementedError(f"Transform {name} not implemented")

    return transforms.Compose(
        [
            video_transforms.ToTensorVideo(),  # TCHW
            crop_op,
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ]
    )
2024-03-26 17:24:46 +01:00
def get_transforms_image(name="center", image_size=(256, 256)):
    """Build the preprocessing pipeline applied to a PIL image.

    Args:
        name: Preset name. ``"center"`` crops with ``center_crop_arr(img, image_size[0])``;
            ``"resize_crop"`` crops with ``resize_crop_to_fill(img, image_size)``;
            ``None`` disables preprocessing.
        image_size: Target size handed to the crop helper.

    Returns:
        ``transforms.Compose`` of crop -> ``ToTensor`` -> per-channel ``Normalize``
        (mean 0.5, std 0.5, in place), or ``None`` if ``name`` is ``None``.

    Raises:
        AssertionError: ``name == "center"`` with a non-square ``image_size``.
        NotImplementedError: ``name`` is not a known preset.
    """
    if name is None:
        return None

    if name == "center":
        assert image_size[0] == image_size[1], "Image size must be square for center crop"

        def _crop(pil_image):
            return center_crop_arr(pil_image, image_size[0])

    elif name == "resize_crop":

        def _crop(pil_image):
            return resize_crop_to_fill(pil_image, image_size)

    else:
        raise NotImplementedError(f"Transform {name} not implemented")

    return transforms.Compose(
        [
            transforms.Lambda(_crop),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ]
    )
2024-03-26 17:24:46 +01:00
def read_image_from_path(path, transform=None, num_frames=1, image_size=(256, 256)):
    """Load an image and repeat it ``num_frames`` times as a static clip.

    Args:
        path: Image file, opened with ``pil_loader``.
        transform: Callable applied to the loaded PIL image; when ``None`` the
            default from ``get_transforms_image(image_size=image_size)`` is used.
        num_frames: Number of times the frame is repeated along the new time axis.
        image_size: Only used to build the default transform.

    Returns:
        Tensor ordered (C, T, H, W) with T == ``num_frames``, assuming the transform
        yields a (C, H, W) tensor -- confirm if a custom transform is passed.
    """
    pil_img = pil_loader(path)
    if transform is None:
        transform = get_transforms_image(image_size=image_size)
    frame = transform(pil_img)
    clip = frame[None].repeat(num_frames, 1, 1, 1)  # T C H W
    return clip.permute(1, 0, 2, 3)  # C T H W
2024-03-26 17:24:46 +01:00
def read_video_from_path(path, transform=None, image_size=(256, 256)):
    """Decode a video file and return its transformed frames as a (C, T, H, W) tensor.

    Args:
        path: Video file passed to ``torchvision.io.read_video``.
        transform: Callable applied to the decoded (T, C, H, W) frames; when ``None``
            the default from ``get_transforms_video(image_size=image_size)`` is used.
        image_size: Only used to build the default transform.

    Returns:
        Transformed frames with the first two axes swapped (T C H W -> C T H W).
        The audio stream and metadata returned by ``read_video`` are discarded.
    """
    frames, _audio, _info = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
    if transform is None:
        transform = get_transforms_video(image_size=image_size)
    clip = transform(frames)  # T C H W
    return clip.permute(1, 0, 2, 3)
def read_from_path(path, image_size):
    """Load a video or a still image from disk as a (C, T, H, W) tensor.

    The file type is chosen from its extension: members of ``VID_EXTENSIONS`` go
    through ``read_video_from_path``, members of ``IMG_EXTENSIONS`` through
    ``read_image_from_path``. Both tuples hold lower-case, dot-prefixed suffixes.

    Args:
        path: File path (``str`` or path-like).
        image_size: Forwarded to the chosen reader to build its default transform.

    Returns:
        Whatever the chosen reader returns.

    Raises:
        ValueError: if the extension is neither a known video nor a known image type.
    """
    # splitext keeps the leading dot (".mp4"), matching both extension tuples; the
    # old ``path.split(".")[-1]`` yielded "mp4", so videos were never recognised, and
    # it also mis-parsed dotted directory names and extension-less paths.
    ext = os.path.splitext(path)[1].lower()
    if ext in VID_EXTENSIONS:
        return read_video_from_path(path, image_size=image_size)
    if ext not in IMG_EXTENSIONS:
        raise ValueError(f"Unsupported file format: {ext!r} ({path})")
    return read_image_from_path(path, image_size=image_size)
2024-03-15 15:00:46 +01:00
def save_sample(x, fps=8, save_path=None, normalize=True, value_range=(-1, 1)):
    """Write a sample to disk: a PNG when it has one frame, otherwise an H.264 MP4.

    Args:
        x (Tensor): shape [C, T, H, W]. Not modified by this function.
        fps: Frame rate passed to ``write_video`` (video output only).
        save_path: Output path *without* extension; ".png" or ".mp4" is appended.
        normalize: If True, values are clamped to ``value_range`` and rescaled to
            [0, 1] before writing (forwarded to ``save_image`` for the PNG case).
        value_range: ``(low, high)`` range of the input values used by ``normalize``.

    Returns:
        The path actually written, including the appended extension.

    Raises:
        TypeError: if ``save_path`` is not given.
    """
    if save_path is None:
        raise TypeError("save_path must be given (output path without extension)")
    assert x.ndim == 4, f"expected x of shape [C, T, H, W], got {tuple(x.shape)}"
    if x.shape[1] == 1:  # T = 1: save as image
        save_path += ".png"
        x = x.squeeze(1)
        save_image([x], save_path, normalize=normalize, value_range=value_range)
    else:
        save_path += ".mp4"
        if normalize:
            low, high = value_range
            # Out-of-place on purpose: the previous clamp_/sub_/div_ silently rewrote
            # the caller's tensor (and fails on leaf tensors that require grad).
            x = x.clamp(min=low, max=high)
            x = (x - low) / max(high - low, 1e-5)
        # [0, 1] floats -> rounded uint8 in (T, H, W, C) order for write_video.
        x = x.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 3, 0).to("cpu", torch.uint8)
        write_video(save_path, x, fps=fps, video_codec="h264")
    print(f"Saved to {save_path}")
    return save_path
2024-03-15 15:00:46 +01:00
def center_crop_arr(pil_image, image_size):
    """
    Center cropping implementation from ADM.
    https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126

    Resize ``pil_image`` so its shorter side becomes ``image_size`` (rounded), then
    return the central ``image_size`` x ``image_size`` square as a new PIL image.
    """
    img = pil_image
    # Halve with BOX filtering while the image is at least twice the target; the
    # final bicubic resize then only has to cover a factor below 2.
    while 2 * image_size <= min(img.size):
        width, height = img.size
        img = img.resize((width // 2, height // 2), resample=Image.BOX)

    ratio = image_size / min(img.size)
    width, height = img.size
    img = img.resize((round(width * ratio), round(height * ratio)), resample=Image.BICUBIC)

    pixels = np.array(img)  # rows = height, cols = width
    top = (pixels.shape[0] - image_size) // 2
    left = (pixels.shape[1] - image_size) // 2
    return Image.fromarray(pixels[top : top + image_size, left : left + image_size])
2024-03-28 14:35:33 +01:00
def resize_crop_to_fill(pil_image, image_size):
    """Scale ``pil_image`` to cover ``image_size`` and center-crop the overflow.

    Args:
        pil_image: Source PIL image.
        image_size: Target ``(height, width)``.

    Returns:
        A new PIL image cropped to ``image_size``: the source is bicubically resized
        by the larger of the two height/width scale factors, and the excess along the
        other axis is trimmed equally from both sides.
    """
    src_w, src_h = pil_image.size  # PIL is (W, H)
    dst_h, dst_w = image_size
    scale_h, scale_w = dst_h / src_h, dst_w / src_w

    # Using the larger factor makes both sides reach at least the target; only the
    # axis that overshoots needs cropping.
    if scale_h > scale_w:
        new_h, new_w = dst_h, round(src_w * scale_h)
        top, left = 0, int(round((new_w - dst_w) / 2.0))
    else:
        new_h, new_w = round(src_h * scale_w), dst_w
        top, left = int(round((new_h - dst_h) / 2.0)), 0
    resized = pil_image.resize((new_w, new_h), Image.BICUBIC)

    pixels = np.array(resized)  # rows = height, cols = width
    assert top + dst_h <= pixels.shape[0] and left + dst_w <= pixels.shape[1]
    return Image.fromarray(pixels[top : top + dst_h, left : left + dst_w])