Merge branch 'dev/v1.2' of https://github.com/hpcaitech/Open-Sora-dev into dev/v1.2

This commit is contained in:
Tom Young 2024-06-17 13:46:26 +00:00
commit d5cf345ef5
3 changed files with 14 additions and 4 deletions


@@ -211,7 +211,7 @@ docker run -ti --gpus all -v {MOUNT_DIR}:/data opensora
| Model | Model Size | Data | #iterations | Batch Size | URL |
| --------- | ---------- | ---- | ----------- | ---------- | ------------------------------------------------------------- |
| Diffusion | 1.1B | 30M | 70k | Dynamic | [:link:](https://huggingface.co/hpcai-tech/OpenSora-STDiT-v3) |
| VAE | 384M | 3M | 1.18M | 8 | [:link:](https://huggingface.co/hpcai-tech/OpenSora-VAE-v1.2) |
See our **[report 1.2](docs/report_03.md)** for more information.
@@ -327,6 +327,15 @@ python scripts/inference.py configs/opensora-v1-2/inference/sample.py \
For image to video generation and other functionalities, the API is compatible with Open-Sora 1.1. See [here](docs/commands.md) for more instructions.
If your installation does not include `apex` and `flash-attn`, you need to disable them in the config file or via the following command.
```bash
python scripts/inference.py configs/opensora-v1-2/inference/sample.py \
--num-frames 4s --resolution 720p \
--layernorm-kernel False --flash-attn False \
--prompt "a beautiful waterfall"
```
### GPT-4o Prompt Refinement
We find that GPT-4o can refine the prompt and improve the quality of the generated video. With this feature, you can also use other languages (e.g., Chinese) as the prompt. To enable this feature, you need to prepare your OpenAI API key in the environment:
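A typical way to provide the key is via the `OPENAI_API_KEY` environment variable, which is the default read by the `openai` Python client; this is an assumption about what the repo reads, so check the project docs for the exact name:

```bash
# Assumed variable name: OPENAI_API_KEY (the openai client's default).
# Replace the placeholder with your actual key.
export OPENAI_API_KEY="sk-your-key-here"
```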


@@ -28,6 +28,7 @@ def parse_args(training=False):
parser.add_argument("--layernorm-kernel", default=None, type=str2bool, help="enable layernorm kernel")
parser.add_argument("--resolution", default=None, type=str, help="multi resolution")
parser.add_argument("--data-path", default=None, type=str, help="path to data csv")
parser.add_argument("--dtype", default=None, type=str, help="data type")
# ======================================================
# Inference
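The new `--dtype` option can be exercised in isolation. A minimal sketch reproducing just that `add_argument` call with a hypothetical value (the `str2bool` helper and the other flags are omitted, since only `--dtype` is tested here):

```python
import argparse

# Minimal reproduction of the parser fragment above.
parser = argparse.ArgumentParser()
parser.add_argument("--dtype", default=None, type=str, help="data type")

# Parse a hypothetical command line.
args = parser.parse_args(["--dtype", "bf16"])
print(args.dtype)  # bf16
```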


@@ -221,11 +221,11 @@ def main():
# recover the prompt list
batched_prompt_segment_list = []
-    start_idx = 0
+    segment_start_idx = 0
    all_prompts = broadcast_obj_list[0]
    for num_segment in prompt_segment_length:
-        batched_prompt_segment_list.append(all_prompts[start_idx : start_idx + num_segment])
-        start_idx += num_segment
+        batched_prompt_segment_list.append(all_prompts[segment_start_idx : segment_start_idx + num_segment])
+        segment_start_idx += num_segment
# 2. append score
for idx, prompt_segment_list in enumerate(batched_prompt_segment_list):
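The renamed recovery loop can be sketched standalone. This is a hedged illustration with made-up data: `all_prompts` stands in for `broadcast_obj_list[0]`, and the segment lengths are hypothetical:

```python
# Standalone sketch of the segment-recovery loop above, with made-up data.
all_prompts = ["sunset over the sea", "rain in a forest", "city at night"]
prompt_segment_length = [2, 1]  # segments per original prompt (hypothetical)

batched_prompt_segment_list = []
segment_start_idx = 0
for num_segment in prompt_segment_length:
    # Slice out the next num_segment entries belonging to one prompt.
    batched_prompt_segment_list.append(
        all_prompts[segment_start_idx : segment_start_idx + num_segment]
    )
    segment_start_idx += num_segment

print(batched_prompt_segment_list)
# [['sunset over the sea', 'rain in a forest'], ['city at night']]
```

The rename from `start_idx` to `segment_start_idx` does not change behavior; it only clarifies that the cursor walks segment boundaries in the flattened prompt list.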