From 356ff604c0502bdb61e94975f26cf6b81b370d1f Mon Sep 17 00:00:00 2001 From: shenchenhui Date: Thu, 28 Mar 2024 22:51:22 +0800 Subject: [PATCH] debug --- opensora/models/vae/vae_3d.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/opensora/models/vae/vae_3d.py b/opensora/models/vae/vae_3d.py index 97d4ef1..42a7a5f 100644 --- a/opensora/models/vae/vae_3d.py +++ b/opensora/models/vae/vae_3d.py @@ -47,9 +47,9 @@ class ResBlock(nn.Module): self.norm2 = nn.GroupNorm(num_groups, self.filters, dtype=dtype, device=device) self.conv2 = conv_fn(self.filters, self.filters, kernel_size=(3, 3, 3), bias=False) if self.use_conv_shortcut: - self.conv3 = conv_fn(self.filters, self.filters, kernel_size=(3, 3, 3), bias=False) + self.conv3 = conv_fn(in_out_channels, self.filters, kernel_size=(3, 3, 3), bias=False) else: - self.conv3 = conv_fn(self.filters, self.filters, kernel_size=(1, 1, 1), bias=False) + self.conv3 = conv_fn(in_out_channels, self.filters, kernel_size=(1, 1, 1), bias=False) def forward(self, x):