Open-Sora/opensora/datasets/utils.py

224 lines
8.1 KiB
Python
Raw Normal View History

2024-03-15 15:00:46 +01:00
import random
from typing import Iterator, Optional
import numpy as np
import torch
2024-03-23 13:28:34 +01:00
import torchvision
import torchvision.transforms as transforms
2024-03-15 15:00:46 +01:00
from PIL import Image
from torch.distributed import ProcessGroup
from torch.distributed.distributed_c10d import _get_default_group
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler
2024-03-23 13:28:34 +01:00
from torchvision.datasets.folder import IMG_EXTENSIONS, pil_loader
2024-03-15 15:00:46 +01:00
from torchvision.io import write_video
from torchvision.utils import save_image
2024-03-23 13:28:34 +01:00
from . import video_transforms
2024-03-23 09:32:51 +01:00
# Lowercase video file extensions, written without the leading dot.
VID_EXTENSIONS = ("mp4", "avi", "mov", "mkv")
2024-03-26 17:24:46 +01:00
2024-03-26 10:32:15 +01:00
def temporal_random_crop(vframes, num_frames, frame_interval):
    """Pick ``num_frames`` frames from a randomly placed temporal window.

    A window of ``num_frames * frame_interval`` frames is requested from
    ``video_transforms.TemporalRandomCrop``; ``num_frames`` indices are then
    spread evenly (via ``np.linspace``) between the window's first and last
    frame and used to index ``vframes``.

    Args:
        vframes: Frame container supporting ``len()`` and indexing with an
            integer array (presumably a tensor with time as the first
            axis; confirm against callers).
        num_frames: Number of frames to return.
        frame_interval: Multiplier applied to ``num_frames`` to get the
            requested window length.

    Returns:
        ``vframes`` indexed by the selected frame indices.

    Raises:
        AssertionError: If the window returned by the cropper spans fewer
            than ``num_frames`` frames.
    """
    window_length = num_frames * frame_interval
    cropper = video_transforms.TemporalRandomCrop(window_length)
    window_start, window_end = cropper(len(vframes))
    assert window_end - window_start >= num_frames
    # The end bound is treated as exclusive, hence the "- 1".
    selected = np.linspace(window_start, window_end - 1, num_frames, dtype=int)
    return vframes[selected]
2024-03-26 17:24:46 +01:00
def get_transforms_video(name="center", resolution=(256, 256)):
    """Build the preprocessing pipeline applied to a decoded video.

    Every pipeline has three stages: ``ToTensorVideo`` (commented as TCHW in
    the original code), a spatial transform chosen by ``name``, and a
    normalization with mean 0.5 / std 0.5 per channel. No horizontal-flip
    augmentation is applied.

    Args:
        name: ``"center"`` uses ``UCFCenterCropVideo(resolution[0])`` and
            requires a square resolution; ``"resize_crop"`` uses
            ``ResizeCrop(resolution)``.
        resolution: Two-element target size.

    Returns:
        A ``transforms.Compose`` instance.

    Raises:
        AssertionError: If ``name == "center"`` and the two entries of
            ``resolution`` differ.
        NotImplementedError: For any other ``name``.
    """
    if name == "center":
        assert resolution[0] == resolution[1], "Resolution must be square for center crop"
        spatial_transform = video_transforms.UCFCenterCropVideo(resolution[0])
    elif name == "resize_crop":
        spatial_transform = video_transforms.ResizeCrop(resolution)
    else:
        raise NotImplementedError(f"Transform {name} not implemented")

    pipeline = [
        video_transforms.ToTensorVideo(),  # TCHW
        spatial_transform,
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
    ]
    return transforms.Compose(pipeline)
2024-03-26 17:24:46 +01:00
def get_transforms_image(name="center", image_size=(256, 256)):
    """Build the preprocessing pipeline applied to a PIL image.

    Args:
        name: ``"center"`` builds a center-crop pipeline (requires a square
            ``image_size``); ``"resize_crop"`` currently yields ``None``.
        image_size: Two-element target size; only ``image_size[0]`` is used
            by the center-crop pipeline.

    Returns:
        For ``"center"``: a ``transforms.Compose`` that center-crops with
        ``center_crop_arr``, converts to a tensor, and normalizes with mean
        0.5 / std 0.5 per channel. For ``"resize_crop"``: ``None``.

    Raises:
        AssertionError: If ``name == "center"`` and the two entries of
            ``image_size`` differ.
        NotImplementedError: For any other ``name``.
    """
    if name == "center":
        assert image_size[0] == image_size[1], "Image size must be square for center crop"

        def _center_crop(pil_image):
            return center_crop_arr(pil_image, image_size[0])

        steps = [
            transforms.Lambda(_center_crop),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ]
        return transforms.Compose(steps)

    if name == "resize_crop":
        # NOTE(review): no image pipeline is built for this mode; the caller
        # receives None and must handle it.
        return None

    raise NotImplementedError(f"Transform {name} not implemented")
2024-03-26 17:24:46 +01:00
def read_image_from_path(path, transform=None, num_frames=1, image_size=(256, 256)):
    """Load an image and present it as a video made of repeated frames.

    Args:
        path: Image file path, loaded with ``pil_loader``.
        transform: Callable applied to the loaded image. When ``None``, the
            default pipeline from ``get_transforms_image(image_size=...)``
            is used.
        num_frames: How many times the transformed image is repeated along
            the new frame axis.
        image_size: Size handed to the default pipeline; ignored when
            ``transform`` is supplied.

    Returns:
        The transformed image stacked ``num_frames`` times on a new leading
        axis, with the first two axes then swapped (assuming the transform
        yields a CHW tensor, the result is C x T x H x W).
    """
    if transform is None:
        transform = get_transforms_image(image_size=image_size)

    frame = transform(pil_loader(path))
    # Add a frame axis in front, tile it, then swap it with the next axis.
    return frame.unsqueeze(0).repeat(num_frames, 1, 1, 1).permute(1, 0, 2, 3)
2024-03-26 17:24:46 +01:00
def read_video_from_path(path, transform=None, image_size=(256, 256)):
    """Decode a video file and return it with the first two axes swapped.

    Args:
        path: Video file path, decoded with ``torchvision.io.read_video``
            using ``pts_unit="sec"`` and ``output_format="TCHW"``.
        transform: Callable applied to the decoded frames. When ``None``,
            the default pipeline from ``get_transforms_video`` is built for
            ``image_size``.
        image_size: Target size for the default pipeline; ignored when
            ``transform`` is supplied.

    Returns:
        The transformed frames with axes 0 and 1 swapped (T C H W becomes
        C T H W, per the layout requested from the decoder).
    """
    # Only the frames are needed; the audio stream and metadata are discarded.
    vframes, _, _ = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
    if transform is None:
        # get_transforms_video names its size parameter ``resolution``;
        # passing ``image_size=`` raises TypeError.
        transform = get_transforms_video(resolution=image_size)
    video = transform(vframes)  # T C H W
    return video.permute(1, 0, 2, 3)
def read_from_path(path, image_size):
    """Load ``path`` as a video or an image, decided by its file extension.

    The extension is the text after the last ``"."`` in ``path``, compared
    case-insensitively.

    Args:
        path: File path as a string.
        image_size: Size forwarded to the selected reader.

    Returns:
        Whatever ``read_video_from_path`` or ``read_image_from_path``
        returns for the file.

    Raises:
        AssertionError: If the extension is in neither ``VID_EXTENSIONS``
            nor ``IMG_EXTENSIONS``.
    """
    ext = path.rsplit(".", 1)[-1]
    normalized_ext = ext.lower()

    if normalized_ext in VID_EXTENSIONS:
        return read_video_from_path(path, image_size=image_size)

    # IMG_EXTENSIONS entries carry a leading dot, VID_EXTENSIONS do not.
    assert f".{normalized_ext}" in IMG_EXTENSIONS, f"Unsupported file format: {ext}"
    return read_image_from_path(path, image_size=image_size)
2024-03-15 15:00:46 +01:00
def save_sample(x, fps=8, save_path=None, normalize=True, value_range=(-1, 1)):
    """Write a sample to disk as a PNG (single frame) or an MP4 (multiple frames).

    The input tensor is never modified.

    Args:
        x (Tensor): shape [C, T, H, W]
        fps (int): Frame rate used when writing a video.
        save_path (str): Output path WITHOUT extension; ``".png"`` or
            ``".mp4"`` is appended. Must be provided: the ``None`` default
            cannot be concatenated with the extension.
        normalize (bool): Whether to map ``value_range`` onto the full
            output range before saving.
        value_range (tuple): ``(low, high)`` bounds used for normalization.

    Returns:
        str: The final path, including the appended extension.

    Raises:
        AssertionError: If ``x`` is not 4-dimensional.
    """
    assert x.ndim == 4

    if x.shape[1] == 1:  # T = 1: save as image
        save_path += ".png"
        frame = x.squeeze(1)
        save_image([frame], save_path, normalize=normalize, value_range=value_range)
    else:
        save_path += ".mp4"
        if normalize:
            low, high = value_range
            # Out-of-place clamp yields a fresh tensor, so the in-place ops
            # below no longer rescale the caller's data.
            x = x.clamp(min=low, max=high)
            x.sub_(low).div_(max(high - low, 1e-5))
        # Scale to [0, 255] with rounding, then reorder to T x H x W x C uint8 on CPU.
        x = x.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 3, 0).to("cpu", torch.uint8)
        write_video(save_path, x, fps=fps, video_codec="h264")

    print(f"Saved to {save_path}")
    return save_path
2024-03-15 15:00:46 +01:00
class StatefulDistributedSampler(DistributedSampler):
    """``DistributedSampler`` that can skip a prefix of its per-rank indices.

    The first ``start_index`` entries of the index sequence produced by the
    parent sampler are dropped. ``start_index`` starts at 0 and stays at
    whatever value was last passed to ``set_start_index``.
    """

    def __init__(
        self,
        dataset: Dataset,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = True,
        seed: int = 0,
        drop_last: bool = False,
    ) -> None:
        """Forward all arguments to ``DistributedSampler`` and reset the offset."""
        super().__init__(
            dataset,
            num_replicas=num_replicas,
            rank=rank,
            shuffle=shuffle,
            seed=seed,
            drop_last=drop_last,
        )
        # Number of leading per-rank indices to skip.
        self.start_index: int = 0

    def __iter__(self) -> Iterator:
        """Iterate over the parent's indices minus the first ``start_index``."""
        rank_indices = list(super().__iter__())
        remaining = rank_indices[self.start_index :]
        return iter(remaining)

    def __len__(self) -> int:
        """Return the number of indices left after skipping ``start_index``."""
        return self.num_samples - self.start_index

    def set_start_index(self, start_index: int) -> None:
        """Set how many leading per-rank indices are skipped from now on."""
        self.start_index = start_index
def prepare_dataloader(
    dataset,
    batch_size,
    shuffle=False,
    seed=1024,
    drop_last=False,
    pin_memory=False,
    num_workers=0,
    process_group: Optional[ProcessGroup] = None,
    **kwargs,
):
    r"""
    Build a `torch.utils.data.DataLoader` for distributed training, driven by a
    `StatefulDistributedSampler`.

    Args:
        dataset (`torch.utils.data.Dataset`): The dataset to be loaded.
        batch_size (int): Batch size passed to the DataLoader.
        shuffle (bool, optional): Whether the sampler shuffles the dataset. Defaults to False.
        seed (int, optional): Seed applied to ``numpy``, ``torch`` and ``random`` in each
            worker's init function; every worker receives this same value. It is NOT
            passed to the sampler, which keeps its own default seed. Defaults to 1024.
        drop_last (bool, optional): Passed to the DataLoader (not to the sampler): set to True
            to drop the last incomplete batch. Defaults to False.
        pin_memory (bool, optional): Passed to the DataLoader. Defaults to False.
        num_workers (int, optional): Passed to the DataLoader. Defaults to 0.
        process_group (ProcessGroup, optional): Group whose size and rank configure the
            sampler. When falsy, the default process group is used.
        kwargs (dict): additional keyword arguments forwarded to ``torch.utils.data.DataLoader``,
            more details could be found in
            `DataLoader <https://pytorch.org/docs/stable/_modules/torch/utils/data/dataloader.html#DataLoader>`_.

    Returns:
        :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
    """
    if not process_group:
        process_group = _get_default_group()

    distributed_sampler = StatefulDistributedSampler(
        dataset,
        num_replicas=process_group.size(),
        rank=process_group.rank(),
        shuffle=shuffle,
    )

    def seed_worker(worker_id):
        # Deterministic dataloader: seed all three RNGs identically in every
        # worker (worker_id is intentionally unused).
        np.random.seed(seed)
        torch.manual_seed(seed)
        random.seed(seed)

    loader_options = dict(
        batch_size=batch_size,
        sampler=distributed_sampler,
        worker_init_fn=seed_worker,
        drop_last=drop_last,
        pin_memory=pin_memory,
        num_workers=num_workers,
    )
    # Duplicate keys between the explicit options and kwargs raise TypeError.
    return DataLoader(dataset, **loader_options, **kwargs)
2024-03-15 15:16:20 +01:00
2024-03-15 15:00:46 +01:00
def center_crop_arr(pil_image, image_size):
    """
    Resize a PIL image so its shorter side equals ``image_size``, then cut
    out the central ``image_size`` x ``image_size`` square.

    Center cropping implementation from ADM.
    https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126

    Args:
        pil_image: Source ``PIL.Image``.
        image_size: Side length (in pixels) of the square output.

    Returns:
        A new ``PIL.Image`` built from the cropped pixel array.
    """
    # Halve with a box filter while the short side is at least twice the target.
    while min(pil_image.size) >= 2 * image_size:
        halved = tuple(side // 2 for side in pil_image.size)
        pil_image = pil_image.resize(halved, resample=Image.BOX)

    # Final bicubic resize brings the short side to exactly image_size (up to rounding).
    ratio = image_size / min(pil_image.size)
    target = tuple(round(side * ratio) for side in pil_image.size)
    pil_image = pil_image.resize(target, resample=Image.BICUBIC)

    pixels = np.array(pil_image)
    top = (pixels.shape[0] - image_size) // 2
    left = (pixels.shape[1] - image_size) // 2
    return Image.fromarray(pixels[top : top + image_size, left : left + image_size])