# Mirror of https://github.com/hpcaitech/Open-Sora.git
# Synced 2026-04-15 03:15:20 +02:00
import torch
|
|
from colossalai.accelerator import get_accelerator
|
|
from colossalai.utils import get_current_device
|
|
from rotary_embedding_torch import RotaryEmbedding
|
|
|
|
from opensora.models.layers.blocks import Attention
|
|
|
|
# B, S, H = 7488, 1, 1152
|
|
# B, S, H = 32, 234, 1152
|
|
B, S, H = 128, 32, 1152
|
|
N, D = 16, 72
|
|
|
|
|
|
def run_attn(enable_flashattn: bool):
|
|
get_accelerator().reset_peak_memory_stats()
|
|
rope = RotaryEmbedding(D).to(device=get_current_device(), dtype=torch.bfloat16)
|
|
attn = Attention(
|
|
H,
|
|
N,
|
|
qkv_bias=True,
|
|
rope=rope.rotate_queries_or_keys,
|
|
enable_flashattn=enable_flashattn,
|
|
).to(device=get_current_device(), dtype=torch.bfloat16)
|
|
x = torch.randn(B, S, H, device=get_current_device(), dtype=torch.bfloat16).requires_grad_()
|
|
y = attn(x)
|
|
y.mean().backward()
|
|
print(f"Peak memory: {get_accelerator().max_memory_allocated() / 1024**2:.2f} MB")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
print("Use flashattn")
|
|
run_attn(True)
|
|
print("No flashattn")
|
|
run_attn(False)
|