diff --git a/opensora/models/vae/vae_3d.py b/opensora/models/vae/vae_3d.py index 29abaec..1907fb4 100644 --- a/opensora/models/vae/vae_3d.py +++ b/opensora/models/vae/vae_3d.py @@ -161,7 +161,7 @@ class Encoder(nn.Module): if self.conv_downsample: t_stride = 2 if self.temporal_downsample[i] else 1 t_pad = 1 if self.temporal_downsample[i] else 0 - self.conv_blocks.append(self.conv_fn(prev_filters, filters, kernel_size=(4, 4, 4), strides=(t_stride, 2, 2)), padding=(t_pad,1,1)) # SCH: should be same in_channel and out_channel + self.conv_blocks.append(self.conv_fn(prev_filters, filters, kernel_size=(4, 4, 4), stride=(t_stride, 2, 2)), padding=(t_pad,1,1)) # SCH: should be same in_channel and out_channel prev_filters = filters # update in_channels # NOTE: downsample, dimensions T, H, W diff --git a/opensora/models/vae/vae_3d_v2.py b/opensora/models/vae/vae_3d_v2.py index a067304..5d0c806 100644 --- a/opensora/models/vae/vae_3d_v2.py +++ b/opensora/models/vae/vae_3d_v2.py @@ -39,6 +39,7 @@ class CausalConv3d(nn.Module): chan_out, kernel_size: Union[int, Tuple[int, int, int]], pad_mode = 'constant', + strides = None, # allow custom stride **kwargs ): super().__init__() @@ -49,7 +50,7 @@ class CausalConv3d(nn.Module): assert is_odd(height_kernel_size) and is_odd(width_kernel_size) dilation = kwargs.pop('dilation', 1) - stride = kwargs.pop('stride', 1) + stride = strides[0] if strides is not None else kwargs.pop('stride', 1) self.pad_mode = pad_mode time_pad = dilation * (time_kernel_size - 1) + (1 - stride) @@ -59,7 +60,7 @@ class CausalConv3d(nn.Module): self.time_pad = time_pad self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0) - stride = (stride, 1, 1) + stride = strides if strides is not None else (stride, 1, 1) dilation = (dilation, 1, 1) self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride = stride, dilation = dilation, **kwargs) @@ -69,7 +70,6 @@ class CausalConv3d(nn.Module): x = nn.F.pad(x, self.time_causal_padding, 
mode = pad_mode) return self.conv(x) -# TODO: CausalConvTranspose3d class ResBlock(nn.Module): def __init__( @@ -187,6 +187,7 @@ class Encoder(nn.Module): if i < self.num_blocks - 1: # SCH: T-Causal Conv 3x3x3, 128->128, stride t x 2 x 2 t_stride = 2 if self.temporal_downsample[i] else 1 + # TODO: conv_fn uses default stride t x 1 x 1, cannot self.conv_blocks.append(self.conv_fn(prev_filters, filters, kernel_size=(3, 3, 3), strides=(t_stride, 2, 2))) # SCH: should be same in_channel and out_channel prev_filters = filters # update in_channels @@ -200,7 +201,6 @@ class Encoder(nn.Module): # MAGVIT uses Group Normalization self.norm1 = nn.GroupNorm(self.num_groups, prev_filters, dtype=dtype, device=device) # SCH: separate channels into 32 groups - # TODO: check if this is correct ?? self.conv2 = nn.Conv3d(prev_filters, self.embedding_dim, kernel_size=(1, 1, 1), dtype=dtype, device=device, padding="same") def forward(self, x):