mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-15 03:15:20 +02:00
debug
This commit is contained in:
parent
e3584b4e43
commit
88be217a13
|
|
@ -180,15 +180,15 @@ class Encoder(nn.Module):
|
|||
self.conv2 = self.conv_fn(prev_filters, self.embedding_dim, kernel_size=(1, 1, 1))
|
||||
|
||||
def forward(self, x):
|
||||
dtype = x.dtype
|
||||
x = self.conv1(x)
|
||||
dtype, device = x.dtype, x.device
|
||||
x = self.conv1.to(device, dtype)(x)
|
||||
for i in range(self.num_blocks):
|
||||
for j in range(self.num_res_blocks):
|
||||
x = self.block_res_blocks[i][j](x)
|
||||
|
||||
if i < self.num_blocks - 1:
|
||||
if self.conv_downsample:
|
||||
x = self.conv_blocks[i](x)
|
||||
x = self.conv_blocks[i].to(device, dtype)(x)
|
||||
else:
|
||||
if self.temporal_downsample[i]:
|
||||
x = self.avg_pool_with_t(x)
|
||||
|
|
@ -198,9 +198,9 @@ class Encoder(nn.Module):
|
|||
for i in range(self.num_res_blocks):
|
||||
x = self.res_blocks[i](x)
|
||||
|
||||
x = self.norm1(x)
|
||||
x = self.norm1.to(device, dtype)(x)
|
||||
x = self.activate(x)
|
||||
x = self.conv2(x).to(dtype)
|
||||
x = self.conv2.to(device, dtype)(x)
|
||||
return x
|
||||
|
||||
|
||||
|
|
@ -326,8 +326,8 @@ class Decoder(nn.Module):
|
|||
x,
|
||||
**kwargs,
|
||||
):
|
||||
dtype = x.dtype
|
||||
x = self.conv1(x)
|
||||
dtype, device = x.dtype, x.device
|
||||
x = self.conv1.to(device, dtype)(x)
|
||||
for i in range(self.num_res_blocks):
|
||||
x = self.res_blocks[i](x)
|
||||
for i in reversed(range(self.num_blocks)): # reverse here to make decoder symmetric with encoder
|
||||
|
|
@ -337,18 +337,18 @@ class Decoder(nn.Module):
|
|||
if i > 0:
|
||||
if self.upsample == 'deconv':
|
||||
assert self.custom_conv_padding is None, ('Custom padding not implemented for ConvTranspose')
|
||||
x = self.conv_blocks[i-1](x)
|
||||
x = self.conv_blocks[i-1].to(device, dtype)(x)
|
||||
elif self.upsample == 'nearest+conv':
|
||||
if self.temporal_downsample[i - 1]:
|
||||
x = self.upsampler_with_t(x)
|
||||
else:
|
||||
x = self.upsampler(x)
|
||||
x = self.conv_blocks[i-1](x)
|
||||
x = self.conv_blocks[i-1].to(device, dtype)(x)
|
||||
else:
|
||||
raise NotImplementedError(f'Unknown upsampler: {self.upsample}')
|
||||
x = self.norm1(x)
|
||||
x = self.norm1.to(device, dtype)(x)
|
||||
x = self.activate(x)
|
||||
x = self.conv2(x).to(dtype)
|
||||
x = self.conv2.to(device, dtype)(x)
|
||||
return x
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue