diff --git a/opensora/models/vae/vae_3d_v2.py b/opensora/models/vae/vae_3d_v2.py index d74a72d..15ea9a9 100644 --- a/opensora/models/vae/vae_3d_v2.py +++ b/opensora/models/vae/vae_3d_v2.py @@ -429,13 +429,16 @@ class Encoder(nn.Module): for i in range(self.num_blocks): for j in range(self.num_res_blocks): + breakpoint() x = self.block_res_blocks[i][j](x) if i < self.num_blocks - 1: + breakpoint() x = self.conv_blocks[i](x) for i in range(self.num_res_blocks): + breakpoint() x = self.res_blocks[i](x) x = self.norm1(x) @@ -550,14 +553,17 @@ class Decoder(nn.Module): dtype, device = x.dtype, x.device x = self.conv1(x) for i in range(self.num_res_blocks): + breakpoint() x = self.res_blocks[i](x) for i in reversed(range(self.num_blocks)): # reverse here to make decoder symmetric with encoder for j in range(self.num_res_blocks): + breakpoint() x = self.block_res_blocks[i][j](x) if i > 0: t_stride = 2 if self.temporal_downsample[i - 1] else 1 # SCH: T-Causal Conv 3x3x3, f -> (t_stride * 2 * 2) * f, depth to space t_stride x 2 x 2 + breakpoint() x = self.conv_blocks[i-1](x) x = rearrange(x, "B (C ts hs ws) T H W -> B C (T ts) (H hs) (W ws)", ts=t_stride, hs=2, ws=2) @@ -738,6 +744,7 @@ class VAE_3D_V2(nn.Module): video, _ = pack([first_frame, video], 'b c * h w') video = pad_at_dim(video, (self.time_padding, 0), dim = 2) + breakpoint() encoded_feature = self.encoder(video) moments = self.quant_conv(encoded_feature).to(video.dtype) @@ -754,7 +761,7 @@ class VAE_3D_V2(nn.Module): z = self.post_quant_conv(z) dec = self.decoder(z) - + breakpoint() # SCH: moved decoder last conv layer here for separate first frame decoding if decode_first_frame_separately: left_pad, dec_ff, dec = dec[:, :, :self.time_padding], dec[:, :, self.time_padding], dec[:, :, (self.time_padding + 1):] @@ -789,6 +796,7 @@ class VAE_3D_V2(nn.Module): batch, channels, frames = video.shape[:3] assert divisible_by(frames - int(video_contains_first_frame), self.time_downsample_factor), f'number of frames {frames} minus the first frame ({frames - int(video_contains_first_frame)}) must be divisible by the total downsample factor across time {self.time_downsample_factor}' + breakpoint() posterior = self.encode( video, video_contains_first_frame = video_contains_first_frame, @@ -799,11 +807,14 @@ class VAE_3D_V2(nn.Module): else: z = posterior.mode() + breakpoint() + recon_video = self.decode( z, video_contains_first_frame = video_contains_first_frame ) + breakpoint() recon_loss = F.mse_loss(video, recon_video) kl_loss = posterior.kl()