This commit is contained in:
shenchenhui 2024-03-28 22:51:22 +08:00
parent f1a6f7523e
commit 356ff604c0

View file

@ -47,9 +47,9 @@ class ResBlock(nn.Module):
self.norm2 = nn.GroupNorm(num_groups, self.filters, dtype=dtype, device=device)
self.conv2 = conv_fn(self.filters, self.filters, kernel_size=(3, 3, 3), bias=False)
if self.use_conv_shortcut:
self.conv3 = conv_fn(self.filters, self.filters, kernel_size=(3, 3, 3), bias=False)
self.conv3 = conv_fn(in_out_channels, self.filters, kernel_size=(3, 3, 3), bias=False)
else:
self.conv3 = conv_fn(self.filters, self.filters, kernel_size=(1, 1, 1), bias=False)
self.conv3 = conv_fn(in_out_channels, self.filters, kernel_size=(1, 1, 1), bias=False)
def forward(self, x):