diff --git a/configs/dit/train/1x256x256.py b/configs/dit/train/1x256x256.py index 667e0a8..c423b24 100644 --- a/configs/dit/train/1x256x256.py +++ b/configs/dit/train/1x256x256.py @@ -1,14 +1,15 @@ -num_frames = 1 -frame_interval = 1 -image_size = (256, 256) - # Define dataset -root = None -data_path = "CSV_PATH" -use_image_transform = True -num_workers = 4 +dataset = dict( + type="VideoTextDataset", + data_path=None, + num_frames=1, + frame_interval=1, + image_size=(256, 256), + transform_name="center", +) # Define acceleration +num_workers = 4 dtype = "bf16" grad_checkpoint = False plugin = "zero2" diff --git a/configs/opensora-v1-1/inference/Vx1024.py b/configs/opensora-v1-1/inference/image.py similarity index 91% rename from configs/opensora-v1-1/inference/Vx1024.py rename to configs/opensora-v1-1/inference/image.py index 649d9e5..96d6783 100644 --- a/configs/opensora-v1-1/inference/Vx1024.py +++ b/configs/opensora-v1-1/inference/image.py @@ -1,6 +1,7 @@ -num_frames = 4 +num_frames = 1 fps = 24 // 3 -image_size = (704, 1472) +# image_size = (1358, 680) +image_size = (2160, 3840) multi_resolution = "STDiT2" # Define model diff --git a/configs/opensora-v1-1/train/test.py b/configs/opensora-v1-1/train/benchmark.py similarity index 73% rename from configs/opensora-v1-1/train/test.py rename to configs/opensora-v1-1/train/benchmark.py index 87f4d3b..d3b60cb 100644 --- a/configs/opensora-v1-1/train/test.py +++ b/configs/opensora-v1-1/train/benchmark.py @@ -4,15 +4,22 @@ dataset = dict( data_path=None, num_frames=None, frame_interval=3, - image_size=(None, None, None), + image_size=(None, None), transform_name="resize_crop", ) bucket_config = { - "256": {1: (1.0, 64)}, # 32 ok, 64 broken + # "256": {1: (1.0, 256)}, # 4.5s/it + # "512": {1: (1.0, 96)}, # 4.7s/it + # "512": {1: (1.0, 128)}, # 6.3s/it + # "480p": {1: (1.0, 50)}, # 4.0s/it + # "1024": {1: (1.0, 32)}, # 6.8s/it + # "1024": {1: (1.0, 20)}, # 4.3s/it + # "1080p": {1: (1.0, 16)}, # 8.6s/it + "1080p": {1: (1.0, 8)}, # 4.4s/it } # Define acceleration -num_workers = 0 +num_workers = 4 dtype = "bf16" grad_checkpoint = True plugin = "zero2" diff --git a/configs/opensora-v1-1/train/image.py b/configs/opensora-v1-1/train/image.py index 588f3c2..fe9a671 100644 --- a/configs/opensora-v1-1/train/image.py +++ b/configs/opensora-v1-1/train/image.py @@ -4,15 +4,15 @@ dataset = dict( data_path=None, num_frames=None, frame_interval=3, - image_size=(None, None, None), + image_size=(None, None), transform_name="resize_crop", ) -bucket_config = { - "256": {1: (1.0, 128)}, - "512": {1: (1.0, 2)}, - "480p": {1: (1.0, 2)}, - "1024": {1: (1.0, 2)}, - "1080p": {1: (1.0, 2)}, +bucket_config = { # 6s/it + "256": {1: (1.0, 256)}, + "512": {1: (1.0, 80)}, + "480p": {1: (1.0, 52)}, + "1024": {1: (1.0, 20)}, + "1080p": {1: (1.0, 8)}, } # Define acceleration @@ -53,7 +53,7 @@ wandb = False epochs = 1000 log_every = 10 -ckpt_every = 1000 +ckpt_every = 500 load = None batch_size = 10 # only for logging diff --git a/opensora/datasets/aspect.py b/opensora/datasets/aspect.py index 0d772c9..7588cde 100644 --- a/opensora/datasets/aspect.py +++ b/opensora/datasets/aspect.py @@ -3,10 +3,12 @@ import math # Ours -def get_h_w(a, ts): +def get_h_w(a, ts, eps=1e-4): h = (ts * a) ** 0.5 + h = h + eps h = math.ceil(h) if math.ceil(h) % 2 == 0 else math.floor(h) w = h / a + w = w + eps w = math.ceil(w) if math.ceil(w) % 2 == 0 else math.floor(w) return h, w @@ -93,11 +95,11 @@ ASPECT_RATIO_720P = { "0.67": (784, 1176), "0.75": (832, 1110), "1.00": (960, 960), - "1.33": (1108, 831), + "1.33": (1108, 832), "1.50": (1176, 784), "1.78": (1280, 720), "1.89": (1320, 698), - "2.00": (1358, 679), + "2.00": (1358, 680), "2.08": (1386, 666), } @@ -116,7 +118,7 @@ ASPECT_RATIO_480P = { "1.52": (790, 520), "1.78": (854, 480), "1.92": (888, 462), - "2.00": (906, 453), + "2.00": (906, 454), "2.10": (928, 442), } @@ -130,14 +132,14 @@ ASPECT_RATIO_360P = { "0.54": (470, 870), "0.56": (480, 854), # base "0.62": (506, 810), - "0.67": (522, 783), + "0.67": (522, 784), "0.75": (554, 738), "1.00": (640, 640), "1.33": (740, 555), "1.50": (784, 522), "1.78": (854, 480), "1.89": (880, 466), - "2.00": (906, 453), + "2.00": (906, 454), "2.08": (924, 444), } diff --git a/opensora/datasets/sampler.py b/opensora/datasets/sampler.py index 08e8930..1478ff2 100644 --- a/opensora/datasets/sampler.py +++ b/opensora/datasets/sampler.py @@ -4,6 +4,7 @@ from pprint import pprint from typing import Iterator, List, Optional import torch +import torch.distributed as dist from torch.utils.data import DistributedSampler from .bucket import Bucket @@ -29,6 +30,27 @@ class VariableVideoBatchSampler(DistributedSampler): self.bucket = Bucket(bucket_config) self.verbose = verbose self.last_micro_batch_access_index = 0 + self.approximate_num_batch = None + + def get_num_batch(self) -> int: + g = torch.Generator() + g.manual_seed(self.seed + self.epoch) + bucket_sample_dict = OrderedDict() + + # group by bucket + # each data sample is put into a bucket with a similar image/video size + for i in range(len(self.dataset)): + t, h, w = self.dataset.get_data_info(i) + bucket_id = self.bucket.get_bucket_id(t, h, w, self.dataset.frame_interval, g) + if bucket_id is None: + continue + if bucket_id not in bucket_sample_dict: + bucket_sample_dict[bucket_id] = [] + bucket_sample_dict[bucket_id].append(i) + + # calculate the number of batches + self._print_bucket_info(bucket_sample_dict) + return self.approximate_num_batch def __iter__(self) -> Iterator[List[int]]: g = torch.Generator() @@ -155,8 +177,9 @@ class VariableVideoBatchSampler(DistributedSampler): def load_state_dict(self, state_dict: dict) -> None: self.__dict__.update(state_dict) - def _print_bucket_info(self, bucket_sample_dict: dict) -> None: + def _print_bucket_info(self, bucket_sample_dict: dict, verbose=True) -> None: total_samples = 0 + num_batch = 0 num_dict = {} num_aspect_dict = defaultdict(int) num_hwt_dict = defaultdict(int) @@ -166,13 +189,17 @@ class VariableVideoBatchSampler(DistributedSampler): num_dict[k] = size num_aspect_dict[k[-1]] += size num_hwt_dict[k[:-1]] += size - print(f"Total training samples: {total_samples}, num buckets: {len(num_dict)}") - print("Bucket samples:") - pprint(num_dict) - print("Bucket samples by HxWxT:") - pprint(num_hwt_dict) - print("Bucket samples by aspect ratio:") - pprint(num_aspect_dict) + num_batch += size // self.bucket.get_batch_size(k[:-1]) + if dist.get_rank() == 0 and verbose: + print(f"Total training samples: {total_samples}, num buckets: {len(num_dict)}") + print("Bucket samples:") + pprint(num_dict) + print("Bucket samples by HxWxT:") + pprint(num_hwt_dict) + print("Bucket samples by aspect ratio:") + pprint(num_aspect_dict) + print(f"Number of batches: {num_batch}") + self.approximate_num_batch = num_batch def set_epoch(self, epoch: int) -> None: super().set_epoch(epoch) diff --git a/opensora/models/vae/vae.py b/opensora/models/vae/vae.py index e6e1441..79f78b1 100644 --- a/opensora/models/vae/vae.py +++ b/opensora/models/vae/vae.py @@ -53,9 +53,9 @@ class VideoAutoencoderKL(nn.Module): def get_latent_size(self, input_size): latent_size = [] for i in range(3): - assert ( - input_size[i] is None or input_size[i] % self.patch_size[i] == 0 - ), "Input size must be divisible by patch size" + # assert ( + # input_size[i] is None or input_size[i] % self.patch_size[i] == 0 + # ), "Input size must be divisible by patch size" latent_size.append(input_size[i] // self.patch_size[i] if input_size[i] is not None else None) return latent_size @@ -89,9 +89,9 @@ class VideoAutoencoderKLTemporalDecoder(nn.Module): def get_latent_size(self, input_size): latent_size = [] for i in range(3): - assert ( - input_size[i] is None or input_size[i] % self.patch_size[i] == 0 - ), "Input size must be divisible by patch size" + # assert ( + # input_size[i] is None or input_size[i] % self.patch_size[i] == 0 + # ), "Input size must be divisible by patch size" latent_size.append(input_size[i] // self.patch_size[i] if input_size[i] is not None else None) return latent_size diff --git a/scripts/train.py b/scripts/train.py index 439b461..8ddf9c6 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -37,8 +37,6 @@ def main(): # 1. args & cfg # ====================================================== cfg = parse_configs(training=True) - print("Training configuration:") - pprint(cfg._cfg_dict) exp_name, exp_dir = create_experiment_workspace(cfg) save_training_config(cfg._cfg_dict, exp_dir) @@ -58,6 +56,8 @@ def main(): if not coordinator.is_master(): logger = create_logger(None) else: + print("Training configuration:") + pprint(cfg._cfg_dict) logger = create_logger(exp_dir) logger.info(f"Experiment directory created at {exp_dir}") @@ -173,13 +173,16 @@ def main(): dataloader=dataloader, ) torch.set_default_dtype(torch.float) - num_steps_per_epoch = len(dataloader) logger.info("Boost model for distributed training") + if cfg.dataset.type == "VariableVideoTextDataset": + num_steps_per_epoch = dataloader.batch_sampler.get_num_batch() // dist.get_world_size() + else: + num_steps_per_epoch = len(dataloader) # ======================================================= # 6. training loop # ======================================================= - start_epoch = start_step = log_step = sampler_start_idx = 0 + start_epoch = start_step = log_step = sampler_start_idx = acc_step = 0 running_loss = 0.0 sampler_to_io = dataloader.batch_sampler if cfg.dataset.type == "VariableVideoTextDataset" else None # 6.1. resume training @@ -254,6 +257,7 @@ def main(): running_loss += loss.item() global_step = epoch * num_steps_per_epoch + step log_step += 1 + acc_step += 1 # Log to tensorboard if coordinator.is_master() and (global_step + 1) % cfg.log_every == 0: @@ -269,6 +273,7 @@ def main(): "epoch": epoch, "loss": loss.item(), "avg_loss": avg_loss, + "acc_step": acc_step, }, step=global_step, )