[sp] added padding (#160)

This commit is contained in:
Frank Lee 2024-06-24 13:59:29 +08:00 committed by GitHub
parent 7115864314
commit ee1c79a898

View file

@ -4,6 +4,7 @@ import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from rotary_embedding_torch import RotaryEmbedding
from timm.models.layers import DropPath
@ -361,6 +362,19 @@ class STDiT3(PreTrainedModel):
# === get pos embed ===
_, _, Tx, Hx, Wx = x.size()
T, H, W = self.get_dynamic_size(x)
# adjust for sequence parallelism
# we need to ensure H * W is divisible by sequence parallel size
# for simplicity, we can adjust the height to make it divisible
if self.enable_sequence_parallelism:
sp_size = dist.get_world_size(get_sequence_parallel_group())
h_pad_size = sp_size - H % sp_size
hx_pad_size = h_pad_size * self.patch_size[1]
# pad x along the H dimension
H += h_pad_size
x = F.pad(x, (0, 0, 0, hx_pad_size))
S = H * W
base_size = round(S**0.5)
resolution_sq = (height[0].item() * width[0].item()) ** 0.5