This commit is contained in:
Shen-Chenhui 2024-04-11 16:33:22 +08:00
parent 4741d952d8
commit dbe1c4cf2f

View file

@ -114,12 +114,11 @@ class CausalConv3d(nn.Module):
self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0)
stride = strides if strides is not None else (stride, 1, 1)
padding = kwargs.pop('padding', 0)
if padding == "same" and not all([pad == 1 for pad in padding]):
padding = "valid"
# padding = kwargs.pop('padding', 0)
# if padding == "same" and not all([pad == 1 for pad in padding]):
# padding = "valid"
dilation = (dilation, 1, 1)
self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride = stride, dilation = dilation, padding=padding, **kwargs)
self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride = stride, dilation = dilation, **kwargs)
def forward(self, x):
pad_mode = self.pad_mode if self.time_pad < x.shape[2] else 'constant'
@ -372,7 +371,7 @@ class Encoder(nn.Module):
self.conv_fn = functools.partial(
CausalConv3d,
padding='valid' if self.custom_conv_padding is not None else 'same', # SCH: lower letter for pytorch
# padding='valid' if self.custom_conv_padding is not None else 'same', # SCH: lower letter for pytorch
dtype=dtype,
device=device,
)
@ -485,7 +484,7 @@ class Decoder(nn.Module):
self.conv_fn = functools.partial(
CausalConv3d,
dtype=dtype,
padding='valid' if self.custom_conv_padding is not None else 'same', # SCH: lower letter for pytorch
# padding='valid' if self.custom_conv_padding is not None else 'same', # SCH: lower letter for pytorch
device=device,
)