[feature] impl ulysses-style seq parallel (#20)

* [feature] add ulysses style sp attn

* [test] add sp attn test

* [feature] add zero sp plugin

* [hotfix] fix sp backward

* [test] add test for dit model
Hongxin Liu 2024-03-01 14:42:06 +08:00 committed by GitHub
parent 3bea560af3
commit 97c089daec
8 changed files with 611 additions and 28 deletions
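
For context: in Ulysses-style sequence parallelism, every rank holds a shard of the sequence, an all-to-all before attention trades the sequence split for a head split (so each rank attends over the full sequence with num_heads / sp_size heads), and a second all-to-all after attention trades it back. A minimal single-process sketch of that shape bookkeeping, with made-up sizes and no real communication:

import torch

B, S, NH, D, P = 2, 8, 4, 16, 2            # batch, seq len, heads, head dim, sp size
H = NH * D
x = torch.randn(B, S, H)
shards = list(x.chunk(P, dim=1))           # each "rank" holds [B, S/P, H]

# all-to-all #1: scatter the hidden dim, gather the sequence -> [B, S, H/P]
pieces = [shard.chunk(P, dim=2) for shard in shards]   # pieces[src][dst]
gathered = [
    torch.cat([pieces[src][dst] for src in range(P)], dim=1) for dst in range(P)
]
assert gathered[0].shape == (B, S, H // P)
assert torch.equal(gathered[1], x[..., H // P :])      # rank 1 owns the last NH/P heads

# attention now runs over the full sequence with NH/P heads per rank; a second
# all-to-all (scatter seq, gather hidden) would restore the [B, S/P, H] layout.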


@ -14,7 +14,7 @@ import torch
from colossalai import launch_from_torch
from colossalai.accelerator import get_accelerator
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.booster.plugin import TorchDDPPlugin
from colossalai.cluster import DistCoordinator
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import HybridAdam
@ -24,6 +24,7 @@ from tqdm import tqdm
from open_sora.diffusion import create_diffusion
from open_sora.modeling import DiT_models
from open_sora.utils.data import create_video_compressor, preprocess_batch
from open_sora.utils.plugin import ZeroSeqParallelPlugin
#################################################################################
# Training Loop #
@ -44,14 +45,17 @@ def main(args):
plugin = TorchDDPPlugin()
elif args.plugin == "zero2":
# use bf16 to avoid skipping the first few iterations due to NaNs
plugin = LowLevelZeroPlugin(stage=2, precision="bf16")
plugin = ZeroSeqParallelPlugin(sp_size=args.sp_size, stage=2, precision="bf16")
else:
raise ValueError(f"Unknown plugin {args.plugin}")
booster = Booster(plugin=plugin)
# Create video compressor
video_compressor = create_video_compressor(args.compressor)
model_kwargs = {"in_channels": video_compressor.out_channels}
model_kwargs = {
"in_channels": video_compressor.out_channels,
"seq_parallel_group": plugin.sp_group,
}
# Create DiT and EMA
model = DiT_models[args.model](**model_kwargs).to(get_current_device())
@ -75,6 +79,7 @@ def main(args):
torch.randn(args.num_frames, args.height, args.width, 3)
for _ in range(args.batch_size)
]
assert args.num_tokens % args.sp_size == 0
input_ids = torch.randn(args.batch_size, args.num_tokens, args.text_embed_dim)
text_mask = torch.ones(input_ids.shape[:2], dtype=torch.int)
batch = {
@ -82,9 +87,15 @@ def main(args):
"text_latent_states": input_ids,
"text_padding_mask": text_mask,
}
batch = preprocess_batch(batch, patch_size, video_compressor)
batch = preprocess_batch(
batch, patch_size, video_compressor, pad_to_multiple=args.sp_size
)
video_inputs = batch.pop("video_latent_states")
mask = batch.pop("video_padding_mask")
logger.info(
f"Num patches: {video_inputs.shape[1]}, num_tokens: {batch['text_latent_states'].shape[1]}",
ranks=[0],
)
# setup booster
model, opt, *_ = booster.boost(model, opt)
@ -125,7 +136,7 @@ def main(args):
ranks=[0],
)
logger.info(
f"Throughput: {throughput:.2f} samples/s",
f"Throughput per device: {throughput:.2f} samples/s",
ranks=[0],
)
@ -139,6 +150,7 @@ if __name__ == "__main__":
parser.add_argument(
"-p", "--plugin", type=str, default="zero2", choices=["ddp", "zero2"]
)
parser.add_argument("--sp_size", type=int, default=1)
parser.add_argument("-w", "--warmup_steps", type=int, default=2)
parser.add_argument("-s", "--steps", type=int, default=3)
parser.add_argument("-b", "--batch_size", type=int, default=4)


@ -14,10 +14,13 @@ from typing import Callable, Optional
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from timm.models.vision_transformer import Mlp
from open_sora.utils.comm import all_to_all, gather_seq, split_seq
class CrossAttention(nn.Module):
r"""
@ -101,6 +104,120 @@ class CrossAttention(nn.Module):
return attn_output
class SeqParallelCrossAttention(nn.Module):
r"""
A cross attention layer with Ulysses-style sequence parallelism (all-to-all over the sequence-parallel group).
Parameters:
query_dim (`int`): The number of channels in the query.
cross_attention_dim (`int`, *optional*):
The number of channels in the context. If not given, defaults to `query_dim`.
num_heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
bias (`bool`, *optional*, defaults to False):
Set to `True` for the query, key, and value linear layers to contain a bias parameter.
sdpa (`bool`, *optional*, defaults to True): Whether to use `F.scaled_dot_product_attention`.
seq_parallel_group (`ProcessGroup`, *optional*): The sequence-parallel process group; if `None`, the layer behaves like a regular cross attention layer.
"""
def __init__(
self,
query_dim: int,
cross_attention_dim: Optional[int] = None,
num_heads: int = 8,
head_dim: int = 64,
dropout: float = 0.0,
bias=False,
sdpa=True,
seq_parallel_group=None,
):
super().__init__()
self.hidden_size = head_dim * num_heads
cross_attention_dim = (
cross_attention_dim if cross_attention_dim is not None else query_dim
)
self.scale = head_dim**-0.5
self.num_heads = num_heads
self.head_dim = head_dim
self.sdpa = sdpa
self.seq_parallel_group = seq_parallel_group
self.seq_parallel_size = (
dist.get_world_size(self.seq_parallel_group)
if seq_parallel_group is not None
else 1
)
assert self.num_heads % self.seq_parallel_size == 0
self.to_q = nn.Linear(query_dim, self.hidden_size, bias=bias)
self.to_k = nn.Linear(cross_attention_dim, self.hidden_size, bias=bias)
self.to_v = nn.Linear(cross_attention_dim, self.hidden_size, bias=bias)
self.to_out = nn.Sequential(
nn.Linear(self.hidden_size, query_dim), nn.Dropout(dropout)
)
def forward(self, hidden_states, context=None, mask=None):
bsz, q_len, _ = hidden_states.shape
query = self.to_q(hidden_states)
context = context if context is not None else hidden_states
kv_seq_len = context.shape[1]
key = self.to_k(context)
value = self.to_v(context)
# [B, S/P, H] -> [B, S, H/P]
num_heads_parallel = self.num_heads // self.seq_parallel_size
hidden_size_parallel = self.hidden_size // self.seq_parallel_size
if self.seq_parallel_group is not None and self.seq_parallel_size > 1:
query = all_to_all(
query, self.seq_parallel_group, scatter_dim=2, gather_dim=1
)
key = all_to_all(key, self.seq_parallel_group, scatter_dim=2, gather_dim=1)
value = all_to_all(
value, self.seq_parallel_group, scatter_dim=2, gather_dim=1
)
q_len *= self.seq_parallel_size
kv_seq_len *= self.seq_parallel_size
# [B, S, H/P] -> [B, S, N/P, D] -> [B, N/P, S, D]
query = query.view(bsz, q_len, num_heads_parallel, self.head_dim).transpose(
1, 2
)
key = key.view(bsz, kv_seq_len, num_heads_parallel, self.head_dim).transpose(
1, 2
)
value = value.view(
bsz, kv_seq_len, num_heads_parallel, self.head_dim
).transpose(1, 2)
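# note: after the all-to-all, q_len and kv_seq_len are the full (gathered) lengths,
# so callers pass the full [B, 1, S_q, S_kv] mask even when their inputs are sequence-sharded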
if mask is not None:
assert mask.shape == (bsz, 1, q_len, kv_seq_len)
if self.sdpa:
attn_output = F.scaled_dot_product_attention(
query, key, value, attn_mask=mask, scale=self.scale
)
else:
attn_weights = torch.matmul(query, key.transpose(2, 3)) * self.scale
assert attn_weights.shape == (bsz, num_heads_parallel, q_len, kv_seq_len)
if mask is not None:
attn_weights = attn_weights + mask
attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
query.dtype
)
attn_output = torch.matmul(attn_weights, value)
assert attn_output.shape == (bsz, num_heads_parallel, q_len, self.head_dim)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, hidden_size_parallel)
# [B, S, H/P] -> [B, S/P, H]
if self.seq_parallel_group is not None and self.seq_parallel_size > 1:
attn_output = all_to_all(
attn_output, self.seq_parallel_group, scatter_dim=1, gather_dim=2
)
attn_output = self.to_out(attn_output)
return attn_output
def modulate(x, shift, scale):
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
@ -290,16 +407,18 @@ class DiTBlock(nn.Module):
num_heads,
cross_attention_dim=None,
mlp_ratio=4.0,
seq_parallel_group=None,
):
super().__init__()
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.attn = CrossAttention(
self.attn = SeqParallelCrossAttention(
query_dim=hidden_size,
cross_attention_dim=cross_attention_dim,
num_heads=num_heads,
head_dim=hidden_size // num_heads,
bias=True,
sdpa=True,
seq_parallel_group=seq_parallel_group,
)
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
@ -380,6 +499,7 @@ class DiT(nn.Module):
text_dropout_prob=0.1,
learn_sigma=True,
use_cross_attn=True,
seq_parallel_group=None,
):
super().__init__()
self.grad_checkpointing = False
@ -388,6 +508,17 @@ class DiT(nn.Module):
self.out_channels = in_channels * 2 if learn_sigma else in_channels
self.patch_size = patch_size
self.num_heads = num_heads
self.seq_parallel_group = seq_parallel_group
self.seq_parallel_size = (
dist.get_world_size(self.seq_parallel_group)
if seq_parallel_group is not None
else 1
)
self.seq_parallel_rank = (
dist.get_rank(self.seq_parallel_group)
if seq_parallel_group is not None
else 0
)
self.video_embedder = PatchEmbedder(
patch_size, in_channels, hidden_size, bias=True
@ -409,7 +540,13 @@ class DiT(nn.Module):
self.blocks = nn.ModuleList(
[
DiTBlock(hidden_size, num_heads, cross_attn_dim, mlp_ratio=mlp_ratio)
DiTBlock(
hidden_size,
num_heads,
cross_attn_dim,
mlp_ratio=mlp_ratio,
seq_parallel_group=seq_parallel_group,
)
for _ in range(depth)
]
)
@ -483,6 +620,18 @@ class DiT(nn.Module):
video_latent_states = video_latent_states + pos_embed
t = self.t_embedder(t) # (N, D)
attention_mask = self._prepare_mask(attention_mask, video_latent_states.dtype)
if self.seq_parallel_group is not None and self.seq_parallel_size > 1:
assert video_latent_states.shape[1] % self.seq_parallel_size == 0
video_latent_states = split_seq(
video_latent_states, self.seq_parallel_size, self.seq_parallel_rank
)
if text_latent_states is not None:
assert text_latent_states.shape[1] % self.seq_parallel_size == 0
text_latent_states = split_seq(
text_latent_states, self.seq_parallel_size, self.seq_parallel_rank
)
for block in self.blocks:
if self.grad_checkpointing and self.training:
video_latent_states = torch.utils.checkpoint.checkpoint(
@ -496,6 +645,15 @@ class DiT(nn.Module):
video_latent_states = block(
video_latent_states, attention_mask, t, text_latent_states
)
if self.seq_parallel_group is not None and self.seq_parallel_size > 1:
video_latent_states = gather_seq(
video_latent_states,
self.seq_parallel_size,
self.seq_parallel_rank,
self.seq_parallel_group,
)
if not self.use_cross_attn:
video_latent_states = video_latent_states[:, text_len:]
video_latent_states = self.final_layer(video_latent_states, t)

open_sora/utils/comm.py (new file, 96 lines)

@ -0,0 +1,96 @@
import torch
import torch.distributed as dist
from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler
from colossalai.shardformer.layer._operation import gather_forward_split_backward
def _all_to_all(
input_: torch.Tensor,
world_size: int,
group: dist.ProcessGroup,
scatter_dim: int,
gather_dim: int,
):
input_list = [
t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)
]
output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)]
dist.all_to_all(output_list, input_list, group=group)
return torch.cat(output_list, dim=gather_dim).contiguous()
class _AllToAll(torch.autograd.Function):
"""All-to-all communication.
Args:
input_: input matrix
process_group: communication group
scatter_dim: scatter dimension
gather_dim: gather dimension
"""
@staticmethod
def forward(ctx, input_, process_group, scatter_dim, gather_dim):
ctx.process_group = process_group
ctx.scatter_dim = scatter_dim
ctx.gather_dim = gather_dim
ctx.world_size = dist.get_world_size(process_group)
return _all_to_all(
input_, ctx.world_size, process_group, scatter_dim, gather_dim
)
@staticmethod
def backward(ctx, grad_output):
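# the gradient of an all-to-all is another all-to-all with scatter_dim and gather_dim swapped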
return (
_all_to_all(
grad_output,
ctx.world_size,
ctx.process_group,
ctx.gather_dim,
ctx.scatter_dim,
),
None,
None,
None,
)
def all_to_all(
input_: torch.Tensor,
process_group: dist.ProcessGroup,
scatter_dim: int = 2,
gather_dim: int = 1,
):
return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim)
def split_seq(input_: torch.Tensor, sp_size: int, sp_rank: int, dim: int = 1):
"""Split a tensor along sequence dimension. It will split input and divide grad by sp_size.
Args:
input_ (torch.Tensor): The common shape is (bs, seq, *).
sp_size (int): Sequence parallel size.
sp_rank (int): Sequence parallel rank.
dim (int, optional): Sequence dimension. Defaults to 1.
"""
input_ = input_.chunk(sp_size, dim=dim)[sp_rank].clone()
return MoeOutGradScaler.apply(input_, sp_size)
def gather_seq(
input_: torch.Tensor,
sp_size: int,
sp_rank: int,
sp_group: dist.ProcessGroup,
dim: int = 1,
):
"""Gather a tensor along sequence dimension. It will gather input and multiply grad by sp_size.
Args:
input_ (torch.Tensor): The common shape is (bs, seq, *).
sp_size (int): Sequence parallel size.
sp_rank (int): Sequence parallel rank.
dim (int, optional): Sequence dimension. Defaults to 1.
"""
input_ = gather_forward_split_backward(input_, dim, sp_group)
return MoeInGradScaler.apply(input_, sp_size)
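
A hedged usage sketch of how these helpers compose inside a sequence-parallel layer (sp_attention_block and attn are hypothetical stand-ins; assumes torch.distributed is initialized and sp_group spans the sequence-parallel ranks):

import torch.distributed as dist

from open_sora.utils.comm import all_to_all, gather_seq, split_seq

def sp_attention_block(x, attn, sp_group):
    # x: [B, S, H], replicated on every rank of sp_group
    sp_size = dist.get_world_size(sp_group)
    sp_rank = dist.get_rank(sp_group)
    x = split_seq(x, sp_size, sp_rank)                         # [B, S/P, H]
    x = all_to_all(x, sp_group, scatter_dim=2, gather_dim=1)   # [B, S, H/P]
    x = attn(x)                                                # full-seq attention on H/P channels
    x = all_to_all(x, sp_group, scatter_dim=1, gather_dim=2)   # [B, S/P, H]
    return gather_seq(x, sp_size, sp_rank, sp_group)           # [B, S, H]

In the DiT model in this change, the split happens once at the model entry and the gather once at the exit, with the two all-to-alls living inside each SeqParallelCrossAttention.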


@ -16,6 +16,13 @@ DatasetType = Union[Dataset, ConcatDataset, dataset_dict.Dataset]
PathType = Union[str, os.PathLike]
def ceil_to_multiple(x: int, multiple: int) -> int:
m = x % multiple
if m == 0:
return x
return x + multiple - m
def video2col(video_4d: torch.Tensor, patch_size: int) -> torch.Tensor:
"""
Convert a 4D video tensor to a 2D tensor where each row is a patch of the video.
@ -67,7 +74,9 @@ def col2video(
return video
def pad_sequences(sequences: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
def pad_sequences(
sequences: List[torch.Tensor], pad_to_multiple: Optional[int] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Pad a list of sequences.
Args:
@ -77,6 +86,8 @@ def pad_sequences(sequences: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Te
Tuple[torch.Tensor, torch.Tensor]: Padded batch of sequences ([B, T, ...]) and padding mask ([B, T]).
"""
max_len = max([sequence.shape[0] for sequence in sequences])
if pad_to_multiple is not None:
max_len = ceil_to_multiple(max_len, pad_to_multiple)
padded_sequences = [
F.pad(
sequence, [0] * (sequence.ndim - 1) * 2 + [0, max_len - sequence.shape[0]]
@ -96,7 +107,7 @@ def pad_sequences(sequences: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Te
def patchify_batch(
videos: List[torch.Tensor], patch_size: int
videos: List[torch.Tensor], patch_size: int, pad_to_multiple: Optional[int] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Patchify a batch of videos.
@ -108,7 +119,7 @@ def patchify_batch(
Tuple[torch.Tensor, torch.Tensor]: Padded batch of patches ([B, S, C, P, P]) and padding mask ([B, S]).
"""
video_patches = [video2col(video, patch_size) for video in videos]
return pad_sequences(video_patches)
return pad_sequences(video_patches, pad_to_multiple=pad_to_multiple)
def expand_mask_4d(q_mask: torch.Tensor, kv_mask: torch.Tensor) -> torch.Tensor:
@ -127,7 +138,9 @@ def expand_mask_4d(q_mask: torch.Tensor, kv_mask: torch.Tensor) -> torch.Tensor:
return mask.unsqueeze(1)
def make_batch(samples: List[dict], video_dir: str) -> dict:
def make_batch(
samples: List[dict], video_dir: str, pad_to_multiple: Optional[int] = None
) -> dict:
"""Make a batch of samples.
Args:
@ -141,7 +154,7 @@ def make_batch(samples: List[dict], video_dir: str) -> dict:
for sample in samples
]
texts = [sample["text_latent_states"] for sample in samples]
texts, text_padding_mask = pad_sequences(texts)
texts, text_padding_mask = pad_sequences(texts, pad_to_multiple=pad_to_multiple)
return {
"videos": videos,
"text_latent_states": texts,
@ -269,6 +282,7 @@ def preprocess_batch(
video_compressor: VideoCompressor,
device=None,
use_cross_attn=True,
pad_to_multiple: Optional[int] = None,
) -> dict:
if device is None:
device = get_current_device()
@ -278,7 +292,9 @@ def preprocess_batch(
video = normalize_video(video)
video = video_compressor.encode(video)
videos.append(video)
video_latent_states, video_padding_mask = patchify_batch(videos, patch_size)
video_latent_states, video_padding_mask = patchify_batch(
videos, patch_size, pad_to_multiple
)
batch["video_latent_states"] = video_latent_states
batch["video_padding_mask"] = video_padding_mask
text_padding_mask = batch.pop("text_padding_mask").to(device)
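
A toy check of the new padding path (assumes pad_sequences and ceil_to_multiple are importable from open_sora.utils.data, alongside the other helpers the training scripts import): rounding the padded length up to a multiple of sp_size is what lets the divisibility asserts in DiT.forward pass.

import torch

from open_sora.utils.data import ceil_to_multiple, pad_sequences

assert ceil_to_multiple(7, 4) == 8 and ceil_to_multiple(8, 4) == 8
seqs = [torch.ones(5, 3), torch.ones(7, 3)]        # lengths 5 and 7
padded, padding_mask = pad_sequences(seqs, pad_to_multiple=4)
assert padded.shape == (2, 8, 3)                   # max_len 7 rounded up to 8
assert padding_mask.shape == (2, 8)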

open_sora/utils/plugin.py (new file, 102 lines)

@ -0,0 +1,102 @@
import random
from typing import Optional
import numpy as np
import torch
from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.cluster import ProcessGroupMesh
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
DP_AXIS, SP_AXIS = 0, 1
class ZeroSeqParallelPlugin(LowLevelZeroPlugin):
def __init__(
self,
sp_size: int = 1,
stage: int = 2,
precision: str = "fp16",
initial_scale: float = 2**32,
min_scale: float = 1,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
max_scale: float = 2**32,
max_norm: float = 0.0,
norm_type: float = 2.0,
reduce_bucket_size_in_m: int = 12,
communication_dtype: Optional[torch.dtype] = None,
overlap_communication: bool = True,
cpu_offload: bool = False,
master_weights: bool = True,
verbose: bool = False,
) -> None:
super().__init__(
stage=stage,
precision=precision,
initial_scale=initial_scale,
min_scale=min_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
max_scale=max_scale,
max_norm=max_norm,
norm_type=norm_type,
reduce_bucket_size_in_m=reduce_bucket_size_in_m,
communication_dtype=communication_dtype,
overlap_communication=overlap_communication,
cpu_offload=cpu_offload,
master_weights=master_weights,
verbose=verbose,
)
self.sp_size = sp_size
assert self.world_size % sp_size == 0, "world_size must be divisible by sp_size"
self.dp_size = self.world_size // sp_size
self.pg_mesh = ProcessGroupMesh(self.dp_size, self.sp_size)
self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS)
self.sp_group = self.pg_mesh.get_group_along_axis(SP_AXIS)
self.dp_rank = self.pg_mesh.coordinate(DP_AXIS)
self.sp_rank = self.pg_mesh.coordinate(SP_AXIS)
def __del__(self):
"""Destroy the prcess groups in ProcessGroupMesh"""
self.pg_mesh.destroy_mesh_process_groups()
def prepare_dataloader(
self,
dataset,
batch_size,
shuffle=False,
seed=1024,
drop_last=False,
pin_memory=False,
num_workers=0,
distributed_sampler_cls=None,
**kwargs
):
_kwargs = kwargs.copy()
distributed_sampler_cls = distributed_sampler_cls or DistributedSampler
sampler = distributed_sampler_cls(
dataset, num_replicas=self.dp_size, rank=self.dp_rank, shuffle=shuffle
)
# Deterministic dataloader
def seed_worker(worker_id):
worker_seed = seed
np.random.seed(worker_seed)
torch.manual_seed(worker_seed)
random.seed(worker_seed)
return DataLoader(
dataset,
batch_size=batch_size,
sampler=sampler,
worker_init_fn=seed_worker,
drop_last=drop_last,
pin_memory=pin_memory,
num_workers=num_workers,
**_kwargs,
)
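
A sketch of the resulting rank layout, assuming colossalai has been launched and that ProcessGroupMesh assigns coordinates row-major so the SP axis varies fastest (e.g. 8 ranks, sp_size=2):

#   global rank:  0  1  2  3  4  5  6  7
#   dp_rank:      0  0  1  1  2  2  3  3
#   sp_rank:      0  1  0  1  0  1  0  1
#
# prepare_dataloader samples with num_replicas=dp_size and rank=dp_rank, so both
# ranks in an sp_group receive the same batch and cooperate on the same sequences.
plugin = ZeroSeqParallelPlugin(sp_size=2, stage=2, precision="bf16")
model_kwargs = {"seq_parallel_group": plugin.sp_group}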

tests/test_model.py (new file, 93 lines)

@ -0,0 +1,93 @@
import colossalai
import pytest
import torch
import torch.distributed as dist
import torch.nn.functional as F
from colossalai.booster import Booster
from colossalai.logging import disable_existing_loggers
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device
from torch.testing import assert_close
from open_sora.modeling import DiT_models
from open_sora.utils.plugin import ZeroSeqParallelPlugin
@parameterize("sp_size", [2, 4])
def check_dit_model_fwd_bwd(
sp_size: int, video_latent_states, text_latent_states, t, mask
):
plugin = ZeroSeqParallelPlugin(
sp_size=sp_size, stage=2, precision="fp32", master_weights=False
)
booster = Booster(plugin=plugin)
model = DiT_models["DiT-B/8"](text_dropout_prob=0.0).to(get_current_device())
parallel_model = DiT_models["DiT-B/8"](
text_dropout_prob=0.0, seq_parallel_group=plugin.sp_group
).to(get_current_device())
parallel_model.load_state_dict(model.state_dict())
opt = HybridAdam(parallel_model.parameters(), lr=1e-3)
parallel_model, opt, *_ = booster.boost(parallel_model, opt)
target = model(video_latent_states, t, text_latent_states, mask)
noise = torch.randn_like(target)
target_loss = F.mse_loss(target, noise)
target_loss.backward()
dp_video_latent_states = video_latent_states.chunk(plugin.dp_size)[plugin.dp_rank]
dp_text_latent_states = text_latent_states.chunk(plugin.dp_size)[plugin.dp_rank]
dp_t = t.chunk(plugin.dp_size)[plugin.dp_rank]
dp_mask = mask.chunk(plugin.dp_size)[plugin.dp_rank]
dp_noise = noise.chunk(plugin.dp_size)[plugin.dp_rank]
output = parallel_model(
dp_video_latent_states, dp_t, dp_text_latent_states, dp_mask
)
loss = F.mse_loss(output, dp_noise)
booster.backward(loss, opt)
if plugin.dp_size == 1:
assert_close(target, output)
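# check each rank's ZeRO-partitioned gradient against the matching slice of the reference model's full gradient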
for p1, p2 in zip(model.parameters(), opt._master_param_groups_of_current_rank[0]):
working_p = opt._param_store.master_to_working_param[id(p2)]
grads = opt._grad_store.get_partitioned_gradients_by_param_id(0, id(working_p))
grad_index = 0 if opt._partition_grads else opt._local_rank
grad = grads[grad_index]
sharded_grad = p1.grad.view(-1).chunk(dist.get_world_size())[dist.get_rank()]
assert_close(sharded_grad, grad[: sharded_grad.shape[0]])
def run_dist(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(
config={},
rank=rank,
world_size=world_size,
port=port,
host="localhost",
backend="nccl",
)
b, s, c, p = 4, 20, 3, 8
dim_text, s_text = 512, 12
video_latent_states = torch.rand(b, s, c, p, p, device=get_current_device())
text_latent_states = torch.rand(b, s_text, dim_text, device=get_current_device())
t = torch.randint(0, 1000, (b,), device=get_current_device())
mask = torch.ones(b, 1, s, s_text, device=get_current_device(), dtype=torch.int)
check_dit_model_fwd_bwd(
video_latent_states=video_latent_states,
text_latent_states=text_latent_states,
t=t,
mask=mask,
)
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_dit_model():
spawn(run_dist, 4)
if __name__ == "__main__":
test_dit_model()

tests/test_sp_attn.py (new file, 76 lines)

@ -0,0 +1,76 @@
import colossalai
import pytest
import torch
import torch.distributed as dist
from colossalai.logging import disable_existing_loggers
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device
from torch.testing import assert_close
from open_sora.modeling.dit import CrossAttention, SeqParallelCrossAttention
from open_sora.utils.comm import gather_seq, split_seq
def check_sp_attn():
sp_size = dist.get_world_size()
sp_rank = dist.get_rank()
q_dim, context_dim = 8, 4
num_heads = 4
head_dim = 16
bs = 2
sq = 8
skv = 4
attn = CrossAttention(q_dim, context_dim, num_heads, head_dim).to(
get_current_device()
)
parallel_attn = SeqParallelCrossAttention(
q_dim, context_dim, num_heads, head_dim, seq_parallel_group=dist.group.WORLD
).to(get_current_device())
parallel_attn.load_state_dict(attn.state_dict())
hidden_states = torch.rand(bs, sq, q_dim, device=get_current_device())
context = torch.rand(bs, skv, context_dim, device=get_current_device())
mask = torch.zeros(bs, 1, sq, skv, device=get_current_device())
target = attn(hidden_states, context, mask)
hidden_states_parallel = split_seq(hidden_states, sp_size, sp_rank)
context_parallel = split_seq(context, sp_size, sp_rank)
output_parallel = parallel_attn(
hidden_states_parallel,
context_parallel,
mask,
)
assert torch.equal(target.chunk(sp_size, dim=1)[sp_rank], output_parallel)
output = gather_seq(output_parallel, sp_size, sp_rank, dist.group.WORLD)
assert torch.equal(target, output)
target.mean().backward()
output.mean().backward()
# all-reduce mean of grads
for p in parallel_attn.parameters():
p.grad.data.div_(sp_size)
dist.all_reduce(p.grad.data)
for p1, p2 in zip(attn.parameters(), parallel_attn.parameters()):
assert_close(p1.grad, p2.grad)
def run_dist(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(
config={},
rank=rank,
world_size=world_size,
port=port,
host="localhost",
backend="nccl",
)
check_sp_attn()
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_seq_parallel_attn():
spawn(run_dist, 2)
if __name__ == "__main__":
test_seq_parallel_attn()


@ -17,7 +17,6 @@ import torch.distributed as dist
from colossalai import launch_from_torch
from colossalai.accelerator import get_accelerator
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.cluster import DistCoordinator
from colossalai.logging import get_dist_logger
from colossalai.nn.lr_scheduler import CosineAnnealingLR
@ -28,7 +27,13 @@ from tqdm import tqdm
from open_sora.diffusion import create_diffusion
from open_sora.modeling import DiT_models
from open_sora.utils.data import create_video_compressor, load_datasets, make_batch, preprocess_batch
from open_sora.utils.data import (
create_video_compressor,
load_datasets,
make_batch,
preprocess_batch,
)
from open_sora.utils.plugin import ZeroSeqParallelPlugin
#################################################################################
# Training Helper Functions #
@ -93,7 +98,7 @@ def main(args):
configure_backends()
# Step 2: set up acceleration plugins
plugin = LowLevelZeroPlugin(stage=2, precision="fp16")
plugin = ZeroSeqParallelPlugin(sp_size=args.sp_size, stage=2, precision="fp16")
booster = Booster(plugin=plugin)
if coordinator.is_master():
@ -103,7 +108,10 @@ def main(args):
# Step 3: Create video compressor
video_compressor = create_video_compressor(args.compressor)
model_kwargs = {"in_channels": video_compressor.out_channels}
model_kwargs = {
"in_channels": video_compressor.out_channels,
"seq_parallel_group": plugin.sp_group,
}
# Step 4: Create DiT and EMA
model = DiT_models[args.model](**model_kwargs).to(get_current_device())
@ -119,7 +127,9 @@ def main(args):
model.enable_gradient_checkpointing()
# Step 5: create diffusion pipeline
diffusion = create_diffusion(timestep_respacing="") # default: 1000 steps, linear noise schedule
diffusion = create_diffusion(
timestep_respacing=""
) # default: 1000 steps, linear noise schedule
# Step 6: setup optimizer (we used default Adam betas=(0.9, 0.999) and a constant learning rate of 1e-4 in our paper):
opt = HybridAdam(model.parameters(), lr=args.lr, weight_decay=0)
@ -129,15 +139,21 @@ def main(args):
dataloader = plugin.prepare_dataloader(
dataset,
batch_size=args.batch_size,
collate_fn=partial(make_batch, video_dir=args.video_dir),
collate_fn=partial(
make_batch, video_dir=args.video_dir, pad_to_multiple=args.sp_size
),
shuffle=True,
drop_last=True,
)
lr_scheduler = CosineAnnealingLR(opt, args.epochs * len(dataloader) // args.accumulation_steps)
lr_scheduler = CosineAnnealingLR(
opt, args.epochs * len(dataloader) // args.accumulation_steps
)
logger.info(f"Dataset contains {len(dataset)} samples", ranks=[0])
# Step 8: setup booster
model, opt, _, dataloader, _ = booster.boost(model, opt, dataloader=dataloader)
model, opt, _, dataloader, lr_scheduler = booster.boost(
model, opt, dataloader=dataloader, lr_scheduler=lr_scheduler
)
if args.load_model is not None:
booster.load_model(model, args.load_model)
if args.load_optimizer is not None:
@ -159,7 +175,9 @@ def main(args):
) as pbar:
total_loss = torch.tensor(0.0, device=get_current_device())
for step, batch in enumerate(dataloader):
batch = preprocess_batch(batch, patch_size, video_compressor)
batch = preprocess_batch(
batch, patch_size, video_compressor, pad_to_multiple=args.sp_size
)
video_inputs = batch.pop("video_latent_states")
mask = batch.pop("video_padding_mask")
t = torch.randint(
@ -168,7 +186,9 @@ def main(args):
(video_inputs.shape[0],),
device=video_inputs.device,
)
loss_dict = diffusion.training_losses(model, video_inputs, t, batch, mask=mask)
loss_dict = diffusion.training_losses(
model, video_inputs, t, batch, mask=mask
)
loss = loss_dict["loss"].mean() / args.accumulation_steps
total_loss.add_(loss.data)
booster.backward(loss, opt)
@ -182,7 +202,9 @@ def main(args):
all_reduce_mean(total_loss)
pbar.set_postfix({"Loss": f"{total_loss.item():.4f}"})
if coordinator.is_master():
global_step = (epoch * num_steps_per_epoch) + (step + 1) // args.accumulation_steps
global_step = (epoch * num_steps_per_epoch) + (
step + 1
) // args.accumulation_steps
writer.add_scalar(
tag="Loss",
scalar_value=total_loss.item(),
@ -193,9 +215,12 @@ def main(args):
# Save DiT checkpoint:
if args.save_interval > 0 and (
(step + 1) % (args.save_interval * args.accumulation_steps) == 0 or (step + 1) == len(dataloader)
(step + 1) % (args.save_interval * args.accumulation_steps) == 0
or (step + 1) == len(dataloader)
):
save_path = os.path.join(args.checkpoint_dir, f"epoch-{epoch}-step-{step}")
save_path = os.path.join(
args.checkpoint_dir, f"epoch-{epoch}-step-{step}"
)
save_checkpoints(booster, model, opt, ema, save_path, coordinator)
logger.info(f"Saved checkpoint to {save_path}", ranks=[0])
@ -212,18 +237,23 @@ def main(args):
if __name__ == "__main__":
# Default args here will train DiT-XL/2 with the hyperparameters we used in our paper (except training iters).
parser = argparse.ArgumentParser()
parser.add_argument("-m", "--model", type=str, choices=list(DiT_models.keys()), default="DiT-S/8")
parser.add_argument(
"-m", "--model", type=str, choices=list(DiT_models.keys()), default="DiT-S/8"
)
parser.add_argument("-d", "--dataset", nargs="+", default=[])
parser.add_argument("-v", "--video_dir", type=str, required=True)
parser.add_argument("-e", "--epochs", type=int, default=10)
parser.add_argument("-b", "--batch_size", type=int, default=4)
parser.add_argument("-g", "--grad_checkpoint", action="store_true", default=False)
parser.add_argument("-a", "--accumulation_steps", default=1, type=int)
parser.add_argument("--sp_size", type=int, default=1)
parser.add_argument("--lr", type=float, default=1e-4)
parser.add_argument("--save_interval", type=int, default=20)
parser.add_argument("--checkpoint_dir", type=str, default="checkpoints")
parser.add_argument("--tensorboard_dir", type=str, default="runs")
parser.add_argument("-c", "--compressor", choices=["raw", "vqvae", "vae"], default="raw")
parser.add_argument(
"-c", "--compressor", choices=["raw", "vqvae", "vae"], default="raw"
)
parser.add_argument("--load_model", default=None)
parser.add_argument("--load_optimizer", default=None)
args = parser.parse_args()