From 88be217a13586d66def321c1365507bda43a3a66 Mon Sep 17 00:00:00 2001 From: Shen-Chenhui Date: Fri, 5 Apr 2024 16:25:32 +0800 Subject: [PATCH] debug --- opensora/models/vae/vae_3d.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/opensora/models/vae/vae_3d.py b/opensora/models/vae/vae_3d.py index 38698a4..58ec853 100644 --- a/opensora/models/vae/vae_3d.py +++ b/opensora/models/vae/vae_3d.py @@ -180,15 +180,15 @@ class Encoder(nn.Module): self.conv2 = self.conv_fn(prev_filters, self.embedding_dim, kernel_size=(1, 1, 1)) def forward(self, x): - dtype = x.dtype - x = self.conv1(x) + dtype, device = x.dtype, x.device + x = self.conv1.to(device, dtype)(x) for i in range(self.num_blocks): for j in range(self.num_res_blocks): x = self.block_res_blocks[i][j](x) if i < self.num_blocks - 1: if self.conv_downsample: - x = self.conv_blocks[i](x) + x = self.conv_blocks[i].to(device, dtype)(x) else: if self.temporal_downsample[i]: x = self.avg_pool_with_t(x) @@ -198,9 +198,9 @@ class Encoder(nn.Module): for i in range(self.num_res_blocks): x = self.res_blocks[i](x) - x = self.norm1(x) + x = self.norm1.to(device, dtype)(x) x = self.activate(x) - x = self.conv2(x).to(dtype) + x = self.conv2.to(device, dtype)(x) return x @@ -326,8 +326,8 @@ class Decoder(nn.Module): x, **kwargs, ): - dtype = x.dtype - x = self.conv1(x) + dtype, device = x.dtype, x.device + x = self.conv1.to(device, dtype)(x) for i in range(self.num_res_blocks): x = self.res_blocks[i](x) for i in reversed(range(self.num_blocks)): # reverse here to make decoder symmetric with encoder @@ -337,18 +337,18 @@ class Decoder(nn.Module): if i > 0: if self.upsample == 'deconv': assert self.custom_conv_padding is None, ('Custom padding not implemented for ConvTranspose') - x = self.conv_blocks[i-1](x) + x = self.conv_blocks[i-1].to(device, dtype)(x) elif self.upsample == 'nearest+conv': if self.temporal_downsample[i - 1]: x = self.upsampler_with_t(x) else: x = self.upsampler(x) - x = self.conv_blocks[i-1](x) + x = self.conv_blocks[i-1].to(device, dtype)(x) else: raise NotImplementedError(f'Unknown upsampler: {self.upsample}') - x = self.norm1(x) + x = self.norm1.to(device, dtype)(x) x = self.activate(x) - x = self.conv2(x).to(dtype) + x = self.conv2.to(device, dtype)(x) return x