This commit is contained in:
Shen-Chenhui 2024-03-29 10:17:09 +08:00
parent c9fac9fa2b
commit 1edc7c60ed

View file

@ -42,7 +42,7 @@ class ResBlock(nn.Module):
# SCH: MAGVIT uses GroupNorm by default
self.norm1 = nn.GroupNorm(num_groups, in_out_channels)
self.conv1 = conv_fn(in_out_channels, self.filters, kernel_size=(3, 3, 3), bias=False)
self.norm2 = nn.GroupNorm(num_groups, self.filters)
self.norm2 = nn.GroupNorm(num_groups, self.filters, device=device)
self.conv2 = conv_fn(self.filters, self.filters, kernel_size=(3, 3, 3), bias=False)
if self.use_conv_shortcut:
self.conv3 = conv_fn(in_out_channels, self.filters, kernel_size=(3, 3, 3), bias=False)