[feat] dynamic for video (base_size not completed)

Zangwei Zheng 2024-03-28 17:58:36 +08:00
parent 94e686177e
commit cba02e8a58
11 changed files with 437 additions and 344 deletions

View file

@@ -4,14 +4,18 @@ dataset = dict(
data_path=None,
num_frames=None,
frame_interval=3,
image_size=(360, 480),
image_size=(360, 480), # base size
transform_name="resize_crop",
bucket=[24, 48, 72],
batch_size_bucket={24: 12, 48: 6, 72: 4},
)
bucket_config = {
"240p": {1: (1.0, 128), 16: (1.0, 4), 48: (1.0, 4), 72: (1.0, 2)},
"360p": {1: (1.0, 64), 16: (0.5, 2), 24: (0.5, 2), 72: (0.0, None)},
"720p": {1: (1.0, 32), 16: (0.5, 1), 72: (0.0, None)},
"1080p": {1: (1.0, 16)},
}
# Define acceleration
num_workers = 4
num_workers = 0
dtype = "bf16"
grad_checkpoint = True
plugin = "zero2"
@@ -51,6 +55,6 @@ log_every = 10
ckpt_every = 1000
load = None
batch_size = 10 # only for logging
lr = 2e-5
grad_clip = 1.0
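For orientation: each bucket_config entry maps a resolution name to {num_frames: (keep_probability, batch_size)}, and a probability of 0.0 with a None batch size disables that combination. A hedged sketch of how one entry reads (the loop below is illustrative, not part of the commit):

bucket_config = {
    "240p": {1: (1.0, 128), 16: (1.0, 4), 48: (1.0, 4), 72: (1.0, 2)},
}
for resolution, frame_map in bucket_config.items():
    for num_frames, (keep_prob, batch_size) in frame_map.items():
        print(f"{resolution} @ {num_frames} frames: keep p={keep_prob}, batch {batch_size}")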

View file

@@ -1,3 +1,3 @@
from .datasets import VideoTextDataset
from .datasets_variable import VariableVideoTextDataset
from .utils import get_transforms_image, get_transforms_video, prepare_dataloader, save_sample
from .dataloader import prepare_dataloader, prepare_variable_dataloader
from .datasets import VariableVideoTextDataset, VideoTextDataset
from .utils import get_transforms_image, get_transforms_video, save_sample

View file

@@ -253,7 +253,7 @@ ASPECT_RATIO_256 = {
def get_closest_ratio(height: float, width: float, ratios: dict):
aspect_ratio = height / width
closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - aspect_ratio))
return ratios[closest_ratio], float(closest_ratio)
return closest_ratio
ASPECT_RATIOS = {
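Note the changed contract: get_closest_ratio now returns only the ratio key rather than a (size, ratio) pair, so callers index the table themselves. A hedged usage sketch with a toy ratio table:

ratios = {"0.56": (720, 1280), "1.00": (960, 960), "1.78": (1280, 720)}  # illustrative
key = get_closest_ratio(720, 1280, ratios)  # 720 / 1280 = 0.5625 -> "0.56"
target_h, target_w = ratios[key]            # the caller now does the lookup itself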

View file

@@ -0,0 +1,303 @@
import random
from collections import OrderedDict
from typing import Iterator, Optional
import numpy as np
import torch
from torch.distributed import ProcessGroup
from torch.distributed.distributed_c10d import _get_default_group
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler
from .aspect import ASPECT_RATIOS, get_closest_ratio
class StatefulDistributedSampler(DistributedSampler):
def __init__(
self,
dataset: Dataset,
num_replicas: Optional[int] = None,
rank: Optional[int] = None,
shuffle: bool = True,
seed: int = 0,
drop_last: bool = False,
) -> None:
super().__init__(dataset, num_replicas, rank, shuffle, seed, drop_last)
self.start_index: int = 0
def __iter__(self) -> Iterator:
iterator = super().__iter__()
indices = list(iterator)
indices = indices[self.start_index :]
return iter(indices)
def __len__(self) -> int:
return self.num_samples - self.start_index
def set_start_index(self, start_index: int) -> None:
self.start_index = start_index
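# A hedged resume sketch (illustrative names, not part of this commit).
# start_index counts per-rank samples, because the slice above is applied
# after DistributedSampler has already split indices across ranks:
#
#     sampler = StatefulDistributedSampler(train_set, shuffle=True, seed=0)
#     sampler.set_epoch(epoch)               # replay the interrupted epoch's order
#     sampler.set_start_index(samples_seen)  # skip what this rank already consumed
#     remaining = list(sampler)              # len(sampler) shrinks to match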
def prepare_dataloader(
dataset,
batch_size,
shuffle=False,
seed=1024,
drop_last=False,
pin_memory=False,
num_workers=0,
process_group: Optional[ProcessGroup] = None,
**kwargs,
):
r"""
Prepare a dataloader for distributed training. The dataloader will be wrapped by
`torch.utils.data.DataLoader` and `StatefulDistributedSampler`.
Args:
dataset (`torch.utils.data.Dataset`): The dataset to be loaded.
shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False.
seed (int, optional): Random worker seed for sampling, defaults to 1024.
batch_size (int): The number of samples in each batch on every rank.
drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size
is not divisible by the batch size. If False and the size of dataset is not divisible by
the batch size, then the last batch will be smaller, defaults to False.
pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False.
num_workers (int, optional): Number of worker processes for this dataloader. Defaults to 0.
kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in
`DataLoader <https://pytorch.org/docs/stable/_modules/torch/utils/data/dataloader.html#DataLoader>`_.
Returns:
:class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
"""
_kwargs = kwargs.copy()
process_group = process_group or _get_default_group()
sampler = StatefulDistributedSampler(
dataset, num_replicas=process_group.size(), rank=process_group.rank(), shuffle=shuffle
)
# Deterministic dataloader
def seed_worker(worker_id):
worker_seed = seed
np.random.seed(worker_seed)
torch.manual_seed(worker_seed)
random.seed(worker_seed)
return DataLoader(
dataset,
batch_size=batch_size,
sampler=sampler,
worker_init_fn=seed_worker,
drop_last=drop_last,
pin_memory=pin_memory,
num_workers=num_workers,
**_kwargs,
)
def find_approximate_hw(hw, hw_dict, approx=0.8):
for k, v in hw_dict.items():
if hw >= v * approx:
return k
return None
def find_closest_smaller_bucket(t, t_dict, frame_interval):
# process image
if t == 1:
if 1 in t_dict:
return 1
else:
return None
# process video
for k, v in t_dict.items():
if t >= v * frame_interval and v != 1:
return k
return None
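# A hedged walk-through with toy values (not from this commit): with t_dict
# ordered largest-first, e.g. {72: 72, 48: 48, 16: 16, 1: 1} and
# frame_interval=3, a 150-frame video fails 150 >= 72*3 but passes
# 150 >= 48*3, so the 48-frame bucket is returned; a video shorter than
# 16*3 = 48 frames matches nothing and falls through to None.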
class Bucket:
def __init__(self, bucket_config):
for key in bucket_config:
assert key in ASPECT_RATIOS, f"Aspect ratio {key} not found."
# wrap config with OrderedDict
bucket_probs = OrderedDict()
bucket_bs = OrderedDict()
bucket_names = sorted(bucket_config.keys(), key=lambda x: ASPECT_RATIOS[x][0], reverse=True)
for key in bucket_names:
bucket_time_names = sorted(bucket_config[key].keys(), key=lambda x: x, reverse=True)
bucket_probs[key] = OrderedDict({k: bucket_config[key][k][0] for k in bucket_time_names})
bucket_bs[key] = OrderedDict({k: bucket_config[key][k][1] for k in bucket_time_names})
# first level: HW
num_bucket = 0
bucket = dict()
hw_criteria = dict()
t_criteria = dict()
ar_criteria = dict()
for k1, v1 in bucket_probs.items():
bucket[k1] = dict()
hw_criteria[k1] = ASPECT_RATIOS[k1][0]
t_criteria[k1] = dict()
ar_criteria[k1] = dict()
for k2, _ in v1.items():
bucket[k1][k2] = dict()
t_criteria[k1][k2] = k2
ar_criteria[k1][k2] = dict()
for k3, v3 in ASPECT_RATIOS[k1][1].items():
bucket[k1][k2][k3] = []
ar_criteria[k1][k2][k3] = v3
num_bucket += 1
self.bucket_probs = bucket_probs
self.bucket_bs = bucket_bs
self.bucket = bucket
self.hw_criteria = hw_criteria
self.t_criteria = t_criteria
self.ar_criteria = ar_criteria
self.num_bucket = num_bucket
print(f"Number of buckets: {num_bucket}")
def info_bucket(self, dataset, frame_interval=1):
infos = dict()
for i in range(len(dataset)):
T, H, W = dataset.get_data_info(i)
bucket_id = self.get_bucket_id(T, H, W, frame_interval)
if bucket_id is None:
continue
if f"{(bucket_id[0], bucket_id[1])}" not in infos:
infos[f"{(bucket_id[0], bucket_id[1])}"] = 0
infos[f"{(bucket_id[0], bucket_id[1])}"] += 1
print(f"Dataset contains {len(dataset)} samples.")
print("Bucket info:", infos)
def get_bucket_id(self, T, H, W, frame_interval=1):
# hw
hw = H * W
hw_id = find_approximate_hw(hw, self.hw_criteria)
if hw_id is None:
return None
# hw drops by probability
while True:
# T
T_id = find_closest_smaller_bucket(T, self.t_criteria[hw_id], frame_interval)
if T_id is not None:
prob = self.get_prob((hw_id, T_id))
if random.random() < prob:
break
hw_id = list(self.hw_criteria.keys()).index(hw_id)
if hw_id == len(self.hw_criteria) - 1:
break
hw_id = list(self.hw_criteria.keys())[hw_id + 1]
if T_id is None:
return None
# ar
ar_criteria = self.ar_criteria[hw_id][T_id]
ar_id = get_closest_ratio(H, W, ar_criteria)
return hw_id, T_id, ar_id
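# A hedged reading of the loop above: a sample first targets the largest
# resolution whose pixel criterion it meets (see find_approximate_hw); if
# random.random() rejects it (keep probability < 1.0), it is demoted to the
# next smaller resolution and retried, so fractional-probability buckets
# spill downward instead of dropping data. At the smallest resolution the
# sample is kept even if the probability check fails; None is returned only
# when no resolution or frame bucket accepts the sample at all.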
def get_thw(self, bucket_id):
assert len(bucket_id) == 3
T = self.t_criteria[bucket_id[0]][bucket_id[1]]
H, W = self.ar_criteria[bucket_id[0]][bucket_id[1]][bucket_id[2]]
return T, H, W
def get_prob(self, bucket_id):
return self.bucket_probs[bucket_id[0]][bucket_id[1]]
def get_batch_size(self, bucket_id):
return self.bucket_bs[bucket_id[0]][bucket_id[1]]
def __getitem__(self, index):
assert len(index) == 3
return self.bucket[index[0]][index[1]][index[2]]
def set_empty(self, index):
assert len(index) == 3
self.bucket[index[0]][index[1]][index[2]] = []
def __len__(self):
return self.num_bucket
def closest_smaller_bucket(value, bucket):
for i in range(1, len(bucket)):
if value < bucket[i]:
return bucket[i - 1]
return bucket[-1]
class VariableVideoBatchSampler(torch.utils.data.BatchSampler):
def __init__(self, sampler, batch_size, drop_last, dataset, bucket_config):
self.sampler = sampler
self.dataset = dataset
self.batch_size = batch_size
self.drop_last = drop_last
self.bucket = Bucket(bucket_config)
self.frame_interval = self.dataset.frame_interval
self.bucket.info_bucket(self.dataset, self.frame_interval)
def __iter__(self):
for idx in self.sampler:
T, H, W = self.dataset.get_data_info(idx)
bucket_id = self.bucket.get_bucket_id(T, H, W, self.frame_interval)
if bucket_id is None:
continue
rT, rH, rW = self.bucket.get_thw(bucket_id)
self.dataset.set_data_info(idx, rT, rH, rW)
buffer = self.bucket[bucket_id]
buffer.append(idx)
if len(buffer) >= self.bucket.get_batch_size(bucket_id):
yield buffer
self.bucket.set_empty(bucket_id)
for k1, v1 in self.bucket.bucket.items():
for k2, v2 in v1.items():
for k3, buffer in v2.items():
if len(buffer) > 0 and not self.drop_last:
yield buffer
self.bucket.set_empty((k1, k2, k3))
def prepare_variable_dataloader(
dataset,
batch_size,
bucket_config,
shuffle=False,
seed=1024,
drop_last=False,
pin_memory=False,
num_workers=0,
process_group=None,
**kwargs,
):
_kwargs = kwargs.copy()
process_group = process_group or _get_default_group()
sampler = StatefulDistributedSampler(
dataset, num_replicas=process_group.size(), rank=process_group.rank(), shuffle=shuffle
)
batch_sampler = VariableVideoBatchSampler(
sampler=sampler,
batch_size=batch_size,
drop_last=drop_last,
dataset=dataset,
bucket_config=bucket_config,
)
# Deterministic dataloader
def seed_worker(worker_id):
worker_seed = seed
np.random.seed(worker_seed)
torch.manual_seed(worker_seed)
random.seed(worker_seed)
return torch.utils.data.DataLoader(
dataset,
batch_sampler=batch_sampler,
worker_init_fn=seed_worker,
pin_memory=pin_memory,
num_workers=num_workers,
**_kwargs,
)
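A hedged end-to-end sketch of wiring this up (the dataset variable and config literal are illustrative, and torch.distributed is assumed to be initialized):

# `train_set` is assumed to be a VariableVideoTextDataset exposing
# get_data_info / set_data_info and a frame_interval attribute
bucket_config = {"240p": {1: (1.0, 128), 16: (1.0, 4)}}
loader = prepare_variable_dataloader(
    train_set,
    batch_size=None,  # stored but unused; per-bucket sizes come from bucket_config
    bucket_config=bucket_config,
    shuffle=True,
    num_workers=4,
)
for batch in loader:
    pass  # every sample in a batch shares one (resolution, frames, aspect) bucket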

View file

@@ -87,3 +87,69 @@ class VideoTextDataset(torch.utils.data.Dataset):
def __len__(self):
return len(self.data)
@DATASETS.register_module()
class VariableVideoTextDataset(VideoTextDataset):
def __init__(
self,
data_path,
num_frames=None,
frame_interval=1,
image_size=None,
transform_name=None,
):
super().__init__(data_path, num_frames, frame_interval, image_size, transform_name=None)
self.transform_name = transform_name
self.data_info = self.data[["num_frames", "height", "width"]].to_numpy().tolist()
def set_data_info(self, idx, T, H, W):
self.data_info[idx] = [T, H, W]
def get_data_info(self, index):
T = self.data.iloc[index]["num_frames"]
H = self.data.iloc[index]["height"]
W = self.data.iloc[index]["width"]
return T, H, W
def getitem(self, index):
sample = self.data.iloc[index]
path = sample["path"]
text = sample["text"]
file_type = self.get_type(path)
num_frames, height, width = self.data_info[index]
ar = width / height
if file_type == "video":
# loading
vframes, _, _ = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
# Sampling video frames
video = temporal_random_crop(vframes, num_frames, self.frame_interval)
# transform
transform = get_transforms_video(self.transform_name, (height, width))
video = transform(video) # T C H W
else:
# loading
image = pil_loader(path)
# transform
transform = get_transforms_image(self.transform_name, (height, width))
image = transform(image)
# repeat
video = image.unsqueeze(0).repeat(num_frames, 1, 1, 1)
# TCHW -> CTHW
video = video.permute(1, 0, 2, 3)
return {"video": video, "text": text, "num_frames": num_frames, "height": height, "width": width, "ar": ar}
def __getitem__(self, index):
for _ in range(10):
try:
return self.getitem(index)
except Exception as e:
print(e)
index = np.random.randint(len(self))
raise RuntimeError("Too many corrupted samples.")
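get_data_info / set_data_info form the handshake with VariableVideoBatchSampler: the sampler reads a sample's native shape, assigns a bucket, then writes the bucket's target shape back so that getitem decodes straight to it. A hedged sketch with toy values (assumes the bucket lookup succeeds):

T, H, W = dataset.get_data_info(i)      # e.g. (120, 1080, 1920), native metadata
bucket_id = bucket.get_bucket_id(T, H, W, frame_interval)
rT, rH, rW = bucket.get_thw(bucket_id)  # e.g. (48, 360, 640) after bucketing
dataset.set_data_info(i, rT, rH, rW)    # __getitem__ now loads/crops to this shape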

View file

@@ -1,192 +0,0 @@
import random
import numpy as np
import pandas as pd
import torch
import torchvision
from torch.distributed.distributed_c10d import _get_default_group
from torchvision.datasets.folder import IMG_EXTENSIONS, pil_loader
from opensora.registry import DATASETS
from .utils import (
VID_EXTENSIONS,
StatefulDistributedSampler,
get_transforms_image,
get_transforms_video,
temporal_random_crop,
)
def closet_smaller_bucket(value, bucket):
for i in range(1, len(bucket)):
if value < bucket[i]:
return bucket[i - 1]
return bucket[-1]
@DATASETS.register_module()
class VariableVideoTextDataset(torch.utils.data.Dataset):
"""load video according to the csv file.
Args:
target_video_len (int): the number of video frames will be load.
align_transform (callable): Align different videos in a specified size.
temporal_sample (callable): Sample the target length of a video.
"""
def __init__(
self,
data_path,
num_frames=None,
frame_interval=1,
image_size=(256, 256),
bucket=None,
batch_size_bucket=None,
transform_name="center",
):
self.data_path = data_path
self.data = pd.read_csv(data_path)
self.batch_size_bucket = batch_size_bucket
# build bucket
self.bucket = bucket
num_effect_frames = self.data["num_frames"] // frame_interval
self.data = self.data[num_effect_frames >= self.bucket[0]]
self.data["bucket"] = num_effect_frames.apply(lambda x: closet_smaller_bucket(x, bucket))
gb = self.data.groupby("bucket")
self.data_bucket = {x: gb.get_group(x) for x in bucket}
self.data_bucket_len = {x: len(self.data_bucket[x]) for x in bucket}
print(self.data_bucket_len)
self.num_frames = num_frames
assert self.num_frames is None, "num_frames must be None"
self.frame_interval = frame_interval
self.image_size = image_size
self.transforms = {
"image": get_transforms_image(transform_name, image_size),
"video": get_transforms_video(transform_name, image_size),
}
def get_type(self, path):
ext = path.split(".")[-1]
if ext.lower() in VID_EXTENSIONS:
return "video"
else:
assert f".{ext.lower()}" in IMG_EXTENSIONS, f"Unsupported file format: {ext}"
return "image"
def get_bucket(self, index):
return self.data.iloc[index]["bucket"]
def getitem(self, index):
sample = self.data.iloc[index]
path = sample["path"]
text = sample["text"]
file_type = self.get_type(path)
num_frames = self.get_bucket(index)
if file_type == "video":
# loading
vframes, _, _ = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
# Sampling video frames
video = temporal_random_crop(vframes, num_frames, self.frame_interval)
# transform
transform = self.transforms["video"]
video = transform(video) # T C H W
else:
# loading
image = pil_loader(path)
# transform
transform = self.transforms["image"]
image = transform(image)
# repeat
video = image.unsqueeze(0).repeat(self.num_frames, 1, 1, 1)
# TCHW -> CTHW
video = video.permute(1, 0, 2, 3)
return {"video": video, "text": text, "num_frames": num_frames}
def __getitem__(self, index):
for _ in range(10):
try:
return self.getitem(index)
except Exception as e:
print(e)
index = np.random.randint(len(self))
raise RuntimeError("Too many bad data.")
def __len__(self):
return len(self.data)
class VariableVideoBatchSampler(torch.utils.data.BatchSampler):
def __init__(self, sampler, dataset, batch_size, bucket, batch_size_bucket, drop_last=False):
self.sampler = sampler
self.dataset = dataset
self.batch_size = batch_size
self.bucket = bucket
self.batch_size_bucket = batch_size_bucket
self._buckets = {x: [] for x in bucket}
self.drop_last = drop_last
def __iter__(self):
for idx in self.sampler:
bucket_id = self.dataset.get_bucket(idx)
bucket = self._buckets[bucket_id]
bucket.append(idx)
if len(bucket) >= self.batch_size_bucket[bucket_id]:
yield bucket
self._buckets[bucket_id] = []
for bucket in self._buckets.values():
if len(bucket) > 0 and not self.drop_last:
yield bucket
def prepare_dataloader_with_batchsampler(
dataset,
batch_size,
shuffle=False,
seed=1024,
drop_last=False,
pin_memory=False,
num_workers=0,
process_group=None,
bucket=None,
batch_size_bucket=None,
**kwargs,
):
_kwargs = kwargs.copy()
process_group = process_group or _get_default_group()
sampler = StatefulDistributedSampler(
dataset, num_replicas=process_group.size(), rank=process_group.rank(), shuffle=shuffle
)
batch_sampler = VariableVideoBatchSampler(
sampler,
dataset,
batch_size=batch_size,
bucket=bucket,
batch_size_bucket=batch_size_bucket,
drop_last=drop_last,
)
# Deterministic dataloader
def seed_worker(worker_id):
worker_seed = seed
np.random.seed(worker_seed)
torch.manual_seed(worker_seed)
random.seed(worker_seed)
return torch.utils.data.DataLoader(
dataset,
batch_sampler=batch_sampler,
worker_init_fn=seed_worker,
pin_memory=pin_memory,
num_workers=num_workers,
**_kwargs,
)

View file

@@ -1,15 +1,9 @@
import random
from typing import Iterator, Optional
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image
from torch.distributed import ProcessGroup
from torch.distributed.distributed_c10d import _get_default_group
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler
from torchvision.datasets.folder import IMG_EXTENSIONS, pil_loader
from torchvision.io import write_video
from torchvision.utils import save_image
@@ -29,14 +23,16 @@ def temporal_random_crop(vframes, num_frames, frame_interval):
return video
def get_transforms_video(name="center", resolution=(256, 256)):
if name == "center":
assert resolution[0] == resolution[1], "Resolution must be square for center crop"
def get_transforms_video(name="center", image_size=(256, 256)):
if name is None:
return None
elif name == "center":
assert image_size[0] == image_size[1], "image_size must be square for center crop"
transform_video = transforms.Compose(
[
video_transforms.ToTensorVideo(), # TCHW
# video_transforms.RandomHorizontalFlipVideo(),
video_transforms.UCFCenterCropVideo(resolution[0]),
video_transforms.UCFCenterCropVideo(image_size[0]),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
]
)
@@ -44,7 +40,7 @@ def get_transforms_video(name="center", resolution=(256, 256)):
transform_video = transforms.Compose(
[
video_transforms.ToTensorVideo(), # TCHW
video_transforms.ResizeCrop(resolution),
video_transforms.ResizeCrop(image_size),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
]
)
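Both transform getters (this one and get_transforms_image below) now accept name=None and return None; VariableVideoTextDataset relies on this by passing transform_name=None to its parent and building a per-sample transform at load time instead. A hedged usage sketch:

transform = get_transforms_video("resize_crop", (360, 480))  # per-bucket target size
assert get_transforms_video(None) is None                    # sentinel: defer the choice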
@@ -54,7 +50,9 @@ def get_transforms_video(name="center", resolution=(256, 256)):
def get_transforms_image(name="center", image_size=(256, 256)):
if name == "center":
if name is None:
return None
elif name == "center":
assert image_size[0] == image_size[1], "Image size must be square for center crop"
transform = transforms.Compose(
[
@@ -123,89 +121,6 @@ def save_sample(x, fps=8, save_path=None, normalize=True, value_range=(-1, 1)):
return save_path
class StatefulDistributedSampler(DistributedSampler):
def __init__(
self,
dataset: Dataset,
num_replicas: Optional[int] = None,
rank: Optional[int] = None,
shuffle: bool = True,
seed: int = 0,
drop_last: bool = False,
) -> None:
super().__init__(dataset, num_replicas, rank, shuffle, seed, drop_last)
self.start_index: int = 0
def __iter__(self) -> Iterator:
iterator = super().__iter__()
indices = list(iterator)
indices = indices[self.start_index :]
return iter(indices)
def __len__(self) -> int:
return self.num_samples - self.start_index
def set_start_index(self, start_index: int) -> None:
self.start_index = start_index
def prepare_dataloader(
dataset,
batch_size,
shuffle=False,
seed=1024,
drop_last=False,
pin_memory=False,
num_workers=0,
process_group: Optional[ProcessGroup] = None,
**kwargs,
):
r"""
Prepare a dataloader for distributed training. The dataloader will be wrapped by
`torch.utils.data.DataLoader` and `StatefulDistributedSampler`.
Args:
dataset (`torch.utils.data.Dataset`): The dataset to be loaded.
shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False.
seed (int, optional): Random worker seed for sampling, defaults to 1024.
add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True.
drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size
is not divisible by the batch size. If False and the size of dataset is not divisible by
the batch size, then the last batch will be smaller, defaults to False.
pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False.
num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0.
kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in
`DataLoader <https://pytorch.org/docs/stable/_modules/torch/utils/data/dataloader.html#DataLoader>`_.
Returns:
:class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
"""
_kwargs = kwargs.copy()
process_group = process_group or _get_default_group()
sampler = StatefulDistributedSampler(
dataset, num_replicas=process_group.size(), rank=process_group.rank(), shuffle=shuffle
)
# Deterministic dataloader
def seed_worker(worker_id):
worker_seed = seed
np.random.seed(worker_seed)
torch.manual_seed(worker_seed)
random.seed(worker_seed)
return DataLoader(
dataset,
batch_size=batch_size,
sampler=sampler,
worker_init_fn=seed_worker,
drop_last=drop_last,
pin_memory=pin_memory,
num_workers=num_workers,
**_kwargs,
)
def center_crop_arr(pil_image, image_size):
"""
Center cropping implementation from ADM.

View file

@@ -212,6 +212,7 @@ class STDiT2(nn.Module):
# support dynamic input
self.patch_size = patch_size
self.input_size = input_size
self.base_size = (input_size[1] / patch_size[1] * input_size[2] / patch_size[2]) ** 0.5
self.x_embedder = PatchEmbed3D(patch_size, in_channels, hidden_size)
self.t_embedder = TimestepEmbedder(hidden_size)
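base_size is the geometric mean of the two spatial patch-grid sides, used below to rescale the positional grid (and flagged as unfinished in the commit title). A quick worked check with hedged toy shapes:

# toy values, not from the config: input_size = (T, H, W) = (16, 256, 256),
# patch_size = (1, 2, 2)
# base_size = (256/2 * 256/2) ** 0.5 = (128 * 128) ** 0.5 = 128.0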
@@ -280,7 +281,7 @@
W = W // self.patch_size[2]
return (T, H, W)
def forward(self, x, timestep, y, mask=None, x_mask=None, num_frames=None):
def forward(self, x, timestep, y, mask=None, x_mask=None, num_frames=None, height=None, width=None, ar=None):
"""
Forward pass of STDiT.
Args:
@@ -297,26 +298,38 @@
timestep = timestep.to(self.dtype)
y = y.to(self.dtype)
# TODO: hard-coded for now
hw = torch.tensor([self.input_size[1], self.input_size[2]], device=x.device, dtype=x.dtype).repeat(B, 1)
ar = torch.tensor([[self.input_size[1] / self.input_size[2]]], device=x.device, dtype=x.dtype).repeat(B, 1)
# === process data info ===
# 1. get dynamic size
if height is None or width is None:
hw = torch.tensor([self.input_size[1], self.input_size[2]], device=x.device, dtype=x.dtype).repeat(B, 1)
else:
hw = torch.cat([height[:, None], width[:, None]], dim=1)
csize = self.csize_embedder(hw, B)
# 2. get aspect ratio
if ar is None:
ar = torch.tensor([[self.input_size[1] / self.input_size[2]]], device=x.device, dtype=x.dtype).repeat(B, 1)
else:
ar = ar.unsqueeze(1)
ar = self.ar_embedder(ar, B)
data_info = torch.cat([csize, ar], dim=1)
# 3. get number of frames
if num_frames is None:
num_frames = torch.tensor([x.shape[2]], device=x.device, dtype=x.dtype)
fl = num_frames.unsqueeze(1)
csize = self.csize_embedder(hw, B)
ar = self.ar_embedder(ar, B)
fl = self.fl_embedder(fl, B)
data_info = torch.cat([csize, ar], dim=1)
# Dynamic
# === get dynamic shape size ===
_, _, Tx, Hx, Wx = x.size()
T, H, W = self.get_dynamic_size(x)
S = H * W
pos_emb = self.get_spatial_pos_embed(H, W, self.base_size).to(x.device, x.dtype)
# embedding
x = self.x_embedder(x) # [B, N, C]
x = rearrange(x, "B (T S) C -> B T S C", T=T, S=S)
x = x + self.get_spatial_pos_embed(H, W).to(x.device, x.dtype)
x = x + pos_emb
x = rearrange(x, "B T S C -> B (T S) C")
# shard over the sequence dim if sp is enabled
@@ -419,11 +432,12 @@
imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))
return imgs
def get_spatial_pos_embed(self, H, W):
def get_spatial_pos_embed(self, H, W, base_size=None):
pos_embed = get_2d_sincos_pos_embed(
self.hidden_size,
(H, W),
scale=self.space_scale,
base_size=base_size,
)
pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False)
return pos_embed
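Threading base_size through lets the sincos grid be renormalized so grids of different sizes span the same coordinate range, the usual trick for generalizing across resolutions. A hedged sketch of the intended effect, assuming get_2d_sincos_pos_embed rescales coordinates by base_size / grid_size (not verified against this commit's helper):

from typing import Optional
import numpy as np

def scaled_coords(grid_size: int, base_size: Optional[int] = None) -> np.ndarray:
    coords = np.arange(grid_size, dtype=np.float32)
    if base_size is not None:
        coords = coords * base_size / grid_size  # same span for any grid_size
    return coords

# scaled_coords(32, base_size=16) spans ~[0, 15.5], matching the span of
# scaled_coords(16, base_size=16), so a 2x-larger grid reuses familiar positions.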

View file

@@ -71,6 +71,8 @@ def merge_args(cfg, args, training=False):
cfg["mask_ratios"] = None
if "transform_name" not in cfg.dataset:
cfg.dataset["transform_name"] = "center"
if "bucket_config" not in cfg:
cfg["bucket_config"] = None
# Both training and inference
if "multi_resolution" not in cfg:

View file

@@ -18,8 +18,7 @@ from opensora.acceleration.parallel_states import (
set_sequence_parallel_group,
)
from opensora.acceleration.plugin import ZeroSeqParallelPlugin
from opensora.datasets import prepare_dataloader
from opensora.datasets.datasets_variable import prepare_dataloader_with_batchsampler
from opensora.datasets import prepare_dataloader, prepare_variable_dataloader
from opensora.registry import DATASETS, MODELS, SCHEDULERS, build_module
from opensora.utils.ckpt_utils import create_logger, load, model_sharding, record_model_param_shape, save
from opensora.utils.config_utils import (
@@ -91,39 +90,22 @@ def main():
# 3. build dataset and dataloader
# ======================================================
dataset = build_module(cfg.dataset, DATASETS)
dataloader_args = dict(
dataset=dataset,
batch_size=cfg.batch_size,
num_workers=cfg.num_workers,
shuffle=True,
drop_last=True,
pin_memory=True,
process_group=get_data_parallel_group(),
)
# TODO: use plugin's prepare dataloader
# a batch contains:
# {
# "video": torch.Tensor, # [B, C, T, H, W],
# "text": List[str],
# }
if cfg.dataset.type == "VideoTextDataset":
dataloader = prepare_dataloader(
dataset,
batch_size=cfg.batch_size,
num_workers=cfg.num_workers,
shuffle=True,
drop_last=True,
pin_memory=True,
process_group=get_data_parallel_group(),
)
if cfg.bucket_config is None:
dataloader = prepare_dataloader(**dataloader_args)
else:
dataloader = prepare_dataloader_with_batchsampler(
dataset,
batch_size=cfg.batch_size,
batch_size_bucket=dataset.batch_size_bucket,
bucket=dataset.bucket,
num_workers=cfg.num_workers,
shuffle=True,
drop_last=True,
pin_memory=True,
process_group=get_data_parallel_group(),
)
print(dataset.batch_size_bucket)
logger.info(f"Dataset contains {len(dataset):,} videos ({dataset.data_path})")
dataloader = prepare_variable_dataloader(bucket_config=cfg.bucket_config, **dataloader_args)
total_batch_size = cfg.batch_size * dist.get_world_size() // cfg.sp_size
logger.info(f"Dataset contains {len(dataset):,} videos ({dataset.data_path})")
logger.info(f"Total batch size: {total_batch_size}")
# ======================================================
@@ -219,9 +201,8 @@
) as pbar:
for step in pbar:
batch = next(dataloader_iter)
x = batch["video"].to(device, dtype) # [B, C, T, H, W]
y = batch["text"]
x = batch.pop("video").to(device, dtype) # [B, C, T, H, W]
y = batch.pop("text")
# Visual and text encoding
with torch.no_grad():
# Prepare visual inputs
@@ -237,8 +218,8 @@
mask = None
# Video info
if "num_frames" in batch:
model_args["num_frames"] = batch["num_frames"].to(device, dtype)
for k, v in batch.items():
model_args[k] = v.to(device, dtype)
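# With "video" and "text" popped above, the keys remaining in a
# VariableVideoTextDataset batch are the per-sample metadata ("num_frames",
# "height", "width", "ar"); this loop forwards them all, so they arrive as
# STDiT2.forward(..., num_frames=..., height=..., width=..., ar=...).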
# Diffusion
t = torch.randint(0, scheduler.num_timesteps, (x.shape[0],), device=device)