Mirror of https://github.com/hpcaitech/Open-Sora.git (synced 2026-04-15 03:15:20 +02:00)
[feat] complete dynamic pos
commit 8b778213d1
parent cba02e8a58
@@ -4,7 +4,7 @@ dataset = dict(
     data_path=None,
     num_frames=None,
     frame_interval=3,
-    image_size=(360, 480),  # base size
+    image_size=(512, 512),  # pretrained model is trained on 512x512
     transform_name="resize_crop",
 )
 bucket_config = {
@@ -24,7 +24,6 @@ sp_size = 1
 # Define model
 model = dict(
     type="STDiT2-XL/2",
-    space_scale=0.5,
     from_pretrained="PixArt-XL-2-1024-MS.pth",
     enable_flashattn=True,
     enable_layernorm_kernel=True,
@@ -189,7 +189,6 @@ class STDiT2(nn.Module):
         caption_channels=4096,
         model_max_length=120,
         dtype=torch.float32,
-        space_scale=1.0,
         freeze=None,
         enable_flashattn=False,
         enable_layernorm_kernel=False,
@@ -207,12 +206,11 @@ class STDiT2(nn.Module):
         self.mlp_ratio = mlp_ratio
         self.enable_flashattn = enable_flashattn
         self.enable_layernorm_kernel = enable_layernorm_kernel
-        self.space_scale = space_scale
 
         # support dynamic input
         self.patch_size = patch_size
         self.input_size = input_size
-        self.base_size = (input_size[1] / patch_size[1] * input_size[2] / patch_size[2]) ** 0.5
+        self.input_sq_size = int((input_size[1] * input_size[2]) ** 0.5)
 
         self.x_embedder = PatchEmbed3D(patch_size, in_channels, hidden_size)
         self.t_embedder = TimestepEmbedder(hidden_size)
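The constructor change above swaps the token-grid base_size for input_sq_size, the side length of a square with the same spatial area as the base input. A quick sanity check of that arithmetic, with illustrative values only (the (T, H, W) layout of input_size is inferred from the [1]/[2] indexing in the diff):

# Illustrative values; input_size layout (T, H, W) is inferred from
# the input_size[1] / input_size[2] indexing in the diff.
input_size = (16, 512, 512)

# Side length of a square with the same spatial area as the base input.
input_sq_size = int((input_size[1] * input_size[2]) ** 0.5)
print(input_sq_size)  # 512 for a square base; (360, 480) would give 415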
@@ -324,7 +322,9 @@ class STDiT2(nn.Module):
         _, _, Tx, Hx, Wx = x.size()
         T, H, W = self.get_dynamic_size(x)
         S = H * W
-        pos_emb = self.get_spatial_pos_embed(H, W, self.base_size).to(x.device, x.dtype)
+        rs = (height[0] * width[0]).sqrt().item()
+        scale = rs / self.input_sq_size
+        pos_emb = self.get_spatial_pos_embed(H, W, scale=scale, base_size=round(S**0.5)).to(x.device, x.dtype)
 
         # embedding
         x = self.x_embedder(x)  # [B, N, C]
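With input_sq_size in place, the forward pass derives the positional-embedding scale per batch: rs is the geometric mean of the actual pixel height and width, and dividing by input_sq_size turns it into a relative resolution factor. A minimal sketch of the same computation in isolation, assuming height and width are 1-D per-sample tensors as the [0] indexing suggests:

import torch

# Assumption: height/width hold per-sample pixel sizes for the batch,
# e.g. the first sample is 720x1280.
height = torch.tensor([720.0])
width = torch.tensor([1280.0])
input_sq_size = 512  # from the constructor change above

rs = (height[0] * width[0]).sqrt().item()  # geometric mean -> 960.0
scale = rs / input_sq_size                 # 1.875: coarser sincos grid
print(rs, scale)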
@@ -432,11 +432,11 @@ class STDiT2(nn.Module):
         imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))
         return imgs
 
-    def get_spatial_pos_embed(self, H, W, base_size=None):
+    def get_spatial_pos_embed(self, H, W, scale=1.0, base_size=None):
         pos_embed = get_2d_sincos_pos_embed(
             self.hidden_size,
             (H, W),
-            scale=self.space_scale,
+            scale=scale,
             base_size=base_size,
         )
         pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False)
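get_2d_sincos_pos_embed itself is not part of this diff; the sketch below shows how a PixArt-style helper typically honors the new scale and base_size arguments. The grid normalization and the sin/cos split are assumptions, not a copy of the repo's implementation:

import numpy as np

def sincos_1d(embed_dim, pos):
    # Standard 1-D sin/cos embedding: half the channels sin, half cos.
    omega = 1.0 / 10000 ** (np.arange(embed_dim // 2, dtype=np.float64) / (embed_dim // 2))
    out = np.einsum("m,d->md", pos.reshape(-1), omega)
    return np.concatenate([np.sin(out), np.cos(out)], axis=1)

def get_2d_sincos_pos_embed_sketch(embed_dim, grid_size, scale=1.0, base_size=None):
    # Sketch only. Dividing coordinates by `scale` keeps the spacing
    # between adjacent tokens constant across input resolutions, which
    # is the point of the per-batch scale computed in the forward pass.
    grid_h = np.arange(grid_size[0], dtype=np.float32) / scale
    grid_w = np.arange(grid_size[1], dtype=np.float32) / scale
    if base_size is not None:
        # Re-normalize so `base_size` tokens always span the same range.
        grid_h *= base_size / grid_size[0]
        grid_w *= base_size / grid_size[1]
    wgrid, hgrid = np.meshgrid(grid_w, grid_h)     # each [H, W]
    emb_h = sincos_1d(embed_dim // 2, hgrid)       # [H*W, D/2]
    emb_w = sincos_1d(embed_dim // 2, wgrid)       # [H*W, D/2]
    return np.concatenate([emb_h, emb_w], axis=1)  # [H*W, D]

pos = get_2d_sincos_pos_embed_sketch(64, (8, 8), scale=1.875, base_size=8)
print(pos.shape)  # (64, 64)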