mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-11 13:14:44 +02:00
Merge branch 'dev/v1.2' of https://github.com/hpcaitech/Open-Sora-dev into dev/v1.2
This commit is contained in:
commit
ba80141b9d
30
.github/workflows/github_page.yaml
vendored
Normal file
30
.github/workflows/github_page.yaml
vendored
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
name: GitHub Pages
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
deploy:
|
||||
runs-on: ubuntu-22.04
|
||||
permissions:
|
||||
contents: write
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
with:
|
||||
ref: gallery
|
||||
|
||||
- name: Setup Node
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: 20
|
||||
|
||||
- run: npm install
|
||||
- run: npm run build
|
||||
|
||||
- name: Deploy
|
||||
uses: peaceiris/actions-gh-pages@v3
|
||||
with:
|
||||
github_token: ${{ secrets.GITHUB_TOKEN }}
|
||||
publish_dir: ./build
|
||||
7
.gitignore
vendored
7
.gitignore
vendored
|
|
@ -183,5 +183,12 @@ hostfile
|
|||
gradio_cached_examples/
|
||||
wandb/
|
||||
|
||||
<<<<<<< HEAD
|
||||
# vae weights
|
||||
eval/vae/flolpips/weights/
|
||||
=======
|
||||
# npm
|
||||
node_modules/
|
||||
package-lock.json
|
||||
package.json
|
||||
>>>>>>> upstream/main
|
||||
|
|
|
|||
92
README.md
92
README.md
|
|
@ -9,24 +9,26 @@
|
|||
<a href="https://twitter.com/yangyou1991/status/1769411544083996787?s=61&t=jT0Dsx2d-MS5vS9rNM5e5g"><img src="https://img.shields.io/badge/Twitter-Discuss-blue?logo=twitter&"></a>
|
||||
<a href="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/WeChat.png"><img src="https://img.shields.io/badge/微信-小助手加群-green?logo=wechat&"></a>
|
||||
<a href="https://hpc-ai.com/blog/open-sora-v1.0"><img src="https://img.shields.io/badge/Open_Sora-Blog-blue"></a>
|
||||
<a href="https://huggingface.co/spaces/hpcai-tech/open-sora"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Gradio Demo-blue"></a>
|
||||
</div>
|
||||
|
||||
## Open-Sora: Democratizing Efficient Video Production for All
|
||||
|
||||
We present **Open-Sora**, an initiative dedicated to **efficiently** produce high-quality video and make the model,
|
||||
tools and contents accessible to all. By embracing **open-source** principles,
|
||||
We design and implement **Open-Sora**, an initiative dedicated to **efficiently** producing high-quality video. We hope to make the model,
|
||||
tools and all details accessible to all. By embracing **open-source** principles,
|
||||
Open-Sora not only democratizes access to advanced video generation techniques, but also offers a
|
||||
streamlined and user-friendly platform that simplifies the complexities of video production.
|
||||
With Open-Sora, we aim to inspire innovation, creativity, and inclusivity in the realm of content creation.
|
||||
streamlined and user-friendly platform that simplifies the complexities of video generation.
|
||||
With Open-Sora, our goal is to foster innovation, creativity, and inclusivity within the field of content creation.
|
||||
|
||||
[[中文文档]](/docs/zh_CN/README.md)
|
||||
[[中文文档]](/docs/zh_CN/README.md) [[潞晨云部署视频教程]](https://www.bilibili.com/video/BV141421R7Ag)
|
||||
|
||||
<h4>Open-Sora is still at an early stage and under active development.</h4>
|
||||
|
||||
## 📰 News
|
||||
|
||||
* **[2024.04.22]** 🔥 We release **Open-Sora 1.1**, which supports **2s~15s, 144p to 720p, any aspect ratio** text-to-image, **text-to-video, image-to-video, video-to-video, infinite time** generation. In addition, a full video processing pipeline is released. [[checkpoints]]() [[report]](/docs/report_02.md)
|
||||
* **[2024.03.18]** We release **Open-Sora 1.0**, a fully open-source project for video generation.
|
||||
* **[2024.04.25]** 🤗 We released the [Gradio demo for Open-Sora](https://huggingface.co/spaces/hpcai-tech/open-sora) on Hugging Face Spaces.
|
||||
* **[2024.04.25]** 🔥 We released **Open-Sora 1.1**, which supports **2s~15s, 144p to 720p, any aspect ratio** text-to-image, **text-to-video, image-to-video, video-to-video, infinite time** generation. In addition, a full video processing pipeline is released. [[checkpoints]]() [[report]](/docs/report_02.md)
|
||||
* **[2024.03.18]** We released **Open-Sora 1.0**, a fully open-source project for video generation.
|
||||
Open-Sora 1.0 supports a full pipeline of video data preprocessing, training with
|
||||
<a href="https://github.com/hpcaitech/ColossalAI"><img src="assets/readme/colossal_ai.png" width="8%" ></a>
|
||||
acceleration,
|
||||
|
|
@ -37,16 +39,20 @@ With Open-Sora, we aim to inspire innovation, creativity, and inclusivity in the
|
|||
|
||||
## 🎥 Latest Demo
|
||||
|
||||
🔥 You can experience Open-Sora on our [🤗 Gradio application on Hugging Face](https://huggingface.co/spaces/hpcai-tech/open-sora). More samples are available in our [Gallery](https://hpcaitech.github.io/Open-Sora/).
|
||||
|
||||
| **2s 240×426** | **2s 240×426** |
|
||||
| ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| [<img src="assets/demo/sample_16x240x426_9.gif" width="">](https://github.com/hpcaitech/Open-Sora-dev/assets/99191637/c31ebc52-de39-4a4e-9b1e-9211d45e05b2) | [<img src="assets/demo/sora_16x240x426_26.gif" width="">](https://github.com/hpcaitech/Open-Sora-dev/assets/99191637/c31ebc52-de39-4a4e-9b1e-9211d45e05b2) |
|
||||
| [<img src="assets/demo/sora_16x240x426_27.gif" width="">](https://github.com/hpcaitech/Open-Sora-dev/assets/99191637/f7ce4aaa-528f-40a8-be7a-72e61eaacbbd) | [<img src="assets/demo/sora_16x240x426_40.gif" width="">](https://github.com/hpcaitech/Open-Sora-dev/assets/99191637/5d58d71e-1fda-4d90-9ad3-5f2f7b75c6a9) |
|
||||
| **2s 240×426** | **2s 240×426** |
|
||||
| ----------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| [<img src="assets/demo/sample_16x240x426_9.gif" width="">](https://github.com/hpcaitech/Open-Sora-dev/assets/99191637/c31ebc52-de39-4a4e-9b1e-9211d45e05b2) | [<img src="assets/demo/sora_16x240x426_26.gif" width="">](https://github.com/hpcaitech/Open-Sora-dev/assets/99191637/c31ebc52-de39-4a4e-9b1e-9211d45e05b2) |
|
||||
| [<img src="assets/demo/sora_16x240x426_27.gif" width="">](https://github.com/hpcaitech/Open-Sora-dev/assets/99191637/f7ce4aaa-528f-40a8-be7a-72e61eaacbbd) | [<img src="assets/demo/sora_16x240x426_40.gif" width="">](https://github.com/hpcaitech/Open-Sora-dev/assets/99191637/5d58d71e-1fda-4d90-9ad3-5f2f7b75c6a9) |
|
||||
|
||||
| **2s 426×240** | **2s 426×240** | **4s 480×854** |
|
||||
| ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| [<img src="assets/demo/sora_16x426x240_24.gif" width="">](https://github.com/hpcaitech/Open-Sora-dev/assets/99191637/34ecb4a0-4eef-4286-ad4c-8e3a87e5a9fd) | [<img src="assets/demo/sora_16x426x240_3.gif" width="">](https://github.com/hpcaitech/Open-Sora-dev/assets/99191637/3e892ad2-9543-4049-b005-643a4c1bf3bf) | [<img src="assets/demo/sample_32x480x854_9.gif" width="">](https://github.com/hpcaitech/Open-Sora-dev/assets/99191637/c1619333-25d7-42ba-a91c-18dbc1870b18) |
|
||||
| **2s 426×240** | **4s 480×854** |
|
||||
| ---------------------------------------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| [<img src="assets/demo/sora_16x426x240_24.gif" width="">](https://github.com/hpcaitech/Open-Sora-dev/assets/99191637/34ecb4a0-4eef-4286-ad4c-8e3a87e5a9fd) | [<img src="assets/demo/sample_32x480x854_9.gif" width="">](https://github.com/hpcaitech/Open-Sora-dev/assets/99191637/c1619333-25d7-42ba-a91c-18dbc1870b18) |
|
||||
|
||||
| **16s 320×320** | **16s 224×448** | **2s 426×240** |
|
||||
| ------------------------------------------------------------------------------------------------------------------------------------------------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------ | --------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| [<img src="assets/demo/sample_16s_320x320.gif" width="">](https://github.com/hpcaitech/Open-Sora/assets/99191637/3cab536e-9b43-4b33-8da8-a0f9cf842ff2) | [<img src="assets/demo/sample_16s_224x448.gif" width="">](https://github.com/hpcaitech/Open-Sora/assets/99191637/9fb0b9e0-c6f4-4935-b29e-4cac10b373c4) | [<img src="assets/demo/sora_16x426x240_3.gif" width="">](https://github.com/hpcaitech/Open-Sora-dev/assets/99191637/3e892ad2-9543-4049-b005-643a4c1bf3bf) |
|
||||
|
||||
<details>
|
||||
<summary>OpenSora 1.0 Demo</summary>
|
||||
|
|
@ -63,13 +69,11 @@ see [here](/assets/texts/t2v_samples.txt) for full prompts.
|
|||
|
||||
</details>
|
||||
|
||||
More samples are available in our [gallery](https://hpcaitech.github.io/Open-Sora/).
|
||||
|
||||
## 🔆 New Features/Updates
|
||||
|
||||
* 📍 **Open-Sora 1.1** released. Model weights are available [here](). It is trained on **0s~15s, 144p to 720p, various aspect ratios** videos. See our **[report 1.1](docs/report_02.md)** for more discussions.
|
||||
* 🔧 **Data processing pipeline v1.1** is released. An automatic [processing pipeline](#data-processing) from raw videos to (text, video clip) pairs is provided, including scene cutting $\rightarrow$ filtering(aesthetic, optical flow, OCR, etc.) $\rightarrow$ captioning $\rightarrow$ managing. With this tool, you can easily build your video dataset.
|
||||
* ✅ Modified ST-DiT architecture includes rope positional encoding, qk norm, longer text length, etc.
|
||||
* ✅ Improved ST-DiT architecture includes rope positional encoding, qk norm, longer text length, etc.
|
||||
* ✅ Support training with any resolution, aspect ratio, and duration (including images).
|
||||
* ✅ Support image and video conditioning and video editing, and thus support animating images, connecting videos, etc.
|
||||
* 📍 **Open-Sora 1.0** released. Model weights are available [here](#model-weights). With only 400K video clips and 200 H800
|
||||
|
|
@ -77,7 +81,7 @@ More samples are available in our [gallery](https://hpcaitech.github.io/Open-Sor
|
|||
* ✅ Three-stage training from an image diffusion model to a video diffusion model. We provide the weights for each
|
||||
stage.
|
||||
* ✅ Support training acceleration including accelerated transformer, faster T5 and VAE, and sequence parallelism.
|
||||
Open-Sora improve **55%** training speed when training on 64x512x512 videos. Details locates
|
||||
Open-Sora improves **55%** training speed when training on 64x512x512 videos. Details locates
|
||||
at [acceleration.md](docs/acceleration.md).
|
||||
* 🔧 **Data preprocessing pipeline v1.0**,
|
||||
including [downloading](/tools/datasets/README.md), [video cutting](/tools/scenedetect/README.md),
|
||||
|
|
@ -106,8 +110,8 @@ More samples are available in our [gallery](https://hpcaitech.github.io/Open-Sor
|
|||
### TODO list sorted by priority
|
||||
|
||||
* [ ] Training Video-VAE and adapt our model to new VAE. **[WIP]**
|
||||
* [ ] Incoporate a better scheduler, e.g., rectified flow in SD3.
|
||||
* [ ] Scaling model parameters and dataset size.
|
||||
* [ ] Scaling model parameters and dataset size. **[WIP]**
|
||||
* [ ] Incoporate a better scheduler, e.g., rectified flow in SD3. **[WIP]**
|
||||
|
||||
<details>
|
||||
<summary>View more</summary>
|
||||
|
|
@ -143,7 +147,7 @@ Other useful documents and links are listed below.
|
|||
|
||||
## Installation
|
||||
|
||||
TODO: discuss how to include data installation here.
|
||||
### Install from Source
|
||||
|
||||
```bash
|
||||
# create a virtual env
|
||||
|
|
@ -174,22 +178,43 @@ cd Open-Sora
|
|||
pip install -v .
|
||||
```
|
||||
|
||||
### Use Docker
|
||||
|
||||
Run the following command to build a docker image from Dockerfile provided.
|
||||
|
||||
```bash
|
||||
docker build -t opensora ./docker
|
||||
```
|
||||
|
||||
Run the following command to start the docker container in interactive mode.
|
||||
|
||||
```bash
|
||||
docker run -ti --gpus all -v {MOUNT_DIR}:/data opensora
|
||||
```
|
||||
|
||||
## Model Weights
|
||||
|
||||
### Open-Sora 1.1 Model Weights
|
||||
|
||||
TBD
|
||||
| Resolution | Model Size | Data | #iterations | Batch Size | URL |
|
||||
| ------------------ | ---------- | -------------------------- | ----------- | ------------------------------------------------- | -------------------------------------------------------------------- |
|
||||
| mainly 144p & 240p | 700M | 10M videos + 2M images | 100k | [dynamic](/configs/opensora-v1-1/train/stage2.py) | [:link:](https://huggingface.co/hpcai-tech/OpenSora-STDiT-v2-stage2) |
|
||||
| 144p to 720p | 700M | 500K HQ videos + 1M images | 4k | [dynamic](/configs/opensora-v1-1/train/stage3.py) | [:link:](https://huggingface.co/hpcai-tech/OpenSora-STDiT-v2-stage3) |
|
||||
|
||||
See our **[report 1.1](docs/report_02.md)** for more infomation.
|
||||
|
||||
:warning: **LIMITATION**: This version contains known issues which we are going to fix in the next version (as we save computation resource for the next release). In addition, the video generation may fail for long duration, and high resolution will have noisy results due to this problem.
|
||||
|
||||
### Open-Sora 1.0 Model Weights
|
||||
|
||||
<details>
|
||||
<summary>View more</summary>
|
||||
|
||||
| Resolution | Data | #iterations | Batch Size | GPU days (H800) | URL |
|
||||
| ---------- | ------ | ----------- | ---------- | --------------- | --------------------------------------------------------------------------------------------- |
|
||||
| 16×512×512 | 20K HQ | 20k | 2×64 | 35 | [:link:](https://huggingface.co/hpcai-tech/Open-Sora/blob/main/OpenSora-v1-HQ-16x512x512.pth) |
|
||||
| 16×256×256 | 20K HQ | 24k | 8×64 | 45 | [:link:](https://huggingface.co/hpcai-tech/Open-Sora/blob/main/OpenSora-v1-HQ-16x256x256.pth) |
|
||||
| 16×256×256 | 366K | 80k | 8×64 | 117 | [:link:](https://huggingface.co/hpcai-tech/Open-Sora/blob/main/OpenSora-v1-16x256x256.pth) |
|
||||
| Resolution | Model Size | Data | #iterations | Batch Size | GPU days (H800) | URL |
|
||||
| ---------- | ---------- | ------ | ----------- | ---------- | --------------- |
|
||||
| 16×512×512 | 700M | 20K HQ | 20k | 2×64 | 35 | [:link:](https://huggingface.co/hpcai-tech/Open-Sora/blob/main/OpenSora-v1-HQ-16x512x512.pth) |
|
||||
| 16×256×256 | 700M | 20K HQ | 24k | 8×64 | 45 | [:link:](https://huggingface.co/hpcai-tech/Open-Sora/blob/main/OpenSora-v1-HQ-16x256x256.pth) |
|
||||
| 16×256×256 | 700M | 366K | 80k | 8×64 | 117 | [:link:](https://huggingface.co/hpcai-tech/Open-Sora/blob/main/OpenSora-v1-16x256x256.pth) |
|
||||
|
||||
Training orders: 16x256x256 $\rightarrow$ 16x256x256 HQ $\rightarrow$ 16x512x512 HQ.
|
||||
|
||||
|
|
@ -207,7 +232,9 @@ on improving the quality and text alignment.
|
|||
|
||||
### Gradio Demo
|
||||
|
||||
We have provided a [Gradio application](./gradio) in this repository, you can use the following the command to start an interactive web application to experience video generation with Open-Sora.
|
||||
🔥 You can experience Open-Sora on our [🤗 Gradio application](https://huggingface.co/spaces/hpcai-tech/open-sora) on Hugging Face online.
|
||||
|
||||
If you want to deploy gradio locally, we have also provided a [Gradio application](./gradio) in this repository, you can use the following the command to start an interactive web application to experience video generation with Open-Sora.
|
||||
|
||||
```bash
|
||||
pip install gradio spaces
|
||||
|
|
@ -221,12 +248,12 @@ This will launch a Gradio application on your localhost. If you want to know mor
|
|||
Since Open-Sora 1.1 supports inference with dynamic input size, you can pass the input size as an argument.
|
||||
|
||||
```bash
|
||||
# video sampling
|
||||
# text to video
|
||||
python scripts/inference.py configs/opensora-v1-1/inference/sample.py \
|
||||
--ckpt-path CKPT_PATH --prompt "A beautiful sunset over the city" --num-frames 32 --image-size 480 854
|
||||
```
|
||||
|
||||
See [here](docs/commands.md#inference-with-open-sora-11) for more instructions.
|
||||
See [here](docs/commands.md#inference-with-open-sora-11) for more instructions including text-to-image, image-to-video, video-to-video, and infinite time generation.
|
||||
|
||||
### Open-Sora 1.0 Command Line Inference
|
||||
|
||||
|
|
@ -256,12 +283,12 @@ To lower the memory usage, set a smaller `vae.micro_batch_size` in the config (s
|
|||
</details>
|
||||
|
||||
## Data Processing
|
||||
|
||||
High-quality data is crucial for training good generation models.
|
||||
To this end, we establish a complete pipeline for data processing, which could seamlessly convert raw videos to high-quality video-text pairs.
|
||||
The pipeline is shown below. For detailed information, please refer to [data processing](docs/data_processing.md).
|
||||
Also check out the [datasets](docs/datasets.md) we use.
|
||||
|
||||
|
||||

|
||||
|
||||
## Training
|
||||
|
|
@ -343,9 +370,6 @@ following [all-contributors](https://github.com/all-contributors/all-contributor
|
|||
|
||||
If you wish to contribute to this project, you can refer to the [Contribution Guideline](./CONTRIBUTING.md).
|
||||
|
||||
[Zangwei Zheng](https://github.com/zhengzangw) and [Xiangyu Peng](https://github.com/xyupeng) equally contributed to
|
||||
this work during their internship at [HPC-AI Tech](https://hpc-ai.com/).
|
||||
|
||||
## Acknowledgement
|
||||
|
||||
* [ColossalAI](https://github.com/hpcaitech/ColossalAI): A powerful large model parallel acceleration and optimization
|
||||
|
|
|
|||
|
|
@ -14,27 +14,35 @@ prompt = [
|
|||
|
||||
loop = 2
|
||||
condition_frame_length = 4
|
||||
reference_path = [
|
||||
"https://cdn.openai.com/tmp/s/interp/d0.mp4",
|
||||
None,
|
||||
"assets/images/condition/wave.png",
|
||||
]
|
||||
# valid when reference_path is not None
|
||||
# (loop id, ref id, ref start, target start, length, edit_ratio)
|
||||
# (
|
||||
# loop id, [the loop index of the condition image or video]
|
||||
# reference id, [the index of the condition image or video in the reference_path]
|
||||
# reference start, [the start frame of the condition image or video]
|
||||
# target start, [the location to insert]
|
||||
# length, [the number of frames to insert]
|
||||
# edit_ratio [the edit rate of the condition image or video]
|
||||
# )
|
||||
# See https://github.com/hpcaitech/Open-Sora/blob/main/docs/config.md#advanced-inference-config for more details
|
||||
# See https://github.com/hpcaitech/Open-Sora/blob/main/docs/commands.md#inference-with-open-sora-11 for more examples
|
||||
mask_strategy = [
|
||||
"0,0,0,0,8,0.3",
|
||||
None,
|
||||
"0",
|
||||
]
|
||||
reference_path = [
|
||||
"https://cdn.openai.com/tmp/s/interp/d0.mp4",
|
||||
None,
|
||||
"assets/images/condition/wave.png",
|
||||
]
|
||||
|
||||
# Define model
|
||||
model = dict(
|
||||
type="STDiT2-XL/2",
|
||||
from_pretrained=None,
|
||||
from_pretrained="hpcai-tech/OpenSora-STDiT-v2-stage3",
|
||||
input_sq_size=512,
|
||||
qk_norm=True,
|
||||
qk_norm_legacy=True,
|
||||
enable_flashattn=True,
|
||||
enable_flash_attn=True,
|
||||
enable_layernorm_kernel=True,
|
||||
)
|
||||
vae = dict(
|
||||
|
|
|
|||
|
|
@ -7,11 +7,11 @@ multi_resolution = "STDiT2"
|
|||
# Define model
|
||||
model = dict(
|
||||
type="STDiT2-XL/2",
|
||||
from_pretrained=None,
|
||||
from_pretrained="hpcai-tech/OpenSora-STDiT-v2-stage3",
|
||||
input_sq_size=512,
|
||||
qk_norm=True,
|
||||
qk_norm_legacy=True,
|
||||
enable_flashattn=True,
|
||||
enable_flash_attn=True,
|
||||
enable_layernorm_kernel=True,
|
||||
)
|
||||
vae = dict(
|
||||
|
|
|
|||
25
docker/Dockerfile
Normal file
25
docker/Dockerfile
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
FROM hpcaitech/pytorch-cuda:2.1.0-12.1.0
|
||||
|
||||
# metainformation
|
||||
LABEL org.opencontainers.image.source = "https://github.com/hpcaitech/Open-Sora"
|
||||
LABEL org.opencontainers.image.licenses = "Apache License 2.0"
|
||||
LABEL org.opencontainers.image.base.name = "docker.io/library/hpcaitech/pytorch-cuda:2.1.0-12.1.0"
|
||||
|
||||
COPY . /workspace/Open-Sora
|
||||
|
||||
# inatall library dependencies
|
||||
RUN apt-get update && apt-get install ffmpeg libsm6 libxext6 -y
|
||||
|
||||
# install flash attention
|
||||
RUN pip install flash-attn --no-build-isolation
|
||||
|
||||
# install apex
|
||||
RUN pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" git+https://github.com/NVIDIA/apex.git
|
||||
|
||||
# install xformers
|
||||
RUN pip install xformers --index-url https://download.pytorch.org/whl/cu121
|
||||
|
||||
# install this project
|
||||
RUN git clone https://github.com/hpcaitech/Open-Sora && \
|
||||
cd Open-Sora && \
|
||||
pip install -v .
|
||||
|
|
@ -51,11 +51,30 @@ You can adjust the `--num-frames` and `--image-size` to generate different resul
|
|||
`inference-long.py` is compatible with `inference.py` and supports advanced features.
|
||||
|
||||
```bash
|
||||
# long video generation
|
||||
# image condition
|
||||
python scripts/inference-long.py configs/opensora-v1-1/inference/sample.py --ckpt-path CKPT_PATH \
|
||||
--num-frames 32 --image-size 240 426 --sample-name image-cond \
|
||||
--prompt 'A breathtaking sunrise scene.{"reference_path": "assets/images/condition/wave.png","mask_strategy": "0"}'
|
||||
|
||||
# video extending
|
||||
python scripts/inference-long.py configs/opensora-v1-1/inference/sample.py --ckpt-path CKPT_PATH \
|
||||
--num-frames 32 --image-size 240 426 --sample-name image-cond \
|
||||
--prompt 'A car driving on the ocean.{"reference_path": "https://cdn.openai.com/tmp/s/interp/d0.mp4","mask_strategy": "0,0,0,-8,8"}'
|
||||
|
||||
# long video generation
|
||||
python scripts/inference-long.py configs/opensora-v1-1/inference/sample.py --ckpt-path CKPT_PATH \
|
||||
--num-frames 32 --image-size 240 426 --loop 16 --condition-frame-length 8 --sample-name long \
|
||||
--prompt '|0|a white jeep equipped with a roof rack driving on a dirt road in a coniferous forest.|2|a white jeep equipped with a roof rack driving on a dirt road in the desert.|4|a white jeep equipped with a roof rack driving on a dirt road in a mountain.|6|A white jeep equipped with a roof rack driving on a dirt road in a city.|8|a white jeep equipped with a roof rack driving on a dirt road on the surface of a river.|10|a white jeep equipped with a roof rack driving on a dirt road under the lake.|12|a white jeep equipped with a roof rack flying into the sky.|14|a white jeep equipped with a roof rack driving in the universe. Earth is the background.{"reference_path": "https://cdn.openai.com/tmp/s/interp/d0.mp4", "mask_strategy": "0,0,0,0,16"}'
|
||||
|
||||
# video connecting
|
||||
python scripts/inference-long.py configs/opensora-v1-1/inference/sample.py --ckpt-path CKPT_PATH \
|
||||
--num-frames 32 --image-size 240 426 --sample-name connect \
|
||||
--prompt 'A breathtaking sunrise scene.{"reference_path": "assets/images/condition/sunset1.png;assets/images/condition/sunset2.png","mask_strategy": "0;0,1,0,-1,1"}'
|
||||
|
||||
# video editing
|
||||
python scripts/inference-long.py configs/opensora-v1-1/inference/sample.py --ckpt-path CKPT_PATH \
|
||||
--num-frames 32 --image-size 480 853 --sample-name edit \
|
||||
--prompt 'A cyberpunk-style city at night.{"reference_path": "https://cdn.pixabay.com/video/2021/10/12/91744-636709154_large.mp4","mask_strategy": "0,0,0,0,32,0.4"}'
|
||||
```
|
||||
|
||||
### Inference with DiT pretrained on ImageNet
|
||||
|
|
|
|||
|
|
@ -3,22 +3,50 @@
|
|||
For Open-Sora 1.1, we conduct mixed training with both images and videos. The main datasets we use are listed below.
|
||||
Please refer to [README](/README.md#data-processing) for data processing.
|
||||
|
||||
## Panda-70M
|
||||
## Video
|
||||
|
||||
### Panda-70M
|
||||
|
||||
[Panda-70M](https://github.com/snap-research/Panda-70M) is a large-scale dataset with 70M video-caption pairs.
|
||||
We use the [training-10M subset](https://github.com/snap-research/Panda-70M/tree/main/dataset_dataloading) for training,
|
||||
We use the [training-10M subset](https://github.com/snap-research/Panda-70M/tree/main/dataset_dataloading) for training,
|
||||
which contains ~10M videos of better quality.
|
||||
|
||||
## Pexels
|
||||
[Pexels](https://www.pexels.com/) is a popular online platform that provides high-quality stock photos, videos, and music for free.
|
||||
### Pexels
|
||||
|
||||
[Pexels](https://www.pexels.com/) is a popular online platform that provides high-quality stock photos, videos, and music for free.
|
||||
Most videos from this website are of high quality. Thus, we use them for both pre-training and HQ fine-tuning.
|
||||
We really appreciate the great platform and the contributors!
|
||||
|
||||
## Inter4K
|
||||
### Inter4K
|
||||
|
||||
[Inter4K](https://github.com/alexandrosstergiou/Inter4K) is a dataset containing 1K video clips with 4K resolution.
|
||||
The dataset is proposed for super-resolution tasks. We use the dataset for HQ fine-tuning.
|
||||
|
||||
### HD-VG-130M
|
||||
|
||||
## HD-VG-130M
|
||||
[HD-VG-130M](https://github.com/daooshee/HD-VG-130M?tab=readme-ov-file) comprises 130M text-video pairs.
|
||||
The caption is generated by BLIP-2.
|
||||
[HD-VG-130M](https://github.com/daooshee/HD-VG-130M?tab=readme-ov-file) comprises 130M text-video pairs.
|
||||
The caption is generated by BLIP-2.
|
||||
We find the scene and the text quality are relatively poor. For OpenSora 1.0, we only use ~350K samples from this dataset.
|
||||
|
||||
## Image
|
||||
|
||||
### Midjourney-v5-1.7M
|
||||
|
||||
[Midjourney-v5-1.7M](https://huggingface.co/datasets/wanng/midjourney-v5-202304-clean) includes 1.7M image-text pairs.
|
||||
In detail, this dataset introduces two subsets: original and upscale.
|
||||
This dataset is proposed for exploring the relationship of prompts and high-quality images.
|
||||
|
||||
### Midjourney-kaggle-clean
|
||||
|
||||
[Midjourney-kaggle-clean](https://huggingface.co/datasets/wanng/midjourney-kaggle-clean) is a reconstructed version of [Midjourney User Prompts & Generated Images (250k)](https://www.kaggle.com/datasets/succinctlyai/midjourney-texttoimage?select=general-01_2022_06_20.json%5D), which is cleaned by rules.
|
||||
Moreover, this dataset is divided into two subsets: original and upscale.
|
||||
This dataset is proposed for enabling research on text-to-image model prompting.
|
||||
|
||||
### upsplash-lite
|
||||
|
||||
The [Unsplash-lite](https://github.com/unsplash/datasets) Dataset comprises 25k nature-themed Unsplash photos, 25k keywords, and 1M searches.
|
||||
This dataset covers a vast range of uses and contexts. Its extensive scope in intent and semantics opens new avenues for research and learning.
|
||||
|
||||
### LAION-AESTHETICS 6.5+
|
||||
|
||||
LAION aesthetic 6.5+ dataset is a subset of the LAION dataset, which contains 625K high-quality images with aesthetic scores > 6.5. However, as LAION is currently not publicly available, we use this 168k [subset](https://huggingface.co/datasets/bhargavsdesai/laion_improved_aesthetics_6.5plus_with_images).
|
||||
|
|
|
|||
|
|
@ -106,7 +106,7 @@ To summarize, the training of Open-Sora 1.1 requires approximately **9 days** on
|
|||
|
||||
As we get one step closer to the replication of Sora, we find many limitations for the current model, and these limitations point to the future work.
|
||||
|
||||
- **Generation Failure**: we fine many cases (especially when the total token number is large or the content is complex), our model fails to generate the scene. There may be a collapse in the temporal attention and we have identified a potential bug in our code. We are working hard to fix it.
|
||||
- **Generation Failure**: we fine many cases (especially when the total token number is large or the content is complex), our model fails to generate the scene. There may be a collapse in the temporal attention and we have identified a potential bug in our code. We are working hard to fix it. Besides, we will increase our model size and training data to improve the generation quality in the next version.
|
||||
- **Noisy generation and influency**: we find the generated model is sometimes noisy and not fluent, especially for long videos. We think the problem is due to not using a temporal VAE. As [Pixart-Sigma](https://arxiv.org/abs/2403.04692) finds that adapting to a new VAE is simple, we plan to develop a temporal VAE for the model in the next version.
|
||||
- **Lack of time consistency**: we find the model cannot generate videos with high time consistency. We think the problem is due to the lack of training FLOPs. We plan to collect more data and continue training the model to improve the time consistency.
|
||||
- **Bad human generation**: We find the model cannot generate high-quality human videos. We think the problem is due to the lack of human data. We plan to collect more human data and continue training the model to improve the human generation.
|
||||
|
|
|
|||
|
|
@ -87,6 +87,7 @@
|
|||
|
||||
## 安装
|
||||
|
||||
### 从源码安装
|
||||
```bash
|
||||
# create a virtual env
|
||||
conda create -n opensora python=3.10
|
||||
|
|
@ -112,6 +113,20 @@ cd Open-Sora
|
|||
pip install -v .
|
||||
```
|
||||
|
||||
### 使用Docker镜像
|
||||
|
||||
运行如下指令使用提供的Dockerfile构建镜像:
|
||||
|
||||
```bash
|
||||
docker build -t opensora ./docker
|
||||
```
|
||||
|
||||
运行以下命令以启动交互模式下的 Docker 容器:
|
||||
|
||||
```bash
|
||||
docker run -ti --gpus all -v {MOUNT_DIR}:/data opensora
|
||||
```
|
||||
|
||||
安装完成后,建议阅读[结构](structure.md),了解项目结构以及如何使用配置文件。
|
||||
|
||||
## 模型权重
|
||||
|
|
|
|||
224
gradio/app.py
224
gradio/app.py
|
|
@ -19,20 +19,56 @@ import spaces
|
|||
import torch
|
||||
|
||||
import gradio as gr
|
||||
from tempfile import NamedTemporaryFile
|
||||
import datetime
|
||||
|
||||
|
||||
MODEL_TYPES = ["v1.1"]
|
||||
|
||||
MODEL_TYPES = ["v1.1-stage2", "v1.1-stage3"]
|
||||
CONFIG_MAP = {
|
||||
"v1.1": "configs/opensora-v1-1/inference/sample-ref.py",
|
||||
"v1.1-stage2": "configs/opensora-v1-1/inference/sample-ref.py",
|
||||
"v1.1-stage3": "configs/opensora-v1-1/inference/sample-ref.py",
|
||||
}
|
||||
HF_STDIT_MAP = {
|
||||
"v1.1": "hpcai-tech/OpenSora-STDiT-v2-stage2",
|
||||
"v1.1-stage2": "hpcai-tech/OpenSora-STDiT-v2-stage2",
|
||||
"v1.1-stage3": "hpcai-tech/OpenSora-STDiT-v2-stage3",
|
||||
}
|
||||
RESOLUTION_MAP = {
|
||||
"360p": (360, 480),
|
||||
"480p": (480, 858),
|
||||
"720p": (720, 1280),
|
||||
"1080p": (1080, 1920)
|
||||
"144p": {
|
||||
"16:9": (256, 144),
|
||||
"9:16": (144, 256),
|
||||
"4:3": (221, 165),
|
||||
"3:4": (165, 221),
|
||||
"1:1": (192, 192),
|
||||
},
|
||||
"240p": {
|
||||
"16:9": (426, 240),
|
||||
"9:16": (240, 426),
|
||||
"4:3": (370, 278),
|
||||
"3:4": (278, 370),
|
||||
"1:1": (320, 320),
|
||||
},
|
||||
"360p": {
|
||||
"16:9": (640, 360),
|
||||
"9:16": (360, 640),
|
||||
"4:3": (554, 416),
|
||||
"3:4": (416, 554),
|
||||
"1:1": (480, 480),
|
||||
},
|
||||
"480p": {
|
||||
"16:9": (854, 480),
|
||||
"9:16": (480, 854),
|
||||
"4:3": (740, 555),
|
||||
"3:4": (555, 740),
|
||||
"1:1": (640, 640),
|
||||
},
|
||||
"720p": {
|
||||
"16:9": (1280, 720),
|
||||
"9:16": (720, 1280),
|
||||
"4:3": (1108, 832),
|
||||
"3:4": (832, 1110),
|
||||
"1:1": (960, 960),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -219,9 +255,9 @@ def build_models(model_type, config, enable_optimization=False):
|
|||
# build stdit
|
||||
# we load model from HuggingFace directly so that we don't need to
|
||||
# handle model download logic in HuggingFace Space
|
||||
from transformers import AutoModel
|
||||
from opensora.models.stdit.stdit2 import STDiT2
|
||||
|
||||
stdit = AutoModel.from_pretrained(
|
||||
stdit = STDiT2.from_pretrained(
|
||||
HF_STDIT_MAP[model_type],
|
||||
enable_flash_attn=enable_optimization,
|
||||
trust_remote_code=True,
|
||||
|
|
@ -249,7 +285,7 @@ def parse_args():
|
|||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--model-type",
|
||||
default="v1.1",
|
||||
default="v1.1-stage3",
|
||||
choices=MODEL_TYPES,
|
||||
help=f"The type of model to run for the Gradio App, can only be {MODEL_TYPES}",
|
||||
)
|
||||
|
|
@ -298,37 +334,53 @@ device = torch.device("cuda")
|
|||
vae, text_encoder, stdit, scheduler = build_models(args.model_type, config, enable_optimization=args.enable_optimization)
|
||||
|
||||
|
||||
@spaces.GPU(duration=200)
|
||||
def run_inference(mode, prompt_text, resolution, length, reference_image):
|
||||
def run_inference(mode, prompt_text, resolution, aspect_ratio, length, reference_image, seed, sampling_steps, cfg_scale):
|
||||
torch.manual_seed(seed)
|
||||
with torch.inference_mode():
|
||||
# ======================
|
||||
# 1. Preparation
|
||||
# ======================
|
||||
# parse the inputs
|
||||
resolution = RESOLUTION_MAP[resolution]
|
||||
|
||||
resolution = RESOLUTION_MAP[resolution][aspect_ratio]
|
||||
|
||||
# gather args from config
|
||||
num_frames = config.num_frames
|
||||
frame_interval = config.frame_interval
|
||||
fps = config.fps
|
||||
condition_frame_length = config.condition_frame_length
|
||||
|
||||
# compute number of loops
|
||||
num_seconds = int(length.rstrip('s'))
|
||||
total_number_of_frames = num_seconds * config.fps / config.frame_interval
|
||||
num_loop = math.ceil(total_number_of_frames / config.num_frames)
|
||||
if mode == "Text2Image":
|
||||
num_frames = 1
|
||||
num_loop = 1
|
||||
else:
|
||||
num_seconds = int(length.rstrip('s'))
|
||||
if num_seconds <= 16:
|
||||
num_frames = num_seconds * fps // frame_interval
|
||||
num_loop = 1
|
||||
else:
|
||||
config.num_frames = 16
|
||||
total_number_of_frames = num_seconds * fps / frame_interval
|
||||
num_loop = math.ceil((total_number_of_frames - condition_frame_length) / (num_frames - condition_frame_length))
|
||||
|
||||
# prepare model args
|
||||
model_args = dict()
|
||||
height = torch.tensor([resolution[0]], device=device, dtype=dtype)
|
||||
width = torch.tensor([resolution[1]], device=device, dtype=dtype)
|
||||
num_frames = torch.tensor([config.num_frames], device=device, dtype=dtype)
|
||||
ar = torch.tensor([resolution[0] / resolution[1]], device=device, dtype=dtype)
|
||||
if config.num_frames == 1:
|
||||
config.fps = IMG_FPS
|
||||
fps = torch.tensor([config.fps], device=device, dtype=dtype)
|
||||
model_args["height"] = height
|
||||
model_args["width"] = width
|
||||
model_args["num_frames"] = num_frames
|
||||
model_args["ar"] = ar
|
||||
model_args["fps"] = fps
|
||||
fps = IMG_FPS
|
||||
|
||||
model_args = dict()
|
||||
height_tensor = torch.tensor([resolution[0]], device=device, dtype=dtype)
|
||||
width_tensor = torch.tensor([resolution[1]], device=device, dtype=dtype)
|
||||
num_frames_tensor = torch.tensor([num_frames], device=device, dtype=dtype)
|
||||
ar_tensor = torch.tensor([resolution[0] / resolution[1]], device=device, dtype=dtype)
|
||||
fps_tensor = torch.tensor([fps], device=device, dtype=dtype)
|
||||
model_args["height"] = height_tensor
|
||||
model_args["width"] = width_tensor
|
||||
model_args["num_frames"] = num_frames_tensor
|
||||
model_args["ar"] = ar_tensor
|
||||
model_args["fps"] = fps_tensor
|
||||
|
||||
# compute latent size
|
||||
input_size = (config.num_frames, *resolution)
|
||||
input_size = (num_frames, *resolution)
|
||||
latent_size = vae.get_latent_size(input_size)
|
||||
|
||||
# process prompt
|
||||
|
|
@ -338,24 +390,32 @@ def run_inference(mode, prompt_text, resolution, length, reference_image):
|
|||
video_clips = []
|
||||
|
||||
# prepare mask strategy
|
||||
if mode == "Text2Video":
|
||||
if mode == "Text2Image":
|
||||
mask_strategy = [None]
|
||||
elif mode == "Image2Video":
|
||||
mask_strategy = ['0']
|
||||
elif mode == "Text2Video":
|
||||
if reference_image is not None:
|
||||
mask_strategy = ['0']
|
||||
else:
|
||||
mask_strategy = [None]
|
||||
else:
|
||||
raise ValueError(f"Invalid mode: {mode}")
|
||||
|
||||
# =========================
|
||||
# 2. Load reference images
|
||||
# =========================
|
||||
if mode == "Text2Video":
|
||||
if mode == "Text2Image":
|
||||
refs_x = collect_references_batch([None], vae, resolution)
|
||||
elif mode == "Image2Video":
|
||||
# save image to disk
|
||||
from PIL import Image
|
||||
im = Image.fromarray(reference_image)
|
||||
im.save("test.jpg")
|
||||
refs_x = collect_references_batch(["test.jpg"], vae, resolution)
|
||||
elif mode == "Text2Video":
|
||||
if reference_image is not None:
|
||||
# save image to disk
|
||||
from PIL import Image
|
||||
im = Image.fromarray(reference_image)
|
||||
|
||||
with NamedTemporaryFile(suffix=".jpg") as temp_file:
|
||||
im.save(temp_file.name)
|
||||
refs_x = collect_references_batch([temp_file.name], vae, resolution)
|
||||
else:
|
||||
refs_x = collect_references_batch([None], vae, resolution)
|
||||
else:
|
||||
raise ValueError(f"Invalid mode: {mode}")
|
||||
|
||||
|
|
@ -382,11 +442,20 @@ def run_inference(mode, prompt_text, resolution, length, reference_image):
|
|||
mask_strategy[j] += ";"
|
||||
mask_strategy[
|
||||
j
|
||||
] += f"{loop_i},{len(refs)-1},-{config.condition_frame_length},0,{config.condition_frame_length}"
|
||||
] += f"{loop_i},{len(refs)-1},-{condition_frame_length},0,{condition_frame_length}"
|
||||
|
||||
masks = apply_mask_strategy(z, refs_x, mask_strategy, loop_i)
|
||||
|
||||
# 4.6. diffusion sampling
|
||||
# hack to update num_sampling_steps and cfg_scale
|
||||
scheduler_kwargs = config.scheduler.copy()
|
||||
scheduler_kwargs.pop('type')
|
||||
scheduler_kwargs['num_sampling_steps'] = sampling_steps
|
||||
scheduler_kwargs['cfg_scale'] = cfg_scale
|
||||
|
||||
scheduler.__init__(
|
||||
**scheduler_kwargs
|
||||
)
|
||||
samples = scheduler.sample(
|
||||
stdit,
|
||||
text_encoder,
|
||||
|
|
@ -406,10 +475,20 @@ def run_inference(mode, prompt_text, resolution, length, reference_image):
|
|||
for i in range(1, num_loop)
|
||||
]
|
||||
video = torch.cat(video_clips_list, dim=1)
|
||||
save_path = f"{args.output}/sample"
|
||||
saved_path = save_sample(video, fps=config.fps // config.frame_interval, save_path=save_path, force_video=True)
|
||||
current_datetime = datetime.datetime.now()
|
||||
timestamp = current_datetime.timestamp()
|
||||
save_path = os.path.join(args.output, f"output_{timestamp}")
|
||||
saved_path = save_sample(video, save_path=save_path, fps=config.fps // config.frame_interval)
|
||||
return saved_path
|
||||
|
||||
@spaces.GPU(duration=200)
|
||||
def run_image_inference(prompt_text, resolution, aspect_ratio, length, reference_image, seed, sampling_steps, cfg_scale):
|
||||
return run_inference("Text2Image", prompt_text, resolution, aspect_ratio, length, reference_image, seed, sampling_steps, cfg_scale)
|
||||
|
||||
@spaces.GPU(duration=200)
|
||||
def run_video_inference(prompt_text, resolution, aspect_ratio, length, reference_image, seed, sampling_steps, cfg_scale):
|
||||
return run_inference("Text2Video", prompt_text, resolution, aspect_ratio, length, reference_image, seed, sampling_steps, cfg_scale)
|
||||
|
||||
|
||||
def main():
|
||||
# create demo
|
||||
|
|
@ -438,31 +517,54 @@ def main():
|
|||
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
mode = gr.Radio(
|
||||
choices=["Text2Video", "Image2Video"],
|
||||
value="Text2Video",
|
||||
label="Usage",
|
||||
info="Choose your usage scenario",
|
||||
)
|
||||
prompt_text = gr.Textbox(
|
||||
label="Prompt",
|
||||
placeholder="Describe your video here",
|
||||
lines=4,
|
||||
)
|
||||
resolution = gr.Radio(
|
||||
choices=["360p", "480p", "720p", "1080p"],
|
||||
value="360p",
|
||||
choices=["144p", "240p", "360p", "480p", "720p"],
|
||||
value="240p",
|
||||
label="Resolution",
|
||||
)
|
||||
aspect_ratio = gr.Radio(
|
||||
choices=["9:16", "16:9", "3:4", "4:3", "1:1"],
|
||||
value="9:16",
|
||||
label="Aspect Ratio (H:W)",
|
||||
)
|
||||
length = gr.Radio(
|
||||
choices=["2s", "4s", "8s"],
|
||||
choices=["2s", "4s", "8s", "16s"],
|
||||
value="2s",
|
||||
label="Video Length",
|
||||
label="Video Length (only effective for video generation)",
|
||||
info="8s may fail as Hugging Face ZeroGPU has the limitation of max 200 seconds inference time."
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
seed = gr.Slider(
|
||||
value=1024,
|
||||
minimum=1,
|
||||
maximum=2048,
|
||||
step=1,
|
||||
label="Seed"
|
||||
)
|
||||
|
||||
sampling_steps = gr.Slider(
|
||||
value=100,
|
||||
minimum=1,
|
||||
maximum=200,
|
||||
step=1,
|
||||
label="Sampling steps"
|
||||
)
|
||||
cfg_scale = gr.Slider(
|
||||
value=7.0,
|
||||
minimum=0.0,
|
||||
maximum=10.0,
|
||||
step=0.1,
|
||||
label="CFG Scale"
|
||||
)
|
||||
|
||||
reference_image = gr.Image(
|
||||
label="Reference Image (only used for Image2Video)",
|
||||
label="Reference Image (Optional)",
|
||||
)
|
||||
|
||||
with gr.Column():
|
||||
|
|
@ -472,12 +574,18 @@ def main():
|
|||
)
|
||||
|
||||
with gr.Row():
|
||||
submit_button = gr.Button("Generate video")
|
||||
image_gen_button = gr.Button("Generate image")
|
||||
video_gen_button = gr.Button("Generate video")
|
||||
|
||||
|
||||
submit_button.click(
|
||||
fn=run_inference,
|
||||
inputs=[mode, prompt_text, resolution, length, reference_image],
|
||||
image_gen_button.click(
|
||||
fn=run_image_inference,
|
||||
inputs=[prompt_text, resolution, aspect_ratio, length, reference_image, seed, sampling_steps, cfg_scale],
|
||||
outputs=reference_image
|
||||
)
|
||||
video_gen_button.click(
|
||||
fn=run_video_inference,
|
||||
inputs=[prompt_text, resolution, aspect_ratio, length, reference_image, seed, sampling_steps, cfg_scale],
|
||||
outputs=output_video
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,3 +1,3 @@
|
|||
xformers
|
||||
git+https://github.com/hpcaitech/Open-Sora.git#egg=opensora
|
||||
transformers
|
||||
git+https://github.com/hpcaitech/Open-Sora.git#egg=opensora
|
||||
|
|
|
|||
|
|
@ -139,7 +139,7 @@ class Attention(nn.Module):
|
|||
attn_drop: float = 0.0,
|
||||
proj_drop: float = 0.0,
|
||||
norm_layer: nn.Module = LlamaRMSNorm,
|
||||
enable_flashattn: bool = False,
|
||||
enable_flash_attn: bool = False,
|
||||
rope=None,
|
||||
qk_norm_legacy: bool = False,
|
||||
) -> None:
|
||||
|
|
@ -149,7 +149,7 @@ class Attention(nn.Module):
|
|||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.scale = self.head_dim**-0.5
|
||||
self.enable_flashattn = enable_flashattn
|
||||
self.enable_flash_attn = enable_flash_attn
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
||||
|
|
@ -167,7 +167,7 @@ class Attention(nn.Module):
|
|||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
B, N, C = x.shape
|
||||
# flash attn is not memory efficient for small sequences, this is empirical
|
||||
enable_flashattn = self.enable_flashattn and (N > B)
|
||||
enable_flash_attn = self.enable_flash_attn and (N > B)
|
||||
qkv = self.qkv(x)
|
||||
qkv_shape = (B, N, 3, self.num_heads, self.head_dim)
|
||||
|
||||
|
|
@ -185,7 +185,7 @@ class Attention(nn.Module):
|
|||
q = self.rotary_emb(q)
|
||||
k = self.rotary_emb(k)
|
||||
|
||||
if enable_flashattn:
|
||||
if enable_flash_attn:
|
||||
from flash_attn import flash_attn_func
|
||||
|
||||
# (B, #heads, N, #dim) -> (B, N, #heads, #dim)
|
||||
|
|
@ -210,7 +210,7 @@ class Attention(nn.Module):
|
|||
x = attn @ v
|
||||
|
||||
x_output_shape = (B, N, C)
|
||||
if not enable_flashattn:
|
||||
if not enable_flash_attn:
|
||||
x = x.transpose(1, 2)
|
||||
x = x.reshape(x_output_shape)
|
||||
x = self.proj(x)
|
||||
|
|
@ -358,7 +358,7 @@ class SeqParallelAttention(Attention):
|
|||
attn_drop: float = 0.0,
|
||||
proj_drop: float = 0.0,
|
||||
norm_layer: nn.Module = LlamaRMSNorm,
|
||||
enable_flashattn: bool = False,
|
||||
enable_flash_attn: bool = False,
|
||||
rope=None,
|
||||
qk_norm_legacy: bool = False,
|
||||
) -> None:
|
||||
|
|
@ -371,7 +371,7 @@ class SeqParallelAttention(Attention):
|
|||
attn_drop=attn_drop,
|
||||
proj_drop=proj_drop,
|
||||
norm_layer=norm_layer,
|
||||
enable_flashattn=enable_flashattn,
|
||||
enable_flash_attn=enable_flash_attn,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
|
|
@ -387,7 +387,7 @@ class SeqParallelAttention(Attention):
|
|||
# [B, SUB_N, 3, NUM_HEAD, HEAD_DIM] -> [B, N, 3, NUM_HEAD_PER_DEVICE, HEAD_DIM]
|
||||
qkv = all_to_all(qkv, sp_group, scatter_dim=3, gather_dim=1)
|
||||
|
||||
if self.enable_flashattn:
|
||||
if self.enable_flash_attn:
|
||||
qkv_permute_shape = (
|
||||
2,
|
||||
0,
|
||||
|
|
@ -408,7 +408,7 @@ class SeqParallelAttention(Attention):
|
|||
# ERROR: Should qk_norm first
|
||||
q, k, v = qkv.unbind(0)
|
||||
q, k = self.q_norm(q), self.k_norm(k)
|
||||
if self.enable_flashattn:
|
||||
if self.enable_flash_attn:
|
||||
from flash_attn import flash_attn_func
|
||||
|
||||
x = flash_attn_func(
|
||||
|
|
@ -428,7 +428,7 @@ class SeqParallelAttention(Attention):
|
|||
attn = self.attn_drop(attn)
|
||||
x = attn @ v
|
||||
|
||||
if not self.enable_flashattn:
|
||||
if not self.enable_flash_attn:
|
||||
x = x.transpose(1, 2)
|
||||
|
||||
# apply all to all to gather back attention heads and split sequence
|
||||
|
|
|
|||
|
|
@ -1,23 +1,19 @@
|
|||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
import os
|
||||
from einops import rearrange
|
||||
from rotary_embedding_torch import RotaryEmbedding
|
||||
from timm.models.layers import DropPath
|
||||
from timm.models.vision_transformer import Mlp
|
||||
|
||||
from opensora.acceleration.checkpoint import auto_grad_checkpoint
|
||||
from opensora.acceleration.communications import gather_forward_split_backward, split_forward_gather_backward
|
||||
from opensora.acceleration.parallel_states import get_sequence_parallel_group
|
||||
from opensora.models.layers.blocks import (
|
||||
Attention,
|
||||
CaptionEmbedder,
|
||||
MultiHeadCrossAttention,
|
||||
PatchEmbed3D,
|
||||
PositionEmbedding2D,
|
||||
SeqParallelAttention,
|
||||
SeqParallelMultiHeadCrossAttention,
|
||||
SizeEmbedder,
|
||||
T2IFinalLayer,
|
||||
TimestepEmbedder,
|
||||
|
|
@ -27,6 +23,7 @@ from opensora.models.layers.blocks import (
|
|||
t2i_modulate,
|
||||
)
|
||||
from opensora.registry import MODELS
|
||||
from transformers import PretrainedConfig, PreTrainedModel
|
||||
from opensora.utils.ckpt_utils import load_checkpoint
|
||||
|
||||
|
||||
|
|
@ -37,7 +34,7 @@ class STDiT2Block(nn.Module):
|
|||
num_heads,
|
||||
mlp_ratio=4.0,
|
||||
drop_path=0.0,
|
||||
enable_flashattn=False,
|
||||
enable_flash_attn=False,
|
||||
enable_layernorm_kernel=False,
|
||||
enable_sequence_parallelism=False,
|
||||
rope=None,
|
||||
|
|
@ -46,31 +43,23 @@ class STDiT2Block(nn.Module):
|
|||
):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.enable_flashattn = enable_flashattn
|
||||
self.enable_flash_attn = enable_flash_attn
|
||||
self._enable_sequence_parallelism = enable_sequence_parallelism
|
||||
|
||||
assert not self._enable_sequence_parallelism, "Sequence parallelism is not supported."
|
||||
if enable_sequence_parallelism:
|
||||
self.attn_cls = SeqParallelAttention
|
||||
self.mha_cls = SeqParallelMultiHeadCrossAttention
|
||||
else:
|
||||
self.attn_cls = Attention
|
||||
self.mha_cls = MultiHeadCrossAttention
|
||||
|
||||
# spatial branch
|
||||
self.norm1 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel)
|
||||
self.attn = self.attn_cls(
|
||||
self.attn = Attention(
|
||||
hidden_size,
|
||||
num_heads=num_heads,
|
||||
qkv_bias=True,
|
||||
enable_flashattn=enable_flashattn,
|
||||
enable_flash_attn=enable_flash_attn,
|
||||
qk_norm=qk_norm,
|
||||
qk_norm_legacy=qk_norm_legacy,
|
||||
)
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size**0.5)
|
||||
|
||||
# cross attn
|
||||
self.cross_attn = self.mha_cls(hidden_size, num_heads)
|
||||
self.cross_attn = MultiHeadCrossAttention(hidden_size, num_heads)
|
||||
|
||||
# mlp branch
|
||||
self.norm2 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel)
|
||||
|
|
@ -81,11 +70,11 @@ class STDiT2Block(nn.Module):
|
|||
|
||||
# temporal branch
|
||||
self.norm_temp = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel) # new
|
||||
self.attn_temp = self.attn_cls(
|
||||
self.attn_temp = Attention(
|
||||
hidden_size,
|
||||
num_heads=num_heads,
|
||||
qkv_bias=True,
|
||||
enable_flashattn=self.enable_flashattn,
|
||||
enable_flash_attn=self.enable_flash_attn,
|
||||
rope=rope,
|
||||
qk_norm=qk_norm,
|
||||
qk_norm_legacy=qk_norm_legacy,
|
||||
|
|
@ -177,8 +166,10 @@ class STDiT2Block(nn.Module):
|
|||
return x
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class STDiT2(nn.Module):
|
||||
class STDiT2Config(PretrainedConfig):
|
||||
|
||||
model_type = "STDiT2"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size=(None, None, None),
|
||||
|
|
@ -195,46 +186,75 @@ class STDiT2(nn.Module):
|
|||
no_temporal_pos_emb=False,
|
||||
caption_channels=4096,
|
||||
model_max_length=120,
|
||||
dtype=torch.float32,
|
||||
freeze=None,
|
||||
qk_norm=False,
|
||||
qk_norm_legacy=False,
|
||||
enable_flashattn=False,
|
||||
enable_flash_attn=False,
|
||||
enable_layernorm_kernel=False,
|
||||
enable_sequence_parallelism=False,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.pred_sigma = pred_sigma
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = in_channels * 2 if pred_sigma else in_channels
|
||||
self.hidden_size = hidden_size
|
||||
self.num_heads = num_heads
|
||||
self.dtype = dtype
|
||||
self.no_temporal_pos_emb = no_temporal_pos_emb
|
||||
self.depth = depth
|
||||
self.mlp_ratio = mlp_ratio
|
||||
self.enable_flashattn = enable_flashattn
|
||||
self.enable_layernorm_kernel = enable_layernorm_kernel
|
||||
|
||||
# support dynamic input
|
||||
self.patch_size = patch_size
|
||||
self.input_size = input_size
|
||||
self.input_sq_size = input_sq_size
|
||||
self.pos_embed = PositionEmbedding2D(hidden_size)
|
||||
self.in_channels = in_channels
|
||||
self.patch_size = patch_size
|
||||
self.hidden_size = hidden_size
|
||||
self.depth = depth
|
||||
self.num_heads = num_heads
|
||||
self.mlp_ratio = mlp_ratio
|
||||
self.class_dropout_prob = class_dropout_prob
|
||||
self.pred_sigma = pred_sigma
|
||||
self.drop_path = drop_path
|
||||
self.no_temporal_pos_emb = no_temporal_pos_emb
|
||||
self.caption_channels = caption_channels
|
||||
self.model_max_length = model_max_length
|
||||
self.freeze = freeze
|
||||
self.qk_norm = qk_norm
|
||||
self.qk_norm_legacy = qk_norm_legacy
|
||||
self.enable_flash_attn = enable_flash_attn
|
||||
self.enable_layernorm_kernel = enable_layernorm_kernel
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.x_embedder = PatchEmbed3D(patch_size, in_channels, hidden_size)
|
||||
self.t_embedder = TimestepEmbedder(hidden_size)
|
||||
self.t_block = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
|
||||
self.t_block_temp = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 3 * hidden_size, bias=True)) # new
|
||||
|
||||
@MODELS.register_module()
|
||||
class STDiT2(PreTrainedModel):
|
||||
|
||||
config_class = STDiT2Config
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config
|
||||
):
|
||||
super().__init__(config)
|
||||
self.pred_sigma = config.pred_sigma
|
||||
self.in_channels = config.in_channels
|
||||
self.out_channels = config.in_channels * 2 if config.pred_sigma else config.in_channels
|
||||
self.hidden_size = config.hidden_size
|
||||
self.num_heads = config.num_heads
|
||||
self.no_temporal_pos_emb = config.no_temporal_pos_emb
|
||||
self.depth = config.depth
|
||||
self.mlp_ratio = config.mlp_ratio
|
||||
self.enable_flash_attn = config.enable_flash_attn
|
||||
self.enable_layernorm_kernel = config.enable_layernorm_kernel
|
||||
|
||||
# support dynamic input
|
||||
self.patch_size = config.patch_size
|
||||
self.input_size = config.input_size
|
||||
self.input_sq_size = config.input_sq_size
|
||||
self.pos_embed = PositionEmbedding2D(config.hidden_size)
|
||||
|
||||
self.x_embedder = PatchEmbed3D(config.patch_size, config.in_channels, config.hidden_size)
|
||||
self.t_embedder = TimestepEmbedder(config.hidden_size)
|
||||
self.t_block = nn.Sequential(nn.SiLU(), nn.Linear(config.hidden_size, 6 * config.hidden_size, bias=True))
|
||||
self.t_block_temp = nn.Sequential(nn.SiLU(), nn.Linear(config.hidden_size, 3 * config.hidden_size, bias=True)) # new
|
||||
self.y_embedder = CaptionEmbedder(
|
||||
in_channels=caption_channels,
|
||||
hidden_size=hidden_size,
|
||||
uncond_prob=class_dropout_prob,
|
||||
in_channels=config.caption_channels,
|
||||
hidden_size=config.hidden_size,
|
||||
uncond_prob=config.class_dropout_prob,
|
||||
act_layer=approx_gelu,
|
||||
token_num=model_max_length,
|
||||
token_num=config.model_max_length,
|
||||
)
|
||||
|
||||
drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)]
|
||||
drop_path = [x.item() for x in torch.linspace(0, config.drop_path, config.depth)]
|
||||
self.rope = RotaryEmbedding(dim=self.hidden_size // self.num_heads) # new
|
||||
self.blocks = nn.ModuleList(
|
||||
[
|
||||
|
|
@ -243,17 +263,16 @@ class STDiT2(nn.Module):
|
|||
self.num_heads,
|
||||
mlp_ratio=self.mlp_ratio,
|
||||
drop_path=drop_path[i],
|
||||
enable_flashattn=self.enable_flashattn,
|
||||
enable_flash_attn=self.enable_flash_attn,
|
||||
enable_layernorm_kernel=self.enable_layernorm_kernel,
|
||||
enable_sequence_parallelism=enable_sequence_parallelism,
|
||||
rope=self.rope.rotate_queries_or_keys,
|
||||
qk_norm=qk_norm,
|
||||
qk_norm_legacy=qk_norm_legacy,
|
||||
qk_norm=config.qk_norm,
|
||||
qk_norm_legacy=config.qk_norm_legacy,
|
||||
)
|
||||
for i in range(self.depth)
|
||||
]
|
||||
)
|
||||
self.final_layer = T2IFinalLayer(hidden_size, np.prod(self.patch_size), self.out_channels)
|
||||
self.final_layer = T2IFinalLayer(config.hidden_size, np.prod(self.patch_size), self.out_channels)
|
||||
|
||||
# multi_res
|
||||
assert self.hidden_size % 3 == 0, "hidden_size must be divisible by 3"
|
||||
|
|
@ -265,20 +284,13 @@ class STDiT2(nn.Module):
|
|||
# init model
|
||||
self.initialize_weights()
|
||||
self.initialize_temporal()
|
||||
if freeze is not None:
|
||||
assert freeze in ["not_temporal", "text"]
|
||||
if freeze == "not_temporal":
|
||||
if config.freeze is not None:
|
||||
assert config.freeze in ["not_temporal", "text"]
|
||||
if config.freeze == "not_temporal":
|
||||
self.freeze_not_temporal()
|
||||
elif freeze == "text":
|
||||
elif config.freeze == "text":
|
||||
self.freeze_text()
|
||||
|
||||
# sequence parallel related configs
|
||||
self.enable_sequence_parallelism = enable_sequence_parallelism
|
||||
if enable_sequence_parallelism:
|
||||
self.sp_rank = dist.get_rank(get_sequence_parallel_group())
|
||||
else:
|
||||
self.sp_rank = None
|
||||
|
||||
def get_dynamic_size(self, x):
|
||||
_, _, T, H, W = x.size()
|
||||
if T % self.patch_size[0] != 0:
|
||||
|
|
@ -307,9 +319,10 @@ class STDiT2(nn.Module):
|
|||
x (torch.Tensor): output latent representation; of shape [B, C, T, H, W]
|
||||
"""
|
||||
B = x.shape[0]
|
||||
x = x.to(self.dtype)
|
||||
timestep = timestep.to(self.dtype)
|
||||
y = y.to(self.dtype)
|
||||
dtype = self.x_embedder.proj.weight.dtype
|
||||
x = x.to(dtype)
|
||||
timestep = timestep.to(dtype)
|
||||
y = y.to(dtype)
|
||||
|
||||
# === process data info ===
|
||||
# 1. get dynamic size
|
||||
|
|
@ -342,10 +355,6 @@ class STDiT2(nn.Module):
|
|||
x = x + pos_emb
|
||||
x = rearrange(x, "B T S C -> B (T S) C")
|
||||
|
||||
# shard over the sequence dim if sp is enabled
|
||||
if self.enable_sequence_parallelism:
|
||||
x = split_forward_gather_backward(x, get_sequence_parallel_group(), dim=1, grad_scale="down")
|
||||
|
||||
# prepare adaIN
|
||||
t = self.t_embedder(timestep, dtype=x.dtype) # [B, C]
|
||||
t_spc = t + data_info # [B, C]
|
||||
|
|
@ -393,10 +402,7 @@ class STDiT2(nn.Module):
|
|||
T,
|
||||
S,
|
||||
)
|
||||
|
||||
if self.enable_sequence_parallelism:
|
||||
x = gather_forward_split_backward(x, get_sequence_parallel_group(), dim=1, grad_scale="up")
|
||||
# x.shape: [B, N, C]
|
||||
# x.shape: [B, N, C]
|
||||
|
||||
# final process
|
||||
x = self.final_layer(x, t, x_mask, t0_spc, T, S) # [B, N, C=T_p * H_p * W_p * C_out]
|
||||
|
|
@ -503,7 +509,28 @@ class STDiT2(nn.Module):
|
|||
|
||||
@MODELS.register_module("STDiT2-XL/2")
|
||||
def STDiT2_XL_2(from_pretrained=None, **kwargs):
|
||||
model = STDiT2(depth=28, hidden_size=1152, patch_size=(1, 2, 2), num_heads=16, **kwargs)
|
||||
if from_pretrained is not None:
|
||||
load_checkpoint(model, from_pretrained)
|
||||
if os.path.isdir(from_pretrained) or os.path.isfile(from_pretrained):
|
||||
# if it is a directory or a file, we load the checkpoint manually
|
||||
config = STDiT2Config(
|
||||
depth=28,
|
||||
hidden_size=1152,
|
||||
patch_size=(1, 2, 2),
|
||||
num_heads=16, **kwargs
|
||||
)
|
||||
model = STDiT2(config)
|
||||
load_checkpoint(model, from_pretrained)
|
||||
return model
|
||||
else:
|
||||
# otherwise, we load the model from hugging face hub
|
||||
return STDiT2.from_pretrained(from_pretrained)
|
||||
else:
|
||||
# create a new model
|
||||
config = STDiT2Config(
|
||||
depth=28,
|
||||
hidden_size=1152,
|
||||
patch_size=(1, 2, 2),
|
||||
num_heads=16, **kwargs
|
||||
)
|
||||
model = STDiT2(config)
|
||||
return model
|
||||
|
|
|
|||
|
|
@ -7,10 +7,11 @@ mmengine
|
|||
pandas
|
||||
pre-commit
|
||||
pyarrow
|
||||
pyav
|
||||
av
|
||||
tensorboard
|
||||
timm
|
||||
tqdm
|
||||
transformers
|
||||
wandb
|
||||
rotary_embedding_torch
|
||||
pandarallel
|
||||
|
|
|
|||
|
|
@ -168,7 +168,6 @@ def main():
|
|||
in_channels=vae.out_channels,
|
||||
caption_channels=text_encoder.output_dim,
|
||||
model_max_length=text_encoder.model_max_length,
|
||||
dtype=dtype,
|
||||
enable_sequence_parallelism=enable_sequence_parallelism,
|
||||
)
|
||||
text_encoder.y_embedder = model.y_embedder # hack for classifier-free guidance
|
||||
|
|
|
|||
|
|
@ -57,6 +57,7 @@ def main():
|
|||
vae = build_module(cfg.vae, MODELS)
|
||||
latent_size = vae.get_latent_size(input_size)
|
||||
text_encoder = build_module(cfg.text_encoder, MODELS, device=device) # T5 must be fp32
|
||||
|
||||
model = build_module(
|
||||
cfg.model,
|
||||
MODELS,
|
||||
|
|
@ -64,7 +65,6 @@ def main():
|
|||
in_channels=vae.out_channels,
|
||||
caption_channels=text_encoder.output_dim,
|
||||
model_max_length=text_encoder.model_max_length,
|
||||
dtype=dtype,
|
||||
enable_sequence_parallelism=enable_sequence_parallelism,
|
||||
)
|
||||
text_encoder.y_embedder = model.y_embedder # hack for classifier-free guidance
|
||||
|
|
|
|||
|
|
@ -133,8 +133,7 @@ def main():
|
|||
input_size=latent_size,
|
||||
in_channels=vae.out_channels,
|
||||
caption_channels=text_encoder.output_dim,
|
||||
model_max_length=text_encoder.model_max_length,
|
||||
dtype=dtype,
|
||||
model_max_length=text_encoder.model_max_length
|
||||
)
|
||||
model_numel, model_numel_trainable = get_model_numel(model)
|
||||
logger.info(
|
||||
|
|
|
|||
2
setup.py
2
setup.py
|
|
@ -30,7 +30,7 @@ def fetch_readme() -> str:
|
|||
|
||||
setup(
|
||||
name="opensora",
|
||||
version="1.0.0",
|
||||
version="1.1.0",
|
||||
packages=find_packages(
|
||||
exclude=(
|
||||
"assets",
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ Human labeling of videos is expensive and time-consuming. We adopt powerful imag
|
|||
|
||||
## LLaVA Captioning
|
||||
|
||||
We extract three frames from the video for captioning. With batch inference, we can achieve 10 times speedup. With approximatly 720p resolution and 3 frames, the speed is 2~3 videos/s on 8 GPUs. If we resize the smaller side to 336, the speed can be 8 videos/s.
|
||||
We extract three frames from the video for captioning. With batch inference, we can achieve 10 times speedup. With approximatly 720p resolution and 1 frames, the speed is 2~3 videos/s on 8 GPUs. If we resize the smaller side to 336, the speed can be 8 videos/s. In Open-Sora v1.1, to lower the cost, we use the 7B model.
|
||||
|
||||
### Requirement
|
||||
|
||||
|
|
@ -36,13 +36,18 @@ pip install flash-attn --no-build-isolation
|
|||
pip install colossalai decord
|
||||
```
|
||||
|
||||
Since only the 34B model's performance is comparable to GPT-4V, we only provide the usage of the 34B model. The 34B model is available [here](https://huggingface.co/liuhaotian/llava-v1.6-vicuna-7b), or run our script and it will be downloaded automatically.
|
||||
|
||||
### Usage
|
||||
|
||||
Prepare a csv file for processing. The csv file can be generated by `convert_dataset.py` according to its [documentation](/tools/datasets/README.md). Then, run the following command to generate captions for videos/images with LLaVA:
|
||||
Prepare a csv file for processing. The csv file can be generated by `convert_dataset.py` according to its [documentation](/tools/datasets/README.md). Then, run the following command to generate captions for videos/images with Llava:
|
||||
|
||||
```bash
|
||||
# caption with mistral-7B
|
||||
torchrun --nproc_per_node 8 --standalone -m tools.caption.caption_llava DATA.csv --dp-size 8 --tp-size 1 --model-path liuhaotian/llava-v1.6-mistral-7b --prompt video
|
||||
|
||||
# caption with llava-34B
|
||||
# NOTE: remember to enable flash attention for this model
|
||||
torchrun --nproc_per_node 8 --standalone -m tools.caption.caption_llava DATA.csv --dp-size 4 --tp-size 2 --model-path liuhaotian/llava-v1.6-34b --prompt image-3ex --flash-attention
|
||||
|
||||
# we run this on 8xH800 GPUs
|
||||
torchrun --nproc_per_node 8 --standalone -m tools.caption.caption_llava DATA.csv --tp-size 2 --dp-size 4 --bs 16
|
||||
|
||||
|
|
@ -51,14 +56,6 @@ torchrun --nproc_per_node 2 --standalone -m tools.caption.caption_llava DATA.csv
|
|||
|
||||
# can also caption images
|
||||
torchrun --nproc_per_node 2 --standalone -m tools.caption.caption_llava DATA.csv --tp-size 2 --dp-size 1 --bs 16 --prompt image-3ex
|
||||
|
||||
# caption with llava-34B
|
||||
# NOTE: remember to enable flash attention for this model
|
||||
torchrun --nproc_per_node 8 --standalone -m tools.caption.caption_llava DATA.csv --dp-size 4 --tp-size 2 --model-path liuhaotian/llava-v1.6-34b --prompt image-3ex --flash-attention
|
||||
|
||||
# caption with mistral-7B
|
||||
torchrun --nproc_per_node 8 --standalone -m tools.caption.caption_llava DATA.csv --dp-size 8 --tp-size 1 --model-path liuhaotian/llava-v1.6-mistral-7b --prompt video
|
||||
# bs can be 48
|
||||
```
|
||||
|
||||
Please note that you should add the `--flash-attention` flag when running with Llama-based Llava models as it provides speedup but do turn it off for mistral-based ones. Reasons can be found in [this issue](https://discuss.huggingface.co/t/flash-attention-has-no-effect-on-inference/73453).
|
||||
|
|
|
|||
Loading…
Reference in a new issue