This commit is contained in:
Shen-Chenhui 2024-04-05 16:25:32 +08:00
parent e3584b4e43
commit 88be217a13

View file

@@ -180,15 +180,15 @@ class Encoder(nn.Module):
self.conv2 = self.conv_fn(prev_filters, self.embedding_dim, kernel_size=(1, 1, 1))
def forward(self, x):
dtype = x.dtype
x = self.conv1(x)
dtype, device = x.dtype, x.device
x = self.conv1.to(device, dtype)(x)
for i in range(self.num_blocks):
for j in range(self.num_res_blocks):
x = self.block_res_blocks[i][j](x)
if i < self.num_blocks - 1:
if self.conv_downsample:
x = self.conv_blocks[i](x)
x = self.conv_blocks[i].to(device, dtype)(x)
else:
if self.temporal_downsample[i]:
x = self.avg_pool_with_t(x)
@@ -198,9 +198,9 @@ class Encoder(nn.Module):
for i in range(self.num_res_blocks):
x = self.res_blocks[i](x)
x = self.norm1(x)
x = self.norm1.to(device, dtype)(x)
x = self.activate(x)
x = self.conv2(x).to(dtype)
x = self.conv2.to(device, dtype)(x)
return x
@@ -326,8 +326,8 @@ class Decoder(nn.Module):
x,
**kwargs,
):
dtype = x.dtype
x = self.conv1(x)
dtype, device = x.dtype, x.device
x = self.conv1.to(device, dtype)(x)
for i in range(self.num_res_blocks):
x = self.res_blocks[i](x)
for i in reversed(range(self.num_blocks)): # reverse here to make decoder symmetric with encoder
@@ -337,18 +337,18 @@ class Decoder(nn.Module):
if i > 0:
if self.upsample == 'deconv':
assert self.custom_conv_padding is None, ('Custom padding not implemented for ConvTranspose')
x = self.conv_blocks[i-1](x)
x = self.conv_blocks[i-1].to(device, dtype)(x)
elif self.upsample == 'nearest+conv':
if self.temporal_downsample[i - 1]:
x = self.upsampler_with_t(x)
else:
x = self.upsampler(x)
x = self.conv_blocks[i-1](x)
x = self.conv_blocks[i-1].to(device, dtype)(x)
else:
raise NotImplementedError(f'Unknown upsampler: {self.upsample}')
x = self.norm1(x)
x = self.norm1.to(device, dtype)(x)
x = self.activate(x)
x = self.conv2(x).to(dtype)
x = self.conv2.to(device, dtype)(x)
return x