Merge branch 'release' of https://github.com/hpcaitech/Open-Sora into release

This commit is contained in:
Zangwei Zheng 2024-03-17 11:01:03 +08:00
commit 342c50ed24
8 changed files with 54 additions and 23 deletions

1
.gitignore vendored
View file

@ -169,6 +169,7 @@ runs/
checkpoints/
outputs/
samples/
pretrained_models/
# Secret files
hostfile

View file

@ -121,11 +121,17 @@ To run inference with our provided weights, first download [T5](https://huggingf
```bash
# Sample 16x256x256 (~2s)
python scripts/inference.py configs/opensora/inference/16x256x256.py --ckpt-path ./path/to/your/ckpt.pth
torchrun --standalone --nproc_per_node 1 scripts/inference.py configs/opensora/inference/16x256x256.py --ckpt-path ./path/to/your/ckpt.pth
# Sample 16x512x512 (~2s)
python scripts/inference.py configs/opensora/inference/16x512x512.py
torchrun --standalone --nproc_per_node 1 scripts/inference.py configs/opensora/inference/16x512x512.py
# Sample 64x512x512 (~5s)
python scripts/inference.py configs/opensora/inference/64x512x512.py
torchrun --standalone --nproc_per_node 1 scripts/inference.py configs/opensora/inference/64x512x512.py
# Sample 64x512x512 with sequence parallelism (~5s)
# Sequence parallelism is enabled automatically when nproc_per_node is larger than 1
torchrun --standalone --nproc_per_node 2 scripts/inference.py configs/opensora/inference/64x512x512.py
```
For inference with other models, see [here](docs/commands.md) for more instructions.

View file

@ -7,6 +7,8 @@ model = dict(
type="STDiT-XL/2",
space_scale=0.5,
time_scale=1.0,
enable_flashattn=True,
enable_layernorm_kernel=True,
from_pretrained="PRETRAINED_MODEL",
)
vae = dict(

View file

@ -7,7 +7,9 @@ model = dict(
type="STDiT-XL/2",
space_scale=1.0,
time_scale=1.0,
from_pretrained="PRETRAINED_MODEL",
enable_flashattn=True,
enable_layernorm_kernel=True,
from_pretrained="PRETRAINED_MODEL"
)
vae = dict(
type="VideoAutoencoderKL",

View file

@ -7,6 +7,8 @@ model = dict(
type="STDiT-XL/2",
space_scale=1.0,
time_scale=2 / 3,
enable_flashattn=True,
enable_layernorm_kernel=True,
from_pretrained="PRETRAINED_MODEL",
)
vae = dict(

View file

@ -21,7 +21,7 @@ import xformers.ops
from einops import rearrange
from timm.models.vision_transformer import Mlp
from opensora.acceleration.communications import all_to_all
from opensora.acceleration.communications import all_to_all, split_forward_gather_backward
from opensora.acceleration.parallel_states import get_sequence_parallel_group
approx_gelu = lambda: nn.GELU(approximate="tanh")
@ -315,8 +315,9 @@ class SeqParallelMultiHeadCrossAttention(MultiHeadCrossAttention):
# apply all_to_all to gather sequence and split attention heads
q = all_to_all(q, sp_group, scatter_dim=2, gather_dim=1)
k = all_to_all(k, sp_group, scatter_dim=2, gather_dim=1)
v = all_to_all(v, sp_group, scatter_dim=2, gather_dim=1)
k = split_forward_gather_backward(k, get_sequence_parallel_group(), dim=2, grad_scale="down")
v = split_forward_gather_backward(v, get_sequence_parallel_group(), dim=2, grad_scale="down")
q = q.view(1, -1, self.num_heads // sp_size, self.head_dim)
k = k.view(1, -1, self.num_heads // sp_size, self.head_dim)
@ -327,7 +328,7 @@ class SeqParallelMultiHeadCrossAttention(MultiHeadCrossAttention):
if mask is not None:
attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens([N] * B, mask)
x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias)
# apply all_to_all to gather back the attention heads and scatter the sequence
x = x.view(B, -1, self.num_heads // sp_size, self.head_dim)
x = all_to_all(x, sp_group, scatter_dim=1, gather_dim=2)

View file

@ -1,12 +1,16 @@
import os
import torch
import colossalai
import torch.distributed as dist
from mmengine.runner import set_random_seed
from opensora.datasets import save_sample
from opensora.registry import MODELS, SCHEDULERS, build_module
from opensora.utils.config_utils import parse_configs
from opensora.utils.misc import to_torch_dtype
from opensora.acceleration.parallel_states import set_sequence_parallel_group
from colossalai.cluster import DistCoordinator
def load_prompts(prompt_path):
@ -17,11 +21,21 @@ def load_prompts(prompt_path):
def main():
# ======================================================
# 1. args & cfg
# 1. cfg and init distributed env
# ======================================================
cfg = parse_configs(training=False)
print(cfg)
# init distributed
colossalai.launch_from_torch({})
coordinator = DistCoordinator()
if coordinator.world_size > 1:
set_sequence_parallel_group(dist.group.WORLD)
enable_sequence_parallelism = True
else:
enable_sequence_parallelism = False
# ======================================================
# 2. runtime variables
# ======================================================
@ -49,6 +63,7 @@ def main():
caption_channels=text_encoder.output_dim,
model_max_length=text_encoder.model_max_length,
dtype=dtype,
enable_sequence_parallelism=enable_sequence_parallelism,
)
text_encoder.y_embedder = model.y_embedder # hack for classifier-free guidance
@ -84,11 +99,13 @@ def main():
additional_args=model_args,
)
samples = vae.decode(samples.to(dtype))
for idx, sample in enumerate(samples):
print(f"Prompt: {batch_prompts[idx]}")
save_path = os.path.join(save_dir, f"sample_{sample_idx}")
save_sample(sample, fps=cfg.fps, save_path=save_path)
sample_idx += 1
if coordinator.is_master():
for idx, sample in enumerate(samples):
print(f"Prompt: {batch_prompts[idx]}")
save_path = os.path.join(save_dir, f"sample_{sample_idx}")
save_sample(sample, fps=cfg.fps, save_path=save_path)
sample_idx += 1
if __name__ == "__main__":

View file

@ -72,18 +72,22 @@ def run_cross_attention(rank, world_size):
seq_parallel_attention = SeqParallelMultiHeadCrossAttention(
d_model=256,
num_heads=4,
).cuda()
).cuda().to(torch.bfloat16)
torch.manual_seed(1024)
attention = MultiHeadCrossAttention(
d_model=256,
num_heads=4,
).cuda()
).cuda().to(torch.bfloat16)
# make sure the weights are the same
for p1, p2 in zip(seq_parallel_attention.parameters(), attention.parameters()):
p1.data.copy_(p2.data)
# create inputs
torch.manual_seed(1024)
x = torch.randn(4, 64, 256).cuda()
y = torch.randn(4, 32, 256).cuda()
x = torch.randn(4, 64, 256).cuda().to(torch.bfloat16)
y = torch.randn(4, 32, 256).cuda().to(torch.bfloat16)
mask = [2, 10, 8, 16]
mask = None
@ -124,11 +128,7 @@ def run_cross_attention(rank, world_size):
# check grad
for p1, p2 in zip(seq_parallel_attention.named_parameters(), attention.named_parameters()):
# if not torch.allclose(p1[1].grad, p2[1].grad, atol=1e-7):
# print(p1[0], p2[0])
assert torch.allclose(
p1[1].grad, p2[1].grad, atol=1e-7
), f"\n{p1[0]}\nvs\n{p2[0]}:\n{p1[1].grad}\nvs\n{p2[1].grad}"
assert torch.allclose(p1[1].grad, p2[1].grad, rtol=1e-3, atol=1e-4), f"\n{p1[0]}\nvs\n{p2[0]}:\n{p1[1].grad}\nvs\n{p2[1].grad}"
# check input grad
assert torch.allclose(x.grad, seq_x.grad, atol=1e-7), f"{x.grad}\nvs\n{seq_x.grad}"