diff --git a/README.md b/README.md
index 5ea3a5a..959fe64 100644
--- a/README.md
+++ b/README.md
@@ -28,27 +28,37 @@ inference, and more. Our provided checkpoint can produce 2s 512x512 videos.
 
 ## 🔆 New Features/Updates
 
-- 📍 Open-Sora-v1 is trained on xxx. We train the model in three stages. Model weights are available here. Training details can be found here.
-- ✅ Support training acceleration including flash-attention, accelerated T5, mixed precision, gradient checkpointing, splitted VAE, sequence parallelism, etc. XXX times. See more discussions [here]().
-- ✅ We provide video cutting and captioning tools for data preprocessing. Our data collection plan can be found [here]().
-- ✅ We find VQ-VAE from [] has a low quality and thus adopt a better VAE from []. We also find patching in the time dimension deteriorates the quality. See more discussions [here]().
-- ✅ We investigate different architectures including DiT, Latte, and our proposed STDiT. Our STDiT achieves a better trade-off between quality and speed. See more discussions [here]().
-- ✅ Support clip and t5 text conditioning.
-- ✅ By viewing images as one-frame videos, our project supports training DiT on both images and videos (e.g., ImageNet & UCF101).
+- 📍 Open-Sora-v1 is trained on xxx. We train the model in three stages. Model weights are available here. Training details can be found here. [WIP]
+- ✅ Support training acceleration, including flash attention, accelerated T5, mixed precision, gradient checkpointing, split VAE, sequence parallelism, etc., for an XXX-times speedup. Details are in [acceleration.md](docs/acceleration.md). [WIP]
+- ✅ We provide video cutting and captioning tools for data preprocessing. Instructions are in [tools/data/README.md](tools/data/README.md), and our data collection plan is in [datasets.md](docs/datasets.md).
+- ✅ We find the VQ-VAE from [VideoGPT](https://wilson1yan.github.io/videogpt/index.html) has low quality and thus adopt a better VAE from [Stability-AI](https://huggingface.co/stabilityai/sd-vae-ft-mse-original). We also find that patching in the time dimension deteriorates quality. See our **[report](docs/report_v1.md)** for more discussion.
+- ✅ We investigate different architectures including DiT, Latte, and our proposed STDiT. Our **STDiT** achieves a better trade-off between quality and speed. See our **[report](docs/report_v1.md)** for more discussion.
+- ✅ Support CLIP and T5 text conditioning.
+- ✅ By viewing images as one-frame videos, our project supports training DiT on both images and videos (e.g., ImageNet & UCF101). See [command.md](docs/command.md) for more instructions, and the sketch after this list for the idea.
 - ✅ Support inference with official weights from [DiT](https://github.com/facebookresearch/DiT), [Latte](https://github.com/Vchitect/Latte), and [PixArt](https://pixart-alpha.github.io/).
+
+<details>
+<summary>View more</summary>
+
+- ✅ Refactor the codebase. See [structure.md](docs/structure.md) to learn the project structure and how to use the config files.
+
+</details>
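+
+For intuition, here is a minimal sketch of the image-as-one-frame-video trick (illustrative only; the tensor layout is an assumption, not this repo's exact code):
+
+```python
+import torch
+
+# A batch of images: (B, C, H, W)
+images = torch.randn(4, 3, 256, 256)
+
+# Insert a singleton time dimension to get (B, C, T=1, H, W),
+# so the same spatio-temporal model consumes images and videos alike.
+videos = images.unsqueeze(2)
+
+print(videos.shape)  # torch.Size([4, 3, 1, 256, 256])
+```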
 ### TODO list sorted by priority
 
 - [ ] Complete the data processing pipeline (including dense optical flow, aesthetics scores, text-image similarity, deduplication, etc.). See [datasets.md](docs/datasets.md) for more information. **[WIP]**
 - [ ] Training Video-VAE. **[WIP]**
+
+<details>
+<summary>View more</summary>
+
 - [ ] Support image and video conditioning.
 - [ ] Evaluation pipeline.
 - [ ] Incorporate a better scheduler, e.g., rectified flow in SD3 (see the sketch after this list).
 - [ ] Support variable aspect ratios, resolutions, durations.
 - [ ] Support SD3 when released.
+
+</details>
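+
+As background for the scheduler item above, here is a minimal sketch of the rectified-flow training objective (our illustration of the general technique, not SD3's or this repo's implementation; `model` is a hypothetical velocity predictor):
+
+```python
+import torch
+import torch.nn.functional as F
+
+def rectified_flow_loss(model, x0, cond):
+    # Rectified flow draws a straight line between data x0 and noise x1
+    # and trains the model to predict the line's constant velocity.
+    x1 = torch.randn_like(x0)
+    t = torch.rand(x0.shape[0], device=x0.device)  # uniform time in [0, 1]
+    t_ = t.view(-1, *([1] * (x0.dim() - 1)))       # broadcastable shape
+    xt = (1 - t_) * x0 + t_ * x1                   # point on the straight path
+    v_target = x1 - x0                             # d(xt)/dt along that path
+    return F.mse_loss(model(xt, t, cond), v_target)
+```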
 
 ## Contents
@@ -78,7 +88,7 @@
 cd Open-Sora
 pip install xxx
 ```
 
-After installation, to get fimilar with the project, you can check the [here]() for the project structure and how to use the config files.
+After installation, we suggest reading [structure.md](docs/structure.md) to learn the project structure and how to use the config files.
 
 ## Model Weights
@@ -128,7 +138,7 @@ We are grateful for their exceptional work and generous contribution to open source.
 }
 ```
 
-Zangwei Zheng and Xiangyu Peng equally contributed to this work during their internship at [HPC-AI Tech](https://hpc-ai.com/).
+[Zangwei Zheng](https://github.com/zhengzangw) and [Xiangyu Peng](https://github.com/xyupeng) equally contributed to this work during their internship at [HPC-AI Tech](https://hpc-ai.com/).
 
 ## Star History
diff --git a/docs/acceleration.md b/docs/acceleration.md
new file mode 100644
index 0000000..e69de29
diff --git a/docs/command.md b/docs/command.md
new file mode 100644
index 0000000..e69de29
diff --git a/docs/report_v1.md b/docs/report_v1.md
new file mode 100644
index 0000000..e69de29
diff --git a/docs/structure.md b/docs/structure.md
new file mode 100644
index 0000000..deb967b
--- /dev/null
+++ b/docs/structure.md
@@ -0,0 +1,167 @@
+# Repo & Config Structure
+
+## Repo Structure
+
+```plaintext
+Open-Sora
+├── README.md
+├── docs
+│   ├── acceleration.md     -> Acceleration & speed benchmark
+│   ├── command.md          -> Commands for training & inference
+│   ├── datasets.md         -> Datasets used in this project
+│   ├── structure.md        -> This file
+│   └── report_v1.md        -> Report for Open-Sora v1
+├── scripts
+│   ├── train.py            -> diffusion training script
+│   └── inference.py        -> diffusion inference script
+├── configs                 -> Configs for training & inference
+├── opensora
+│   ├── __init__.py
+│   ├── registry.py         -> Registry helper
+│   ├── acceleration        -> Acceleration related code
+│   ├── dataset             -> Dataset related code
+│   ├── models
+│   │   ├── layers          -> Common layers
+│   │   ├── vae             -> VAE as image encoder
+│   │   ├── text_encoder    -> Text encoder
+│   │   │   ├── classes.py  -> Class id encoder (inference only)
+│   │   │   ├── clip.py     -> CLIP encoder
+│   │   │   └── t5.py       -> T5 encoder
+│   │   ├── dit
+│   │   ├── latte
+│   │   ├── pixart
+│   │   └── stdit           -> Our STDiT related code
+│   ├── schedulers          -> Diffusion schedulers
+│   │   ├── iddpm           -> IDDPM for training and inference
+│   │   └── dpms            -> DPM-Solver for fast inference
+│   └── utils
+└── tools                   -> Tools for data processing and more
+```
+
+## Configs
+
+Our config files follow [MMEngine](https://github.com/open-mmlab/mmengine). MMEngine reads a config file (a `.py` file) and parses it into a dictionary-like object.
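+
+For example (a minimal sketch; the exact entry points used by this repo's scripts may differ), a config can be loaded and inspected like this:
+
+```python
+from mmengine.config import Config
+
+# Parse the Python config file into a dict-like Config object.
+cfg = Config.fromfile("configs/opensora/train/16x256x256.py")
+
+# Values are accessible by attribute or by key.
+print(cfg.model["type"])  # "STDiT-XL/2"
+print(cfg.batch_size)
+```
+
+The config files are organized as follows: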
+
+```plaintext
+Open-Sora
+└── configs                      -> Configs for training & inference
+    ├── opensora                 -> STDiT related configs
+    │   ├── inference
+    │   │   ├── 16x256x256.py    -> Sample videos 16 frames 256x256
+    │   │   ├── 16x512x512.py    -> Sample videos 16 frames 512x512
+    │   │   └── 64x512x512.py    -> Sample videos 64 frames 512x512
+    │   └── train
+    │       ├── 16x256x256.py    -> Train on videos 16 frames 256x256
+    │       └── 64x512x512.py    -> Train on videos 64 frames 512x512
+    ├── dit                      -> DiT related configs
+    │   ├── inference
+    │   │   ├── 1x256x256-class.py -> Sample images with ckpts from DiT
+    │   │   ├── 1x256x256.py     -> Sample images with CLIP condition
+    │   │   └── 16x256x256.py    -> Sample videos
+    │   └── train
+    │       ├── 1x256x256.py     -> Train on images with CLIP condition
+    │       └── 16x256x256.py    -> Train on videos
+    ├── latte                    -> Latte related configs
+    └── pixart                   -> PixArt related configs
+```
+
+## Inference config demos
+
+```python
+# Define sampling size
+num_frames = 64          # number of frames
+fps = 24 // 2            # frames per second (divided by 2 for frame_interval=2)
+image_size = (512, 512)  # image size (height, width)
+
+# Define model
+model = dict(
+    type="STDiT-XL/2",                   # Select model type (STDiT-XL/2, DiT-XL/2, etc.)
+    space_scale=1.0,                     # (Optional) Space positional encoding scale (new height / old height)
+    time_scale=2 / 3,                    # (Optional) Time positional encoding scale (new frame_interval / old frame_interval)
+    from_pretrained="PRETRAINED_MODEL",  # (Optional) Load from pretrained model
+    no_temporal_pos_emb=True,            # (Optional) Disable temporal positional encoding (for images)
+)
+vae = dict(
+    type="VideoAutoencoderKL",                    # Select VAE type
+    from_pretrained="stabilityai/sd-vae-ft-ema",  # Load from pretrained VAE
+    split=8,                                      # Split VAE encoding into micro batches of batch_size * num_frames // split
+)
+text_encoder = dict(
+    type="t5",                                       # Select text encoder type (t5, clip)
+    from_pretrained="./pretrained_models/t5_ckpts",  # Load from pretrained text encoder
+    model_max_length=120,                            # Maximum length of input text
+)
+scheduler = dict(
+    type="iddpm",            # Select scheduler type (iddpm, dpm-solver)
+    num_sampling_steps=100,  # Number of sampling steps
+    cfg_scale=7.0,           # Classifier-free guidance scale
+)
+dtype = "fp16"  # Computation type (fp16, fp32, bf16)
+
+# Other settings
+batch_size = 1                                  # batch size
+seed = 42                                       # random seed
+prompt_path = "./assets/texts/t2v_samples.txt"  # path to prompt file
+save_dir = "./samples"                          # path to save samples
+```
+
+## Training config demos
+
+```python
+# Define sampling size
+num_frames = 64
+frame_interval = 2       # sample every 2 frames
+image_size = (512, 512)
+
+# Define dataset
+root = None                  # root path to the dataset
+data_path = "CSV_PATH"       # path to the csv file
+use_image_transform = False  # True if training on images
+num_workers = 4              # number of workers for the dataloader
+
+# Define acceleration
+dtype = "bf16"          # Computation type (fp16, fp32, bf16)
+grad_checkpoint = True  # Use gradient checkpointing
+plugin = "zero2"        # Plugin for distributed training (zero2, zero2-seq)
+sp_size = 1             # Sequence parallelism size (1 for no sequence parallelism)
+
+# Define model
+model = dict(
+    type="STDiT-XL/2",
+    space_scale=1.0,
+    time_scale=2 / 3,
+    from_pretrained="YOUR_PRETRAINED_MODEL",
+    enable_flashattn=True,         # Enable flash attention
+    enable_layernorm_kernel=True,  # Enable layernorm kernel
+)
+vae = dict(
+    type="VideoAutoencoderKL",
+    from_pretrained="stabilityai/sd-vae-ft-ema",
+    split=8,  # micro batch size = batch_size * num_frames // split (see the sketch after this demo)
+)
+text_encoder = dict(
+    type="t5",
+    from_pretrained="./pretrained_models/t5_ckpts",
+    model_max_length=120,
+    shardformer=True,  # Enable shardformer for T5 acceleration
+)
+scheduler = dict(
+    type="iddpm",
+    timestep_respacing="",  # Default 1000 timesteps
+)
+
+# Others
+seed = 42
+outputs = "outputs"  # path to save checkpoints
+wandb = False        # use wandb for logging
+
+epochs = 1000    # number of epochs (set it large; stop the run once results are satisfactory)
+log_every = 10
+ckpt_every = 250
+load = None      # path to a checkpoint to resume training from
+
+batch_size = 4
+lr = 2e-5
+grad_clip = 1.0  # gradient clipping norm
+```
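+
+To make the `split` option above concrete, here is a minimal sketch of micro-batched VAE encoding (our own illustration, not the repo's actual implementation; `encode_fn` is a hypothetical stand-in for the VAE's encode call):
+
+```python
+import torch
+
+def encode_in_micro_batches(encode_fn, frames, split):
+    # frames: (B * T, C, H, W) — batch and time flattened together.
+    # Encoding `split` sequential chunks caps peak VAE memory at a
+    # micro batch size of (batch_size * num_frames) // split.
+    chunks = frames.chunk(split, dim=0)
+    return torch.cat([encode_fn(c) for c in chunks], dim=0)
+
+# Example with dummy shapes: batch_size=4, num_frames=64, split=8
+frames = torch.randn(4 * 64, 3, 64, 64)
+latents = encode_in_micro_batches(lambda x: x, frames, split=8)
+print(latents.shape)  # torch.Size([256, 3, 64, 64])
+```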