[feat] complete dynamic pos

This commit is contained in:
Zangwei Zheng 2024-03-28 19:27:17 +08:00
parent cba02e8a58
commit 8b778213d1
2 changed files with 7 additions and 8 deletions

View file

@ -4,7 +4,7 @@ dataset = dict(
data_path=None,
num_frames=None,
frame_interval=3,
image_size=(360, 480), # base size
image_size=(512, 512), # pretrained model is trained on 512x512
transform_name="resize_crop",
)
bucket_config = {
@ -24,7 +24,6 @@ sp_size = 1
# Define model
model = dict(
type="STDiT2-XL/2",
space_scale=0.5,
from_pretrained="PixArt-XL-2-1024-MS.pth",
enable_flashattn=True,
enable_layernorm_kernel=True,

View file

@ -189,7 +189,6 @@ class STDiT2(nn.Module):
caption_channels=4096,
model_max_length=120,
dtype=torch.float32,
space_scale=1.0,
freeze=None,
enable_flashattn=False,
enable_layernorm_kernel=False,
@ -207,12 +206,11 @@ class STDiT2(nn.Module):
self.mlp_ratio = mlp_ratio
self.enable_flashattn = enable_flashattn
self.enable_layernorm_kernel = enable_layernorm_kernel
self.space_scale = space_scale
# support dynamic input
self.patch_size = patch_size
self.input_size = input_size
self.base_size = (input_size[1] / patch_size[1] * input_size[2] / patch_size[2]) ** 0.5
self.input_sq_size = int((input_size[1] * input_size[2]) ** 0.5)
self.x_embedder = PatchEmbed3D(patch_size, in_channels, hidden_size)
self.t_embedder = TimestepEmbedder(hidden_size)
@ -324,7 +322,9 @@ class STDiT2(nn.Module):
_, _, Tx, Hx, Wx = x.size()
T, H, W = self.get_dynamic_size(x)
S = H * W
pos_emb = self.get_spatial_pos_embed(H, W, self.base_size).to(x.device, x.dtype)
rs = (height[0] * width[0]).sqrt().item()
scale = rs / self.input_sq_size
pos_emb = self.get_spatial_pos_embed(H, W, scale=scale, base_size=round(S**0.5)).to(x.device, x.dtype)
# embedding
x = self.x_embedder(x) # [B, N, C]
@ -432,11 +432,11 @@ class STDiT2(nn.Module):
imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))
return imgs
def get_spatial_pos_embed(self, H, W, base_size=None):
def get_spatial_pos_embed(self, H, W, scale=1.0, base_size=None):
pos_embed = get_2d_sincos_pos_embed(
self.hidden_size,
(H, W),
scale=self.space_scale,
scale=scale,
base_size=base_size,
)
pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False)