mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-11 05:13:31 +02:00
disable t-conv for t=1
This commit is contained in:
parent
c59b499085
commit
e3b84247c1
|
|
@ -529,12 +529,20 @@ class Encoder(nn.Module):
|
|||
self.block_res_blocks.append(block_items)
|
||||
|
||||
if i < self.num_blocks - 1: # SCH: T-Causal Conv 3x3x3, 128->128, stride t x stride s x stride s
|
||||
t_stride = 2 if self.temporal_downsample[i] else 1
|
||||
s_stride = 2 if not self.disable_spatial_downsample else 1
|
||||
self.conv_blocks.append(
|
||||
self.conv_fn(prev_filters, filters, kernel_size=(3, 3, 3), strides=(t_stride, s_stride, s_stride))
|
||||
) # SCH: should be same in_channel and out_channel
|
||||
prev_filters = filters # update in_channels
|
||||
if self.temporal_downsample[i]:
|
||||
t_stride = 2 if self.temporal_downsample[i] else 1
|
||||
s_stride = 2 if not self.disable_spatial_downsample else 1
|
||||
self.conv_blocks.append(
|
||||
self.conv_fn(prev_filters, filters, kernel_size=(3, 3, 3), strides=(t_stride, s_stride, s_stride))
|
||||
) # SCH: should be same in_channel and out_channel
|
||||
prev_filters = filters # update in_channels
|
||||
else: # if no t downsample, don't add since this does nothing for pipeline models
|
||||
self.conv_blocks.append( # Identity
|
||||
nn.Identity(prev_filters)
|
||||
)
|
||||
prev_filters = filters # update in_channels
|
||||
|
||||
|
||||
|
||||
# last layer res block
|
||||
self.res_blocks = nn.ModuleList([])
|
||||
|
|
@ -673,14 +681,23 @@ class Decoder(nn.Module):
|
|||
|
||||
# conv blocks with upsampling
|
||||
if i > 0:
|
||||
t_stride = 2 if self.temporal_downsample[i - 1] else 1
|
||||
# SCH: T-Causal Conv 3x3x3, f -> (t_stride * 2 * 2) * f, depth to space t_stride x 2 x 2
|
||||
self.conv_blocks.insert(
|
||||
0,
|
||||
self.conv_fn(
|
||||
prev_filters, prev_filters * t_stride * self.s_stride * self.s_stride, kernel_size=(3, 3, 3)
|
||||
),
|
||||
)
|
||||
if self.temporal_downsample[i]:
|
||||
t_stride = 2 if self.temporal_downsample[i - 1] else 1
|
||||
# SCH: T-Causal Conv 3x3x3, f -> (t_stride * 2 * 2) * f, depth to space t_stride x 2 x 2
|
||||
self.conv_blocks.insert(
|
||||
0,
|
||||
self.conv_fn(
|
||||
prev_filters, prev_filters * t_stride * self.s_stride * self.s_stride, kernel_size=(3, 3, 3)
|
||||
),
|
||||
)
|
||||
else:
|
||||
self.conv_blocks.insert(
|
||||
0,
|
||||
nn.Identity(
|
||||
prev_filters
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
self.norm1 = nn.GroupNorm(self.num_groups, prev_filters, device=device, dtype=dtype)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue