[feature] impl ulysses-style seq parallel (#20)

* [feature] add ulysses style sp attn
* [test] add sp attn test
* [feature] add zero sp plugin
* [hotfix] fix sp backward
* [test] add test for dit model

parent 3bea560af3
commit 97c089daec
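For context: in Ulysses-style sequence parallelism each rank holds a sequence shard of shape [B, S/P, H]; an all-to-all re-shards the tensor along the hidden dimension so every rank attends over the full sequence with its own slice of heads, and a second all-to-all restores the sequence shard afterwards. Below is a minimal single-process sketch of that layout swap (illustrative only, not part of the patch; the P ranks are simulated with a Python list):

```python
import torch

B, S, H, P = 2, 8, 16, 4  # batch, full sequence, hidden size, sequence-parallel size
x = torch.arange(B * S * H, dtype=torch.float32).reshape(B, S, H)

# before attention: each "rank" holds a sequence shard [B, S/P, H]
seq_shards = list(x.chunk(P, dim=1))

# emulate dist.all_to_all with scatter_dim=2, gather_dim=1 for every rank
head_shards = []
for rank in range(P):
    pieces = [shard.chunk(P, dim=2)[rank] for shard in seq_shards]
    head_shards.append(torch.cat(pieces, dim=1))  # [B, S, H/P]

# each rank now sees the full sequence but only its slice of the hidden dim,
# so it can run its subset of attention heads without further communication
for rank in range(P):
    expected = x[:, :, rank * H // P : (rank + 1) * H // P]
    assert torch.equal(head_shards[rank], expected)
print(head_shards[0].shape)  # torch.Size([2, 8, 4])
```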
benchmark.py (22 changes)

@@ -14,7 +14,7 @@ import torch
 from colossalai import launch_from_torch
 from colossalai.accelerator import get_accelerator
 from colossalai.booster import Booster
-from colossalai.booster.plugin import LowLevelZeroPlugin, TorchDDPPlugin
+from colossalai.booster.plugin import TorchDDPPlugin
 from colossalai.cluster import DistCoordinator
 from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import HybridAdam
@@ -24,6 +24,7 @@ from tqdm import tqdm
 from open_sora.diffusion import create_diffusion
 from open_sora.modeling import DiT_models
 from open_sora.utils.data import create_video_compressor, preprocess_batch
+from open_sora.utils.plugin import ZeroSeqParallelPlugin

 #################################################################################
 #                                 Training Loop                                 #
@@ -44,14 +45,17 @@ def main(args):
         plugin = TorchDDPPlugin()
     elif args.plugin == "zero2":
         # use bf16 to avoid skipping the first few iterations due to NaNs
-        plugin = LowLevelZeroPlugin(stage=2, precision="bf16")
+        plugin = ZeroSeqParallelPlugin(sp_size=args.sp_size, stage=2, precision="bf16")
     else:
         raise ValueError(f"Unknown plugin {args.plugin}")
     booster = Booster(plugin=plugin)

     # Create video compressor
     video_compressor = create_video_compressor(args.compressor)
-    model_kwargs = {"in_channels": video_compressor.out_channels}
+    model_kwargs = {
+        "in_channels": video_compressor.out_channels,
+        "seq_parallel_group": plugin.sp_group,
+    }

     # Create DiT and EMA
     model = DiT_models[args.model](**model_kwargs).to(get_current_device())
@@ -75,6 +79,7 @@ def main(args):
         torch.randn(args.num_frames, args.height, args.width, 3)
         for _ in range(args.batch_size)
     ]
+    assert args.num_tokens % args.sp_size == 0
     input_ids = torch.randn(args.batch_size, args.num_tokens, args.text_embed_dim)
     text_mask = torch.ones(input_ids.shape[:2], dtype=torch.int)
     batch = {
@@ -82,9 +87,15 @@ def main(args):
         "text_latent_states": input_ids,
         "text_padding_mask": text_mask,
     }
-    batch = preprocess_batch(batch, patch_size, video_compressor)
+    batch = preprocess_batch(
+        batch, patch_size, video_compressor, pad_to_multiple=args.sp_size
+    )
     video_inputs = batch.pop("video_latent_states")
     mask = batch.pop("video_padding_mask")
+    logger.info(
+        f"Num patches: {video_inputs.shape[1]}, num_tokens: {batch['text_latent_states'].shape[1]}",
+        ranks=[0],
+    )

     # setup booster
     model, opt, *_ = booster.boost(model, opt)
@@ -125,7 +136,7 @@ def main(args):
         ranks=[0],
     )
     logger.info(
-        f"Throughput: {throughput:.2f} samples/s",
+        f"Throughput per device: {throughput:.2f} samples/s",
         ranks=[0],
     )

@@ -139,6 +150,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "-p", "--plugin", type=str, default="zero2", choices=["ddp", "zero2"]
     )
+    parser.add_argument("--sp_size", type=int, default=1)
     parser.add_argument("-w", "--warmup_steps", type=int, default=2)
     parser.add_argument("-s", "--steps", type=int, default=3)
     parser.add_argument("-b", "--batch_size", type=int, default=4)
open_sora/modeling/dit.py

@@ -14,10 +14,13 @@ from typing import Callable, Optional

 import numpy as np
 import torch
+import torch.distributed as dist
 import torch.nn as nn
 import torch.nn.functional as F
 from timm.models.vision_transformer import Mlp

+from open_sora.utils.comm import all_to_all, gather_seq, split_seq
+

 class CrossAttention(nn.Module):
     r"""
@@ -101,6 +104,120 @@ class CrossAttention(nn.Module):
         return attn_output


+class SeqParallelCrossAttention(nn.Module):
+    r"""
+    A cross attention layer.
+
+    Parameters:
+        query_dim (`int`): The number of channels in the query.
+        cross_attention_dim (`int`, *optional*):
+            The number of channels in the context. If not given, defaults to `query_dim`.
+        num_heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
+        head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
+        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+        bias (`bool`, *optional*, defaults to False):
+            Set to `True` for the query, key, and value linear layers to contain a bias parameter.
+    """
+
+    def __init__(
+        self,
+        query_dim: int,
+        cross_attention_dim: Optional[int] = None,
+        num_heads: int = 8,
+        head_dim: int = 64,
+        dropout: float = 0.0,
+        bias=False,
+        sdpa=True,
+        seq_parallel_group=None,
+    ):
+        super().__init__()
+        self.hidden_size = head_dim * num_heads
+        cross_attention_dim = (
+            cross_attention_dim if cross_attention_dim is not None else query_dim
+        )
+
+        self.scale = head_dim**-0.5
+        self.num_heads = num_heads
+        self.head_dim = head_dim
+        self.sdpa = sdpa
+        self.seq_parallel_group = seq_parallel_group
+        self.seq_parallel_size = (
+            dist.get_world_size(self.seq_parallel_group)
+            if seq_parallel_group is not None
+            else 1
+        )
+        assert self.num_heads % self.seq_parallel_size == 0
+
+        self.to_q = nn.Linear(query_dim, self.hidden_size, bias=bias)
+        self.to_k = nn.Linear(cross_attention_dim, self.hidden_size, bias=bias)
+        self.to_v = nn.Linear(cross_attention_dim, self.hidden_size, bias=bias)
+
+        self.to_out = nn.Sequential(
+            nn.Linear(self.hidden_size, query_dim), nn.Dropout(dropout)
+        )
+
+    def forward(self, hidden_states, context=None, mask=None):
+        bsz, q_len, _ = hidden_states.shape
+
+        query = self.to_q(hidden_states)
+        context = context if context is not None else hidden_states
+        kv_seq_len = context.shape[1]
+        key = self.to_k(context)
+        value = self.to_v(context)
+
+        # [B, S/P, H] -> [B, S, H/P]
+        num_heads_parallel = self.num_heads // self.seq_parallel_size
+        hidden_size_parallel = self.hidden_size // self.seq_parallel_size
+        if self.seq_parallel_group is not None and self.seq_parallel_size > 1:
+            query = all_to_all(
+                query, self.seq_parallel_group, scatter_dim=2, gather_dim=1
+            )
+            key = all_to_all(key, self.seq_parallel_group, scatter_dim=2, gather_dim=1)
+            value = all_to_all(
+                value, self.seq_parallel_group, scatter_dim=2, gather_dim=1
+            )
+
+            q_len *= self.seq_parallel_size
+            kv_seq_len *= self.seq_parallel_size
+
+        # [B, S, H/P] -> [B, S, N/P, D] -> [B, N/P, S, D]
+        query = query.view(bsz, q_len, num_heads_parallel, self.head_dim).transpose(
+            1, 2
+        )
+        key = key.view(bsz, kv_seq_len, num_heads_parallel, self.head_dim).transpose(
+            1, 2
+        )
+        value = value.view(
+            bsz, kv_seq_len, num_heads_parallel, self.head_dim
+        ).transpose(1, 2)
+
+        if mask is not None:
+            assert mask.shape == (bsz, 1, q_len, kv_seq_len)
+        if self.sdpa:
+            attn_output = F.scaled_dot_product_attention(
+                query, key, value, attn_mask=mask, scale=self.scale
+            )
+        else:
+            attn_weights = torch.matmul(query, key.transpose(2, 3)) / self.scale
+            assert attn_weights.shape == (bsz, num_heads_parallel, q_len, kv_seq_len)
+            if mask is not None:
+                attn_weights = attn_weights + mask
+            attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
+                query.dtype
+            )
+            attn_output = torch.matmul(attn_weights, value)
+        assert attn_output.shape == (bsz, num_heads_parallel, q_len, self.head_dim)
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(bsz, q_len, hidden_size_parallel)
+        # [B, S, H/P] -> [B, S/P, H]
+        if self.seq_parallel_group is not None and self.seq_parallel_size > 1:
+            attn_output = all_to_all(
+                attn_output, self.seq_parallel_group, scatter_dim=1, gather_dim=2
+            )
+        attn_output = self.to_out(attn_output)
+        return attn_output
+
+
 def modulate(x, shift, scale):
     return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

@@ -290,16 +407,18 @@ class DiTBlock(nn.Module):
         num_heads,
         cross_attention_dim=None,
         mlp_ratio=4.0,
+        seq_parallel_group=None,
     ):
         super().__init__()
         self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
-        self.attn = CrossAttention(
+        self.attn = SeqParallelCrossAttention(
             query_dim=hidden_size,
             cross_attention_dim=cross_attention_dim,
             num_heads=num_heads,
             head_dim=hidden_size // num_heads,
             bias=True,
             sdpa=True,
+            seq_parallel_group=seq_parallel_group,
         )

         self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
@@ -380,6 +499,7 @@ class DiT(nn.Module):
         text_dropout_prob=0.1,
         learn_sigma=True,
         use_cross_attn=True,
+        seq_parallel_group=None,
     ):
         super().__init__()
         self.grad_checkpointing = False
@@ -388,6 +508,17 @@ class DiT(nn.Module):
         self.out_channels = in_channels * 2 if learn_sigma else in_channels
         self.patch_size = patch_size
         self.num_heads = num_heads
+        self.seq_parallel_group = seq_parallel_group
+        self.seq_parallel_size = (
+            dist.get_world_size(self.seq_parallel_group)
+            if seq_parallel_group is not None
+            else 1
+        )
+        self.seq_parallel_rank = (
+            dist.get_rank(self.seq_parallel_group)
+            if seq_parallel_group is not None
+            else 0
+        )

         self.video_embedder = PatchEmbedder(
             patch_size, in_channels, hidden_size, bias=True
@@ -409,7 +540,13 @@ class DiT(nn.Module):

         self.blocks = nn.ModuleList(
             [
-                DiTBlock(hidden_size, num_heads, cross_attn_dim, mlp_ratio=mlp_ratio)
+                DiTBlock(
+                    hidden_size,
+                    num_heads,
+                    cross_attn_dim,
+                    mlp_ratio=mlp_ratio,
+                    seq_parallel_group=seq_parallel_group,
+                )
                 for _ in range(depth)
             ]
         )
@@ -483,6 +620,18 @@ class DiT(nn.Module):
         video_latent_states = video_latent_states + pos_embed
         t = self.t_embedder(t)  # (N, D)
         attention_mask = self._prepare_mask(attention_mask, video_latent_states.dtype)
+
+        if self.seq_parallel_group is not None and self.seq_parallel_size > 1:
+            assert video_latent_states.shape[1] % self.seq_parallel_size == 0
+            video_latent_states = split_seq(
+                video_latent_states, self.seq_parallel_size, self.seq_parallel_rank
+            )
+            if text_latent_states is not None:
+                assert text_latent_states.shape[1] % self.seq_parallel_size == 0
+                text_latent_states = split_seq(
+                    text_latent_states, self.seq_parallel_size, self.seq_parallel_rank
+                )
+
         for block in self.blocks:
             if self.grad_checkpointing and self.training:
                 video_latent_states = torch.utils.checkpoint.checkpoint(
@@ -496,6 +645,15 @@ class DiT(nn.Module):
                 video_latent_states = block(
                     video_latent_states, attention_mask, t, text_latent_states
                 )
+
+        if self.seq_parallel_group is not None and self.seq_parallel_size > 1:
+            video_latent_states = gather_seq(
+                video_latent_states,
+                self.seq_parallel_size,
+                self.seq_parallel_rank,
+                self.seq_parallel_group,
+            )
+
         if not self.use_cross_attn:
             video_latent_states = video_latent_states[:, text_len:]
         video_latent_states = self.final_layer(video_latent_states, t)
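With `seq_parallel_group=None` the new layer degenerates to ordinary cross attention (`seq_parallel_size == 1`, so the all-to-all calls are skipped), which is how existing single-GPU code keeps working. A small usage sketch, assuming the repo is importable and a PyTorch version whose `scaled_dot_product_attention` accepts the `scale` keyword:

```python
import torch
from open_sora.modeling.dit import SeqParallelCrossAttention

# no process group passed -> serial fallback path
attn = SeqParallelCrossAttention(
    query_dim=8, cross_attention_dim=4, num_heads=4, head_dim=16
)
hidden_states = torch.rand(2, 8, 8)  # [B, S, query_dim]
context = torch.rand(2, 4, 4)        # [B, S_kv, cross_attention_dim]
out = attn(hidden_states, context)   # mask is optional
print(out.shape)                     # torch.Size([2, 8, 8])
```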
open_sora/utils/comm.py (new file, 96 lines)

@@ -0,0 +1,96 @@
+import torch
+import torch.distributed as dist
+from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler
+from colossalai.shardformer.layer._operation import gather_forward_split_backward
+
+
+def _all_to_all(
+    input_: torch.Tensor,
+    world_size: int,
+    group: dist.ProcessGroup,
+    scatter_dim: int,
+    gather_dim: int,
+):
+    input_list = [
+        t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)
+    ]
+    output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)]
+    dist.all_to_all(output_list, input_list, group=group)
+    return torch.cat(output_list, dim=gather_dim).contiguous()
+
+
+class _AllToAll(torch.autograd.Function):
+    """All-to-all communication.
+
+    Args:
+        input_: input matrix
+        process_group: communication group
+        scatter_dim: scatter dimension
+        gather_dim: gather dimension
+    """
+
+    @staticmethod
+    def forward(ctx, input_, process_group, scatter_dim, gather_dim):
+        ctx.process_group = process_group
+        ctx.scatter_dim = scatter_dim
+        ctx.gather_dim = gather_dim
+        ctx.world_size = dist.get_world_size(process_group)
+        return _all_to_all(
+            input_, ctx.world_size, process_group, scatter_dim, gather_dim
+        )
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return (
+            _all_to_all(
+                grad_output,
+                ctx.world_size,
+                ctx.process_group,
+                ctx.gather_dim,
+                ctx.scatter_dim,
+            ),
+            None,
+            None,
+            None,
+        )
+
+
+def all_to_all(
+    input_: torch.Tensor,
+    process_group: dist.ProcessGroup,
+    scatter_dim: int = 2,
+    gather_dim: int = 1,
+):
+    return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim)
+
+
+def split_seq(input_: torch.Tensor, sp_size: int, sp_rank: int, dim: int = 1):
+    """Split a tensor along sequence dimension. It will split input and divide grad by sp_size.
+
+    Args:
+        input_ (torch.Tensor): The common shape is (bs, seq, *).
+        sp_size (int): Sequence parallel size.
+        sp_rank (int): Sequence parallel rank.
+        dim (int, optional): Sequence dimension. Defaults to 1.
+    """
+    input_ = input_.chunk(sp_size, dim=dim)[sp_rank].clone()
+    return MoeOutGradScaler.apply(input_, sp_size)
+
+
+def gather_seq(
+    input_: torch.Tensor,
+    sp_size: int,
+    sp_rank: int,
+    sp_group: dist.ProcessGroup,
+    dim: int = 1,
+):
+    """Gather a tensor along sequence dimension. It will gather input and multiply grad by sp_size.
+
+    Args:
+        input_ (torch.Tensor): The common shape is (bs, seq, *).
+        sp_size (int): Sequence parallel size.
+        sp_rank (int): Sequence parallel rank.
+        dim (int, optional): Sequence dimension. Defaults to 1.
+    """
+    input_ = gather_forward_split_backward(input_, dim, sp_group)
+    return MoeInGradScaler.apply(input_, sp_size)
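The backward of `_AllToAll` is simply another all-to-all with `scatter_dim` and `gather_dim` swapped, because the exchange is its own inverse under that swap. A single-process sketch of that property (the `emulate_all_to_all` helper is illustrative only, not part of the patch):

```python
import torch

def emulate_all_to_all(shards, scatter_dim, gather_dim):
    """Emulate dist.all_to_all over len(shards) simulated ranks on one process."""
    world_size = len(shards)
    outputs = []
    for rank in range(world_size):
        pieces = [s.chunk(world_size, dim=scatter_dim)[rank] for s in shards]
        outputs.append(torch.cat(pieces, dim=gather_dim))
    return outputs

shards = [torch.randn(1, 4, 8) for _ in range(2)]                        # [B, S/P, H] per rank
head_shards = emulate_all_to_all(shards, scatter_dim=2, gather_dim=1)    # [B, S, H/P]
restored = emulate_all_to_all(head_shards, scatter_dim=1, gather_dim=2)  # back to [B, S/P, H]
assert all(torch.equal(a, b) for a, b in zip(shards, restored))
```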
open_sora/utils/data.py

@@ -16,6 +16,13 @@ DatasetType = Union[Dataset, ConcatDataset, dataset_dict.Dataset]
 PathType = Union[str, os.PathLike]


+def ceil_to_multiple(x: int, multiple: int) -> int:
+    m = x % multiple
+    if m == 0:
+        return x
+    return x + multiple - m
+
+
 def video2col(video_4d: torch.Tensor, patch_size: int) -> torch.Tensor:
     """
     Convert a 4D video tensor to a 2D tensor where each row is a patch of the video.
@@ -67,7 +74,9 @@ def col2video(
     return video


-def pad_sequences(sequences: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
+def pad_sequences(
+    sequences: List[torch.Tensor], pad_to_multiple: Optional[int] = None
+) -> Tuple[torch.Tensor, torch.Tensor]:
     """Pad a list of sequences.

     Args:
@@ -77,6 +86,8 @@ def pad_sequences(
         Tuple[torch.Tensor, torch.Tensor]: Padded batch of sequences ([B, T, ...]) and padding mask ([B, T]).
     """
     max_len = max([sequence.shape[0] for sequence in sequences])
+    if pad_to_multiple is not None:
+        max_len = ceil_to_multiple(max_len, pad_to_multiple)
     padded_sequences = [
         F.pad(
             sequence, [0] * (sequence.ndim - 1) * 2 + [0, max_len - sequence.shape[0]]
@@ -96,7 +107,7 @@ def pad_sequences(


 def patchify_batch(
-    videos: List[torch.Tensor], patch_size: int
+    videos: List[torch.Tensor], patch_size: int, pad_to_multiple: Optional[int] = None
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """Patchify a batch of videos.

@@ -108,7 +119,7 @@ def patchify_batch(
         Tuple[torch.Tensor, torch.Tensor]: Padded batch of patches ([B, S, C, P, P]) and padding mask ([B, S]).
     """
     video_patches = [video2col(video, patch_size) for video in videos]
-    return pad_sequences(video_patches)
+    return pad_sequences(video_patches, pad_to_multiple=pad_to_multiple)


 def expand_mask_4d(q_mask: torch.Tensor, kv_mask: torch.Tensor) -> torch.Tensor:
@@ -127,7 +138,9 @@ def expand_mask_4d(q_mask: torch.Tensor, kv_mask: torch.Tensor) -> torch.Tensor:
     return mask.unsqueeze(1)


-def make_batch(samples: List[dict], video_dir: str) -> dict:
+def make_batch(
+    samples: List[dict], video_dir: str, pad_to_multiple: Optional[int] = None
+) -> dict:
     """Make a batch of samples.

     Args:
@@ -141,7 +154,7 @@ def make_batch(
         for sample in samples
     ]
     texts = [sample["text_latent_states"] for sample in samples]
-    texts, text_padding_mask = pad_sequences(texts)
+    texts, text_padding_mask = pad_sequences(texts, pad_to_multiple=pad_to_multiple)
     return {
         "videos": videos,
         "text_latent_states": texts,
@@ -269,6 +282,7 @@ def preprocess_batch(
     video_compressor: VideoCompressor,
     device=None,
     use_cross_attn=True,
+    pad_to_multiple: Optional[int] = None,
 ) -> dict:
     if device is None:
         device = get_current_device()
@@ -278,7 +292,9 @@ def preprocess_batch(
         video = normalize_video(video)
         video = video_compressor.encode(video)
         videos.append(video)
-    video_latent_states, video_padding_mask = patchify_batch(videos, patch_size)
+    video_latent_states, video_padding_mask = patchify_batch(
+        videos, patch_size, pad_to_multiple
+    )
     batch["video_latent_states"] = video_latent_states
     batch["video_padding_mask"] = video_padding_mask
     text_padding_mask = batch.pop("text_padding_mask").to(device)
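The `pad_to_multiple` plumbing exists so that padded sequence lengths stay divisible by `sp_size`, which `split_seq` (and the asserts in `DiT.forward`) require. A standalone sketch of the padding arithmetic, re-implementing `ceil_to_multiple` from the hunk above so it runs without the repo:

```python
import torch
import torch.nn.functional as F

def ceil_to_multiple(x: int, multiple: int) -> int:
    m = x % multiple
    return x if m == 0 else x + multiple - m

# two text sequences of uneven length, sp_size = 4
sequences = [torch.randn(5, 16), torch.randn(7, 16)]
max_len = ceil_to_multiple(max(s.shape[0] for s in sequences), 4)  # 7 -> 8
padded = torch.stack(
    [F.pad(s, [0] * (s.ndim - 1) * 2 + [0, max_len - s.shape[0]]) for s in sequences]
)
print(padded.shape)  # torch.Size([2, 8, 16]) -- length divisible by sp_size
```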
open_sora/utils/plugin.py (new file, 102 lines)

@@ -0,0 +1,102 @@
+import random
+from typing import Optional
+
+import numpy as np
+import torch
+from colossalai.booster.plugin import LowLevelZeroPlugin
+from colossalai.cluster import ProcessGroupMesh
+from torch.utils.data import DataLoader
+from torch.utils.data.distributed import DistributedSampler
+
+DP_AXIS, SP_AXIS = 0, 1
+
+
+class ZeroSeqParallelPlugin(LowLevelZeroPlugin):
+    def __init__(
+        self,
+        sp_size: int = 1,
+        stage: int = 2,
+        precision: str = "fp16",
+        initial_scale: float = 2**32,
+        min_scale: float = 1,
+        growth_factor: float = 2,
+        backoff_factor: float = 0.5,
+        growth_interval: int = 1000,
+        hysteresis: int = 2,
+        max_scale: float = 2**32,
+        max_norm: float = 0.0,
+        norm_type: float = 2.0,
+        reduce_bucket_size_in_m: int = 12,
+        communication_dtype: Optional[torch.dtype] = None,
+        overlap_communication: bool = True,
+        cpu_offload: bool = False,
+        master_weights: bool = True,
+        verbose: bool = False,
+    ) -> None:
+        super().__init__(
+            stage=stage,
+            precision=precision,
+            initial_scale=initial_scale,
+            min_scale=min_scale,
+            growth_factor=growth_factor,
+            backoff_factor=backoff_factor,
+            growth_interval=growth_interval,
+            hysteresis=hysteresis,
+            max_scale=max_scale,
+            max_norm=max_norm,
+            norm_type=norm_type,
+            reduce_bucket_size_in_m=reduce_bucket_size_in_m,
+            communication_dtype=communication_dtype,
+            overlap_communication=overlap_communication,
+            cpu_offload=cpu_offload,
+            master_weights=master_weights,
+            verbose=verbose,
+        )
+        self.sp_size = sp_size
+        assert self.world_size % sp_size == 0, "world_size must be divisible by sp_size"
+        self.dp_size = self.world_size // sp_size
+        self.pg_mesh = ProcessGroupMesh(self.dp_size, self.sp_size)
+        self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS)
+        self.sp_group = self.pg_mesh.get_group_along_axis(SP_AXIS)
+        self.dp_rank = self.pg_mesh.coordinate(DP_AXIS)
+        self.sp_rank = self.pg_mesh.coordinate(SP_AXIS)
+
+    def __del__(self):
+        """Destroy the process groups in ProcessGroupMesh"""
+        self.pg_mesh.destroy_mesh_process_groups()
+
+    def prepare_dataloader(
+        self,
+        dataset,
+        batch_size,
+        shuffle=False,
+        seed=1024,
+        drop_last=False,
+        pin_memory=False,
+        num_workers=0,
+        distributed_sampler_cls=None,
+        **kwargs
+    ):
+        _kwargs = kwargs.copy()
+        distributed_sampler_cls = distributed_sampler_cls or DistributedSampler
+        sampler = distributed_sampler_cls(
+            dataset, num_replicas=self.dp_size, rank=self.dp_rank, shuffle=shuffle
+        )
+
+        # Deterministic dataloader
+        def seed_worker(worker_id):
+            worker_seed = seed
+            np.random.seed(worker_seed)
+            torch.manual_seed(worker_seed)
+            random.seed(worker_seed)
+
+        return DataLoader(
+            dataset,
+            batch_size=batch_size,
+            sampler=sampler,
+            worker_init_fn=seed_worker,
+            drop_last=drop_last,
+            pin_memory=pin_memory,
+            num_workers=num_workers,
+            **_kwargs,
+        )
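The plugin builds a 2D ProcessGroupMesh of shape (dp_size, sp_size), so ZeRO sharding and the dataloader sampler run over the DP axis while attention all-to-alls run over the SP axis. A back-of-the-envelope sketch of how global ranks would map onto that mesh, assuming the usual row-major layout (an assumption about ProcessGroupMesh, not something this patch spells out):

```python
# illustrative only: 8 GPUs with sp_size=2 -> 4 data-parallel replicas of 2-way sequence parallelism
world_size, sp_size = 8, 2
dp_size = world_size // sp_size  # mirrors the divisibility assert in ZeroSeqParallelPlugin.__init__
for rank in range(world_size):
    dp_rank, sp_rank = divmod(rank, sp_size)  # assumed row-major coordinates on the (DP, SP) mesh
    print(f"rank {rank}: dp_rank={dp_rank}, sp_rank={sp_rank}")
# under this layout, ranks {0,1}, {2,3}, {4,5}, {6,7} would each form one sp_group
```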
tests/test_model.py (new file, 93 lines)

@@ -0,0 +1,93 @@
+import colossalai
+import pytest
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+from colossalai.booster import Booster
+from colossalai.logging import disable_existing_loggers
+from colossalai.nn.optimizer import HybridAdam
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.utils import get_current_device
+from torch.testing import assert_close
+
+from open_sora.modeling import DiT_models
+from open_sora.utils.plugin import ZeroSeqParallelPlugin
+
+
+@parameterize("sp_size", [2, 4])
+def check_dit_model_fwd_bwd(
+    sp_size: int, video_latent_states, text_latent_states, t, mask
+):
+    plugin = ZeroSeqParallelPlugin(
+        sp_size=sp_size, stage=2, precision="fp32", master_weights=False
+    )
+    booster = Booster(plugin=plugin)
+    model = DiT_models["DiT-B/8"](text_dropout_prob=0.0).to(get_current_device())
+    parallel_model = DiT_models["DiT-B/8"](
+        text_dropout_prob=0.0, seq_parallel_group=plugin.sp_group
+    ).to(get_current_device())
+    parallel_model.load_state_dict(model.state_dict())
+    opt = HybridAdam(parallel_model.parameters(), lr=1e-3)
+    parallel_model, opt, *_ = booster.boost(parallel_model, opt)
+
+    target = model(video_latent_states, t, text_latent_states, mask)
+    noise = torch.randn_like(target)
+    target_loss = F.mse_loss(target, noise)
+    target_loss.backward()
+
+    dp_video_latent_states = video_latent_states.chunk(plugin.dp_size)[plugin.dp_rank]
+    dp_text_latent_states = text_latent_states.chunk(plugin.dp_size)[plugin.dp_rank]
+    dp_t = t.chunk(plugin.dp_size)[plugin.dp_rank]
+    dp_mask = mask.chunk(plugin.dp_size)[plugin.dp_rank]
+    dp_noise = noise.chunk(plugin.dp_size)[plugin.dp_rank]
+
+    output = parallel_model(
+        dp_video_latent_states, dp_t, dp_text_latent_states, dp_mask
+    )
+    loss = F.mse_loss(output, dp_noise)
+    booster.backward(loss, opt)
+
+    if plugin.dp_size == 1:
+        assert_close(target, output)
+
+    for p1, p2 in zip(model.parameters(), opt._master_param_groups_of_current_rank[0]):
+        working_p = opt._param_store.master_to_working_param[id(p2)]
+        grads = opt._grad_store.get_partitioned_gradients_by_param_id(0, id(working_p))
+        grad_index = 0 if opt._partition_grads else opt._local_rank
+        grad = grads[grad_index]
+        sharded_grad = p1.grad.view(-1).chunk(dist.get_world_size())[dist.get_rank()]
+        assert_close(sharded_grad, grad[: sharded_grad.shape[0]])
+
+
+def run_dist(rank, world_size, port):
+    disable_existing_loggers()
+    colossalai.launch(
+        config={},
+        rank=rank,
+        world_size=world_size,
+        port=port,
+        host="localhost",
+        backend="nccl",
+    )
+    b, s, c, p = 4, 20, 3, 8
+    dim_text, s_text = 512, 12
+    video_latent_states = torch.rand(b, s, c, p, p, device=get_current_device())
+    text_latent_states = torch.rand(b, s_text, dim_text, device=get_current_device())
+    t = torch.randint(0, 1000, (b,), device=get_current_device())
+    mask = torch.ones(b, 1, s, s_text, device=get_current_device(), dtype=torch.int)
+    check_dit_model_fwd_bwd(
+        video_latent_states=video_latent_states,
+        text_latent_states=text_latent_states,
+        t=t,
+        mask=mask,
+    )
+
+
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+def test_dit_model():
+    spawn(run_dist, 4)
+
+
+if __name__ == "__main__":
+    test_dit_model()
tests/test_sp_attn.py (new file, 76 lines)

@@ -0,0 +1,76 @@
+import colossalai
+import pytest
+import torch
+import torch.distributed as dist
+from colossalai.logging import disable_existing_loggers
+from colossalai.testing import rerun_if_address_is_in_use, spawn
+from colossalai.utils import get_current_device
+from torch.testing import assert_close
+
+from open_sora.modeling.dit import CrossAttention, SeqParallelCrossAttention
+from open_sora.utils.comm import gather_seq, split_seq
+
+
+def check_sp_attn():
+    sp_size = dist.get_world_size()
+    sp_rank = dist.get_rank()
+    q_dim, context_dim = 8, 4
+    num_heads = 4
+    head_dim = 16
+    bs = 2
+    sq = 8
+    skv = 4
+    attn = CrossAttention(q_dim, context_dim, num_heads, head_dim).to(
+        get_current_device()
+    )
+    parallel_attn = SeqParallelCrossAttention(
+        q_dim, context_dim, num_heads, head_dim, seq_parallel_group=dist.group.WORLD
+    ).to(get_current_device())
+    parallel_attn.load_state_dict(attn.state_dict())
+    hidden_states = torch.rand(bs, sq, q_dim, device=get_current_device())
+    context = torch.rand(bs, skv, context_dim, device=get_current_device())
+    mask = torch.zeros(bs, 1, sq, skv, device=get_current_device())
+    target = attn(hidden_states, context, mask)
+    hidden_states_parallel = split_seq(hidden_states, sp_size, sp_rank)
+    context_parallel = split_seq(context, sp_size, sp_rank)
+    output_parallel = parallel_attn(
+        hidden_states_parallel,
+        context_parallel,
+        mask,
+    )
+    assert torch.equal(target.chunk(sp_size, dim=1)[sp_rank], output_parallel)
+    output = gather_seq(output_parallel, sp_size, sp_rank, dist.group.WORLD)
+    assert torch.equal(target, output)
+    target.mean().backward()
+    output.mean().backward()
+
+    # all-reduce mean of grads
+    for p in parallel_attn.parameters():
+        p.grad.data.div_(sp_size)
+        dist.all_reduce(p.grad.data)
+
+    for p1, p2 in zip(attn.parameters(), parallel_attn.parameters()):
+        assert_close(p1.grad, p2.grad)
+
+
+def run_dist(rank, world_size, port):
+    disable_existing_loggers()
+    colossalai.launch(
+        config={},
+        rank=rank,
+        world_size=world_size,
+        port=port,
+        host="localhost",
+        backend="nccl",
+    )
+    check_sp_attn()
+
+
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+def test_seq_parallel_attn():
+    spawn(run_dist, 2)
+
+
+if __name__ == "__main__":
+    test_seq_parallel_attn()
train.py (60 changes)

@@ -17,7 +17,6 @@ import torch.distributed as dist
 from colossalai import launch_from_torch
 from colossalai.accelerator import get_accelerator
 from colossalai.booster import Booster
-from colossalai.booster.plugin import LowLevelZeroPlugin
 from colossalai.cluster import DistCoordinator
 from colossalai.logging import get_dist_logger
 from colossalai.nn.lr_scheduler import CosineAnnealingLR
@@ -28,7 +27,13 @@ from tqdm import tqdm

 from open_sora.diffusion import create_diffusion
 from open_sora.modeling import DiT_models
-from open_sora.utils.data import create_video_compressor, load_datasets, make_batch, preprocess_batch
+from open_sora.utils.data import (
+    create_video_compressor,
+    load_datasets,
+    make_batch,
+    preprocess_batch,
+)
+from open_sora.utils.plugin import ZeroSeqParallelPlugin

 #################################################################################
 #                           Training Helper Functions                           #
@@ -93,7 +98,7 @@ def main(args):
     configure_backends()

     # Step 2: set up acceleration plugins
-    plugin = LowLevelZeroPlugin(stage=2, precision="fp16")
+    plugin = ZeroSeqParallelPlugin(sp_size=args.sp_size, stage=2, precision="fp16")
     booster = Booster(plugin=plugin)

     if coordinator.is_master():
@@ -103,7 +108,10 @@ def main(args):

     # Step 3: Create video compressor
     video_compressor = create_video_compressor(args.compressor)
-    model_kwargs = {"in_channels": video_compressor.out_channels}
+    model_kwargs = {
+        "in_channels": video_compressor.out_channels,
+        "seq_parallel_group": plugin.sp_group,
+    }

     # Step 4: Create DiT and EMA
     model = DiT_models[args.model](**model_kwargs).to(get_current_device())
@@ -119,7 +127,9 @@ def main(args):
         model.enable_gradient_checkpointing()

     # Step 5: create diffusion pipeline
-    diffusion = create_diffusion(timestep_respacing="")  # default: 1000 steps, linear noise schedule
+    diffusion = create_diffusion(
+        timestep_respacing=""
+    )  # default: 1000 steps, linear noise schedule

     # Step 6: setup optimizer (we used default Adam betas=(0.9, 0.999) and a constant learning rate of 1e-4 in our paper):
     opt = HybridAdam(model.parameters(), lr=args.lr, weight_decay=0)
@@ -129,15 +139,21 @@ def main(args):
     dataloader = plugin.prepare_dataloader(
         dataset,
         batch_size=args.batch_size,
-        collate_fn=partial(make_batch, video_dir=args.video_dir),
+        collate_fn=partial(
+            make_batch, video_dir=args.video_dir, pad_to_multiple=args.sp_size
+        ),
         shuffle=True,
         drop_last=True,
     )
-    lr_scheduler = CosineAnnealingLR(opt, args.epochs * len(dataloader) // args.accumulation_steps)
+    lr_scheduler = CosineAnnealingLR(
+        opt, args.epochs * len(dataloader) // args.accumulation_steps
+    )
     logger.info(f"Dataset contains {len(dataset)} samples", ranks=[0])

     # Step 8: setup booster
-    model, opt, _, dataloader, _ = booster.boost(model, opt, dataloader=dataloader)
+    model, opt, _, dataloader, lr_scheduler = booster.boost(
+        model, opt, dataloader=dataloader, lr_scheduler=lr_scheduler
+    )
     if args.load_model is not None:
         booster.load_model(model, args.load_model)
     if args.load_optimizer is not None:
@@ -159,7 +175,9 @@ def main(args):
     ) as pbar:
         total_loss = torch.tensor(0.0, device=get_current_device())
         for step, batch in enumerate(dataloader):
-            batch = preprocess_batch(batch, patch_size, video_compressor)
+            batch = preprocess_batch(
+                batch, patch_size, video_compressor, pad_to_multiple=args.sp_size
+            )
             video_inputs = batch.pop("video_latent_states")
             mask = batch.pop("video_padding_mask")
             t = torch.randint(
@@ -168,7 +186,9 @@ def main(args):
                 (video_inputs.shape[0],),
                 device=video_inputs.device,
             )
-            loss_dict = diffusion.training_losses(model, video_inputs, t, batch, mask=mask)
+            loss_dict = diffusion.training_losses(
+                model, video_inputs, t, batch, mask=mask
+            )
             loss = loss_dict["loss"].mean() / args.accumulation_steps
             total_loss.add_(loss.data)
             booster.backward(loss, opt)
@@ -182,7 +202,9 @@ def main(args):
                 all_reduce_mean(total_loss)
                 pbar.set_postfix({"Loss": f"{total_loss.item():.4f}"})
                 if coordinator.is_master():
-                    global_step = (epoch * num_steps_per_epoch) + (step + 1) // args.accumulation_steps
+                    global_step = (epoch * num_steps_per_epoch) + (
+                        step + 1
+                    ) // args.accumulation_steps
                     writer.add_scalar(
                         tag="Loss",
                         scalar_value=total_loss.item(),
@@ -193,9 +215,12 @@ def main(args):

             # Save DiT checkpoint:
             if args.save_interval > 0 and (
-                (step + 1) % (args.save_interval * args.accumulation_steps) == 0 or (step + 1) == len(dataloader)
+                (step + 1) % (args.save_interval * args.accumulation_steps) == 0
+                or (step + 1) == len(dataloader)
             ):
-                save_path = os.path.join(args.checkpoint_dir, f"epoch-{epoch}-step-{step}")
+                save_path = os.path.join(
+                    args.checkpoint_dir, f"epoch-{epoch}-step-{step}"
+                )
                 save_checkpoints(booster, model, opt, ema, save_path, coordinator)
                 logger.info(f"Saved checkpoint to {save_path}", ranks=[0])

@@ -212,18 +237,23 @@ def main(args):
 if __name__ == "__main__":
     # Default args here will train DiT-XL/2 with the hyperparameters we used in our paper (except training iters).
     parser = argparse.ArgumentParser()
-    parser.add_argument("-m", "--model", type=str, choices=list(DiT_models.keys()), default="DiT-S/8")
+    parser.add_argument(
+        "-m", "--model", type=str, choices=list(DiT_models.keys()), default="DiT-S/8"
+    )
     parser.add_argument("-d", "--dataset", nargs="+", default=[])
     parser.add_argument("-v", "--video_dir", type=str, required=True)
     parser.add_argument("-e", "--epochs", type=int, default=10)
     parser.add_argument("-b", "--batch_size", type=int, default=4)
     parser.add_argument("-g", "--grad_checkpoint", action="store_true", default=False)
     parser.add_argument("-a", "--accumulation_steps", default=1, type=int)
+    parser.add_argument("--sp_size", type=int, default=1)
     parser.add_argument("--lr", type=float, default=1e-4)
     parser.add_argument("--save_interval", type=int, default=20)
     parser.add_argument("--checkpoint_dir", type=str, default="checkpoints")
     parser.add_argument("--tensorboard_dir", type=str, default="runs")
-    parser.add_argument("-c", "--compressor", choices=["raw", "vqvae", "vae"], default="raw")
+    parser.add_argument(
+        "-c", "--compressor", choices=["raw", "vqvae", "vae"], default="raw"
+    )
     parser.add_argument("--load_model", default=None)
     parser.add_argument("--load_optimizer", default=None)
     args = parser.parse_args()