From e3b84247c1bb143b2211817fec6433110248ed56 Mon Sep 17 00:00:00 2001 From: Shen-Chenhui Date: Mon, 29 Apr 2024 20:06:18 +0800 Subject: [PATCH] disable t-conv for t=1 --- opensora/models/vae/vae_3d.py | 45 ++++++++++++++++++++++++----------- 1 file changed, 31 insertions(+), 14 deletions(-) diff --git a/opensora/models/vae/vae_3d.py b/opensora/models/vae/vae_3d.py index 9c61c97..d4a47ad 100644 --- a/opensora/models/vae/vae_3d.py +++ b/opensora/models/vae/vae_3d.py @@ -529,12 +529,20 @@ class Encoder(nn.Module): self.block_res_blocks.append(block_items) if i < self.num_blocks - 1: # SCH: T-Causal Conv 3x3x3, 128->128, stride t x stride s x stride s - t_stride = 2 if self.temporal_downsample[i] else 1 - s_stride = 2 if not self.disable_spatial_downsample else 1 - self.conv_blocks.append( - self.conv_fn(prev_filters, filters, kernel_size=(3, 3, 3), strides=(t_stride, s_stride, s_stride)) - ) # SCH: should be same in_channel and out_channel - prev_filters = filters # update in_channels + if self.temporal_downsample[i]: + t_stride = 2 if self.temporal_downsample[i] else 1 + s_stride = 2 if not self.disable_spatial_downsample else 1 + self.conv_blocks.append( + self.conv_fn(prev_filters, filters, kernel_size=(3, 3, 3), strides=(t_stride, s_stride, s_stride)) + ) # SCH: should be same in_channel and out_channel + prev_filters = filters # update in_channels + else: # if no t downsample, don't add since this does nothing for pipeline models + self.conv_blocks.append( # Identity + nn.Identity(prev_filters) + ) + prev_filters = filters # update in_channels + + # last layer res block self.res_blocks = nn.ModuleList([]) @@ -673,14 +681,23 @@ class Decoder(nn.Module): # conv blocks with upsampling if i > 0: - t_stride = 2 if self.temporal_downsample[i - 1] else 1 - # SCH: T-Causal Conv 3x3x3, f -> (t_stride * 2 * 2) * f, depth to space t_stride x 2 x 2 - self.conv_blocks.insert( - 0, - self.conv_fn( - prev_filters, prev_filters * t_stride * self.s_stride * self.s_stride, kernel_size=(3, 3, 3) - ), - ) + if self.temporal_downsample[i]: + t_stride = 2 if self.temporal_downsample[i - 1] else 1 + # SCH: T-Causal Conv 3x3x3, f -> (t_stride * 2 * 2) * f, depth to space t_stride x 2 x 2 + self.conv_blocks.insert( + 0, + self.conv_fn( + prev_filters, prev_filters * t_stride * self.s_stride * self.s_stride, kernel_size=(3, 3, 3) + ), + ) + else: + self.conv_blocks.insert( + 0, + nn.Identity( + prev_filters + ), + ) + self.norm1 = nn.GroupNorm(self.num_groups, prev_filters, device=device, dtype=dtype)