diff --git a/configs/opensora-v1-2/train/stage1_feat.py b/configs/opensora-v1-2/train/stage1_feat.py new file mode 100644 index 0000000..ce96d30 --- /dev/null +++ b/configs/opensora-v1-2/train/stage1_feat.py @@ -0,0 +1,52 @@ +# Dataset settings +dataset = dict(type="BatchFeatureDataset") +grad_checkpoint = True +num_workers = 4 + +# Acceleration settings +dtype = "bf16" +plugin = "zero2" + +# Model settings +model = dict( + type="STDiT3-XL/2", + from_pretrained=None, + qk_norm=True, + enable_flash_attn=True, + enable_layernorm_kernel=True, + freeze_y_embedder=True, +) +scheduler = dict( + type="rflow", + use_timestep_transform=True, + sample_method="logit-normal", +) + +# Mask settings +mask_ratios = { + "random": 0.2, + "intepolate": 0.01, + "quarter_random": 0.01, + "quarter_head": 0.01, + "quarter_tail": 0.01, + "quarter_head_tail": 0.01, + "image_random": 0.05, + "image_head": 0.1, + "image_tail": 0.05, + "image_head_tail": 0.05, +} + +# Log settings +seed = 42 +outputs = "outputs" +wandb = False +epochs = 1000 +log_every = 10 +ckpt_every = 1 + +# optimization settings +load = None +grad_clip = 1.0 +lr = 2e-4 +ema_decay = 0.99 +adam_eps = 1e-15 diff --git a/configs/opensora-v1-2/train/train_load_batch.py b/configs/opensora-v1-2/train/train_load_batch.py index ae20146..1ce3de8 100644 --- a/configs/opensora-v1-2/train/train_load_batch.py +++ b/configs/opensora-v1-2/train/train_load_batch.py @@ -1,6 +1,6 @@ # Dataset settings dataset = dict( - type="BatchDataset", + type="BatchFeatureDataset", ) grad_checkpoint = True diff --git a/opensora/datasets/__init__.py b/opensora/datasets/__init__.py index fa4c798..ff4c88d 100644 --- a/opensora/datasets/__init__.py +++ b/opensora/datasets/__init__.py @@ -1,2 +1,2 @@ -from .datasets import IMG_FPS, VariableVideoTextDataset, VideoTextDataset, BatchDataset +from .datasets import IMG_FPS, BatchFeatureDataset, VariableVideoTextDataset, VideoTextDataset from .utils import get_transforms_image, get_transforms_video, is_img, 
is_vid, save_sample diff --git a/opensora/datasets/dataloader.py b/opensora/datasets/dataloader.py index 37a549c..8516951 100644 --- a/opensora/datasets/dataloader.py +++ b/opensora/datasets/dataloader.py @@ -8,7 +8,7 @@ from torch.distributed import ProcessGroup from torch.distributed.distributed_c10d import _get_default_group from torch.utils.data import DataLoader -from .datasets import VariableVideoTextDataset, VideoTextDataset +from .datasets import BatchFeatureDataset, VariableVideoTextDataset, VideoTextDataset from .sampler import BatchDistributedSampler, StatefulDistributedSampler, VariableVideoBatchSampler @@ -56,6 +56,7 @@ def prepare_dataloader( worker_init_fn=get_seed_worker(seed), pin_memory=pin_memory, num_workers=num_workers, + collate_fn=collate_fn_default, **_kwargs, ), batch_sampler, @@ -77,105 +78,29 @@ def prepare_dataloader( drop_last=drop_last, pin_memory=pin_memory, num_workers=num_workers, + collate_fn=collate_fn_default, **_kwargs, ), sampler, ) - else: - raise ValueError(f"Unsupported dataset type: {type(dataset)}") - - -def prepare_variable_dataloader( - dataset, - batch_size, - bucket_config, - shuffle=False, - seed=1024, - drop_last=False, - pin_memory=False, - num_workers=0, - process_group=None, - num_bucket_build_workers=1, - **kwargs, -): - _kwargs = kwargs.copy() - process_group = process_group or _get_default_group() - batch_sampler = VariableVideoBatchSampler( - dataset, - bucket_config, - num_replicas=process_group.size(), - rank=process_group.rank(), - shuffle=shuffle, - seed=seed, - drop_last=drop_last, - verbose=True, - num_bucket_build_workers=num_bucket_build_workers, - ) - - # Deterministic dataloader - def seed_worker(worker_id): - worker_seed = seed - np.random.seed(worker_seed) - torch.manual_seed(worker_seed) - random.seed(worker_seed) - - return torch.utils.data.DataLoader( - dataset, - batch_sampler=batch_sampler, - worker_init_fn=seed_worker, - pin_memory=pin_memory, - num_workers=num_workers, - **_kwargs, - ) - - 
-def build_batch_dataloader(
-    dataset,
-    # batch_size=1,
-    # shuffle=False,
-    seed=1024,
-    # drop_last=False,
-    pin_memory=False,
-    num_workers=0,
-    process_group: Optional[ProcessGroup] = None,
-    distributed=True,
-    **kwargs,
-):
-    r"""
-    Prepare a dataloader for distributed training. The dataloader will be wrapped by
-    `torch.utils.data.DataLoader` and `BatchDistributedSampler`.
-
-    batch_size must be 1; shuffle is not supported so far
-    """
-    _kwargs = kwargs.copy()
-    if distributed:
-        process_group = process_group or _get_default_group()
+    elif isinstance(dataset, BatchFeatureDataset):
         sampler = BatchDistributedSampler(
             dataset,
             num_replicas=process_group.size(),
             rank=process_group.rank(),
         )
+        return (DataLoader(
+            dataset,
+            batch_size=1,
+            sampler=sampler,
+            worker_init_fn=get_seed_worker(seed),
+            pin_memory=pin_memory,
+            num_workers=num_workers,
+            collate_fn=collate_fn_batch,
+            **_kwargs,
+        ), sampler)
     else:
-        raise NotImplementedError
-        sampler = None
-
-    # Deterministic dataloader
-    def seed_worker(worker_id):
-        worker_seed = seed
-        np.random.seed(worker_seed)
-        torch.manual_seed(worker_seed)
-        random.seed(worker_seed)
-
-    return DataLoader(
-        dataset,
-        batch_size=1,
-        sampler=sampler,
-        worker_init_fn=seed_worker,
-        # drop_last=drop_last,
-        pin_memory=pin_memory,
-        num_workers=num_workers,
-        **_kwargs,
-    )
+        raise ValueError(f"Unsupported dataset type: {type(dataset)}")
 
 
 def collate_fn_default(batch):
diff --git a/opensora/datasets/datasets.py b/opensora/datasets/datasets.py
index 4297c3c..7050c98 100644
--- a/opensora/datasets/datasets.py
+++ b/opensora/datasets/datasets.py
@@ -1,4 +1,5 @@
 import os
+from glob import glob
 
 import numpy as np
 import torch
@@ -185,16 +186,11 @@ class VariableVideoTextDataset(VideoTextDataset):
 
         return ret
 
     def __getitem__(self, index):
-        try:
-            return self.getitem(index)
-        except Exception:
-            # we return None here in case of errorneous data
-            # the collate function will handle it
-            return None
+        return self.getitem(index)
 @DATASETS.register_module()
-class BatchDataset(torch.utils.data.Dataset):
+class BatchFeatureDataset(torch.utils.data.Dataset):
     """
     The dataset is composed of multiple .bin files.
     Each .bin file is a list of batch data (like a buffer). All .bin files have the same length.
@@ -203,16 +199,15 @@ class BatchDataset(torch.utils.data.Dataset):
     Avoid loading the same .bin on two difference GPUs, i.e., one .bin is assigned to one GPU only.
     """
 
-    def __init__(self):
-        # self.meta = read_file(data_path)
-        # self.path_list = self.meta['path'].tolist()
-        self.path_list = [f"/mnt/nfs-207/sora_data/webvid-10M/feat_text/data/{idx}.bin" for idx in range(5)]
+    def __init__(self, data_path=None):
+        self.path_list = sorted(glob(data_path + "/**/*.bin", recursive=True))
 
         self._len_buffer = len(torch.load(self.path_list[0]))
         self._num_buffers = len(self.path_list)
         self.num_samples = self.len_buffer * len(self.path_list)
 
         self.cur_file_idx = -1
+        self.cur_buffer = None
 
     @property
     def num_buffers(self):
@@ -224,10 +219,9 @@ class BatchDataset(torch.utils.data.Dataset):
 
     def _load_buffer(self, idx):
         file_idx = idx // self.len_buffer
-        if file_idx == self.cur_file_idx:
-            return
-        self.cur_file_idx = file_idx
-        self.cur_buffer = torch.load(self.path_list[file_idx])
+        if file_idx != self.cur_file_idx:
+            self.cur_file_idx = file_idx
+            self.cur_buffer = torch.load(self.path_list[file_idx])
 
     def __len__(self):
         return self.num_samples
diff --git a/opensora/datasets/sampler.py b/opensora/datasets/sampler.py
index 00c6ba2..64c36e2 100644
--- a/opensora/datasets/sampler.py
+++ b/opensora/datasets/sampler.py
@@ -1,6 +1,3 @@
-import math
-import warnings
-
 from collections import OrderedDict, defaultdict
 from pprint import pformat
 from typing import Iterator, List, Optional
@@ -291,13 +288,15 @@ class VariableVideoBatchSampler(DistributedSampler):
 
 class BatchDistributedSampler(DistributedSampler):
     """
-    Used with BatchDataset;
+    Used with BatchFeatureDataset;
     Suppose len_buffer == 5, num_buffers == 6, #GPUs == 3, then
| buffer {i} | buffer {i+1} + ------ | ------------------- | ------------------- rank 0 | 0, 1, 2, 3, 4, | 5, 6, 7, 8, 9 rank 1 | 10, 11, 12, 13, 14, | 15, 16, 17, 18, 19 rank 2 | 20, 21, 22, 23, 24, | 25, 26, 27, 28, 29 """ + def __iter__(self): num_buffers = self.dataset.num_buffers len_buffer = self.dataset.len_buffer diff --git a/opensora/registry.py b/opensora/registry.py index 4335d38..4c2f283 100644 --- a/opensora/registry.py +++ b/opensora/registry.py @@ -15,6 +15,8 @@ def build_module(module, builder, **kwargs): Returns: Any: The built module. """ + if module is None: + return None if isinstance(module, dict): cfg = deepcopy(module) for k, v in kwargs.items(): diff --git a/scripts/misc/extract_feat.py b/scripts/misc/extract_feat.py index b08baeb..47fee72 100644 --- a/scripts/misc/extract_feat.py +++ b/scripts/misc/extract_feat.py @@ -7,7 +7,7 @@ import torch.distributed as dist from tqdm import tqdm from opensora.acceleration.parallel_states import get_data_parallel_group, set_data_parallel_group -from opensora.datasets.dataloader import collate_fn_default, prepare_dataloader +from opensora.datasets.dataloader import prepare_dataloader from opensora.registry import DATASETS, MODELS, build_module from opensora.utils.config_utils import parse_configs, save_training_config from opensora.utils.misc import FeatureSaver, Timer, create_logger, format_numel_str, get_model_numel, to_torch_dtype @@ -60,7 +60,6 @@ def main(): drop_last=True, pin_memory=True, process_group=get_data_parallel_group(), - collate_fn=collate_fn_default, ) dataloader, sampler = prepare_dataloader( bucket_config=cfg.get("bucket_config", None), diff --git a/scripts/misc/profile_train.py b/scripts/misc/profile_train.py index 3475aec..1221b30 100644 --- a/scripts/misc/profile_train.py +++ b/scripts/misc/profile_train.py @@ -14,8 +14,7 @@ from tqdm import tqdm from opensora.acceleration.checkpoint import set_grad_checkpoint from opensora.acceleration.parallel_states import 
get_data_parallel_group -from opensora.datasets import prepare_dataloader -from opensora.datasets.dataloader import collate_fn_default +from opensora.datasets.dataloader import prepare_dataloader from opensora.registry import DATASETS, MODELS, SCHEDULERS, build_module from opensora.utils.ckpt_utils import load, model_gathering, model_sharding, record_model_param_shape, save from opensora.utils.config_utils import define_experiment_workspace, parse_configs, save_training_config @@ -99,7 +98,6 @@ def main(): drop_last=True, pin_memory=True, process_group=get_data_parallel_group(), - collate_fn=collate_fn_default, ) dataloader, sampler = prepare_dataloader( bucket_config=cfg.get("bucket_config", None), diff --git a/scripts/train.py b/scripts/train.py index a4df266..1818782 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -14,7 +14,7 @@ from tqdm import tqdm from opensora.acceleration.checkpoint import set_grad_checkpoint from opensora.acceleration.parallel_states import get_data_parallel_group -from opensora.datasets.dataloader import collate_fn_default, prepare_dataloader +from opensora.datasets.dataloader import prepare_dataloader from opensora.registry import DATASETS, MODELS, SCHEDULERS, build_module from opensora.utils.ckpt_utils import load, model_gathering, model_sharding, record_model_param_shape, save from opensora.utils.config_utils import define_experiment_workspace, parse_configs, save_training_config @@ -95,7 +95,6 @@ def main(): drop_last=True, pin_memory=True, process_group=get_data_parallel_group(), - collate_fn=collate_fn_default, ) dataloader, sampler = prepare_dataloader( bucket_config=cfg.get("bucket_config", None), @@ -109,12 +108,15 @@ def main(): # ====================================================== logger.info("Building models...") # == build text-encoder and vae == - text_encoder = build_module(cfg.text_encoder, MODELS, device=device, dtype=dtype) - vae = build_module(cfg.vae, MODELS).to(device, dtype).eval() + text_encoder = 
build_module(cfg.get("text_encoder", None), MODELS, device=device, dtype=dtype)
+    vae = build_module(cfg.get("vae", None), MODELS)
 
     # == build diffusion model ==
     input_size = (dataset.num_frames, *dataset.image_size)
-    latent_size = vae.get_latent_size(input_size)
+    if vae is not None:
+        latent_size = vae.to(device, dtype).eval().get_latent_size(input_size)  # Module.to()/eval() mutate in place
+    else:
+        latent_size = (None, None, None)
     model = (
         build_module(
             cfg.model,
diff --git a/scripts/train_load_batch.py b/scripts/train_load_batch.py
index a612bb8..dab6c04 100644
--- a/scripts/train_load_batch.py
+++ b/scripts/train_load_batch.py
@@ -14,7 +14,7 @@ from tqdm import tqdm
 
 from opensora.acceleration.checkpoint import set_grad_checkpoint
 from opensora.acceleration.parallel_states import get_data_parallel_group
-from opensora.datasets.dataloader import build_batch_dataloader, collate_fn_batch
+from opensora.datasets.dataloader import build_batch_dataloader  # FIXME(review): build_batch_dataloader is deleted by this patch; migrate this script to prepare_dataloader
 from opensora.registry import DATASETS, MODELS, SCHEDULERS, build_module
 from opensora.utils.ckpt_utils import load, model_gathering, model_sharding, record_model_param_shape, save
 from opensora.utils.config_utils import define_experiment_workspace, parse_configs, save_training_config
@@ -98,7 +98,6 @@ def main():
         # drop_last=True,
         pin_memory=True,
         process_group=get_data_parallel_group(),
-        collate_fn=collate_fn_batch,
     )
     dataloader = build_batch_dataloader(**dataloader_args)
     num_steps_per_epoch = len(dataset) // dist.get_world_size()