From 150cf4666a5b2b0ce888faee4cdb265561704b4b Mon Sep 17 00:00:00 2001 From: "Zheng Zangwei (Alex Zheng)" Date: Sat, 16 Mar 2024 21:17:16 +0800 Subject: [PATCH] Docs/readme (#74) * update docs * update docs * update docs * update acceleration docs and fix typos --- README.md | 25 +++++++++-- configs/dit/train/16x256x256.py | 2 +- configs/dit/train/1x256x256.py | 2 +- configs/opensora/inference/16x512x512.py | 2 +- configs/opensora/inference/64x512x512.py | 2 +- configs/opensora/train/16x512x512.py | 4 +- configs/opensora/train/360x512x512.py | 55 +++++++++++++++++++++++ configs/opensora/train/64x512x512.py | 2 +- configs/pixart/train/16x256x256.py | 2 +- configs/pixart/train/64x512x512.py | 54 ++++++++++++++++++++++ docs/acceleration.md | 57 ++++++++++++++++++++++++ docs/structure.md | 4 +- opensora/datasets/datasets.py | 2 +- opensora/models/vae/vae.py | 24 +++++----- 14 files changed, 213 insertions(+), 24 deletions(-) create mode 100644 configs/opensora/train/360x512x512.py create mode 100644 configs/pixart/train/64x512x512.py diff --git a/README.md b/README.md index 6422cb4..869cd16 100644 --- a/README.md +++ b/README.md @@ -23,8 +23,8 @@ inference, and more. Our provided checkpoint can produce 2s 512x512 videos. ## 🎥 Latest Demo -| **2s 512x512** | **2s 512x512** | **2s 512x512** | -| ----------------------------------------------- | ----------------------------------------------- |-------------------------------------------------| +| **2s 512x512** | **2s 512x512** | **2s 512x512** | +| ---------------------------------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------- | | [](https://github.com/hpcaitech/Open-Sora/assets/99191637/de1963d3-b43b-4e68-a670-bb821ebb6f80) | [](https://github.com/hpcaitech/Open-Sora/assets/99191637/13f8338f-3d42-4b71-8142-d234fbd746cc) | [ ](https://github.com/hpcaitech/Open-Sora/assets/99191637/fa6a65a6-e32a-4d64-9a9e-eabb0ebb8c16) | Click for the original video. @@ -32,7 +32,7 @@ Click for the original video. ## 🔆 New Features/Updates - 📍 Open-Sora-v1 is trained on xxx. We train the model in three stages. Model weights are available here. Training details can be found here. [WIP] -- ✅ Support training acceleration including flash-attention, accelerated T5, mixed precision, gradient checkpointing, splitted VAE, sequence parallelism, etc. XXX times. Details locates at [acceleration.md](docs/acceleration.md). [WIP] +- ✅ Support training acceleration including accelerated transformer, faster T5 and VAE, and sequence parallelism. Open-Sora improve **55%** training speed when training on 64x512x512 videos. Details locates at [acceleration.md](docs/acceleration.md). - ✅ We provide video cutting and captioning tools for data preprocessing. Instructions can be found [here](tools/data/README.md) and our data collection plan can be found at [datasets.md](docs/datasets.md). - ✅ We find VQ-VAE from [VideoGPT](https://wilson1yan.github.io/videogpt/index.html) has a low quality and thus adopt a better VAE from [Stability-AI](https://huggingface.co/stabilityai/sd-vae-ft-mse-original). We also find patching in the time dimension deteriorates the quality. See our **[report](docs/report_v1.md)** for more discussions. - ✅ We investigate different architectures including DiT, Latte, and our proposed STDiT. Our **STDiT** achieves a better trade-off between quality and speed. See our **[report](docs/report_v1.md)** for more discussions. @@ -145,6 +145,25 @@ We provide code to split a long video into separate clips efficiently using `mul ## Training +To launch training, first prepare the dataset and the pretrained weights. [WIP] + +Then run the following commands to launch training on a single node. + +```bash +# 1 GPU, 16x256x256 +torchrun --nnodes=1 --nproc_per_node=1 scripts/train.py configs/opensora/train/16x256x512.py --data-path YOUR_CSV_PATH +# 8 GPUs, 64x512x512 +torchrun --nnodes=1 --nproc_per_node=8 scripts/train.py configs/opensora/train/64x512x512.py --data-path YOUR_CSV_PATH --ckpt-path YOUR_PRETRAINED_CKPT +``` + +To launch training on multiple nodes, prepare a hostfile according to [ColossalAI](https://colossalai.org/docs/basics/launch_colossalai/#launch-with-colossal-ai-cli), and run the following commands. + +```bash +colossalai run --nproc_per_node 8 --hostfile hostfile scripts/train.py configs/opensora/train/64x512x512.py --data-path YOUR_CSV_PATH --ckpt-path YOUR_PRETRAINED_CKPT +``` + +For training other models and advanced usage, see [here](docs/commands.md) for more instructions. + ## Acknowledgement * [DiT](https://github.com/facebookresearch/DiT): Scalable Diffusion Models with Transformers. diff --git a/configs/dit/train/16x256x256.py b/configs/dit/train/16x256x256.py index 67d3736..af8ee87 100644 --- a/configs/dit/train/16x256x256.py +++ b/configs/dit/train/16x256x256.py @@ -10,7 +10,7 @@ num_workers = 4 # Define acceleration dtype = "bf16" -grad_checkpoint = True +grad_checkpoint = False plugin = "zero2" sp_size = 1 diff --git a/configs/dit/train/1x256x256.py b/configs/dit/train/1x256x256.py index f8bd8d3..667e0a8 100644 --- a/configs/dit/train/1x256x256.py +++ b/configs/dit/train/1x256x256.py @@ -10,7 +10,7 @@ num_workers = 4 # Define acceleration dtype = "bf16" -grad_checkpoint = True +grad_checkpoint = False plugin = "zero2" sp_size = 1 diff --git a/configs/opensora/inference/16x512x512.py b/configs/opensora/inference/16x512x512.py index 2837275..b18dbef 100644 --- a/configs/opensora/inference/16x512x512.py +++ b/configs/opensora/inference/16x512x512.py @@ -12,7 +12,7 @@ model = dict( vae = dict( type="VideoAutoencoderKL", from_pretrained="stabilityai/sd-vae-ft-ema", - split=8, + micro_batch_size=128, ) text_encoder = dict( type="t5", diff --git a/configs/opensora/inference/64x512x512.py b/configs/opensora/inference/64x512x512.py index 25aad23..5983464 100644 --- a/configs/opensora/inference/64x512x512.py +++ b/configs/opensora/inference/64x512x512.py @@ -12,7 +12,7 @@ model = dict( vae = dict( type="VideoAutoencoderKL", from_pretrained="stabilityai/sd-vae-ft-ema", - split=8, + micro_batch_size=128, ) text_encoder = dict( type="t5", diff --git a/configs/opensora/train/16x512x512.py b/configs/opensora/train/16x512x512.py index f827ce3..885aad1 100644 --- a/configs/opensora/train/16x512x512.py +++ b/configs/opensora/train/16x512x512.py @@ -10,7 +10,7 @@ num_workers = 4 # Define acceleration dtype = "bf16" -grad_checkpoint = True +grad_checkpoint = False plugin = "zero2" sp_size = 1 @@ -26,7 +26,7 @@ model = dict( vae = dict( type="VideoAutoencoderKL", from_pretrained="stabilityai/sd-vae-ft-ema", - split=4, + micro_batch_size=128, ) text_encoder = dict( type="t5", diff --git a/configs/opensora/train/360x512x512.py b/configs/opensora/train/360x512x512.py new file mode 100644 index 0000000..7a6f759 --- /dev/null +++ b/configs/opensora/train/360x512x512.py @@ -0,0 +1,55 @@ +num_frames = 360 +frame_interval = 1 +image_size = (512, 512) + +# Define dataset +root = None +data_path = "CSV_PATH" +use_image_transform = False +num_workers = 4 + +# Define acceleration +dtype = "bf16" +grad_checkpoint = True +plugin = "zero2-seq" +sp_size = 2 + +# Define model +model = dict( + type="STDiT-XL/2", + space_scale=1.0, + time_scale=2 / 3, + from_pretrained=None, + enable_flashattn=True, + enable_layernorm_kernel=True, + enable_sequence_parallelism=True, # enable sq here +) +vae = dict( + type="VideoAutoencoderKL", + from_pretrained="stabilityai/sd-vae-ft-ema", + micro_batch_size=128, +) +text_encoder = dict( + type="t5", + from_pretrained="./pretrained_models/t5_ckpts", + model_max_length=120, + shardformer=True, +) +scheduler = dict( + type="iddpm", + timestep_respacing="", +) + +# Others +seed = 42 +outputs = "outputs" +wandb = False + +epochs = 1000 +log_every = 10 +ckpt_every = 250 +load = None + +batch_size = 1 +lr = 2e-5 +grad_clip = 1.0 diff --git a/configs/opensora/train/64x512x512.py b/configs/opensora/train/64x512x512.py index d902849..81154c8 100644 --- a/configs/opensora/train/64x512x512.py +++ b/configs/opensora/train/64x512x512.py @@ -26,7 +26,7 @@ model = dict( vae = dict( type="VideoAutoencoderKL", from_pretrained="stabilityai/sd-vae-ft-ema", - split=8, # split to lower memory usage + micro_batch_size=128, ) text_encoder = dict( type="t5", diff --git a/configs/pixart/train/16x256x256.py b/configs/pixart/train/16x256x256.py index 6819573..b47731e 100644 --- a/configs/pixart/train/16x256x256.py +++ b/configs/pixart/train/16x256x256.py @@ -10,7 +10,7 @@ num_workers = 4 # Define acceleration dtype = "bf16" -grad_checkpoint = True +grad_checkpoint = False plugin = "zero2" sp_size = 1 diff --git a/configs/pixart/train/64x512x512.py b/configs/pixart/train/64x512x512.py new file mode 100644 index 0000000..628cf25 --- /dev/null +++ b/configs/pixart/train/64x512x512.py @@ -0,0 +1,54 @@ +num_frames = 64 +frame_interval = 2 +image_size = (512, 512) + +# Define dataset +root = None +data_path = "CSV_PATH" +use_image_transform = False +num_workers = 4 + +# Define acceleration +dtype = "bf16" +grad_checkpoint = True +plugin = "zero2" +sp_size = 1 + +# Define model +model = dict( + type="PixArt-XL/2", + space_scale=1.0, + time_scale=2 / 3, + from_pretrained=None, + enable_flashattn=True, + enable_layernorm_kernel=True, +) +vae = dict( + type="VideoAutoencoderKL", + from_pretrained="stabilityai/sd-vae-ft-ema", + micro_batch_size=128, +) +text_encoder = dict( + type="t5", + from_pretrained="./pretrained_models/t5_ckpts", + model_max_length=120, + shardformer=True, +) +scheduler = dict( + type="iddpm", + timestep_respacing="", +) + +# Others +seed = 42 +outputs = "outputs" +wandb = False + +epochs = 1000 +log_every = 10 +ckpt_every = 250 +load = None + +batch_size = 4 +lr = 2e-5 +grad_clip = 1.0 diff --git a/docs/acceleration.md b/docs/acceleration.md index e69de29..3d1eb51 100644 --- a/docs/acceleration.md +++ b/docs/acceleration.md @@ -0,0 +1,57 @@ +# Acceleration + +Open-Sora aims to provide a high-speed training framework for diffusion models. We can achieve **55%** training speed acceleration when training on **64 frames 512x512 videos**. Our framework support training **1min 1080p videos**. + +## Accelerated Transformer + +Open-Sora boosts the training speed by: + +- Kernal optimization including [flash attention](https://github.com/Dao-AILab/flash-attention), fused layernorm kernal, and the ones compiled by colossalAI. +- Hybrid parallelism including ZeRO. +- Gradient checkpointing for larger batch size. + +Our training speed on images is comparable to [OpenDiT](https://github.com/NUS-HPC-AI-Lab/OpenDiT), an project to accelerate DiT training. The training speed is measured on 8 H800 GPUs with batch size 128, image size 256x256. + +| Model | Throughput (img/s) | Throughput (tokens/s) | +| -------- | ------------------ | --------------------- | +| DiT | | | +| OpenDiT | | | +| OpenSora | 175 | 45k | + +## Efficient STDiT + +Our STDiT adopts spatial-temporal attention to model the video data. Compared with directly applying full attention on DiT, our STDiT is more efficient as the number of frames increases. Our current framework only supports sequence parallelism for very long sequence. + +The training speed is measured on 8 H800 GPUs with acceleration techniques applied, GC means gradient checkpointing. Both with T5 conditioning like PixArt. + +| Model | Setting | Throughput (sample/s) | Throughput (tokens/s) | +| ---------------- | -------------- | --------------------- | --------------------- | +| DiT | 16x256 (4k) | 7.20 | 29k | +| STDiT | 16x256 (4k) | 7.00 | 28k | +| DiT | 16x512 (16k) | 0.85 | 14k | +| STDiT | 16x512 (16k) | 1.45 | 23k | +| DiT (GC) | 64x512 (65k) | 0.08 | 5k | +| STDiT (GC) | 64x512 (65k) | 0.40 | 25k | +| STDiT (GC, sp=2) | 360x512 (370k) | 0.10 | 18k | + +With a 4x downsampling in the temporal dimension with Video-VAE, an 24fps video has 450 frames. The gap between the speed of STDiT (28k tokens/s) and DiT on images (up to 45k tokens/s) mainly comes from the T5 and VAE encoding, and temperal attention. + +## Accelerated Encoder (T5, VAE) + +During training, texts are encoded by T5, and videos are encoded by VAE. Typically there are two ways to accelerate the training: + +1. Preprocess text and video data in advance and save them to disk. +2. Encode text and video data during training, and accelerate the encoding process. + +For option 1, 120 tokens for one sample require 1M disk space, and a 64x64x64 latent requires 4M. Considering a training dataset with 10M video clips, the total disk space required is 50TB. Our storage system is not ready at this time for this scale of data. + +For option 2, we boost T5 speed and memory requirement. According to [OpenDiT](https://github.com/NUS-HPC-AI-Lab/OpenDiT), we find VAE consumes a large number of GPU memory. Thus we split batch size into smaller ones for VAE encoding. With both techniques, we can greatly accelerated the training speed. + +The training speed is measured on 8 H800 GPUs with STDiT. + +| Acceleration | Setting | Throughput (img/s) | Throughput (tokens/s) | +| ------------ | ------------- | ------------------ | --------------------- | +| Baseline | 16x256 (4k) | 6.16 | 25k | +| w. faster T5 | 16x256 (4k) | 7.00 | 29k | +| Baseline | 64x512 (65k) | 0.94 | 15k | +| w. both | 64x512 (65k) | 1.45 | 23k | diff --git a/docs/structure.md b/docs/structure.md index fbbcc05..6a521ae 100644 --- a/docs/structure.md +++ b/docs/structure.md @@ -94,7 +94,7 @@ model = dict( vae = dict( type="VideoAutoencoderKL", # Select VAE type from_pretrained="stabilityai/sd-vae-ft-ema", # Load from pretrained VAE - split=8, # Split VAE micro batch size to be batch_size * num_frames // split + micro_batch_size=128, # VAE with micro batch size to save memory ) text_encoder = dict( type="t5", # Select text encoder type (t5, clip) @@ -147,7 +147,7 @@ model = dict( vae = dict( type="VideoAutoencoderKL", from_pretrained="stabilityai/sd-vae-ft-ema", - split=8, + micro_batch_size=128, ) text_encoder = dict( type="t5", diff --git a/opensora/datasets/datasets.py b/opensora/datasets/datasets.py index c709685..9033841 100644 --- a/opensora/datasets/datasets.py +++ b/opensora/datasets/datasets.py @@ -61,7 +61,7 @@ class DatasetFromCSV(torch.utils.data.Dataset): if ext.lower() in ["mp4", "avi", "mov", "mkv"]: self.is_video = True else: - assert f".{ext.lower()}" in IMG_EXTENSIONS + assert f".{ext.lower()}" in IMG_EXTENSIONS, f"Unsupported file format: {ext}" self.is_video = False self.transform = transform diff --git a/opensora/models/vae/vae.py b/opensora/models/vae/vae.py index 3761f0a..363bbfe 100644 --- a/opensora/models/vae/vae.py +++ b/opensora/models/vae/vae.py @@ -8,25 +8,27 @@ from opensora.registry import MODELS @MODELS.register_module() class VideoAutoencoderKL(nn.Module): - def __init__(self, from_pretrained=None, split=None): + def __init__(self, from_pretrained=None, micro_batch_size=None): super().__init__() self.module = AutoencoderKL.from_pretrained(from_pretrained) self.out_channels = self.module.config.latent_channels self.patch_size = (1, 8, 8) - self.split = split + self.micro_batch_size = micro_batch_size def encode(self, x): # x: (B, C, T, H, W) B = x.shape[0] x = rearrange(x, "B C T H W -> (B T) C H W") - if self.split is None: + if self.micro_batch_size is None: x = self.module.encode(x).latent_dist.sample().mul_(0.18215) else: - bs = x.shape[0] // self.split + bs = self.micro_batch_size x_out = [] - for i in range(self.split): - x_out.append(self.module.encode(x[i * bs : (i + 1) * bs]).latent_dist.sample().mul_(0.18215)) + for i in range(0, x.shape[0], bs): + x_bs = x[i : i + bs] + x_bs = self.module.encode(x_bs).latent_dist.sample().mul_(0.18215) + x_out.append(x_bs) x = torch.cat(x_out, dim=0) x = rearrange(x, "(B T) C H W -> B C T H W", B=B) return x @@ -35,13 +37,15 @@ class VideoAutoencoderKL(nn.Module): # x: (B, C, T, H, W) B = x.shape[0] x = rearrange(x, "B C T H W -> (B T) C H W") - if self.split is None: + if self.micro_batch_size is None: x = self.module.decode(x / 0.18215).sample else: - bs = x.shape[0] // self.split + bs = self.micro_batch_size x_out = [] - for i in range(self.split): - x_out.append(self.module.decode(x[i * bs : (i + 1) * bs] / 0.18215).sample) + for i in range(0, x.shape[0], bs): + x_bs = x[i : i + bs] + x_bs = self.module.decode(x_bs / 0.18215).sample + x_out.append(x_bs) x = torch.cat(x_out, dim=0) x = rearrange(x, "(B T) C H W -> B C T H W", B=B) return x