diff --git a/opensora/models/vae/vae_3d_v2.py b/opensora/models/vae/vae_3d_v2.py index 669cf54..d74a72d 100644 --- a/opensora/models/vae/vae_3d_v2.py +++ b/opensora/models/vae/vae_3d_v2.py @@ -114,12 +114,11 @@ class CausalConv3d(nn.Module): self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0) stride = strides if strides is not None else (stride, 1, 1) - padding = kwargs.pop('padding', 0) - - if padding == "same" and not all([pad == 1 for pad in padding]): - padding = "valid" + # padding = kwargs.pop('padding', 0) + # if padding == "same" and not all([pad == 1 for pad in padding]): + # padding = "valid" dilation = (dilation, 1, 1) - self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride = stride, dilation = dilation, padding=padding, **kwargs) + self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride = stride, dilation = dilation, **kwargs) def forward(self, x): pad_mode = self.pad_mode if self.time_pad < x.shape[2] else 'constant' @@ -372,7 +371,7 @@ class Encoder(nn.Module): self.conv_fn = functools.partial( CausalConv3d, - padding='valid' if self.custom_conv_padding is not None else 'same', # SCH: lower letter for pytorch + # padding='valid' if self.custom_conv_padding is not None else 'same', # SCH: lower letter for pytorch dtype=dtype, device=device, ) @@ -485,7 +484,7 @@ class Decoder(nn.Module): self.conv_fn = functools.partial( CausalConv3d, dtype=dtype, - padding='valid' if self.custom_conv_padding is not None else 'same', # SCH: lower letter for pytorch + # padding='valid' if self.custom_conv_padding is not None else 'same', # SCH: lower letter for pytorch device=device, )