strides debug

This commit is contained in:
Shen-Chenhui 2024-04-09 18:14:25 +08:00
parent d8666a73c4
commit 1bbef18f15
2 changed files with 5 additions and 5 deletions

View file

@ -161,7 +161,7 @@ class Encoder(nn.Module):
if self.conv_downsample:
t_stride = 2 if self.temporal_downsample[i] else 1
t_pad = 1 if self.temporal_downsample[i] else 0
self.conv_blocks.append(self.conv_fn(prev_filters, filters, kernel_size=(4, 4, 4), strides=(t_stride, 2, 2)), padding=(t_pad,1,1)) # SCH: should be same in_channel and out_channel
self.conv_blocks.append(self.conv_fn(prev_filters, filters, kernel_size=(4, 4, 4), stride=(t_stride, 2, 2)), padding=(t_pad,1,1)) # SCH: should be same in_channel and out_channel
prev_filters = filters # update in_channels
# NOTE: downsample, dimensions T, H, W

View file

@ -39,6 +39,7 @@ class CausalConv3d(nn.Module):
chan_out,
kernel_size: Union[int, Tuple[int, int, int]],
pad_mode = 'constant',
strides = None, # allow custom stride
**kwargs
):
super().__init__()
@ -49,7 +50,7 @@ class CausalConv3d(nn.Module):
assert is_odd(height_kernel_size) and is_odd(width_kernel_size)
dilation = kwargs.pop('dilation', 1)
stride = kwargs.pop('stride', 1)
stride = strides[0] if strides is not None else kwargs.pop('stride', 1)
self.pad_mode = pad_mode
time_pad = dilation * (time_kernel_size - 1) + (1 - stride)
@ -59,7 +60,7 @@ class CausalConv3d(nn.Module):
self.time_pad = time_pad
self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0)
stride = (stride, 1, 1)
stride = strides if strides is not None else (stride, 1, 1)
dilation = (dilation, 1, 1)
self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride = stride, dilation = dilation, **kwargs)
@ -69,7 +70,6 @@ class CausalConv3d(nn.Module):
x = nn.F.pad(x, self.time_causal_padding, mode = pad_mode)
return self.conv(x)
# TODO: CausalConvTranspose3d
class ResBlock(nn.Module):
def __init__(
@ -187,6 +187,7 @@ class Encoder(nn.Module):
if i < self.num_blocks - 1: # SCH: T-Causal Conv 3x3x3, 128->128, stride t x 2 x 2
t_stride = 2 if self.temporal_downsample[i] else 1
# TODO: conv_fn usess default stride t x 1 x 1, cannot
self.conv_blocks.append(self.conv_fn(prev_filters, filters, kernel_size=(3, 3, 3), strides=(t_stride, 2, 2))) # SCH: should be same in_channel and out_channel
prev_filters = filters # update in_channels
@ -200,7 +201,6 @@ class Encoder(nn.Module):
# MAGVIT uses Group Normalization
self.norm1 = nn.GroupNorm(self.num_groups, prev_filters, dtype=dtype, device=device) # SCH: separate <prev_filters> channels into 32 groups
# TODO: check if this is correct ??
self.conv2 = nn.Conv3d(prev_filters, self.embedding_dim, kernel_size=(1, 1, 1), dtype=dtype, device=device, padding="same")
def forward(self, x):