From cba02e8a5872480f4998db1f6e25a8f9af01f515 Mon Sep 17 00:00:00 2001 From: Zangwei Zheng Date: Thu, 28 Mar 2024 17:58:36 +0800 Subject: [PATCH] [feat] dynamic for video (base_size not completed) --- .../train/{Vx360x480.py => Vx360p.py} | 14 +- .../train/{Tx360x480.py => test.py} | 0 opensora/datasets/__init__.py | 6 +- opensora/datasets/aspect.py | 2 +- opensora/datasets/dataloader.py | 303 ++++++++++++++++++ opensora/datasets/datasets.py | 66 ++++ opensora/datasets/datasets_variable.py | 192 ----------- opensora/datasets/utils.py | 107 +------ opensora/models/stdit/stdit2.py | 34 +- opensora/utils/config_utils.py | 2 + scripts/train.py | 55 ++-- 11 files changed, 437 insertions(+), 344 deletions(-) rename configs/opensora-v1-1/train/{Vx360x480.py => Vx360p.py} (72%) rename configs/opensora-v1-1/train/{Tx360x480.py => test.py} (100%) create mode 100644 opensora/datasets/dataloader.py delete mode 100644 opensora/datasets/datasets_variable.py diff --git a/configs/opensora-v1-1/train/Vx360x480.py b/configs/opensora-v1-1/train/Vx360p.py similarity index 72% rename from configs/opensora-v1-1/train/Vx360x480.py rename to configs/opensora-v1-1/train/Vx360p.py index 1b138b4..4c7ea62 100644 --- a/configs/opensora-v1-1/train/Vx360x480.py +++ b/configs/opensora-v1-1/train/Vx360p.py @@ -4,14 +4,18 @@ dataset = dict( data_path=None, num_frames=None, frame_interval=3, - image_size=(360, 480), + image_size=(360, 480), # base size transform_name="resize_crop", - bucket=[24, 48, 72], - batch_size_bucket={24: 12, 48: 6, 72: 4}, ) +bucket_config = { + "240p": {1: (1.0, 128), 16: (1.0, 4), 48: (1.0, 4), 72: (1.0, 2)}, + "360p": {1: (1.0, 64), 16: (0.5, 2), 24: (0.5, 2), 72: (0.0, None)}, + "720p": {1: (1.0, 32), 16: (0.5, 1), 72: (0.0, None)}, + "1080p": {1: (1.0, 16)}, +} # Define acceleration -num_workers = 4 +num_workers = 0 dtype = "bf16" grad_checkpoint = True plugin = "zero2" @@ -51,6 +55,6 @@ log_every = 10 ckpt_every = 1000 load = None -batch_size = 10 # only for logging +batch_size = 10 # only for logging lr = 2e-5 grad_clip = 1.0 diff --git a/configs/opensora-v1-1/train/Tx360x480.py b/configs/opensora-v1-1/train/test.py similarity index 100% rename from configs/opensora-v1-1/train/Tx360x480.py rename to configs/opensora-v1-1/train/test.py diff --git a/opensora/datasets/__init__.py b/opensora/datasets/__init__.py index a81891f..c096ef9 100644 --- a/opensora/datasets/__init__.py +++ b/opensora/datasets/__init__.py @@ -1,3 +1,3 @@ -from .datasets import VideoTextDataset -from .datasets_variable import VariableVideoTextDataset -from .utils import get_transforms_image, get_transforms_video, prepare_dataloader, save_sample +from .dataloader import prepare_dataloader, prepare_variable_dataloader +from .datasets import VariableVideoTextDataset, VideoTextDataset +from .utils import get_transforms_image, get_transforms_video, save_sample diff --git a/opensora/datasets/aspect.py b/opensora/datasets/aspect.py index db078c7..9191ef2 100644 --- a/opensora/datasets/aspect.py +++ b/opensora/datasets/aspect.py @@ -253,7 +253,7 @@ ASPECT_RATIO_256 = { def get_closest_ratio(height: float, width: float, ratios: dict): aspect_ratio = height / width closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - aspect_ratio)) - return ratios[closest_ratio], float(closest_ratio) + return closest_ratio ASPECT_RATIOS = { diff --git a/opensora/datasets/dataloader.py b/opensora/datasets/dataloader.py new file mode 100644 index 0000000..1a6d466 --- /dev/null +++ b/opensora/datasets/dataloader.py @@ -0,0 +1,303 @@ 
+import random +from collections import OrderedDict +from typing import Iterator, Optional + +import numpy as np +import torch +from torch.distributed import ProcessGroup +from torch.distributed.distributed_c10d import _get_default_group +from torch.utils.data import DataLoader, Dataset +from torch.utils.data.distributed import DistributedSampler + +from .aspect import ASPECT_RATIOS, get_closest_ratio + + +class StatefulDistributedSampler(DistributedSampler): + def __init__( + self, + dataset: Dataset, + num_replicas: Optional[int] = None, + rank: Optional[int] = None, + shuffle: bool = True, + seed: int = 0, + drop_last: bool = False, + ) -> None: + super().__init__(dataset, num_replicas, rank, shuffle, seed, drop_last) + self.start_index: int = 0 + + def __iter__(self) -> Iterator: + iterator = super().__iter__() + indices = list(iterator) + indices = indices[self.start_index :] + return iter(indices) + + def __len__(self) -> int: + return self.num_samples - self.start_index + + def set_start_index(self, start_index: int) -> None: + self.start_index = start_index + + +def prepare_dataloader( + dataset, + batch_size, + shuffle=False, + seed=1024, + drop_last=False, + pin_memory=False, + num_workers=0, + process_group: Optional[ProcessGroup] = None, + **kwargs, +): + r""" + Prepare a dataloader for distributed training. The dataloader will be wrapped by + `torch.utils.data.DataLoader` and `StatefulDistributedSampler`. + + + Args: + dataset (`torch.utils.data.Dataset`): The dataset to be loaded. + shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False. + seed (int, optional): Random worker seed for sampling, defaults to 1024. + add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True. + drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size + is not divisible by the batch size. If False and the size of dataset is not divisible by + the batch size, then the last batch will be smaller, defaults to False. + pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False. + num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0. + kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in + `DataLoader `_. + + Returns: + :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing. 
+ """ + _kwargs = kwargs.copy() + process_group = process_group or _get_default_group() + sampler = StatefulDistributedSampler( + dataset, num_replicas=process_group.size(), rank=process_group.rank(), shuffle=shuffle + ) + + # Deterministic dataloader + def seed_worker(worker_id): + worker_seed = seed + np.random.seed(worker_seed) + torch.manual_seed(worker_seed) + random.seed(worker_seed) + + return DataLoader( + dataset, + batch_size=batch_size, + sampler=sampler, + worker_init_fn=seed_worker, + drop_last=drop_last, + pin_memory=pin_memory, + num_workers=num_workers, + **_kwargs, + ) + + +def find_approximate_hw(hw, hw_dict, approx=0.8): + for k, v in hw_dict.items(): + if hw >= v * approx: + return k + return None + + +def find_closet_smaller_bucket(t, t_dict, frame_interval): + # process image + if t == 1: + if 1 in t_dict: + return 1 + else: + return None + # process video + for k, v in t_dict.items(): + if t >= v * frame_interval and v != 1: + return k + return None + + +class Bucket: + def __init__(self, bucket_config): + for key in bucket_config: + assert key in ASPECT_RATIOS, f"Aspect ratio {key} not found." + # wrap config with OrderedDict + bucket_probs = OrderedDict() + bucket_bs = OrderedDict() + bucket_names = sorted(bucket_config.keys(), key=lambda x: ASPECT_RATIOS[x][0], reverse=True) + for key in bucket_names: + bucket_time_names = sorted(bucket_config[key].keys(), key=lambda x: x, reverse=True) + bucket_probs[key] = OrderedDict({k: bucket_config[key][k][0] for k in bucket_time_names}) + bucket_bs[key] = OrderedDict({k: bucket_config[key][k][1] for k in bucket_time_names}) + + # first level: HW + num_bucket = 0 + bucket = dict() + hw_criteria = dict() + t_criteria = dict() + ar_criteria = dict() + for k1, v1 in bucket_probs.items(): + bucket[k1] = dict() + hw_criteria[k1] = ASPECT_RATIOS[k1][0] + t_criteria[k1] = dict() + ar_criteria[k1] = dict() + for k2, _ in v1.items(): + bucket[k1][k2] = dict() + t_criteria[k1][k2] = k2 + ar_criteria[k1][k2] = dict() + for k3, v3 in ASPECT_RATIOS[k1][1].items(): + bucket[k1][k2][k3] = [] + ar_criteria[k1][k2][k3] = v3 + num_bucket += 1 + + self.bucket_probs = bucket_probs + self.bucket_bs = bucket_bs + self.bucket = bucket + self.hw_criteria = hw_criteria + self.t_criteria = t_criteria + self.ar_criteria = ar_criteria + self.num_bucket = num_bucket + print(f"Number of buckets: {num_bucket}") + + def info_bucket(self, dataset, frame_interval=1): + infos = dict() + for i in range(len(dataset)): + T, H, W = dataset.get_data_info(i) + bucket_id = self.get_bucket_id(T, H, W, frame_interval) + if bucket_id is None: + continue + if f"{(bucket_id[0], bucket_id[1])}" not in infos: + infos[f"{(bucket_id[0], bucket_id[1])}"] = 0 + infos[f"{(bucket_id[0], bucket_id[1])}"] += 1 + print(f"Dataset contains {len(dataset)} samples.") + print("Bucket info:", infos) + + def get_bucket_id(self, T, H, W, frame_interval=1): + # hw + hw = H * W + hw_id = find_approximate_hw(hw, self.hw_criteria) + if hw_id is None: + return None + + # hw drops by probablity + while True: + # T + T_id = find_closet_smaller_bucket(T, self.t_criteria[hw_id], frame_interval) + if T_id is not None: + prob = self.get_prob((hw_id, T_id)) + if random.random() < prob: + break + hw_id = list(self.hw_criteria.keys()).index(hw_id) + if hw_id == len(self.hw_criteria) - 1: + break + hw_id = list(self.hw_criteria.keys())[hw_id + 1] + if T_id is None: + return None + + # ar + ar_criteria = self.ar_criteria[hw_id][T_id] + ar_id = get_closest_ratio(H, W, ar_criteria) + return hw_id, T_id, 
+
+    def get_thw(self, bucket_id):
+        assert len(bucket_id) == 3
+        T = self.t_criteria[bucket_id[0]][bucket_id[1]]
+        H, W = self.ar_criteria[bucket_id[0]][bucket_id[1]][bucket_id[2]]
+        return T, H, W
+
+    def get_prob(self, bucket_id):
+        return self.bucket_probs[bucket_id[0]][bucket_id[1]]
+
+    def get_batch_size(self, bucket_id):
+        return self.bucket_bs[bucket_id[0]][bucket_id[1]]
+
+    def __getitem__(self, index):
+        assert len(index) == 3
+        return self.bucket[index[0]][index[1]][index[2]]
+
+    def set_empty(self, index):
+        assert len(index) == 3
+        self.bucket[index[0]][index[1]][index[2]] = []
+
+    def __len__(self):
+        return self.num_bucket
+
+
+def closest_smaller_bucket(value, bucket):
+    for i in range(1, len(bucket)):
+        if value < bucket[i]:
+            return bucket[i - 1]
+    return bucket[-1]
+
+
+class VariableVideoBatchSampler(torch.utils.data.BatchSampler):
+    def __init__(self, sampler, batch_size, drop_last, dataset, bucket_config):
+        self.sampler = sampler
+        self.dataset = dataset
+        self.batch_size = batch_size
+        self.drop_last = drop_last
+        self.bucket = Bucket(bucket_config)
+        self.frame_interval = self.dataset.frame_interval
+        self.bucket.info_bucket(self.dataset, self.frame_interval)
+
+    def __iter__(self):
+        for idx in self.sampler:
+            T, H, W = self.dataset.get_data_info(idx)
+            bucket_id = self.bucket.get_bucket_id(T, H, W, self.frame_interval)
+            if bucket_id is None:
+                continue
+            rT, rH, rW = self.bucket.get_thw(bucket_id)
+            self.dataset.set_data_info(idx, rT, rH, rW)
+            buffer = self.bucket[bucket_id]
+            buffer.append(idx)
+            if len(buffer) >= self.bucket.get_batch_size(bucket_id):
+                yield buffer
+                self.bucket.set_empty(bucket_id)
+
+        for k1, v1 in self.bucket.bucket.items():
+            for k2, v2 in v1.items():
+                for k3, buffer in v2.items():
+                    if len(buffer) > 0 and not self.drop_last:
+                        yield buffer
+                        self.bucket.set_empty((k1, k2, k3))
+
+
+def prepare_variable_dataloader(
+    dataset,
+    batch_size,
+    bucket_config,
+    shuffle=False,
+    seed=1024,
+    drop_last=False,
+    pin_memory=False,
+    num_workers=0,
+    process_group=None,
+    **kwargs,
+):
+    _kwargs = kwargs.copy()
+    process_group = process_group or _get_default_group()
+    sampler = StatefulDistributedSampler(
+        dataset, num_replicas=process_group.size(), rank=process_group.rank(), shuffle=shuffle
+    )
+    batch_sampler = VariableVideoBatchSampler(
+        sampler=sampler,
+        batch_size=batch_size,
+        drop_last=drop_last,
+        dataset=dataset,
+        bucket_config=bucket_config,
+    )
+
+    # Deterministic dataloader
+    def seed_worker(worker_id):
+        worker_seed = seed
+        np.random.seed(worker_seed)
+        torch.manual_seed(worker_seed)
+        random.seed(worker_seed)
+
+    return torch.utils.data.DataLoader(
+        dataset,
+        batch_sampler=batch_sampler,
+        worker_init_fn=seed_worker,
+        pin_memory=pin_memory,
+        num_workers=num_workers,
+        **_kwargs,
+    )
diff --git a/opensora/datasets/datasets.py b/opensora/datasets/datasets.py
index c466615..99505c1 100644
--- a/opensora/datasets/datasets.py
+++ b/opensora/datasets/datasets.py
@@ -87,3 +87,69 @@ class VideoTextDataset(torch.utils.data.Dataset):
 
     def __len__(self):
         return len(self.data)
+
+
+@DATASETS.register_module()
+class VariableVideoTextDataset(VideoTextDataset):
+    def __init__(
+        self,
+        data_path,
+        num_frames=None,
+        frame_interval=1,
+        image_size=None,
+        transform_name=None,
+    ):
+        super().__init__(data_path, num_frames, frame_interval, image_size, transform_name=None)
+        self.transform_name = transform_name
+        self.data_info = self.data[["num_frames", "height", "width"]].to_numpy().tolist()
+
+    def set_data_info(self, idx, T, H, W):
+        self.data_info[idx] = [T, H, W]
+
+    def get_data_info(self, index):
+        T = self.data.iloc[index]["num_frames"]
+        H = self.data.iloc[index]["height"]
+        W = self.data.iloc[index]["width"]
+        return T, H, W
+
+    def getitem(self, index):
+        sample = self.data.iloc[index]
+        path = sample["path"]
+        text = sample["text"]
+        file_type = self.get_type(path)
+        num_frames, height, width = self.data_info[index]
+        ar = height / width
+
+        if file_type == "video":
+            # loading
+            vframes, _, _ = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
+
+            # Sampling video frames
+            video = temporal_random_crop(vframes, num_frames, self.frame_interval)
+
+            # transform
+            transform = get_transforms_video(self.transform_name, (height, width))
+            video = transform(video)  # T C H W
+        else:
+            # loading
+            image = pil_loader(path)
+
+            # transform
+            transform = get_transforms_image(self.transform_name, (height, width))
+            image = transform(image)
+
+            # repeat
+            video = image.unsqueeze(0).repeat(num_frames, 1, 1, 1)
+
+        # TCHW -> CTHW
+        video = video.permute(1, 0, 2, 3)
+        return {"video": video, "text": text, "num_frames": num_frames, "height": height, "width": width, "ar": ar}
+
+    def __getitem__(self, index):
+        for _ in range(10):
+            try:
+                return self.getitem(index)
+            except Exception as e:
+                print(e)
+                index = np.random.randint(len(self))
+        raise RuntimeError("Too many bad data.")
diff --git a/opensora/datasets/datasets_variable.py b/opensora/datasets/datasets_variable.py
deleted file mode 100644
index eea2ea5..0000000
--- a/opensora/datasets/datasets_variable.py
+++ /dev/null
@@ -1,192 +0,0 @@
-import random
-
-import numpy as np
-import pandas as pd
-import torch
-import torchvision
-from torch.distributed.distributed_c10d import _get_default_group
-from torchvision.datasets.folder import IMG_EXTENSIONS, pil_loader
-
-from opensora.registry import DATASETS
-
-from .utils import (
-    VID_EXTENSIONS,
-    StatefulDistributedSampler,
-    get_transforms_image,
-    get_transforms_video,
-    temporal_random_crop,
-)
-
-
-def closet_smaller_bucket(value, bucket):
-    for i in range(1, len(bucket)):
-        if value < bucket[i]:
-            return bucket[i - 1]
-    return bucket[-1]
-
-
-@DATASETS.register_module()
-class VariableVideoTextDataset(torch.utils.data.Dataset):
-    """load video according to the csv file.
-
-    Args:
-        target_video_len (int): the number of video frames will be load.
-        align_transform (callable): Align different videos in a specified size.
-        temporal_sample (callable): Sample the target length of a video.
- """ - - def __init__( - self, - data_path, - num_frames=None, - frame_interval=1, - image_size=(256, 256), - bucket=None, - batch_size_bucket=None, - transform_name="center", - ): - self.data_path = data_path - self.data = pd.read_csv(data_path) - self.batch_size_bucket = batch_size_bucket - - # build bucket - self.bucket = bucket - num_effect_frames = self.data["num_frames"] // frame_interval - self.data = self.data[num_effect_frames >= self.bucket[0]] - self.data["bucket"] = num_effect_frames.apply(lambda x: closet_smaller_bucket(x, bucket)) - gb = self.data.groupby("bucket") - self.data_bucket = {x: gb.get_group(x) for x in bucket} - self.data_bucket_len = {x: len(self.data_bucket[x]) for x in bucket} - print(self.data_bucket_len) - - self.num_frames = num_frames - assert self.num_frames is None, "num_frames must be None" - self.frame_interval = frame_interval - self.image_size = image_size - self.transforms = { - "image": get_transforms_image(transform_name, image_size), - "video": get_transforms_video(transform_name, image_size), - } - - def get_type(self, path): - ext = path.split(".")[-1] - if ext.lower() in VID_EXTENSIONS: - return "video" - else: - assert f".{ext.lower()}" in IMG_EXTENSIONS, f"Unsupported file format: {ext}" - return "image" - - def get_bucket(self, index): - return self.data.iloc[index]["bucket"] - - def getitem(self, index): - sample = self.data.iloc[index] - path = sample["path"] - text = sample["text"] - file_type = self.get_type(path) - num_frames = self.get_bucket(index) - - if file_type == "video": - # loading - vframes, _, _ = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW") - - # Sampling video frames - video = temporal_random_crop(vframes, num_frames, self.frame_interval) - - # transform - transform = self.transforms["video"] - video = transform(video) # T C H W - else: - # loading - image = pil_loader(path) - - # transform - transform = self.transforms["image"] - image = transform(image) - - # repeat - video = image.unsqueeze(0).repeat(self.num_frames, 1, 1, 1) - - # TCHW -> CTHW - video = video.permute(1, 0, 2, 3) - return {"video": video, "text": text, "num_frames": num_frames} - - def __getitem__(self, index): - for _ in range(10): - try: - return self.getitem(index) - except Exception as e: - print(e) - index = np.random.randint(len(self)) - raise RuntimeError("Too many bad data.") - - def __len__(self): - return len(self.data) - - -class VariableVideoBatchSampler(torch.utils.data.BatchSampler): - def __init__(self, sampler, dataset, batch_size, bucket, batch_size_bucket, drop_last=False): - self.sampler = sampler - self.dataset = dataset - self.batch_size = batch_size - self.bucket = bucket - self.batch_size_bucket = batch_size_bucket - self._buckets = {x: [] for x in bucket} - self.drop_last = drop_last - - def __iter__(self): - for idx in self.sampler: - bucket_id = self.dataset.get_bucket(idx) - bucket = self._buckets[bucket_id] - bucket.append(idx) - if len(bucket) >= self.batch_size_bucket[bucket_id]: - yield bucket - self._buckets[bucket_id] = [] - - for bucket in self._buckets.values(): - if len(bucket) > 0 and not self.drop_last: - yield bucket - - -def prepare_dataloader_with_batchsampler( - dataset, - batch_size, - shuffle=False, - seed=1024, - drop_last=False, - pin_memory=False, - num_workers=0, - process_group=None, - bucket=None, - batch_size_bucket=None, - **kwargs, -): - _kwargs = kwargs.copy() - process_group = process_group or _get_default_group() - sampler = StatefulDistributedSampler( - dataset, 
num_replicas=process_group.size(), rank=process_group.rank(), shuffle=shuffle - ) - batch_sampler = VariableVideoBatchSampler( - sampler, - dataset, - batch_size=batch_size, - bucket=bucket, - batch_size_bucket=batch_size_bucket, - drop_last=drop_last, - ) - - # Deterministic dataloader - def seed_worker(worker_id): - worker_seed = seed - np.random.seed(worker_seed) - torch.manual_seed(worker_seed) - random.seed(worker_seed) - - return torch.utils.data.DataLoader( - dataset, - batch_sampler=batch_sampler, - worker_init_fn=seed_worker, - pin_memory=pin_memory, - num_workers=num_workers, - **_kwargs, - ) diff --git a/opensora/datasets/utils.py b/opensora/datasets/utils.py index 22212eb..82166da 100644 --- a/opensora/datasets/utils.py +++ b/opensora/datasets/utils.py @@ -1,15 +1,9 @@ -import random -from typing import Iterator, Optional - import numpy as np import torch import torchvision import torchvision.transforms as transforms from PIL import Image -from torch.distributed import ProcessGroup -from torch.distributed.distributed_c10d import _get_default_group -from torch.utils.data import DataLoader, Dataset -from torch.utils.data.distributed import DistributedSampler + from torchvision.datasets.folder import IMG_EXTENSIONS, pil_loader from torchvision.io import write_video from torchvision.utils import save_image @@ -29,14 +23,16 @@ def temporal_random_crop(vframes, num_frames, frame_interval): return video -def get_transforms_video(name="center", resolution=(256, 256)): - if name == "center": - assert resolution[0] == resolution[1], "Resolution must be square for center crop" +def get_transforms_video(name="center", image_size=(256, 256)): + if name is None: + return None + elif name == "center": + assert image_size[0] == image_size[1], "image_size must be square for center crop" transform_video = transforms.Compose( [ video_transforms.ToTensorVideo(), # TCHW # video_transforms.RandomHorizontalFlipVideo(), - video_transforms.UCFCenterCropVideo(resolution[0]), + video_transforms.UCFCenterCropVideo(image_size[0]), transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), ] ) @@ -44,7 +40,7 @@ def get_transforms_video(name="center", resolution=(256, 256)): transform_video = transforms.Compose( [ video_transforms.ToTensorVideo(), # TCHW - video_transforms.ResizeCrop(resolution), + video_transforms.ResizeCrop(image_size), transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), ] ) @@ -54,7 +50,9 @@ def get_transforms_video(name="center", resolution=(256, 256)): def get_transforms_image(name="center", image_size=(256, 256)): - if name == "center": + if name is None: + return None + elif name == "center": assert image_size[0] == image_size[1], "Image size must be square for center crop" transform = transforms.Compose( [ @@ -123,89 +121,6 @@ def save_sample(x, fps=8, save_path=None, normalize=True, value_range=(-1, 1)): return save_path -class StatefulDistributedSampler(DistributedSampler): - def __init__( - self, - dataset: Dataset, - num_replicas: Optional[int] = None, - rank: Optional[int] = None, - shuffle: bool = True, - seed: int = 0, - drop_last: bool = False, - ) -> None: - super().__init__(dataset, num_replicas, rank, shuffle, seed, drop_last) - self.start_index: int = 0 - - def __iter__(self) -> Iterator: - iterator = super().__iter__() - indices = list(iterator) - indices = indices[self.start_index :] - return iter(indices) - - def __len__(self) -> int: - return self.num_samples - self.start_index - - def set_start_index(self, start_index: 
int) -> None: - self.start_index = start_index - - -def prepare_dataloader( - dataset, - batch_size, - shuffle=False, - seed=1024, - drop_last=False, - pin_memory=False, - num_workers=0, - process_group: Optional[ProcessGroup] = None, - **kwargs, -): - r""" - Prepare a dataloader for distributed training. The dataloader will be wrapped by - `torch.utils.data.DataLoader` and `StatefulDistributedSampler`. - - - Args: - dataset (`torch.utils.data.Dataset`): The dataset to be loaded. - shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False. - seed (int, optional): Random worker seed for sampling, defaults to 1024. - add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True. - drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size - is not divisible by the batch size. If False and the size of dataset is not divisible by - the batch size, then the last batch will be smaller, defaults to False. - pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False. - num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0. - kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in - `DataLoader `_. - - Returns: - :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing. - """ - _kwargs = kwargs.copy() - process_group = process_group or _get_default_group() - sampler = StatefulDistributedSampler( - dataset, num_replicas=process_group.size(), rank=process_group.rank(), shuffle=shuffle - ) - - # Deterministic dataloader - def seed_worker(worker_id): - worker_seed = seed - np.random.seed(worker_seed) - torch.manual_seed(worker_seed) - random.seed(worker_seed) - - return DataLoader( - dataset, - batch_size=batch_size, - sampler=sampler, - worker_init_fn=seed_worker, - drop_last=drop_last, - pin_memory=pin_memory, - num_workers=num_workers, - **_kwargs, - ) - - def center_crop_arr(pil_image, image_size): """ Center cropping implementation from ADM. diff --git a/opensora/models/stdit/stdit2.py b/opensora/models/stdit/stdit2.py index 9222f30..bb08f24 100644 --- a/opensora/models/stdit/stdit2.py +++ b/opensora/models/stdit/stdit2.py @@ -212,6 +212,7 @@ class STDiT2(nn.Module): # support dynamic input self.patch_size = patch_size self.input_size = input_size + self.base_size = (input_size[1] / patch_size[1] * input_size[2] / patch_size[2]) ** 0.5 self.x_embedder = PatchEmbed3D(patch_size, in_channels, hidden_size) self.t_embedder = TimestepEmbedder(hidden_size) @@ -280,7 +281,7 @@ class STDiT2(nn.Module): W = W // self.patch_size[2] return (T, H, W) - def forward(self, x, timestep, y, mask=None, x_mask=None, num_frames=None): + def forward(self, x, timestep, y, mask=None, x_mask=None, num_frames=None, height=None, width=None, ar=None): """ Forward pass of STDiT. Args: @@ -297,26 +298,38 @@ class STDiT2(nn.Module): timestep = timestep.to(self.dtype) y = y.to(self.dtype) - # TODO: hard-coded for now - hw = torch.tensor([self.input_size[1], self.input_size[2]], device=x.device, dtype=x.dtype).repeat(B, 1) - ar = torch.tensor([[self.input_size[1] / self.input_size[2]]], device=x.device, dtype=x.dtype).repeat(B, 1) + # === process data info === + # 1. 
get dynamic size + if height is None or width is None: + hw = torch.tensor([self.input_size[1], self.input_size[2]], device=x.device, dtype=x.dtype).repeat(B, 1) + else: + hw = torch.cat([height[:, None], width[:, None]], dim=1) + csize = self.csize_embedder(hw, B) + + # 2. get aspect ratio + if ar is None: + ar = torch.tensor([[self.input_size[1] / self.input_size[2]]], device=x.device, dtype=x.dtype).repeat(B, 1) + else: + ar = ar.unsqueeze(1) + ar = self.ar_embedder(ar, B) + data_info = torch.cat([csize, ar], dim=1) + + # 3. get number of frames if num_frames is None: num_frames = torch.tensor([x.shape[2]], device=x.device, dtype=x.dtype) fl = num_frames.unsqueeze(1) - csize = self.csize_embedder(hw, B) - ar = self.ar_embedder(ar, B) fl = self.fl_embedder(fl, B) - data_info = torch.cat([csize, ar], dim=1) - # Dynamic + # === get dynamic shape size === _, _, Tx, Hx, Wx = x.size() T, H, W = self.get_dynamic_size(x) S = H * W + pos_emb = self.get_spatial_pos_embed(H, W, self.base_size).to(x.device, x.dtype) # embedding x = self.x_embedder(x) # [B, N, C] x = rearrange(x, "B (T S) C -> B T S C", T=T, S=S) - x = x + self.get_spatial_pos_embed(H, W).to(x.device, x.dtype) + x = x + pos_emb x = rearrange(x, "B T S C -> B (T S) C") # shard over the sequence dim if sp is enabled @@ -419,11 +432,12 @@ class STDiT2(nn.Module): imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw)) return imgs - def get_spatial_pos_embed(self, H, W): + def get_spatial_pos_embed(self, H, W, base_size=None): pos_embed = get_2d_sincos_pos_embed( self.hidden_size, (H, W), scale=self.space_scale, + base_size=base_size, ) pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False) return pos_embed diff --git a/opensora/utils/config_utils.py b/opensora/utils/config_utils.py index c8d8ee0..2840605 100644 --- a/opensora/utils/config_utils.py +++ b/opensora/utils/config_utils.py @@ -71,6 +71,8 @@ def merge_args(cfg, args, training=False): cfg["mask_ratios"] = None if "transform_name" not in cfg.dataset: cfg.dataset["transform_name"] = "center" + if "bucket_config" not in cfg: + cfg["bucket_config"] = None # Both training and inference if "multi_resolution" not in cfg: diff --git a/scripts/train.py b/scripts/train.py index b710d1b..c3c41d2 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -18,8 +18,7 @@ from opensora.acceleration.parallel_states import ( set_sequence_parallel_group, ) from opensora.acceleration.plugin import ZeroSeqParallelPlugin -from opensora.datasets import prepare_dataloader -from opensora.datasets.datasets_variable import prepare_dataloader_with_batchsampler +from opensora.datasets import prepare_dataloader, prepare_variable_dataloader from opensora.registry import DATASETS, MODELS, SCHEDULERS, build_module from opensora.utils.ckpt_utils import create_logger, load, model_sharding, record_model_param_shape, save from opensora.utils.config_utils import ( @@ -91,39 +90,22 @@ def main(): # 3. 
build dataset and dataloader # ====================================================== dataset = build_module(cfg.dataset, DATASETS) - + dataloader_args = dict( + dataset=dataset, + batch_size=cfg.batch_size, + num_workers=cfg.num_workers, + shuffle=True, + drop_last=True, + pin_memory=True, + process_group=get_data_parallel_group(), + ) # TODO: use plugin's prepare dataloader - # a batch contains: - # { - # "video": torch.Tensor, # [B, C, T, H, W], - # "text": List[str], - # } - if cfg.dataset.type == "VideoTextDataset": - dataloader = prepare_dataloader( - dataset, - batch_size=cfg.batch_size, - num_workers=cfg.num_workers, - shuffle=True, - drop_last=True, - pin_memory=True, - process_group=get_data_parallel_group(), - ) + if cfg.bucket_config is None: + dataloader = prepare_dataloader(**dataloader_args) else: - dataloader = prepare_dataloader_with_batchsampler( - dataset, - batch_size=cfg.batch_size, - batch_size_bucket=dataset.batch_size_bucket, - bucket=dataset.bucket, - num_workers=cfg.num_workers, - shuffle=True, - drop_last=True, - pin_memory=True, - process_group=get_data_parallel_group(), - ) - print(dataset.batch_size_bucket) - logger.info(f"Dataset contains {len(dataset):,} videos ({dataset.data_path})") - + dataloader = prepare_variable_dataloader(bucket_config=cfg.bucket_config, **dataloader_args) total_batch_size = cfg.batch_size * dist.get_world_size() // cfg.sp_size + logger.info(f"Dataset contains {len(dataset):,} videos ({dataset.data_path})") logger.info(f"Total batch size: {total_batch_size}") # ====================================================== @@ -219,9 +201,8 @@ def main(): ) as pbar: for step in pbar: batch = next(dataloader_iter) - x = batch["video"].to(device, dtype) # [B, C, T, H, W] - y = batch["text"] - + x = batch.pop("video").to(device, dtype) # [B, C, T, H, W] + y = batch.pop("text") # Visual and text encoding with torch.no_grad(): # Prepare visual inputs @@ -237,8 +218,8 @@ def main(): mask = None # Video info - if "num_frames" in batch: - model_args["num_frames"] = batch["num_frames"].to(device, dtype) + for k, v in batch.items(): + model_args[k] = v.to(device, dtype) # Diffusion t = torch.randint(0, scheduler.num_timesteps, (x.shape[0],), device=device)
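
Usage sketch (illustrative only, not part of the patch): how the new bucket-based dataloader introduced above is wired together. It assumes a single-process gloo group for a local smoke test and a hypothetical CSV at data/meta.csv with path, text, num_frames, height and width columns; the bucket_config format mirrors the training config above, i.e. {resolution: {num_frames: (keep probability, batch size per rank)}}.

# usage_sketch.py -- illustrative only; the file name and CSV path are hypothetical.
import torch.distributed as dist

from opensora.datasets import VariableVideoTextDataset, prepare_variable_dataloader

# prepare_variable_dataloader resolves the default process group, so one must
# exist; a single-process gloo group is enough for a local smoke test.
dist.init_process_group(backend="gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1)

# {resolution name: {num_frames: (keep probability, batch size per rank)}}
bucket_config = {
    "240p": {1: (1.0, 128), 16: (1.0, 4), 48: (1.0, 4), 72: (1.0, 2)},
    "360p": {1: (1.0, 64), 16: (0.5, 2), 24: (0.5, 2), 72: (0.0, None)},
}

dataset = VariableVideoTextDataset(
    data_path="data/meta.csv",  # hypothetical CSV with path/text/num_frames/height/width
    frame_interval=3,
    transform_name="resize_crop",
)
dataloader = prepare_variable_dataloader(
    dataset,
    batch_size=10,  # only for logging; per-bucket batch sizes come from bucket_config
    bucket_config=bucket_config,
    shuffle=True,
    drop_last=True,
    num_workers=0,
)

# Every yielded batch holds samples from a single (resolution, num_frames, aspect ratio) bucket.
batch = next(iter(dataloader))
print(batch["video"].shape, batch["num_frames"], batch["height"], batch["width"], batch["ar"])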