import os

import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image
from torchvision.datasets.folder import IMG_EXTENSIONS, pil_loader
from torchvision.io import write_video
from torchvision.utils import save_image

from . import video_transforms

VID_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv")


def temporal_random_crop(vframes, num_frames, frame_interval):
    """Sample ``num_frames`` evenly spaced frames from a temporal window.

    Args:
        vframes: indexable frame container (supports ``len`` and indexing
            with an integer index array), e.g. a ``[T, ...]`` tensor.
        num_frames: number of frames to return.
        frame_interval: the window requested from ``TemporalRandomCrop``
            spans ``num_frames * frame_interval`` frames.

    Returns:
        ``vframes`` indexed at the ``num_frames`` selected positions.

    Raises:
        AssertionError: if the window returned by the sampler is shorter
            than ``num_frames``.
    """
    # NOTE(review): TemporalRandomCrop is defined in video_transforms;
    # presumably it picks a random (start, end) window of the given size.
    temporal_sample = video_transforms.TemporalRandomCrop(num_frames * frame_interval)
    total_frames = len(vframes)
    start_frame_ind, end_frame_ind = temporal_sample(total_frames)
    assert end_frame_ind - start_frame_ind >= num_frames
    # Evenly spaced integer indices over [start, end - 1], both ends included.
    frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, num_frames, dtype=int)
    video = vframes[frame_indice]
    return video


def get_transforms_video(name="center", image_size=(256, 256)):
    """Build the preprocessing pipeline applied to a video tensor.

    Args:
        name: ``"center"`` (square center crop), ``"resize_crop"``, or
            ``None`` to request no transform.
        image_size: target size as a 2-sequence; must be square for
            ``"center"``.

    Returns:
        A ``transforms.Compose`` pipeline, or ``None`` when ``name`` is None.

    Raises:
        AssertionError: if ``name == "center"`` and ``image_size`` is not square.
        NotImplementedError: for any other ``name``.
    """
    if name is None:
        return None
    elif name == "center":
        assert image_size[0] == image_size[1], "image_size must be square for center crop"
        transform_video = transforms.Compose(
            [
                video_transforms.ToTensorVideo(),  # TCHW
                # video_transforms.RandomHorizontalFlipVideo(),
                video_transforms.UCFCenterCropVideo(image_size[0]),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
            ]
        )
    elif name == "resize_crop":
        transform_video = transforms.Compose(
            [
                video_transforms.ToTensorVideo(),  # TCHW
                video_transforms.ResizeCrop(image_size),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
            ]
        )
    else:
        raise NotImplementedError(f"Transform {name} not implemented")
    return transform_video


def get_transforms_image(name="center", image_size=(256, 256)):
    """Build the preprocessing pipeline applied to a PIL image.

    Args:
        name: ``"center"`` (ADM-style square center crop), ``"resize_crop"``
            (resize to cover, then center crop), or ``None`` for no transform.
        image_size: target size; ``"resize_crop"`` unpacks it as
            ``(height, width)``, ``"center"`` requires it to be square.

    Returns:
        A ``transforms.Compose`` pipeline producing a normalized tensor
        (``ToTensor`` followed by mean/std 0.5 normalization), or ``None``
        when ``name`` is None.

    Raises:
        AssertionError: if ``name == "center"`` and ``image_size`` is not square.
        NotImplementedError: for any other ``name``.
    """
    if name is None:
        return None
    elif name == "center":
        assert image_size[0] == image_size[1], "Image size must be square for center crop"
        transform = transforms.Compose(
            [
                transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, image_size[0])),
                # transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
            ]
        )
    elif name == "resize_crop":
        transform = transforms.Compose(
            [
                transforms.Lambda(lambda pil_image: resize_crop_to_fill(pil_image, image_size)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
            ]
        )
    else:
        raise NotImplementedError(f"Transform {name} not implemented")
    return transform


def read_image_from_path(path, transform=None, num_frames=1, image_size=(256, 256)):
    """Load an image and tile it along a new time axis as a static clip.

    Args:
        path: image file path, loaded with torchvision's ``pil_loader``.
        transform: callable applied to the PIL image; defaults to
            ``get_transforms_image(image_size=image_size)``.
        num_frames: how many times the image is repeated along time.
        image_size: only used to build the default transform.

    Returns:
        Tensor of shape ``[C, num_frames, H, W]``.
    """
    image = pil_loader(path)
    if transform is None:
        transform = get_transforms_image(image_size=image_size)
    image = transform(image)
    # [C, H, W] -> [T, C, H, W] -> [C, T, H, W]
    video = image.unsqueeze(0).repeat(num_frames, 1, 1, 1)
    video = video.permute(1, 0, 2, 3)
    return video


def read_video_from_path(path, transform=None, image_size=(256, 256)):
    """Decode a video file and return it with channels first.

    Args:
        path: video file path.
        transform: callable applied to the decoded ``[T, C, H, W]`` frames;
            defaults to ``get_transforms_video(image_size=image_size)``.
        image_size: only used to build the default transform.

    Returns:
        The transformed video with its first two axes swapped, i.e.
        ``[C, T, H, W]`` provided the transform keeps the ``T C H W`` layout.
    """
    # Audio frames and stream metadata are decoded but not needed here.
    vframes, _, _ = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
    if transform is None:
        transform = get_transforms_video(image_size=image_size)
    video = transform(vframes)  # T C H W
    video = video.permute(1, 0, 2, 3)
    return video


def read_from_path(path, image_size):
    """Load ``path`` as a video or an image, chosen by file extension.

    Args:
        path: file path; the extension is matched case-insensitively.
        image_size: forwarded to the video/image reader.

    Returns:
        The tensor produced by ``read_video_from_path`` or
        ``read_image_from_path``.

    Raises:
        AssertionError: if the extension is neither in ``VID_EXTENSIONS``
            nor in torchvision's ``IMG_EXTENSIONS``.
    """
    ext = os.path.splitext(path)[-1].lower()
    if ext in VID_EXTENSIONS:
        return read_video_from_path(path, image_size=image_size)
    else:
        assert ext in IMG_EXTENSIONS, f"Unsupported file format: {ext}"
        return read_image_from_path(path, image_size=image_size)


def save_sample(x, fps=8, save_path=None, normalize=True, value_range=(-1, 1)):
    """Write a sample to disk as a PNG (single frame) or an H.264 MP4.

    The input tensor is never modified.

    Args:
        x (Tensor): shape [C, T, H, W]
        fps: frame rate used when writing a video.
        save_path: output path WITHOUT extension; ``.png`` or ``.mp4`` is
            appended. Required despite the ``None`` default.
        normalize: if True, clamp to ``value_range`` and rescale to [0, 1]
            before converting to 8-bit.
        value_range: ``(low, high)`` range of the values in ``x``.

    Returns:
        The full path that was written, including the extension.

    Raises:
        AssertionError: if ``x`` is not 4-dimensional.
        TypeError: if ``save_path`` is None.
    """
    assert x.ndim == 4
    if save_path is None:
        # Previously this surfaced as an opaque "NoneType += str" TypeError.
        raise TypeError("save_path must be a path prefix string, got None")

    if x.shape[1] == 1:  # T = 1: save as image
        save_path += ".png"
        x = x.squeeze(1)
        save_image([x], save_path, normalize=normalize, value_range=value_range)
    else:
        save_path += ".mp4"
        if normalize:
            low, high = value_range
            # Out-of-place clamp yields a fresh tensor, so the in-place
            # rescale below cannot touch the caller's data.
            x = x.clamp(min=low, max=high)
            x.sub_(low).div_(max(high - low, 1e-5))

        # Scale to [0, 255] with rounding, then [C, T, H, W] -> [T, H, W, C]
        # uint8 on CPU, the layout write_video expects.
        x = x.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 3, 0).to("cpu", torch.uint8)
        write_video(save_path, x, fps=fps, video_codec="h264")
    print(f"Saved to {save_path}")
    return save_path


def center_crop_arr(pil_image, image_size):
    """
    Center cropping implementation from ADM.
    https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126

    Args:
        pil_image: input PIL image.
        image_size: side length (int) of the square output.

    Returns:
        A PIL image of size ``image_size x image_size`` cropped from the
        center of the resized input.
    """
    # Halve with a box filter while the short side is at least twice the target.
    while min(*pil_image.size) >= 2 * image_size:
        pil_image = pil_image.resize(tuple(x // 2 for x in pil_image.size), resample=Image.BOX)

    # Final bicubic resize so the short side matches image_size.
    scale = image_size / min(*pil_image.size)
    pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC)

    arr = np.array(pil_image)
    crop_y = (arr.shape[0] - image_size) // 2
    crop_x = (arr.shape[1] - image_size) // 2
    return Image.fromarray(arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size])


def resize_crop_to_fill(pil_image, image_size):
    """Resize an image to cover the target size, then center crop to it.

    The aspect ratio is preserved: the image is scaled by the larger of the
    two height/width ratios so both target dimensions are covered, and the
    overflowing axis is cropped symmetrically.

    Args:
        pil_image: input PIL image.
        image_size: target size as ``(height, width)``.

    Returns:
        A PIL image of exactly ``height x width`` pixels.

    Raises:
        AssertionError: if the resized image is smaller than the crop window.
    """
    w, h = pil_image.size  # PIL is (W, H)
    th, tw = image_size
    rh, rw = th / h, tw / w
    if rh > rw:
        # Height is the binding constraint: match it, crop the extra width.
        sh, sw = th, round(w * rh)
        image = pil_image.resize((sw, sh), Image.BICUBIC)
        i = 0
        j = int(round((sw - tw) / 2.0))
    else:
        # Width is the binding constraint: match it, crop the extra height.
        sh, sw = round(h * rw), tw
        image = pil_image.resize((sw, sh), Image.BICUBIC)
        i = int(round((sh - th) / 2.0))
        j = 0
    arr = np.array(image)
    assert i + th <= arr.shape[0] and j + tw <= arr.shape[1]
    return Image.fromarray(arr[i : i + th, j : j + tw])