This commit is contained in:
Shen-Chenhui 2024-04-11 16:44:02 +08:00
parent dbe1c4cf2f
commit 89a53d423c

View file

@ -429,13 +429,16 @@ class Encoder(nn.Module):
for i in range(self.num_blocks):
for j in range(self.num_res_blocks):
breakpoint()
x = self.block_res_blocks[i][j](x)
if i < self.num_blocks - 1:
breakpoint()
x = self.conv_blocks[i](x)
for i in range(self.num_res_blocks):
breakpoint()
x = self.res_blocks[i](x)
x = self.norm1(x)
@ -550,14 +553,17 @@ class Decoder(nn.Module):
dtype, device = x.dtype, x.device
x = self.conv1(x)
for i in range(self.num_res_blocks):
breakpoint()
x = self.res_blocks[i](x)
for i in reversed(range(self.num_blocks)): # reverse here to make decoder symmetric with encoder
for j in range(self.num_res_blocks):
breakpoint()
x = self.block_res_blocks[i][j](x)
if i > 0:
t_stride = 2 if self.temporal_downsample[i - 1] else 1
# SCH: T-Causal Conv 3x3x3, f -> (t_stride * 2 * 2) * f, depth to space t_stride x 2 x 2
breakpoint()
x = self.conv_blocks[i-1](x)
x = rearrange(x, "B (C ts hs ws) T H W -> B C (T ts) (H hs) (W ws)", ts=t_stride, hs=2, ws=2)
@ -738,6 +744,7 @@ class VAE_3D_V2(nn.Module):
video, _ = pack([first_frame, video], 'b c * h w')
video = pad_at_dim(video, (self.time_padding, 0), dim = 2)
breakpoint()
encoded_feature = self.encoder(video)
moments = self.quant_conv(encoded_feature).to(video.dtype)
@ -754,7 +761,7 @@ class VAE_3D_V2(nn.Module):
z = self.post_quant_conv(z)
dec = self.decoder(z)
breakpoint()
# SCH: moved decoder last conv layer here for separate first frame decoding
if decode_first_frame_separately:
left_pad, dec_ff, dec = dec[:, :, :self.time_padding], dec[:, :, self.time_padding], dec[:, :, (self.time_padding + 1):]
@ -789,6 +796,7 @@ class VAE_3D_V2(nn.Module):
batch, channels, frames = video.shape[:3]
assert divisible_by(frames - int(video_contains_first_frame), self.time_downsample_factor), f'number of frames {frames} minus the first frame ({frames - int(video_contains_first_frame)}) must be divisible by the total downsample factor across time {self.time_downsample_factor}'
breakpoint()
posterior = self.encode(
video,
video_contains_first_frame = video_contains_first_frame,
@ -799,11 +807,14 @@ class VAE_3D_V2(nn.Module):
else:
z = posterior.mode()
breakpoint()
recon_video = self.decode(
z,
video_contains_first_frame = video_contains_first_frame
)
breakpoint()
recon_loss = F.mse_loss(video, recon_video)
kl_loss = posterior.kl()