From 97c089daec480aa3da3064112ebb170f6f85bbe3 Mon Sep 17 00:00:00 2001
From: Hongxin Liu
Date: Fri, 1 Mar 2024 14:42:06 +0800
Subject: [PATCH] [feature] impl ulysses-style seq parallel (#20)

* [feature] add ulysses style sp attn
* [test] add sp attn test
* [feature] add zero sp plugin
* [hotfix] fix sp backward
* [test] add test for dit model
---
 benchmark.py              |  22 ++++--
 open_sora/modeling/dit.py | 162 +++++++++++++++++++++++++++++++++++++-
 open_sora/utils/comm.py   |  96 ++++++++++++++++++++++
 open_sora/utils/data.py   |  28 +++++--
 open_sora/utils/plugin.py | 102 ++++++++++++++++++++++++
 tests/test_model.py       |  93 ++++++++++++++++++++++
 tests/test_sp_attn.py     |  76 ++++++++++++++++++
 train.py                  |  60 ++++++++++----
 8 files changed, 611 insertions(+), 28 deletions(-)
 create mode 100644 open_sora/utils/comm.py
 create mode 100644 open_sora/utils/plugin.py
 create mode 100644 tests/test_model.py
 create mode 100644 tests/test_sp_attn.py

diff --git a/benchmark.py b/benchmark.py
index 3244707..dbe9855 100644
--- a/benchmark.py
+++ b/benchmark.py
@@ -14,7 +14,7 @@ import torch
 from colossalai import launch_from_torch
 from colossalai.accelerator import get_accelerator
 from colossalai.booster import Booster
-from colossalai.booster.plugin import LowLevelZeroPlugin, TorchDDPPlugin
+from colossalai.booster.plugin import TorchDDPPlugin
 from colossalai.cluster import DistCoordinator
 from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import HybridAdam
@@ -24,6 +24,7 @@ from tqdm import tqdm

 from open_sora.diffusion import create_diffusion
 from open_sora.modeling import DiT_models
 from open_sora.utils.data import create_video_compressor, preprocess_batch
+from open_sora.utils.plugin import ZeroSeqParallelPlugin

 #################################################################################
 #                                 Training Loop                                 #
 #################################################################################
@@ -44,14 +45,17 @@ def main(args):
         plugin = TorchDDPPlugin()
     elif args.plugin == "zero2":
         # use bf16 to avoid skipping the first few iterations due to NaNs
-        plugin = LowLevelZeroPlugin(stage=2, precision="bf16")
+        plugin = ZeroSeqParallelPlugin(sp_size=args.sp_size, stage=2, precision="bf16")
     else:
         raise ValueError(f"Unknown plugin {args.plugin}")
     booster = Booster(plugin=plugin)

     # Create video compressor
     video_compressor = create_video_compressor(args.compressor)
-    model_kwargs = {"in_channels": video_compressor.out_channels}
+    model_kwargs = {
+        "in_channels": video_compressor.out_channels,
+        "seq_parallel_group": plugin.sp_group,
+    }

     # Create DiT and EMA
     model = DiT_models[args.model](**model_kwargs).to(get_current_device())
@@ -75,6 +79,7 @@
         torch.randn(args.num_frames, args.height, args.width, 3)
         for _ in range(args.batch_size)
     ]
+    assert args.num_tokens % args.sp_size == 0
    input_ids = torch.randn(args.batch_size, args.num_tokens, args.text_embed_dim)
    text_mask = torch.ones(input_ids.shape[:2], dtype=torch.int)
    batch = {
@@ -82,9 +87,15 @@
         "text_latent_states": input_ids,
         "text_padding_mask": text_mask,
     }
-    batch = preprocess_batch(batch, patch_size, video_compressor)
+    batch = preprocess_batch(
+        batch, patch_size, video_compressor, pad_to_multiple=args.sp_size
+    )
     video_inputs = batch.pop("video_latent_states")
     mask = batch.pop("video_padding_mask")
+    logger.info(
+        f"Num patches: {video_inputs.shape[1]}, num_tokens: {batch['text_latent_states'].shape[1]}",
+        ranks=[0],
+    )

     # setup booster
     model, opt, *_ = booster.boost(model, opt)
@@ -125,7 +136,7 @@
         ranks=[0],
     )
     logger.info(
-        f"Throughput: {throughput:.2f} samples/s",
+        f"Throughput per device: {throughput:.2f} samples/s",
         ranks=[0],
     )
@@ -139,6 +150,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "-p", "--plugin", type=str, default="zero2", choices=["ddp", "zero2"]
     )
+    parser.add_argument("--sp_size", type=int, default=1)
     parser.add_argument("-w", "--warmup_steps", type=int, default=2)
     parser.add_argument("-s", "--steps", type=int, default=3)
     parser.add_argument("-b", "--batch_size", type=int, default=4)
diff --git a/open_sora/modeling/dit.py b/open_sora/modeling/dit.py
index 5710abd..ac9ce8a 100644
--- a/open_sora/modeling/dit.py
+++ b/open_sora/modeling/dit.py
@@ -14,10 +14,13 @@ from typing import Callable, Optional

 import numpy as np
 import torch
+import torch.distributed as dist
 import torch.nn as nn
 import torch.nn.functional as F
 from timm.models.vision_transformer import Mlp

+from open_sora.utils.comm import all_to_all, gather_seq, split_seq
+

 class CrossAttention(nn.Module):
     r"""
@@ -101,6 +104,120 @@ class CrossAttention(nn.Module):
         return attn_output


+class SeqParallelCrossAttention(nn.Module):
+    r"""
+    A cross attention layer with Ulysses-style sequence parallelism.
+
+    Parameters:
+        query_dim (`int`): The number of channels in the query.
+        cross_attention_dim (`int`, *optional*):
+            The number of channels in the context. If not given, defaults to `query_dim`.
+        num_heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
+        head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
+        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+        bias (`bool`, *optional*, defaults to False):
+            Set to `True` for the query, key, and value linear layers to contain a bias parameter.
+    """
+
+    def __init__(
+        self,
+        query_dim: int,
+        cross_attention_dim: Optional[int] = None,
+        num_heads: int = 8,
+        head_dim: int = 64,
+        dropout: float = 0.0,
+        bias=False,
+        sdpa=True,
+        seq_parallel_group=None,
+    ):
+        super().__init__()
+        self.hidden_size = head_dim * num_heads
+        cross_attention_dim = (
+            cross_attention_dim if cross_attention_dim is not None else query_dim
+        )
+
+        self.scale = head_dim**-0.5
+        self.num_heads = num_heads
+        self.head_dim = head_dim
+        self.sdpa = sdpa
+        self.seq_parallel_group = seq_parallel_group
+        self.seq_parallel_size = (
+            dist.get_world_size(self.seq_parallel_group)
+            if seq_parallel_group is not None
+            else 1
+        )
+        assert self.num_heads % self.seq_parallel_size == 0
+
+        self.to_q = nn.Linear(query_dim, self.hidden_size, bias=bias)
+        self.to_k = nn.Linear(cross_attention_dim, self.hidden_size, bias=bias)
+        self.to_v = nn.Linear(cross_attention_dim, self.hidden_size, bias=bias)
+
+        self.to_out = nn.Sequential(
+            nn.Linear(self.hidden_size, query_dim), nn.Dropout(dropout)
+        )
+
+    def forward(self, hidden_states, context=None, mask=None):
+        bsz, q_len, _ = hidden_states.shape
+
+        query = self.to_q(hidden_states)
+        context = context if context is not None else hidden_states
+        kv_seq_len = context.shape[1]
+        key = self.to_k(context)
+        value = self.to_v(context)
+
+        # [B, S/P, H] -> [B, S, H/P]
+        num_heads_parallel = self.num_heads // self.seq_parallel_size
+        hidden_size_parallel = self.hidden_size // self.seq_parallel_size
+        if self.seq_parallel_group is not None and self.seq_parallel_size > 1:
+            query = all_to_all(
+                query, self.seq_parallel_group, scatter_dim=2, gather_dim=1
+            )
+            key = all_to_all(key, self.seq_parallel_group, scatter_dim=2, gather_dim=1)
+            value = all_to_all(
+                value, self.seq_parallel_group, scatter_dim=2, gather_dim=1
+            )
+
+            q_len *= self.seq_parallel_size
+            kv_seq_len *= self.seq_parallel_size
+
+        # [B, S, H/P] -> [B, S, N/P, D] -> [B, N/P, S, D]
+        query = query.view(bsz, q_len, num_heads_parallel, self.head_dim).transpose(
+            1, 2
+        )
+        key = key.view(bsz, kv_seq_len, num_heads_parallel, self.head_dim).transpose(
+            1, 2
+        )
+        value = value.view(
+            bsz, kv_seq_len, num_heads_parallel, self.head_dim
+        ).transpose(1, 2)
+
+        if mask is not None:
+            assert mask.shape == (bsz, 1, q_len, kv_seq_len)
+        if self.sdpa:
+            attn_output = F.scaled_dot_product_attention(
+                query, key, value, attn_mask=mask, scale=self.scale
+            )
+        else:
+            attn_weights = torch.matmul(query, key.transpose(2, 3)) * self.scale
+            assert attn_weights.shape == (bsz, num_heads_parallel, q_len, kv_seq_len)
+            if mask is not None:
+                attn_weights = attn_weights + mask
+            attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
+                query.dtype
+            )
+            attn_output = torch.matmul(attn_weights, value)
+        assert attn_output.shape == (bsz, num_heads_parallel, q_len, self.head_dim)
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(bsz, q_len, hidden_size_parallel)
+        # [B, S, H/P] -> [B, S/P, H]
+        if self.seq_parallel_group is not None and self.seq_parallel_size > 1:
+            attn_output = all_to_all(
+                attn_output, self.seq_parallel_group, scatter_dim=1, gather_dim=2
+            )
+        attn_output = self.to_out(attn_output)
+        return attn_output
+
+
 def modulate(x, shift, scale):
     return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

@@ -290,16 +407,18 @@ class DiTBlock(nn.Module):
         num_heads,
         cross_attention_dim=None,
         mlp_ratio=4.0,
+        seq_parallel_group=None,
     ):
         super().__init__()
         self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
-        self.attn = CrossAttention(
+        self.attn = SeqParallelCrossAttention(
             query_dim=hidden_size,
             cross_attention_dim=cross_attention_dim,
             num_heads=num_heads,
             head_dim=hidden_size // num_heads,
             bias=True,
             sdpa=True,
+            seq_parallel_group=seq_parallel_group,
         )
         self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
@@ -380,6 +499,7 @@ class DiT(nn.Module):
         text_dropout_prob=0.1,
         learn_sigma=True,
         use_cross_attn=True,
+        seq_parallel_group=None,
     ):
         super().__init__()
         self.grad_checkpointing = False
@@ -388,6 +508,17 @@
         self.out_channels = in_channels * 2 if learn_sigma else in_channels
         self.patch_size = patch_size
         self.num_heads = num_heads
+        self.seq_parallel_group = seq_parallel_group
+        self.seq_parallel_size = (
+            dist.get_world_size(self.seq_parallel_group)
+            if seq_parallel_group is not None
+            else 1
+        )
+        self.seq_parallel_rank = (
+            dist.get_rank(self.seq_parallel_group)
+            if seq_parallel_group is not None
+            else 0
+        )

         self.video_embedder = PatchEmbedder(
             patch_size, in_channels, hidden_size, bias=True
@@ -409,7 +540,13 @@

         self.blocks = nn.ModuleList(
             [
-                DiTBlock(hidden_size, num_heads, cross_attn_dim, mlp_ratio=mlp_ratio)
+                DiTBlock(
+                    hidden_size,
+                    num_heads,
+                    cross_attn_dim,
+                    mlp_ratio=mlp_ratio,
+                    seq_parallel_group=seq_parallel_group,
+                )
                 for _ in range(depth)
             ]
         )
@@ -483,6 +620,18 @@
         video_latent_states = video_latent_states + pos_embed
         t = self.t_embedder(t)  # (N, D)
         attention_mask = self._prepare_mask(attention_mask, video_latent_states.dtype)
+
+        if self.seq_parallel_group is not None and self.seq_parallel_size > 1:
+            assert video_latent_states.shape[1] % self.seq_parallel_size == 0
+            video_latent_states = split_seq(
+                video_latent_states, self.seq_parallel_size, self.seq_parallel_rank
+            )
+            if text_latent_states is not None:
+                assert text_latent_states.shape[1] % self.seq_parallel_size == 0
+                text_latent_states = split_seq(
+                    text_latent_states, self.seq_parallel_size, self.seq_parallel_rank
+                )
+
         for block in self.blocks:
             if self.grad_checkpointing and self.training:
                 video_latent_states = torch.utils.checkpoint.checkpoint(
@@ -496,6 +645,15 @@
             video_latent_states = block(
                 video_latent_states, attention_mask, t, text_latent_states
             )
+
+        if self.seq_parallel_group is not None and self.seq_parallel_size > 1:
+            video_latent_states = gather_seq(
+                video_latent_states,
+                self.seq_parallel_size,
+                self.seq_parallel_rank,
+                self.seq_parallel_group,
+            )
+
         if not self.use_cross_attn:
             video_latent_states = video_latent_states[:, text_len:]
         video_latent_states = self.final_layer(video_latent_states, t)
diff --git a/open_sora/utils/comm.py b/open_sora/utils/comm.py
new file mode 100644
index 0000000..00703d3
--- /dev/null
+++ b/open_sora/utils/comm.py
@@ -0,0 +1,96 @@
+import torch
+import torch.distributed as dist
+from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler
+from colossalai.shardformer.layer._operation import gather_forward_split_backward
+
+
+def _all_to_all(
+    input_: torch.Tensor,
+    world_size: int,
+    group: dist.ProcessGroup,
+    scatter_dim: int,
+    gather_dim: int,
+):
+    input_list = [
+        t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)
+    ]
+    output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)]
+    dist.all_to_all(output_list, input_list, group=group)
+    return torch.cat(output_list, dim=gather_dim).contiguous()
+
+
+class _AllToAll(torch.autograd.Function):
+    """All-to-all communication.
+
+    Args:
+        input_: input matrix
+        process_group: communication group
+        scatter_dim: scatter dimension
+        gather_dim: gather dimension
+    """
+
+    @staticmethod
+    def forward(ctx, input_, process_group, scatter_dim, gather_dim):
+        ctx.process_group = process_group
+        ctx.scatter_dim = scatter_dim
+        ctx.gather_dim = gather_dim
+        ctx.world_size = dist.get_world_size(process_group)
+        return _all_to_all(
+            input_, ctx.world_size, process_group, scatter_dim, gather_dim
+        )
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return (
+            _all_to_all(
+                grad_output,
+                ctx.world_size,
+                ctx.process_group,
+                ctx.gather_dim,
+                ctx.scatter_dim,
+            ),
+            None,
+            None,
+            None,
+        )
+
+
+def all_to_all(
+    input_: torch.Tensor,
+    process_group: dist.ProcessGroup,
+    scatter_dim: int = 2,
+    gather_dim: int = 1,
+):
+    return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim)
+
+
+def split_seq(input_: torch.Tensor, sp_size: int, sp_rank: int, dim: int = 1):
+    """Split a tensor along sequence dimension. It will split input and divide grad by sp_size.
+
+    Args:
+        input_ (torch.Tensor): The common shape is (bs, seq, *).
+        sp_size (int): Sequence parallel size.
+        sp_rank (int): Sequence parallel rank.
+        dim (int, optional): Sequence dimension. Defaults to 1.
+    """
+    input_ = input_.chunk(sp_size, dim=dim)[sp_rank].clone()
+    return MoeOutGradScaler.apply(input_, sp_size)
+
+
+def gather_seq(
+    input_: torch.Tensor,
+    sp_size: int,
+    sp_rank: int,
+    sp_group: dist.ProcessGroup,
+    dim: int = 1,
+):
+    """Gather a tensor along sequence dimension. It will gather input and multiply grad by sp_size.
+
+    Args:
+        input_ (torch.Tensor): The common shape is (bs, seq, *).
+        sp_size (int): Sequence parallel size.
+        sp_rank (int): Sequence parallel rank.
+        dim (int, optional): Sequence dimension. Defaults to 1.
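+        sp_group (dist.ProcessGroup): Sequence parallel process group.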
+ """ + input_ = gather_forward_split_backward(input_, dim, sp_group) + return MoeInGradScaler.apply(input_, sp_size) diff --git a/open_sora/utils/data.py b/open_sora/utils/data.py index d9078ec..99ad63c 100644 --- a/open_sora/utils/data.py +++ b/open_sora/utils/data.py @@ -16,6 +16,13 @@ DatasetType = Union[Dataset, ConcatDataset, dataset_dict.Dataset] PathType = Union[str, os.PathLike] +def ceil_to_multiple(x: int, multiple: int) -> int: + m = x % multiple + if m == 0: + return x + return x + multiple - m + + def video2col(video_4d: torch.Tensor, patch_size: int) -> torch.Tensor: """ Convert a 4D video tensor to a 2D tensor where each row is a patch of the video. @@ -67,7 +74,9 @@ def col2video( return video -def pad_sequences(sequences: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: +def pad_sequences( + sequences: List[torch.Tensor], pad_to_multiple: Optional[int] = None +) -> Tuple[torch.Tensor, torch.Tensor]: """Pad a list of sequences. Args: @@ -77,6 +86,8 @@ def pad_sequences(sequences: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Te Tuple[torch.Tensor, torch.Tensor]: Padded batch of sequences ([B, T, ...]) and padding mask ([B, T]). """ max_len = max([sequence.shape[0] for sequence in sequences]) + if pad_to_multiple is not None: + max_len = ceil_to_multiple(max_len, pad_to_multiple) padded_sequences = [ F.pad( sequence, [0] * (sequence.ndim - 1) * 2 + [0, max_len - sequence.shape[0]] @@ -96,7 +107,7 @@ def pad_sequences(sequences: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Te def patchify_batch( - videos: List[torch.Tensor], patch_size: int + videos: List[torch.Tensor], patch_size: int, pad_to_multiple: Optional[int] = None ) -> Tuple[torch.Tensor, torch.Tensor]: """Patchify a batch of videos. @@ -108,7 +119,7 @@ def patchify_batch( Tuple[torch.Tensor, torch.Tensor]: Padded batch of patches ([B, S, C, P, P]) and padding mask ([B, S]). """ video_patches = [video2col(video, patch_size) for video in videos] - return pad_sequences(video_patches) + return pad_sequences(video_patches, pad_to_multiple=pad_to_multiple) def expand_mask_4d(q_mask: torch.Tensor, kv_mask: torch.Tensor) -> torch.Tensor: @@ -127,7 +138,9 @@ def expand_mask_4d(q_mask: torch.Tensor, kv_mask: torch.Tensor) -> torch.Tensor: return mask.unsqueeze(1) -def make_batch(samples: List[dict], video_dir: str) -> dict: +def make_batch( + samples: List[dict], video_dir: str, pad_to_multiple: Optional[int] = None +) -> dict: """Make a batch of samples. 
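+
+    Text latent states are padded to a length that is a multiple of
+    ``pad_to_multiple`` when it is given, so the padded batch can be split
+    evenly across sequence-parallel ranks.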
 
     Args:
@@ -141,7 +154,7 @@ def make_batch(
         for sample in samples
     ]
     texts = [sample["text_latent_states"] for sample in samples]
-    texts, text_padding_mask = pad_sequences(texts)
+    texts, text_padding_mask = pad_sequences(texts, pad_to_multiple=pad_to_multiple)
     return {
         "videos": videos,
         "text_latent_states": texts,
@@ -269,6 +282,7 @@ def preprocess_batch(
     video_compressor: VideoCompressor,
     device=None,
     use_cross_attn=True,
+    pad_to_multiple: Optional[int] = None,
 ) -> dict:
     if device is None:
         device = get_current_device()
@@ -278,7 +292,9 @@
         video = normalize_video(video)
         video = video_compressor.encode(video)
         videos.append(video)
-    video_latent_states, video_padding_mask = patchify_batch(videos, patch_size)
+    video_latent_states, video_padding_mask = patchify_batch(
+        videos, patch_size, pad_to_multiple
+    )
     batch["video_latent_states"] = video_latent_states
     batch["video_padding_mask"] = video_padding_mask
     text_padding_mask = batch.pop("text_padding_mask").to(device)
diff --git a/open_sora/utils/plugin.py b/open_sora/utils/plugin.py
new file mode 100644
index 0000000..f687b41
--- /dev/null
+++ b/open_sora/utils/plugin.py
@@ -0,0 +1,102 @@
+import random
+from typing import Optional
+
+import numpy as np
+import torch
+from colossalai.booster.plugin import LowLevelZeroPlugin
+from colossalai.cluster import ProcessGroupMesh
+from torch.utils.data import DataLoader
+from torch.utils.data.distributed import DistributedSampler
+
+DP_AXIS, SP_AXIS = 0, 1
+
+
+class ZeroSeqParallelPlugin(LowLevelZeroPlugin):
+    def __init__(
+        self,
+        sp_size: int = 1,
+        stage: int = 2,
+        precision: str = "fp16",
+        initial_scale: float = 2**32,
+        min_scale: float = 1,
+        growth_factor: float = 2,
+        backoff_factor: float = 0.5,
+        growth_interval: int = 1000,
+        hysteresis: int = 2,
+        max_scale: float = 2**32,
+        max_norm: float = 0.0,
+        norm_type: float = 2.0,
+        reduce_bucket_size_in_m: int = 12,
+        communication_dtype: Optional[torch.dtype] = None,
+        overlap_communication: bool = True,
+        cpu_offload: bool = False,
+        master_weights: bool = True,
+        verbose: bool = False,
+    ) -> None:
+        super().__init__(
+            stage=stage,
+            precision=precision,
+            initial_scale=initial_scale,
+            min_scale=min_scale,
+            growth_factor=growth_factor,
+            backoff_factor=backoff_factor,
+            growth_interval=growth_interval,
+            hysteresis=hysteresis,
+            max_scale=max_scale,
+            max_norm=max_norm,
+            norm_type=norm_type,
+            reduce_bucket_size_in_m=reduce_bucket_size_in_m,
+            communication_dtype=communication_dtype,
+            overlap_communication=overlap_communication,
+            cpu_offload=cpu_offload,
+            master_weights=master_weights,
+            verbose=verbose,
+        )
+        self.sp_size = sp_size
+        assert self.world_size % sp_size == 0, "world_size must be divisible by sp_size"
+        self.dp_size = self.world_size // sp_size
+        self.pg_mesh = ProcessGroupMesh(self.dp_size, self.sp_size)
+        self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS)
+        self.sp_group = self.pg_mesh.get_group_along_axis(SP_AXIS)
+        self.dp_rank = self.pg_mesh.coordinate(DP_AXIS)
+        self.sp_rank = self.pg_mesh.coordinate(SP_AXIS)
+
+    def __del__(self):
+        """Destroy the process groups in ProcessGroupMesh"""
+        self.pg_mesh.destroy_mesh_process_groups()
+
+    def prepare_dataloader(
+        self,
+        dataset,
+        batch_size,
+        shuffle=False,
+        seed=1024,
+        drop_last=False,
+        pin_memory=False,
+        num_workers=0,
+        distributed_sampler_cls=None,
+        **kwargs
+    ):
+        _kwargs = kwargs.copy()
+        distributed_sampler_cls = distributed_sampler_cls or DistributedSampler
+        sampler = distributed_sampler_cls(
+            dataset, num_replicas=self.dp_size, rank=self.dp_rank, shuffle=shuffle
+        )
+
+        # Deterministic dataloader
+        def seed_worker(worker_id):
+            worker_seed = seed
+            np.random.seed(worker_seed)
+            torch.manual_seed(worker_seed)
+            random.seed(worker_seed)
+
+        return DataLoader(
+            dataset,
+            batch_size=batch_size,
+            sampler=sampler,
+            worker_init_fn=seed_worker,
+            drop_last=drop_last,
+            pin_memory=pin_memory,
+            num_workers=num_workers,
+            **_kwargs,
+        )
diff --git a/tests/test_model.py b/tests/test_model.py
new file mode 100644
index 0000000..62b3172
--- /dev/null
+++ b/tests/test_model.py
@@ -0,0 +1,93 @@
+import colossalai
+import pytest
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+from colossalai.booster import Booster
+from colossalai.logging import disable_existing_loggers
+from colossalai.nn.optimizer import HybridAdam
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.utils import get_current_device
+from torch.testing import assert_close
+
+from open_sora.modeling import DiT_models
+from open_sora.utils.plugin import ZeroSeqParallelPlugin
+
+
+@parameterize("sp_size", [2, 4])
+def check_dit_model_fwd_bwd(
+    sp_size: int, video_latent_states, text_latent_states, t, mask
+):
+    plugin = ZeroSeqParallelPlugin(
+        sp_size=sp_size, stage=2, precision="fp32", master_weights=False
+    )
+    booster = Booster(plugin=plugin)
+    model = DiT_models["DiT-B/8"](text_dropout_prob=0.0).to(get_current_device())
+    parallel_model = DiT_models["DiT-B/8"](
+        text_dropout_prob=0.0, seq_parallel_group=plugin.sp_group
+    ).to(get_current_device())
+    parallel_model.load_state_dict(model.state_dict())
+    opt = HybridAdam(parallel_model.parameters(), lr=1e-3)
+    parallel_model, opt, *_ = booster.boost(parallel_model, opt)
+
+    target = model(video_latent_states, t, text_latent_states, mask)
+    noise = torch.randn_like(target)
+    target_loss = F.mse_loss(target, noise)
+    target_loss.backward()
+
+    dp_video_latent_states = video_latent_states.chunk(plugin.dp_size)[plugin.dp_rank]
+    dp_text_latent_states = text_latent_states.chunk(plugin.dp_size)[plugin.dp_rank]
+    dp_t = t.chunk(plugin.dp_size)[plugin.dp_rank]
+    dp_mask = mask.chunk(plugin.dp_size)[plugin.dp_rank]
+    dp_noise = noise.chunk(plugin.dp_size)[plugin.dp_rank]
+
+    output = parallel_model(
+        dp_video_latent_states, dp_t, dp_text_latent_states, dp_mask
+    )
+    loss = F.mse_loss(output, dp_noise)
+    booster.backward(loss, opt)
+
+    if plugin.dp_size == 1:
+        assert_close(target, output)
+
+    for p1, p2 in zip(model.parameters(), opt._master_param_groups_of_current_rank[0]):
+        working_p = opt._param_store.master_to_working_param[id(p2)]
+        grads = opt._grad_store.get_partitioned_gradients_by_param_id(0, id(working_p))
+        grad_index = 0 if opt._partition_grads else opt._local_rank
+        grad = grads[grad_index]
+        sharded_grad = p1.grad.view(-1).chunk(dist.get_world_size())[dist.get_rank()]
+        assert_close(sharded_grad, grad[: sharded_grad.shape[0]])
+
+
+def run_dist(rank, world_size, port):
+    disable_existing_loggers()
+    colossalai.launch(
+        config={},
+        rank=rank,
+        world_size=world_size,
+        port=port,
+        host="localhost",
+        backend="nccl",
+    )
+    b, s, c, p = 4, 20, 3, 8
+    dim_text, s_text = 512, 12
+    video_latent_states = torch.rand(b, s, c, p, p, device=get_current_device())
+    text_latent_states = torch.rand(b, s_text, dim_text, device=get_current_device())
+    t = torch.randint(0, 1000, (b,), device=get_current_device())
+    mask = torch.ones(b, 1, s, s_text, device=get_current_device(), dtype=torch.int)
+    check_dit_model_fwd_bwd(
+        video_latent_states=video_latent_states,
+        text_latent_states=text_latent_states,
+        t=t,
+        mask=mask,
+    )
+
+
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+def test_dit_model():
+    spawn(run_dist, 4)
+
+
+if __name__ == "__main__":
+    test_dit_model()
diff --git a/tests/test_sp_attn.py b/tests/test_sp_attn.py
new file mode 100644
index 0000000..58d3b7d
--- /dev/null
+++ b/tests/test_sp_attn.py
@@ -0,0 +1,76 @@
+import colossalai
+import pytest
+import torch
+import torch.distributed as dist
+from colossalai.logging import disable_existing_loggers
+from colossalai.testing import rerun_if_address_is_in_use, spawn
+from colossalai.utils import get_current_device
+from torch.testing import assert_close
+
+from open_sora.modeling.dit import CrossAttention, SeqParallelCrossAttention
+from open_sora.utils.comm import gather_seq, split_seq
+
+
+def check_sp_attn():
+    sp_size = dist.get_world_size()
+    sp_rank = dist.get_rank()
+    q_dim, context_dim = 8, 4
+    num_heads = 4
+    head_dim = 16
+    bs = 2
+    sq = 8
+    skv = 4
+    attn = CrossAttention(q_dim, context_dim, num_heads, head_dim).to(
+        get_current_device()
+    )
+    parallel_attn = SeqParallelCrossAttention(
+        q_dim, context_dim, num_heads, head_dim, seq_parallel_group=dist.group.WORLD
+    ).to(get_current_device())
+    parallel_attn.load_state_dict(attn.state_dict())
+    hidden_states = torch.rand(bs, sq, q_dim, device=get_current_device())
+    context = torch.rand(bs, skv, context_dim, device=get_current_device())
+    mask = torch.zeros(bs, 1, sq, skv, device=get_current_device())
+    target = attn(hidden_states, context, mask)
+    hidden_states_parallel = split_seq(hidden_states, sp_size, sp_rank)
+    context_parallel = split_seq(context, sp_size, sp_rank)
+    output_parallel = parallel_attn(
+        hidden_states_parallel,
+        context_parallel,
+        mask,
+    )
+    assert torch.equal(target.chunk(sp_size, dim=1)[sp_rank], output_parallel)
+    output = gather_seq(output_parallel, sp_size, sp_rank, dist.group.WORLD)
+    assert torch.equal(target, output)
+    target.mean().backward()
+    output.mean().backward()
+
+    # all-reduce mean of grads
+    for p in parallel_attn.parameters():
+        p.grad.data.div_(sp_size)
+        dist.all_reduce(p.grad.data)
+
+    for p1, p2 in zip(attn.parameters(), parallel_attn.parameters()):
+        assert_close(p1.grad, p2.grad)
+
+
+def run_dist(rank, world_size, port):
+    disable_existing_loggers()
+    colossalai.launch(
+        config={},
+        rank=rank,
+        world_size=world_size,
+        port=port,
+        host="localhost",
+        backend="nccl",
+    )
+    check_sp_attn()
+
+
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+def test_seq_parallel_attn():
+    spawn(run_dist, 2)
+
+
+if __name__ == "__main__":
+    test_seq_parallel_attn()
diff --git a/train.py b/train.py
index 7e8093d..1f4da90 100644
--- a/train.py
+++ b/train.py
@@ -17,7 +17,6 @@ import torch.distributed as dist
 from colossalai import launch_from_torch
 from colossalai.accelerator import get_accelerator
 from colossalai.booster import Booster
-from colossalai.booster.plugin import LowLevelZeroPlugin
 from colossalai.cluster import DistCoordinator
 from colossalai.logging import get_dist_logger
 from colossalai.nn.lr_scheduler import CosineAnnealingLR
@@ -28,7 +27,13 @@ from tqdm import tqdm

 from open_sora.diffusion import create_diffusion
 from open_sora.modeling import DiT_models
-from open_sora.utils.data import create_video_compressor, load_datasets, make_batch, preprocess_batch
+from open_sora.utils.data import (
+    create_video_compressor,
+    load_datasets,
+    make_batch,
+    preprocess_batch,
+)
+from open_sora.utils.plugin import ZeroSeqParallelPlugin

 #################################################################################
 #                          Training Helper Functions                            #
 #################################################################################
@@ -93,7 +98,7 @@ def main(args):
     configure_backends()

     # Step 2: set up acceleration plugins
-    plugin = LowLevelZeroPlugin(stage=2, precision="fp16")
+    plugin = ZeroSeqParallelPlugin(sp_size=args.sp_size, stage=2, precision="fp16")
     booster = Booster(plugin=plugin)

     if coordinator.is_master():
@@ -103,7 +108,10 @@

     # Step 3: Create video compressor
     video_compressor = create_video_compressor(args.compressor)
-    model_kwargs = {"in_channels": video_compressor.out_channels}
+    model_kwargs = {
+        "in_channels": video_compressor.out_channels,
+        "seq_parallel_group": plugin.sp_group,
+    }

     # Step 4: Create DiT and EMA
     model = DiT_models[args.model](**model_kwargs).to(get_current_device())
@@ -119,7 +127,9 @@
         model.enable_gradient_checkpointing()

     # Step 5: create diffusion pipeline
-    diffusion = create_diffusion(timestep_respacing="")  # default: 1000 steps, linear noise schedule
+    diffusion = create_diffusion(
+        timestep_respacing=""
+    )  # default: 1000 steps, linear noise schedule

     # Step 6: setup optimizer (we used default Adam betas=(0.9, 0.999) and a constant learning rate of 1e-4 in our paper):
     opt = HybridAdam(model.parameters(), lr=args.lr, weight_decay=0)
@@ -129,15 +139,21 @@
     dataloader = plugin.prepare_dataloader(
         dataset,
         batch_size=args.batch_size,
-        collate_fn=partial(make_batch, video_dir=args.video_dir),
+        collate_fn=partial(
+            make_batch, video_dir=args.video_dir, pad_to_multiple=args.sp_size
+        ),
         shuffle=True,
         drop_last=True,
     )
-    lr_scheduler = CosineAnnealingLR(opt, args.epochs * len(dataloader) // args.accumulation_steps)
+    lr_scheduler = CosineAnnealingLR(
+        opt, args.epochs * len(dataloader) // args.accumulation_steps
+    )
     logger.info(f"Dataset contains {len(dataset)} samples", ranks=[0])

     # Step 8: setup booster
-    model, opt, _, dataloader, _ = booster.boost(model, opt, dataloader=dataloader)
+    model, opt, _, dataloader, lr_scheduler = booster.boost(
+        model, opt, dataloader=dataloader, lr_scheduler=lr_scheduler
+    )
     if args.load_model is not None:
         booster.load_model(model, args.load_model)
     if args.load_optimizer is not None:
@@ -159,7 +175,9 @@
     ) as pbar:
         total_loss = torch.tensor(0.0, device=get_current_device())
         for step, batch in enumerate(dataloader):
-            batch = preprocess_batch(batch, patch_size, video_compressor)
+            batch = preprocess_batch(
+                batch, patch_size, video_compressor, pad_to_multiple=args.sp_size
+            )
             video_inputs = batch.pop("video_latent_states")
             mask = batch.pop("video_padding_mask")
             t = torch.randint(
@@ -168,7 +186,9 @@
                 (video_inputs.shape[0],),
                 device=video_inputs.device,
             )
-            loss_dict = diffusion.training_losses(model, video_inputs, t, batch, mask=mask)
+            loss_dict = diffusion.training_losses(
+                model, video_inputs, t, batch, mask=mask
+            )
             loss = loss_dict["loss"].mean() / args.accumulation_steps
             total_loss.add_(loss.data)
             booster.backward(loss, opt)
@@ -182,7 +202,9 @@
                 all_reduce_mean(total_loss)
                 pbar.set_postfix({"Loss": f"{total_loss.item():.4f}"})
                 if coordinator.is_master():
-                    global_step = (epoch * num_steps_per_epoch) + (step + 1) // args.accumulation_steps
+                    global_step = (epoch * num_steps_per_epoch) + (
+                        step + 1
+                    ) // args.accumulation_steps
                     writer.add_scalar(
                         tag="Loss",
                         scalar_value=total_loss.item(),
@@ -193,9 +215,12 @@

                 # Save DiT checkpoint:
                 if args.save_interval > 0 and (
-                    (step + 1) % (args.save_interval * args.accumulation_steps) == 0 or (step + 1) == len(dataloader)
+                    (step + 1) % (args.save_interval * args.accumulation_steps) == 0
+                    or (step + 1) == len(dataloader)
                 ):
-                    save_path = os.path.join(args.checkpoint_dir, f"epoch-{epoch}-step-{step}")
+                    save_path = os.path.join(
+                        args.checkpoint_dir, f"epoch-{epoch}-step-{step}"
+                    )
                     save_checkpoints(booster, model, opt, ema, save_path, coordinator)
                     logger.info(f"Saved checkpoint to {save_path}", ranks=[0])

@@ -212,18 +237,23 @@
 if __name__ == "__main__":
     # Default args here will train DiT-XL/2 with the hyperparameters we used in our paper (except training iters).
     parser = argparse.ArgumentParser()
-    parser.add_argument("-m", "--model", type=str, choices=list(DiT_models.keys()), default="DiT-S/8")
+    parser.add_argument(
+        "-m", "--model", type=str, choices=list(DiT_models.keys()), default="DiT-S/8"
+    )
     parser.add_argument("-d", "--dataset", nargs="+", default=[])
     parser.add_argument("-v", "--video_dir", type=str, required=True)
     parser.add_argument("-e", "--epochs", type=int, default=10)
     parser.add_argument("-b", "--batch_size", type=int, default=4)
     parser.add_argument("-g", "--grad_checkpoint", action="store_true", default=False)
     parser.add_argument("-a", "--accumulation_steps", default=1, type=int)
+    parser.add_argument("--sp_size", type=int, default=1)
     parser.add_argument("--lr", type=float, default=1e-4)
     parser.add_argument("--save_interval", type=int, default=20)
     parser.add_argument("--checkpoint_dir", type=str, default="checkpoints")
     parser.add_argument("--tensorboard_dir", type=str, default="runs")
-    parser.add_argument("-c", "--compressor", choices=["raw", "vqvae", "vae"], default="raw")
+    parser.add_argument(
+        "-c", "--compressor", choices=["raw", "vqvae", "vae"], default="raw"
+    )
     parser.add_argument("--load_model", default=None)
     parser.add_argument("--load_optimizer", default=None)
     args = parser.parse_args()
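
Note: the following standalone script is an illustrative sketch of the layout swap that SeqParallelCrossAttention performs, not part of the patch. Outside attention, activations stay sequence-sharded as [B, S/P, H]; a first all_to_all regroups them to [B, S, H/P] (full sequence, heads sharded) so each rank attends over the whole sequence with num_heads/P heads, and a second all_to_all restores [B, S/P, H] afterwards. The file name, toy sizes, and the two-GPU NCCL launch below are assumptions; the helper mirrors the data movement of _all_to_all in open_sora/utils/comm.py without the autograd wrapper.

# ulysses_sketch.py -- illustrative only; launch with:
#   torchrun --nproc_per_node=2 ulysses_sketch.py
# Assumes one node with 2 CUDA devices and the NCCL backend.
import torch
import torch.distributed as dist


def all_to_all_swap(x, world_size, group, scatter_dim, gather_dim):
    # Split along scatter_dim, exchange one chunk per rank, concat along gather_dim,
    # which is the same data movement as _all_to_all in open_sora/utils/comm.py.
    inputs = [t.contiguous() for t in x.chunk(world_size, dim=scatter_dim)]
    outputs = [torch.empty_like(inputs[0]) for _ in range(world_size)]
    dist.all_to_all(outputs, inputs, group=group)
    return torch.cat(outputs, dim=gather_dim)


def main():
    dist.init_process_group("nccl")
    rank, world = dist.get_rank(), dist.get_world_size()
    torch.cuda.set_device(rank)  # assumes rank == local rank (single node)
    bsz, seq, heads, head_dim = 2, 8, 4, 16  # toy sizes, divisible by world
    hidden = heads * head_dim
    # Each rank starts with a [B, S/P, H] shard of the projected q/k/v.
    shard = torch.randn(bsz, seq // world, hidden, device="cuda")
    # First swap: [B, S/P, H] -> [B, S, H/P], full sequence with heads sharded.
    q = all_to_all_swap(shard, world, dist.group.WORLD, scatter_dim=2, gather_dim=1)
    assert q.shape == (bsz, seq, hidden // world)
    # Attention itself would run here on [B, N/P, S, D] after view/transpose.
    attn_out = q
    # Second swap: [B, S, H/P] -> [B, S/P, H], back to sequence sharding for
    # the MLP, residual stream, and LayerNorms.
    out = all_to_all_swap(attn_out, world, dist.group.WORLD, scatter_dim=1, gather_dim=2)
    assert out.shape == (bsz, seq // world, hidden)
    dist.destroy_process_group()


if __name__ == "__main__":
    main()

With P sequence-parallel ranks, each attention block issues three all_to_alls on q/k/v and one on the output, each moving roughly B*S*H/P elements per rank, and this layout swap is why the patch asserts divisibility throughout: num_heads % sp_size in the attention module, the padded patch and token lengths via pad_to_multiple=args.sp_size, and num_tokens % args.sp_size in benchmark.py.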