Disable t-conv for t=1 (use nn.Identity in place of the temporal conv block when the temporal stride is 1)

This commit is contained in:
Shen-Chenhui 2024-04-29 20:06:18 +08:00
parent c59b499085
commit e3b84247c1

View file

@ -529,12 +529,20 @@ class Encoder(nn.Module):
self.block_res_blocks.append(block_items)
if i < self.num_blocks - 1: # SCH: T-Causal Conv 3x3x3, 128->128, stride t x stride s x stride s
t_stride = 2 if self.temporal_downsample[i] else 1
s_stride = 2 if not self.disable_spatial_downsample else 1
self.conv_blocks.append(
self.conv_fn(prev_filters, filters, kernel_size=(3, 3, 3), strides=(t_stride, s_stride, s_stride))
) # SCH: should be same in_channel and out_channel
prev_filters = filters # update in_channels
if self.temporal_downsample[i]:
t_stride = 2 if self.temporal_downsample[i] else 1
s_stride = 2 if not self.disable_spatial_downsample else 1
self.conv_blocks.append(
self.conv_fn(prev_filters, filters, kernel_size=(3, 3, 3), strides=(t_stride, s_stride, s_stride))
) # SCH: should be same in_channel and out_channel
prev_filters = filters # update in_channels
else: # if no t downsample, don't add since this does nothing for pipeline models
self.conv_blocks.append( # Identity
nn.Identity(prev_filters)
)
prev_filters = filters # update in_channels
# last layer res block
self.res_blocks = nn.ModuleList([])
@ -673,14 +681,23 @@ class Decoder(nn.Module):
# conv blocks with upsampling
if i > 0:
t_stride = 2 if self.temporal_downsample[i - 1] else 1
# SCH: T-Causal Conv 3x3x3, f -> (t_stride * 2 * 2) * f, depth to space t_stride x 2 x 2
self.conv_blocks.insert(
0,
self.conv_fn(
prev_filters, prev_filters * t_stride * self.s_stride * self.s_stride, kernel_size=(3, 3, 3)
),
)
if self.temporal_downsample[i]:
t_stride = 2 if self.temporal_downsample[i - 1] else 1
# SCH: T-Causal Conv 3x3x3, f -> (t_stride * 2 * 2) * f, depth to space t_stride x 2 x 2
self.conv_blocks.insert(
0,
self.conv_fn(
prev_filters, prev_filters * t_stride * self.s_stride * self.s_stride, kernel_size=(3, 3, 3)
),
)
else:
self.conv_blocks.insert(
0,
nn.Identity(
prev_filters
),
)
self.norm1 = nn.GroupNorm(self.num_groups, prev_filters, device=device, dtype=dtype)