mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-05-18 08:31:02 +02:00
[sp] added padding (#160)
This commit is contained in:
parent
7115864314
commit
ee1c79a898
|
|
@ -4,6 +4,7 @@ import numpy as np
|
|||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from rotary_embedding_torch import RotaryEmbedding
|
||||
from timm.models.layers import DropPath
|
||||
|
|
@ -361,6 +362,19 @@ class STDiT3(PreTrainedModel):
|
|||
# === get pos embed ===
|
||||
_, _, Tx, Hx, Wx = x.size()
|
||||
T, H, W = self.get_dynamic_size(x)
|
||||
|
||||
# adjust for sequence parallelism
|
||||
# we need to ensure H * W (the value S computed below) is divisible by the sequence parallel size
|
||||
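# for simplicity, we pad only the height: once H is divisible by sp_size, H * W is divisible too
# NOTE(review): when H is already divisible by sp_size, h_pad_size below equals sp_size (not 0),
# so a full extra block of rows is still padded — confirm this is intended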
|
||||
if self.enable_sequence_parallelism:
|
||||
sp_size = dist.get_world_size(get_sequence_parallel_group())
|
||||
h_pad_size = sp_size - H % sp_size
|
||||
hx_pad_size = h_pad_size * self.patch_size[1]
|
||||
|
||||
# pad x along the H dimension
|
||||
H += h_pad_size
|
||||
x = F.pad(x, (0, 0, 0, hx_pad_size))
|
||||
|
||||
S = H * W
|
||||
base_size = round(S**0.5)
|
||||
resolution_sq = (height[0].item() * width[0].item()) ** 0.5
|
||||
|
|
|
|||
Loading…
Reference in a new issue