mirror of https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-17 05:36:01 +02:00

[feat] dynamic for video (base_size not completed)

parent 94e686177e
commit cba02e8a58
@@ -4,14 +4,18 @@ dataset = dict(
     data_path=None,
     num_frames=None,
     frame_interval=3,
-    image_size=(360, 480),
+    image_size=(360, 480),  # base size
     transform_name="resize_crop",
-    bucket=[24, 48, 72],
-    batch_size_bucket={24: 12, 48: 6, 72: 4},
 )
+bucket_config = {
+    "240p": {1: (1.0, 128), 16: (1.0, 4), 48: (1.0, 4), 72: (1.0, 2)},
+    "360p": {1: (1.0, 64), 16: (0.5, 2), 24: (0.5, 2), 72: (0.0, None)},
+    "720p": {1: (1.0, 32), 16: (0.5, 1), 72: (0.0, None)},
+    "1080p": {1: (1.0, 16)},
+}
 
 # Define acceleration
-num_workers = 4
+num_workers = 0
 dtype = "bf16"
 grad_checkpoint = True
 plugin = "zero2"
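For orientation (not part of the commit): each bucket_config entry maps a resolution name to {num_frames: (keep_prob, batch_size)}. A sample whose pixel count roughly reaches a resolution tier and whose length covers a frame bucket is kept there with probability keep_prob and batched at that bucket's batch size; a (0.0, None) entry is never kept at that tier, so those samples fall through to a lower resolution, as get_bucket_id below implements. A minimal reading sketch, reusing one entry from the dict above:

    bucket_config = {
        "240p": {1: (1.0, 128), 16: (1.0, 4), 48: (1.0, 4), 72: (1.0, 2)},
    }
    for resolution, frame_buckets in bucket_config.items():
        for num_frames, (keep_prob, batch_size) in frame_buckets.items():
            print(resolution, num_frames, keep_prob, batch_size)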
@@ -51,6 +55,6 @@ log_every = 10
 ckpt_every = 1000
 load = None
 
-batch_size = 10 # only for logging
+batch_size = 10 # only for logging
 lr = 2e-5
 grad_clip = 1.0
@@ -1,3 +1,3 @@
-from .datasets import VideoTextDataset
-from .datasets_variable import VariableVideoTextDataset
-from .utils import get_transforms_image, get_transforms_video, prepare_dataloader, save_sample
+from .dataloader import prepare_dataloader, prepare_variable_dataloader
+from .datasets import VariableVideoTextDataset, VideoTextDataset
+from .utils import get_transforms_image, get_transforms_video, save_sample
@@ -253,7 +253,7 @@ ASPECT_RATIO_256 = {
 def get_closest_ratio(height: float, width: float, ratios: dict):
     aspect_ratio = height / width
     closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - aspect_ratio))
-    return ratios[closest_ratio], float(closest_ratio)
+    return closest_ratio
 
 
 ASPECT_RATIOS = {
opensora/datasets/dataloader.py (new file, 303 lines)
@@ -0,0 +1,303 @@
+import random
+from collections import OrderedDict
+from typing import Iterator, Optional
+
+import numpy as np
+import torch
+from torch.distributed import ProcessGroup
+from torch.distributed.distributed_c10d import _get_default_group
+from torch.utils.data import DataLoader, Dataset
+from torch.utils.data.distributed import DistributedSampler
+
+from .aspect import ASPECT_RATIOS, get_closest_ratio
+
+
+class StatefulDistributedSampler(DistributedSampler):
+    def __init__(
+        self,
+        dataset: Dataset,
+        num_replicas: Optional[int] = None,
+        rank: Optional[int] = None,
+        shuffle: bool = True,
+        seed: int = 0,
+        drop_last: bool = False,
+    ) -> None:
+        super().__init__(dataset, num_replicas, rank, shuffle, seed, drop_last)
+        self.start_index: int = 0
+
+    def __iter__(self) -> Iterator:
+        iterator = super().__iter__()
+        indices = list(iterator)
+        indices = indices[self.start_index :]
+        return iter(indices)
+
+    def __len__(self) -> int:
+        return self.num_samples - self.start_index
+
+    def set_start_index(self, start_index: int) -> None:
+        self.start_index = start_index
+
+
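A minimal sketch (not from the commit) of the sampler's resume behavior; num_replicas and rank are passed explicitly so no process group is needed for the illustration:

    import torch

    data = torch.utils.data.TensorDataset(torch.arange(10))
    sampler = StatefulDistributedSampler(data, num_replicas=1, rank=0, shuffle=False)
    sampler.set_start_index(4)   # skip samples already consumed before a restart
    print(list(iter(sampler)))   # [4, 5, 6, 7, 8, 9]
    print(len(sampler))          # 6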
+def prepare_dataloader(
+    dataset,
+    batch_size,
+    shuffle=False,
+    seed=1024,
+    drop_last=False,
+    pin_memory=False,
+    num_workers=0,
+    process_group: Optional[ProcessGroup] = None,
+    **kwargs,
+):
+    r"""
+    Prepare a dataloader for distributed training. The dataloader will be wrapped by
+    `torch.utils.data.DataLoader` and `StatefulDistributedSampler`.
+
+    Args:
+        dataset (`torch.utils.data.Dataset`): The dataset to be loaded.
+        shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False.
+        seed (int, optional): Random worker seed for sampling, defaults to 1024.
+        add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True.
+        drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size
+            is not divisible by the batch size. If False and the size of dataset is not divisible by
+            the batch size, then the last batch will be smaller, defaults to False.
+        pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False.
+        num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0.
+        kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in
+            `DataLoader <https://pytorch.org/docs/stable/_modules/torch/utils/data/dataloader.html#DataLoader>`_.
+
+    Returns:
+        :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
+    """
+    _kwargs = kwargs.copy()
+    process_group = process_group or _get_default_group()
+    sampler = StatefulDistributedSampler(
+        dataset, num_replicas=process_group.size(), rank=process_group.rank(), shuffle=shuffle
+    )
+
+    # Deterministic dataloader
+    def seed_worker(worker_id):
+        worker_seed = seed
+        np.random.seed(worker_seed)
+        torch.manual_seed(worker_seed)
+        random.seed(worker_seed)
+
+    return DataLoader(
+        dataset,
+        batch_size=batch_size,
+        sampler=sampler,
+        worker_init_fn=seed_worker,
+        drop_last=drop_last,
+        pin_memory=pin_memory,
+        num_workers=num_workers,
+        **_kwargs,
+    )
+
+
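Typical call, assuming the launcher has already initialized torch.distributed (the default process group supplies world size and rank):

    loader = prepare_dataloader(dataset, batch_size=4, shuffle=True, drop_last=True, pin_memory=True)
    for batch in loader:
        ...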
+def find_approximate_hw(hw, hw_dict, approx=0.8):
+    for k, v in hw_dict.items():
+        if hw >= v * approx:
+            return k
+    return None
+
+
+def find_closet_smaller_bucket(t, t_dict, frame_interval):
+    # process image
+    if t == 1:
+        if 1 in t_dict:
+            return 1
+        else:
+            return None
+    # process video
+    for k, v in t_dict.items():
+        if t >= v * frame_interval and v != 1:
+            return k
+    return None
+
+
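A worked example with hypothetical criteria; both helpers rely on the dicts being ordered from largest to smallest, which is how Bucket builds them:

    hw_dict = {"360p": 360 * 640, "240p": 240 * 426}          # assumed pixel-count thresholds
    find_approximate_hw(280 * 500, hw_dict)                    # -> "240p"
    # 140000 < 0.8 * 230400 rejects 360p; 140000 >= 0.8 * 102240 accepts 240p.

    t_dict = {48: 48, 16: 16, 1: 1}                            # assumed frame buckets
    find_closet_smaller_bucket(40, t_dict, frame_interval=2)   # -> 16, since 16 * 2 <= 40 < 48 * 2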
+class Bucket:
+    def __init__(self, bucket_config):
+        for key in bucket_config:
+            assert key in ASPECT_RATIOS, f"Aspect ratio {key} not found."
+        # wrap config with OrderedDict
+        bucket_probs = OrderedDict()
+        bucket_bs = OrderedDict()
+        bucket_names = sorted(bucket_config.keys(), key=lambda x: ASPECT_RATIOS[x][0], reverse=True)
+        for key in bucket_names:
+            bucket_time_names = sorted(bucket_config[key].keys(), key=lambda x: x, reverse=True)
+            bucket_probs[key] = OrderedDict({k: bucket_config[key][k][0] for k in bucket_time_names})
+            bucket_bs[key] = OrderedDict({k: bucket_config[key][k][1] for k in bucket_time_names})
+
+        # first level: HW
+        num_bucket = 0
+        bucket = dict()
+        hw_criteria = dict()
+        t_criteria = dict()
+        ar_criteria = dict()
+        for k1, v1 in bucket_probs.items():
+            bucket[k1] = dict()
+            hw_criteria[k1] = ASPECT_RATIOS[k1][0]
+            t_criteria[k1] = dict()
+            ar_criteria[k1] = dict()
+            for k2, _ in v1.items():
+                bucket[k1][k2] = dict()
+                t_criteria[k1][k2] = k2
+                ar_criteria[k1][k2] = dict()
+                for k3, v3 in ASPECT_RATIOS[k1][1].items():
+                    bucket[k1][k2][k3] = []
+                    ar_criteria[k1][k2][k3] = v3
+                    num_bucket += 1
+
+        self.bucket_probs = bucket_probs
+        self.bucket_bs = bucket_bs
+        self.bucket = bucket
+        self.hw_criteria = hw_criteria
+        self.t_criteria = t_criteria
+        self.ar_criteria = ar_criteria
+        self.num_bucket = num_bucket
+        print(f"Number of buckets: {num_bucket}")
+
+    def info_bucket(self, dataset, frame_interval=1):
+        infos = dict()
+        for i in range(len(dataset)):
+            T, H, W = dataset.get_data_info(i)
+            bucket_id = self.get_bucket_id(T, H, W, frame_interval)
+            if bucket_id is None:
+                continue
+            if f"{(bucket_id[0], bucket_id[1])}" not in infos:
+                infos[f"{(bucket_id[0], bucket_id[1])}"] = 0
+            infos[f"{(bucket_id[0], bucket_id[1])}"] += 1
+        print(f"Dataset contains {len(dataset)} samples.")
+        print("Bucket info:", infos)
+
+    def get_bucket_id(self, T, H, W, frame_interval=1):
+        # hw
+        hw = H * W
+        hw_id = find_approximate_hw(hw, self.hw_criteria)
+        if hw_id is None:
+            return None
+
+        # hw drops by probability
+        while True:
+            # T
+            T_id = find_closet_smaller_bucket(T, self.t_criteria[hw_id], frame_interval)
+            if T_id is not None:
+                prob = self.get_prob((hw_id, T_id))
+                if random.random() < prob:
+                    break
+            hw_id = list(self.hw_criteria.keys()).index(hw_id)
+            if hw_id == len(self.hw_criteria) - 1:
+                break
+            hw_id = list(self.hw_criteria.keys())[hw_id + 1]
+        if T_id is None:
+            return None
+
+        # ar
+        ar_criteria = self.ar_criteria[hw_id][T_id]
+        ar_id = get_closest_ratio(H, W, ar_criteria)
+        return hw_id, T_id, ar_id
+
+    def get_thw(self, bucket_id):
+        assert len(bucket_id) == 3
+        T = self.t_criteria[bucket_id[0]][bucket_id[1]]
+        H, W = self.ar_criteria[bucket_id[0]][bucket_id[1]][bucket_id[2]]
+        return T, H, W
+
+    def get_prob(self, bucket_id):
+        return self.bucket_probs[bucket_id[0]][bucket_id[1]]
+
+    def get_batch_size(self, bucket_id):
+        return self.bucket_bs[bucket_id[0]][bucket_id[1]]
+
+    def __getitem__(self, index):
+        assert len(index) == 3
+        return self.bucket[index[0]][index[1]][index[2]]
+
+    def set_empty(self, index):
+        assert len(index) == 3
+        self.bucket[index[0]][index[1]][index[2]] = []
+
+    def __len__(self):
+        return self.num_bucket
+
+
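Bucket builds a three-level container indexed by (resolution, num_frames, aspect ratio); each leaf is a list of pending sample indices. A sketch of the layout and of one lookup, with hypothetical ratio keys (get_bucket_id is stochastic, since a failed keep_prob draw pushes the sample down one resolution tier):

    # self.bucket, illustratively, for 2 resolutions x 2 frame buckets x 2 ratios:
    # {"360p": {48: {"0.56": [], "1.78": []}, 1: {"0.56": [], "1.78": []}},
    #  "240p": {48: {"0.56": [], "1.78": []}, 1: {"0.56": [], "1.78": []}}}
    # num_bucket == 8: one sample list per leaf.

    bucket = Bucket(bucket_config)
    bucket_id = bucket.get_bucket_id(T=96, H=360, W=640, frame_interval=3)  # e.g. ("360p", 24, "0.56")
    if bucket_id is not None:
        rT, rH, rW = bucket.get_thw(bucket_id)   # target (frames, height, width) to resize to
        bs = bucket.get_batch_size(bucket_id)    # per-bucket batch size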
+def closet_smaller_bucket(value, bucket):
+    for i in range(1, len(bucket)):
+        if value < bucket[i]:
+            return bucket[i - 1]
+    return bucket[-1]
+
+
+class VariableVideoBatchSampler(torch.utils.data.BatchSampler):
+    def __init__(self, sampler, batch_size, drop_last, dataset, buckect_config):
+        self.sampler = sampler
+        self.dataset = dataset
+        self.batch_size = batch_size
+        self.drop_last = drop_last
+        self.bucket = Bucket(buckect_config)
+        self.frame_interval = self.dataset.frame_interval
+        self.bucket.info_bucket(self.dataset, self.frame_interval)
+
+    def __iter__(self):
+        for idx in self.sampler:
+            T, H, W = self.dataset.get_data_info(idx)
+            bucket_id = self.bucket.get_bucket_id(T, H, W, self.frame_interval)
+            if bucket_id is None:
+                continue
+            rT, rH, rW = self.bucket.get_thw(bucket_id)
+            self.dataset.set_data_info(idx, rT, rH, rW)
+            buffer = self.bucket[bucket_id]
+            buffer.append(idx)
+            if len(buffer) >= self.bucket.get_batch_size(bucket_id):
+                yield buffer
+                self.bucket.set_empty(bucket_id)
+
+        for k1, v1 in self.bucket.bucket.items():
+            for k2, v2 in v1.items():
+                for k3, buffer in v2.items():
+                    if len(buffer) > 0 and not self.drop_last:
+                        yield buffer
+                        self.bucket.set_empty((k1, k2, k3))
+
+
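Each yielded list therefore holds indices that landed in the same (resolution, frames, ratio) bucket, sized by that bucket's own batch size; the batch_size argument is stored but not consulted when batching. A wiring sketch (hypothetical setup; note the parameter is spelled buckect_config in this commit):

    batch_sampler = VariableVideoBatchSampler(
        sampler=sampler, batch_size=None, drop_last=False,
        dataset=dataset, buckect_config=bucket_config,
    )
    for indices in batch_sampler:
        # the dataset was already told the shared (T, H, W) via set_data_info(),
        # so default collation yields uniformly shaped video tensors
        ...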
+def prepare_variable_dataloader(
+    dataset,
+    batch_size,
+    bucket_config,
+    shuffle=False,
+    seed=1024,
+    drop_last=False,
+    pin_memory=False,
+    num_workers=0,
+    process_group=None,
+    **kwargs,
+):
+    _kwargs = kwargs.copy()
+    process_group = process_group or _get_default_group()
+    sampler = StatefulDistributedSampler(
+        dataset, num_replicas=process_group.size(), rank=process_group.rank(), shuffle=shuffle
+    )
+    batch_sampler = VariableVideoBatchSampler(
+        sampler=sampler,
+        batch_size=batch_size,
+        drop_last=drop_last,
+        dataset=dataset,
+        buckect_config=bucket_config,
+    )
+
+    # Deterministic dataloader
+    def seed_worker(worker_id):
+        worker_seed = seed
+        np.random.seed(worker_seed)
+        torch.manual_seed(worker_seed)
+        random.seed(worker_seed)
+
+    return torch.utils.data.DataLoader(
+        dataset,
+        batch_sampler=batch_sampler,
+        worker_init_fn=seed_worker,
+        pin_memory=pin_memory,
+        num_workers=num_workers,
+        **_kwargs,
+    )
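Putting it together, mirroring the call train.py makes below (assumes an initialized distributed environment):

    dataloader = prepare_variable_dataloader(
        dataset,                          # a VariableVideoTextDataset
        batch_size=cfg.batch_size,        # forwarded to the batch sampler
        bucket_config=cfg.bucket_config,
        shuffle=True,
        drop_last=True,
        pin_memory=True,
        num_workers=cfg.num_workers,
    )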
@@ -87,3 +87,69 @@ class VideoTextDataset(torch.utils.data.Dataset):
 
     def __len__(self):
         return len(self.data)
+
+
+@DATASETS.register_module()
+class VariableVideoTextDataset(VideoTextDataset):
+    def __init__(
+        self,
+        data_path,
+        num_frames=None,
+        frame_interval=1,
+        image_size=None,
+        transform_name=None,
+    ):
+        super().__init__(data_path, num_frames, frame_interval, image_size, transform_name=None)
+        self.transform_name = transform_name
+        self.data_info = self.data[["num_frames", "height", "width"]].to_numpy().tolist()
+
+    def set_data_info(self, idx, T, H, W):
+        self.data_info[idx] = [T, H, W]
+
+    def get_data_info(self, index):
+        T = self.data.iloc[index]["num_frames"]
+        H = self.data.iloc[index]["height"]
+        W = self.data.iloc[index]["width"]
+        return T, H, W
+
+    def getitem(self, index):
+        sample = self.data.iloc[index]
+        path = sample["path"]
+        text = sample["text"]
+        file_type = self.get_type(path)
+        num_frames, height, width = self.data_info[index]
+        ar = width / height
+
+        if file_type == "video":
+            # loading
+            vframes, _, _ = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
+
+            # Sampling video frames
+            video = temporal_random_crop(vframes, num_frames, self.frame_interval)
+
+            # transform
+            transform = get_transforms_video(self.transform_name, (height, width))
+            video = transform(video)  # T C H W
+        else:
+            # loading
+            image = pil_loader(path)
+
+            # transform
+            transform = get_transforms_image(self.transform_name, (height, width))
+            image = transform(image)
+
+            # repeat
+            video = image.unsqueeze(0).repeat(self.num_frames, 1, 1, 1)
+
+        # TCHW -> CTHW
+        video = video.permute(1, 0, 2, 3)
+        return {"video": video, "text": text, "num_frames": num_frames, "height": height, "width": width, "ar": ar}
+
+    def __getitem__(self, index):
+        for _ in range(10):
+            try:
+                return self.getitem(index)
+            except Exception as e:
+                print(e)
+                index = np.random.randint(len(self))
+        raise RuntimeError("Too many bad data.")
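The sampler/dataset handshake goes through data_info: the sampler reads the raw (T, H, W) recorded in the CSV, picks a bucket, and writes the bucket's target shape back before the worker loads the sample, so getitem crops and resizes to the bucket shape. A condensed sketch of that round trip (hypothetical index):

    T, H, W = dataset.get_data_info(idx)                  # raw values from the CSV row
    bucket_id = bucket.get_bucket_id(T, H, W, frame_interval=3)
    if bucket_id is not None:
        dataset.set_data_info(idx, *bucket.get_thw(bucket_id))
        sample = dataset[idx]                             # video now has the bucket's (T, H, W)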
opensora/datasets/datasets_variable.py (deleted file, 192 lines)
@@ -1,192 +0,0 @@
-import random
-
-import numpy as np
-import pandas as pd
-import torch
-import torchvision
-from torch.distributed.distributed_c10d import _get_default_group
-from torchvision.datasets.folder import IMG_EXTENSIONS, pil_loader
-
-from opensora.registry import DATASETS
-
-from .utils import (
-    VID_EXTENSIONS,
-    StatefulDistributedSampler,
-    get_transforms_image,
-    get_transforms_video,
-    temporal_random_crop,
-)
-
-
-def closet_smaller_bucket(value, bucket):
-    for i in range(1, len(bucket)):
-        if value < bucket[i]:
-            return bucket[i - 1]
-    return bucket[-1]
-
-
-@DATASETS.register_module()
-class VariableVideoTextDataset(torch.utils.data.Dataset):
-    """load video according to the csv file.
-
-    Args:
-        target_video_len (int): the number of video frames will be load.
-        align_transform (callable): Align different videos in a specified size.
-        temporal_sample (callable): Sample the target length of a video.
-    """
-
-    def __init__(
-        self,
-        data_path,
-        num_frames=None,
-        frame_interval=1,
-        image_size=(256, 256),
-        bucket=None,
-        batch_size_bucket=None,
-        transform_name="center",
-    ):
-        self.data_path = data_path
-        self.data = pd.read_csv(data_path)
-        self.batch_size_bucket = batch_size_bucket
-
-        # build bucket
-        self.bucket = bucket
-        num_effect_frames = self.data["num_frames"] // frame_interval
-        self.data = self.data[num_effect_frames >= self.bucket[0]]
-        self.data["bucket"] = num_effect_frames.apply(lambda x: closet_smaller_bucket(x, bucket))
-        gb = self.data.groupby("bucket")
-        self.data_bucket = {x: gb.get_group(x) for x in bucket}
-        self.data_bucket_len = {x: len(self.data_bucket[x]) for x in bucket}
-        print(self.data_bucket_len)
-
-        self.num_frames = num_frames
-        assert self.num_frames is None, "num_frames must be None"
-        self.frame_interval = frame_interval
-        self.image_size = image_size
-        self.transforms = {
-            "image": get_transforms_image(transform_name, image_size),
-            "video": get_transforms_video(transform_name, image_size),
-        }
-
-    def get_type(self, path):
-        ext = path.split(".")[-1]
-        if ext.lower() in VID_EXTENSIONS:
-            return "video"
-        else:
-            assert f".{ext.lower()}" in IMG_EXTENSIONS, f"Unsupported file format: {ext}"
-            return "image"
-
-    def get_bucket(self, index):
-        return self.data.iloc[index]["bucket"]
-
-    def getitem(self, index):
-        sample = self.data.iloc[index]
-        path = sample["path"]
-        text = sample["text"]
-        file_type = self.get_type(path)
-        num_frames = self.get_bucket(index)
-
-        if file_type == "video":
-            # loading
-            vframes, _, _ = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
-
-            # Sampling video frames
-            video = temporal_random_crop(vframes, num_frames, self.frame_interval)
-
-            # transform
-            transform = self.transforms["video"]
-            video = transform(video)  # T C H W
-        else:
-            # loading
-            image = pil_loader(path)
-
-            # transform
-            transform = self.transforms["image"]
-            image = transform(image)
-
-            # repeat
-            video = image.unsqueeze(0).repeat(self.num_frames, 1, 1, 1)
-
-        # TCHW -> CTHW
-        video = video.permute(1, 0, 2, 3)
-        return {"video": video, "text": text, "num_frames": num_frames}
-
-    def __getitem__(self, index):
-        for _ in range(10):
-            try:
-                return self.getitem(index)
-            except Exception as e:
-                print(e)
-                index = np.random.randint(len(self))
-        raise RuntimeError("Too many bad data.")
-
-    def __len__(self):
-        return len(self.data)
-
-
-class VariableVideoBatchSampler(torch.utils.data.BatchSampler):
-    def __init__(self, sampler, dataset, batch_size, bucket, batch_size_bucket, drop_last=False):
-        self.sampler = sampler
-        self.dataset = dataset
-        self.batch_size = batch_size
-        self.bucket = bucket
-        self.batch_size_bucket = batch_size_bucket
-        self._buckets = {x: [] for x in bucket}
-        self.drop_last = drop_last
-
-    def __iter__(self):
-        for idx in self.sampler:
-            bucket_id = self.dataset.get_bucket(idx)
-            bucket = self._buckets[bucket_id]
-            bucket.append(idx)
-            if len(bucket) >= self.batch_size_bucket[bucket_id]:
-                yield bucket
-                self._buckets[bucket_id] = []
-
-        for bucket in self._buckets.values():
-            if len(bucket) > 0 and not self.drop_last:
-                yield bucket
-
-
-def prepare_dataloader_with_batchsampler(
-    dataset,
-    batch_size,
-    shuffle=False,
-    seed=1024,
-    drop_last=False,
-    pin_memory=False,
-    num_workers=0,
-    process_group=None,
-    bucket=None,
-    batch_size_bucket=None,
-    **kwargs,
-):
-    _kwargs = kwargs.copy()
-    process_group = process_group or _get_default_group()
-    sampler = StatefulDistributedSampler(
-        dataset, num_replicas=process_group.size(), rank=process_group.rank(), shuffle=shuffle
-    )
-    batch_sampler = VariableVideoBatchSampler(
-        sampler,
-        dataset,
-        batch_size=batch_size,
-        bucket=bucket,
-        batch_size_bucket=batch_size_bucket,
-        drop_last=drop_last,
-    )
-
-    # Deterministic dataloader
-    def seed_worker(worker_id):
-        worker_seed = seed
-        np.random.seed(worker_seed)
-        torch.manual_seed(worker_seed)
-        random.seed(worker_seed)
-
-    return torch.utils.data.DataLoader(
-        dataset,
-        batch_sampler=batch_sampler,
-        worker_init_fn=seed_worker,
-        pin_memory=pin_memory,
-        num_workers=num_workers,
-        **_kwargs,
-    )
@@ -1,15 +1,9 @@
 import random
-from typing import Iterator, Optional
 
 import numpy as np
 import torch
 import torchvision
 import torchvision.transforms as transforms
 from PIL import Image
-from torch.distributed import ProcessGroup
-from torch.distributed.distributed_c10d import _get_default_group
-from torch.utils.data import DataLoader, Dataset
-from torch.utils.data.distributed import DistributedSampler
-
 from torchvision.datasets.folder import IMG_EXTENSIONS, pil_loader
 from torchvision.io import write_video
 from torchvision.utils import save_image
@@ -29,14 +23,16 @@ def temporal_random_crop(vframes, num_frames, frame_interval):
     return video
 
 
-def get_transforms_video(name="center", resolution=(256, 256)):
-    if name == "center":
-        assert resolution[0] == resolution[1], "Resolution must be square for center crop"
+def get_transforms_video(name="center", image_size=(256, 256)):
+    if name is None:
+        return None
+    elif name == "center":
+        assert image_size[0] == image_size[1], "image_size must be square for center crop"
         transform_video = transforms.Compose(
             [
                 video_transforms.ToTensorVideo(),  # TCHW
                 # video_transforms.RandomHorizontalFlipVideo(),
-                video_transforms.UCFCenterCropVideo(resolution[0]),
+                video_transforms.UCFCenterCropVideo(image_size[0]),
                 transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
             ]
         )
@@ -44,7 +40,7 @@ def get_transforms_video(name="center", resolution=(256, 256)):
         transform_video = transforms.Compose(
             [
                 video_transforms.ToTensorVideo(),  # TCHW
-                video_transforms.ResizeCrop(resolution),
+                video_transforms.ResizeCrop(image_size),
                 transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
            ]
         )
@@ -54,7 +50,9 @@ def get_transforms_video(name="center", resolution=(256, 256)):
 
 
 def get_transforms_image(name="center", image_size=(256, 256)):
-    if name == "center":
+    if name is None:
+        return None
+    elif name == "center":
         assert image_size[0] == image_size[1], "Image size must be square for center crop"
         transform = transforms.Compose(
             [
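The new name-is-None branch supports the variable dataset above: it passes transform_name=None to its parent so no fixed-size transform is built at init, then constructs one per sample once the bucket's target size is known. Sketch:

    get_transforms_video(None)                        # -> None; nothing built up front
    get_transforms_video("resize_crop", (240, 426))   # built at load time for the chosen bucket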
@@ -123,89 +121,6 @@ def save_sample(x, fps=8, save_path=None, normalize=True, value_range=(-1, 1)):
     return save_path
 
 
-class StatefulDistributedSampler(DistributedSampler):
-    def __init__(
-        self,
-        dataset: Dataset,
-        num_replicas: Optional[int] = None,
-        rank: Optional[int] = None,
-        shuffle: bool = True,
-        seed: int = 0,
-        drop_last: bool = False,
-    ) -> None:
-        super().__init__(dataset, num_replicas, rank, shuffle, seed, drop_last)
-        self.start_index: int = 0
-
-    def __iter__(self) -> Iterator:
-        iterator = super().__iter__()
-        indices = list(iterator)
-        indices = indices[self.start_index :]
-        return iter(indices)
-
-    def __len__(self) -> int:
-        return self.num_samples - self.start_index
-
-    def set_start_index(self, start_index: int) -> None:
-        self.start_index = start_index
-
-
-def prepare_dataloader(
-    dataset,
-    batch_size,
-    shuffle=False,
-    seed=1024,
-    drop_last=False,
-    pin_memory=False,
-    num_workers=0,
-    process_group: Optional[ProcessGroup] = None,
-    **kwargs,
-):
-    r"""
-    Prepare a dataloader for distributed training. The dataloader will be wrapped by
-    `torch.utils.data.DataLoader` and `StatefulDistributedSampler`.
-
-    Args:
-        dataset (`torch.utils.data.Dataset`): The dataset to be loaded.
-        shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False.
-        seed (int, optional): Random worker seed for sampling, defaults to 1024.
-        add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True.
-        drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size
-            is not divisible by the batch size. If False and the size of dataset is not divisible by
-            the batch size, then the last batch will be smaller, defaults to False.
-        pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False.
-        num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0.
-        kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in
-            `DataLoader <https://pytorch.org/docs/stable/_modules/torch/utils/data/dataloader.html#DataLoader>`_.
-
-    Returns:
-        :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
-    """
-    _kwargs = kwargs.copy()
-    process_group = process_group or _get_default_group()
-    sampler = StatefulDistributedSampler(
-        dataset, num_replicas=process_group.size(), rank=process_group.rank(), shuffle=shuffle
-    )
-
-    # Deterministic dataloader
-    def seed_worker(worker_id):
-        worker_seed = seed
-        np.random.seed(worker_seed)
-        torch.manual_seed(worker_seed)
-        random.seed(worker_seed)
-
-    return DataLoader(
-        dataset,
-        batch_size=batch_size,
-        sampler=sampler,
-        worker_init_fn=seed_worker,
-        drop_last=drop_last,
-        pin_memory=pin_memory,
-        num_workers=num_workers,
-        **_kwargs,
-    )
-
-
 def center_crop_arr(pil_image, image_size):
     """
     Center cropping implementation from ADM.
@@ -212,6 +212,7 @@ class STDiT2(nn.Module):
         # support dynamic input
         self.patch_size = patch_size
         self.input_size = input_size
+        self.base_size = (input_size[1] / patch_size[1] * input_size[2] / patch_size[2]) ** 0.5
 
         self.x_embedder = PatchEmbed3D(patch_size, in_channels, hidden_size)
         self.t_embedder = TimestepEmbedder(hidden_size)
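base_size is the geometric mean of the spatial token grid at the configured input size; it is handed to the position embedding below so embeddings for other resolutions can be rescaled against it (the commit message flags this part as not completed). With the config's (360, 480) base image size and an assumed spatial patch size of 2:

    base_size = (360 / 2 * 480 / 2) ** 0.5   # sqrt(180 * 240) ~ 207.8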
@@ -280,7 +281,7 @@ class STDiT2(nn.Module):
         W = W // self.patch_size[2]
         return (T, H, W)
 
-    def forward(self, x, timestep, y, mask=None, x_mask=None, num_frames=None):
+    def forward(self, x, timestep, y, mask=None, x_mask=None, num_frames=None, height=None, width=None, ar=None):
         """
         Forward pass of STDiT.
         Args:
@@ -297,26 +298,38 @@ class STDiT2(nn.Module):
         timestep = timestep.to(self.dtype)
         y = y.to(self.dtype)
 
-        # TODO: hard-coded for now
-        hw = torch.tensor([self.input_size[1], self.input_size[2]], device=x.device, dtype=x.dtype).repeat(B, 1)
-        ar = torch.tensor([[self.input_size[1] / self.input_size[2]]], device=x.device, dtype=x.dtype).repeat(B, 1)
+        # === process data info ===
+        # 1. get dynamic size
+        if height is None or width is None:
+            hw = torch.tensor([self.input_size[1], self.input_size[2]], device=x.device, dtype=x.dtype).repeat(B, 1)
+        else:
+            hw = torch.cat([height[:, None], width[:, None]], dim=1)
+        csize = self.csize_embedder(hw, B)
+
+        # 2. get aspect ratio
+        if ar is None:
+            ar = torch.tensor([[self.input_size[1] / self.input_size[2]]], device=x.device, dtype=x.dtype).repeat(B, 1)
+        else:
+            ar = ar.unsqueeze(1)
+        ar = self.ar_embedder(ar, B)
+        data_info = torch.cat([csize, ar], dim=1)
+
+        # 3. get number of frames
         if num_frames is None:
            num_frames = torch.tensor([x.shape[2]], device=x.device, dtype=x.dtype)
         fl = num_frames.unsqueeze(1)
-        csize = self.csize_embedder(hw, B)
-        ar = self.ar_embedder(ar, B)
         fl = self.fl_embedder(fl, B)
-        data_info = torch.cat([csize, ar], dim=1)
 
-        # Dynamic
+        # === get dynamic shape size ===
         _, _, Tx, Hx, Wx = x.size()
         T, H, W = self.get_dynamic_size(x)
         S = H * W
+        pos_emb = self.get_spatial_pos_embed(H, W, self.base_size).to(x.device, x.dtype)
 
         # embedding
         x = self.x_embedder(x)  # [B, N, C]
         x = rearrange(x, "B (T S) C -> B T S C", T=T, S=S)
-        x = x + self.get_spatial_pos_embed(H, W).to(x.device, x.dtype)
+        x = x + pos_emb
         x = rearrange(x, "B T S C -> B (T S) C")
 
         # shard over the sequence dim if sp is enabled
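In the dynamic path, height, width, and ar arrive as per-sample tensors of shape [B] (the extra keys VariableVideoTextDataset returns, forwarded by the training loop below); forward reshapes them itself. Illustratively, for B=2:

    height = torch.tensor([240.0, 480.0])
    width = torch.tensor([426.0, 854.0])
    ar = torch.tensor([1.78, 1.78])            # width / height, unsqueezed to [B, 1] inside forward
    num_frames = torch.tensor([16.0, 16.0])
    out = model(x, timestep, y, num_frames=num_frames, height=height, width=width, ar=ar)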
@@ -419,11 +432,12 @@ class STDiT2(nn.Module):
         imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))
         return imgs
 
-    def get_spatial_pos_embed(self, H, W):
+    def get_spatial_pos_embed(self, H, W, base_size=None):
         pos_embed = get_2d_sincos_pos_embed(
             self.hidden_size,
             (H, W),
             scale=self.space_scale,
+            base_size=base_size,
         )
         pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False)
         return pos_embed
@@ -71,6 +71,8 @@ def merge_args(cfg, args, training=False):
         cfg["mask_ratios"] = None
         if "transform_name" not in cfg.dataset:
             cfg.dataset["transform_name"] = "center"
+        if "bucket_config" not in cfg:
+            cfg["bucket_config"] = None
 
     # Both training and inference
     if "multi_resolution" not in cfg:
@@ -18,8 +18,7 @@ from opensora.acceleration.parallel_states import (
     set_sequence_parallel_group,
 )
 from opensora.acceleration.plugin import ZeroSeqParallelPlugin
-from opensora.datasets import prepare_dataloader
-from opensora.datasets.datasets_variable import prepare_dataloader_with_batchsampler
+from opensora.datasets import prepare_dataloader, prepare_variable_dataloader
 from opensora.registry import DATASETS, MODELS, SCHEDULERS, build_module
 from opensora.utils.ckpt_utils import create_logger, load, model_sharding, record_model_param_shape, save
 from opensora.utils.config_utils import (
@@ -91,39 +90,22 @@ def main():
     # 3. build dataset and dataloader
     # ======================================================
     dataset = build_module(cfg.dataset, DATASETS)
-
+    dataloader_args = dict(
+        dataset=dataset,
+        batch_size=cfg.batch_size,
+        num_workers=cfg.num_workers,
+        shuffle=True,
+        drop_last=True,
+        pin_memory=True,
+        process_group=get_data_parallel_group(),
+    )
     # TODO: use plugin's prepare dataloader
     # a batch contains:
     # {
     #     "video": torch.Tensor,  # [B, C, T, H, W],
    #     "text": List[str],
     # }
-    if cfg.dataset.type == "VideoTextDataset":
-        dataloader = prepare_dataloader(
-            dataset,
-            batch_size=cfg.batch_size,
-            num_workers=cfg.num_workers,
-            shuffle=True,
-            drop_last=True,
-            pin_memory=True,
-            process_group=get_data_parallel_group(),
-        )
+    if cfg.bucket_config is None:
+        dataloader = prepare_dataloader(**dataloader_args)
     else:
-        dataloader = prepare_dataloader_with_batchsampler(
-            dataset,
-            batch_size=cfg.batch_size,
-            batch_size_bucket=dataset.batch_size_bucket,
-            bucket=dataset.bucket,
-            num_workers=cfg.num_workers,
-            shuffle=True,
-            drop_last=True,
-            pin_memory=True,
-            process_group=get_data_parallel_group(),
-        )
-        print(dataset.batch_size_bucket)
-    logger.info(f"Dataset contains {len(dataset):,} videos ({dataset.data_path})")
-
+        dataloader = prepare_variable_dataloader(bucket_config=cfg.bucket_config, **dataloader_args)
     total_batch_size = cfg.batch_size * dist.get_world_size() // cfg.sp_size
+    logger.info(f"Dataset contains {len(dataset):,} videos ({dataset.data_path})")
     logger.info(f"Total batch size: {total_batch_size}")
 
     # ======================================================
@@ -219,9 +201,8 @@ def main():
         ) as pbar:
             for step in pbar:
                 batch = next(dataloader_iter)
-                x = batch["video"].to(device, dtype)  # [B, C, T, H, W]
-                y = batch["text"]
-
+                x = batch.pop("video").to(device, dtype)  # [B, C, T, H, W]
+                y = batch.pop("text")
                 # Visual and text encoding
                 with torch.no_grad():
                     # Prepare visual inputs
@@ -237,8 +218,8 @@ def main():
                 mask = None
 
                 # Video info
-                if "num_frames" in batch:
-                    model_args["num_frames"] = batch["num_frames"].to(device, dtype)
+                for k, v in batch.items():
+                    model_args[k] = v.to(device, dtype)
 
                 # Diffusion
                 t = torch.randint(0, scheduler.num_timesteps, (x.shape[0],), device=device)
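Because video and text were pop()ed above, every key still left in the batch (num_frames, height, width, ar from VariableVideoTextDataset) is moved to the device and becomes a model kwarg. Sketch of the effect for one collated batch (hypothetical values):

    batch = {"num_frames": torch.tensor([16.0]), "height": torch.tensor([240.0]),
             "width": torch.tensor([426.0]), "ar": torch.tensor([1.78])}
    model_args = {}
    for k, v in batch.items():
        model_args[k] = v.to(device, dtype)
    # model_args is then handed to the diffusion scheduler alongside x, t, and y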