From 8b778213d17128c9028184f64cd136141f163327 Mon Sep 17 00:00:00 2001 From: Zangwei Zheng Date: Thu, 28 Mar 2024 19:27:17 +0800 Subject: [PATCH] [feat] complete dynamic pos --- configs/opensora-v1-1/train/Vx360p.py | 3 +-- opensora/models/stdit/stdit2.py | 12 ++++++------ 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/configs/opensora-v1-1/train/Vx360p.py b/configs/opensora-v1-1/train/Vx360p.py index 4c7ea62..8d831a6 100644 --- a/configs/opensora-v1-1/train/Vx360p.py +++ b/configs/opensora-v1-1/train/Vx360p.py @@ -4,7 +4,7 @@ dataset = dict( data_path=None, num_frames=None, frame_interval=3, - image_size=(360, 480), # base size + image_size=(512, 512), # pretrained model is trained on 512x512 transform_name="resize_crop", ) bucket_config = { @@ -24,7 +24,6 @@ sp_size = 1 # Define model model = dict( type="STDiT2-XL/2", - space_scale=0.5, from_pretrained="PixArt-XL-2-1024-MS.pth", enable_flashattn=True, enable_layernorm_kernel=True, diff --git a/opensora/models/stdit/stdit2.py b/opensora/models/stdit/stdit2.py index bb08f24..25973fd 100644 --- a/opensora/models/stdit/stdit2.py +++ b/opensora/models/stdit/stdit2.py @@ -189,7 +189,6 @@ class STDiT2(nn.Module): caption_channels=4096, model_max_length=120, dtype=torch.float32, - space_scale=1.0, freeze=None, enable_flashattn=False, enable_layernorm_kernel=False, @@ -207,12 +206,11 @@ class STDiT2(nn.Module): self.mlp_ratio = mlp_ratio self.enable_flashattn = enable_flashattn self.enable_layernorm_kernel = enable_layernorm_kernel - self.space_scale = space_scale # support dynamic input self.patch_size = patch_size self.input_size = input_size - self.base_size = (input_size[1] / patch_size[1] * input_size[2] / patch_size[2]) ** 0.5 + self.input_sq_size = int((input_size[1] * input_size[2]) ** 0.5) self.x_embedder = PatchEmbed3D(patch_size, in_channels, hidden_size) self.t_embedder = TimestepEmbedder(hidden_size) @@ -324,7 +322,9 @@ class STDiT2(nn.Module): _, _, Tx, Hx, Wx = x.size() T, H, W = self.get_dynamic_size(x) S = H * W - pos_emb = self.get_spatial_pos_embed(H, W, self.base_size).to(x.device, x.dtype) + rs = (height[0] * width[0]).sqrt().item() + scale = rs / self.input_sq_size + pos_emb = self.get_spatial_pos_embed(H, W, scale=scale, base_size=round(S**0.5)).to(x.device, x.dtype) # embedding x = self.x_embedder(x) # [B, N, C] @@ -432,11 +432,11 @@ class STDiT2(nn.Module): imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw)) return imgs - def get_spatial_pos_embed(self, H, W, base_size=None): + def get_spatial_pos_embed(self, H, W, scale=1.0, base_size=None): pos_embed = get_2d_sincos_pos_embed( self.hidden_size, (H, W), - scale=self.space_scale, + scale=scale, base_size=base_size, ) pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False)