mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-05-21 11:59:01 +02:00
Merge branch 'release' of https://github.com/hpcaitech/Open-Sora into release
This commit is contained in:
commit
342c50ed24
1
.gitignore
vendored
1
.gitignore
vendored
|
|
@ -169,6 +169,7 @@ runs/
|
|||
checkpoints/
|
||||
outputs/
|
||||
samples/
|
||||
pretrained_models/
|
||||
|
||||
# Secret files
|
||||
hostfile
|
||||
|
|
|
|||
12
README.md
12
README.md
|
|
@ -121,11 +121,17 @@ To run inference with our provided weights, first download [T5](https://huggingf
|
|||
|
||||
```bash
|
||||
# Sample 16x256x256 (~2s)
|
||||
python scripts/inference.py configs/opensora/inference/16x256x256.py --ckpt-path ./path/to/your/ckpt.pth
|
||||
torchrun --standalone --nproc_per_node 1 scripts/inference.py configs/opensora/inference/16x256x256.py --ckpt-path ./path/to/your/ckpt.pth
|
||||
|
||||
# Sample 16x512x512 (~2s)
|
||||
python scripts/inference.py configs/opensora/inference/16x512x512.py
|
||||
torchrun --standalone --nproc_per_node 1 scripts/inference.py configs/opensora/inference/16x512x512.py
|
||||
|
||||
# Sample 64x512x512 (~5s)
|
||||
python scripts/inference.py configs/opensora/inference/64x512x512.py
|
||||
torchrun --standalone --nproc_per_node 1 scripts/inference.py configs/opensora/inference/64x512x512.py
|
||||
|
||||
# Sample 64x512x512 with sequence parallelism(~5s)
|
||||
# sequence parallelism is enabled automatically when nproc_per_node is larger than 1
|
||||
torchrun --standalone --nproc_per_node 2 scripts/inference.py configs/opensora/inference/64x512x512.py
|
||||
```
|
||||
|
||||
For inference with other models, see [here](docs/commands.md) for more instructions.
|
||||
|
|
|
|||
|
|
@ -7,6 +7,8 @@ model = dict(
|
|||
type="STDiT-XL/2",
|
||||
space_scale=0.5,
|
||||
time_scale=1.0,
|
||||
enable_flashattn=True,
|
||||
enable_layernorm_kernel=True,
|
||||
from_pretrained="PRETRAINED_MODEL",
|
||||
)
|
||||
vae = dict(
|
||||
|
|
|
|||
|
|
@ -7,7 +7,9 @@ model = dict(
|
|||
type="STDiT-XL/2",
|
||||
space_scale=1.0,
|
||||
time_scale=1.0,
|
||||
from_pretrained="PRETRAINED_MODEL",
|
||||
enable_flashattn=True,
|
||||
enable_layernorm_kernel=True,
|
||||
from_pretrained="PRETRAINED_MODEL"
|
||||
)
|
||||
vae = dict(
|
||||
type="VideoAutoencoderKL",
|
||||
|
|
|
|||
|
|
@ -7,6 +7,8 @@ model = dict(
|
|||
type="STDiT-XL/2",
|
||||
space_scale=1.0,
|
||||
time_scale=2 / 3,
|
||||
enable_flashattn=True,
|
||||
enable_layernorm_kernel=True,
|
||||
from_pretrained="PRETRAINED_MODEL",
|
||||
)
|
||||
vae = dict(
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ import xformers.ops
|
|||
from einops import rearrange
|
||||
from timm.models.vision_transformer import Mlp
|
||||
|
||||
from opensora.acceleration.communications import all_to_all
|
||||
from opensora.acceleration.communications import all_to_all, split_forward_gather_backward
|
||||
from opensora.acceleration.parallel_states import get_sequence_parallel_group
|
||||
|
||||
approx_gelu = lambda: nn.GELU(approximate="tanh")
|
||||
|
|
@ -315,8 +315,9 @@ class SeqParallelMultiHeadCrossAttention(MultiHeadCrossAttention):
|
|||
|
||||
# apply all_to_all to gather sequence and split attention heads
|
||||
q = all_to_all(q, sp_group, scatter_dim=2, gather_dim=1)
|
||||
k = all_to_all(k, sp_group, scatter_dim=2, gather_dim=1)
|
||||
v = all_to_all(v, sp_group, scatter_dim=2, gather_dim=1)
|
||||
|
||||
k = split_forward_gather_backward(k, get_sequence_parallel_group(), dim=2, grad_scale="down")
|
||||
v = split_forward_gather_backward(v, get_sequence_parallel_group(), dim=2, grad_scale="down")
|
||||
|
||||
q = q.view(1, -1, self.num_heads // sp_size, self.head_dim)
|
||||
k = k.view(1, -1, self.num_heads // sp_size, self.head_dim)
|
||||
|
|
@ -327,7 +328,7 @@ class SeqParallelMultiHeadCrossAttention(MultiHeadCrossAttention):
|
|||
if mask is not None:
|
||||
attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens([N] * B, mask)
|
||||
x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias)
|
||||
|
||||
|
||||
# apply all to all to gather back attention heads and scatter sequence
|
||||
x = x.view(B, -1, self.num_heads // sp_size, self.head_dim)
|
||||
x = all_to_all(x, sp_group, scatter_dim=1, gather_dim=2)
|
||||
|
|
|
|||
|
|
@ -1,12 +1,16 @@
|
|||
import os
|
||||
|
||||
import torch
|
||||
import colossalai
|
||||
import torch.distributed as dist
|
||||
from mmengine.runner import set_random_seed
|
||||
|
||||
from opensora.datasets import save_sample
|
||||
from opensora.registry import MODELS, SCHEDULERS, build_module
|
||||
from opensora.utils.config_utils import parse_configs
|
||||
from opensora.utils.misc import to_torch_dtype
|
||||
from opensora.acceleration.parallel_states import set_sequence_parallel_group
|
||||
from colossalai.cluster import DistCoordinator
|
||||
|
||||
|
||||
def load_prompts(prompt_path):
|
||||
|
|
@ -17,11 +21,21 @@ def load_prompts(prompt_path):
|
|||
|
||||
def main():
|
||||
# ======================================================
|
||||
# 1. args & cfg
|
||||
# 1. cfg and init distributed env
|
||||
# ======================================================
|
||||
cfg = parse_configs(training=False)
|
||||
print(cfg)
|
||||
|
||||
# init distributed
|
||||
colossalai.launch_from_torch({})
|
||||
coordinator = DistCoordinator()
|
||||
|
||||
if coordinator.world_size > 1:
|
||||
set_sequence_parallel_group(dist.group.WORLD)
|
||||
enable_sequence_parallelism = True
|
||||
else:
|
||||
enable_sequence_parallelism = False
|
||||
|
||||
# ======================================================
|
||||
# 2. runtime variables
|
||||
# ======================================================
|
||||
|
|
@ -49,6 +63,7 @@ def main():
|
|||
caption_channels=text_encoder.output_dim,
|
||||
model_max_length=text_encoder.model_max_length,
|
||||
dtype=dtype,
|
||||
enable_sequence_parallelism=enable_sequence_parallelism,
|
||||
)
|
||||
text_encoder.y_embedder = model.y_embedder # hack for classifier-free guidance
|
||||
|
||||
|
|
@ -84,11 +99,13 @@ def main():
|
|||
additional_args=model_args,
|
||||
)
|
||||
samples = vae.decode(samples.to(dtype))
|
||||
for idx, sample in enumerate(samples):
|
||||
print(f"Prompt: {batch_prompts[idx]}")
|
||||
save_path = os.path.join(save_dir, f"sample_{sample_idx}")
|
||||
save_sample(sample, fps=cfg.fps, save_path=save_path)
|
||||
sample_idx += 1
|
||||
|
||||
if coordinator.is_master():
|
||||
for idx, sample in enumerate(samples):
|
||||
print(f"Prompt: {batch_prompts[idx]}")
|
||||
save_path = os.path.join(save_dir, f"sample_{sample_idx}")
|
||||
save_sample(sample, fps=cfg.fps, save_path=save_path)
|
||||
sample_idx += 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -72,18 +72,22 @@ def run_cross_attention(rank, world_size):
|
|||
seq_parallel_attention = SeqParallelMultiHeadCrossAttention(
|
||||
d_model=256,
|
||||
num_heads=4,
|
||||
).cuda()
|
||||
).cuda().to(torch.bfloat16)
|
||||
|
||||
torch.manual_seed(1024)
|
||||
attention = MultiHeadCrossAttention(
|
||||
d_model=256,
|
||||
num_heads=4,
|
||||
).cuda()
|
||||
).cuda().to(torch.bfloat16)
|
||||
|
||||
# make sure the weights are the same
|
||||
for p1, p2 in zip(seq_parallel_attention.parameters(), attention.parameters()):
|
||||
p1.data.copy_(p2.data)
|
||||
|
||||
# create inputs
|
||||
torch.manual_seed(1024)
|
||||
x = torch.randn(4, 64, 256).cuda()
|
||||
y = torch.randn(4, 32, 256).cuda()
|
||||
x = torch.randn(4, 64, 256).cuda().to(torch.bfloat16)
|
||||
y = torch.randn(4, 32, 256).cuda().to(torch.bfloat16)
|
||||
|
||||
mask = [2, 10, 8, 16]
|
||||
mask = None
|
||||
|
|
@ -124,11 +128,7 @@ def run_cross_attention(rank, world_size):
|
|||
|
||||
# # check grad
|
||||
for p1, p2 in zip(seq_parallel_attention.named_parameters(), attention.named_parameters()):
|
||||
# if not torch.allclose(p1[1].grad, p2[1].grad, atol=1e-7):
|
||||
# print(p1[0], p2[0])
|
||||
assert torch.allclose(
|
||||
p1[1].grad, p2[1].grad, atol=1e-7
|
||||
), f"\n{p1[0]}\nvs\n{p2[0]}:\n{p1[1].grad}\nvs\n{p2[1].grad}"
|
||||
assert torch.allclose(p1[1].grad, p2[1].grad, rtol=1e-3, atol=1e-4), f"\n{p1[0]}\nvs\n{p2[0]}:\n{p1[1].grad}\nvs\n{p2[1].grad}"
|
||||
|
||||
# # check input grad
|
||||
assert torch.allclose(x.grad, seq_x.grad, atol=1e-7), f"{x.grad}\nvs\n{seq_x.grad}"
|
||||
|
|
|
|||
Loading…
Reference in a new issue