mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-11 05:13:31 +02:00
debug
This commit is contained in:
parent
dbe1c4cf2f
commit
89a53d423c
|
|
@ -429,13 +429,16 @@ class Encoder(nn.Module):
|
|||
|
||||
for i in range(self.num_blocks):
|
||||
for j in range(self.num_res_blocks):
|
||||
breakpoint()
|
||||
x = self.block_res_blocks[i][j](x)
|
||||
|
||||
if i < self.num_blocks - 1:
|
||||
breakpoint()
|
||||
x = self.conv_blocks[i](x)
|
||||
|
||||
|
||||
for i in range(self.num_res_blocks):
|
||||
breakpoint()
|
||||
x = self.res_blocks[i](x)
|
||||
|
||||
x = self.norm1(x)
|
||||
|
|
@ -550,14 +553,17 @@ class Decoder(nn.Module):
|
|||
dtype, device = x.dtype, x.device
|
||||
x = self.conv1(x)
|
||||
for i in range(self.num_res_blocks):
|
||||
breakpoint()
|
||||
x = self.res_blocks[i](x)
|
||||
for i in reversed(range(self.num_blocks)): # reverse here to make decoder symmetric with encoder
|
||||
for j in range(self.num_res_blocks):
|
||||
breakpoint()
|
||||
x = self.block_res_blocks[i][j](x)
|
||||
|
||||
if i > 0:
|
||||
t_stride = 2 if self.temporal_downsample[i - 1] else 1
|
||||
# SCH: T-Causal Conv 3x3x3, f -> (t_stride * 2 * 2) * f, depth to space t_stride x 2 x 2
|
||||
breakpoint()
|
||||
x = self.conv_blocks[i-1](x)
|
||||
x = rearrange(x, "B (C ts hs ws) T H W -> B C (T ts) (H hs) (W ws)", ts=t_stride, hs=2, ws=2)
|
||||
|
||||
|
|
@ -738,6 +744,7 @@ class VAE_3D_V2(nn.Module):
|
|||
video, _ = pack([first_frame, video], 'b c * h w')
|
||||
video = pad_at_dim(video, (self.time_padding, 0), dim = 2)
|
||||
|
||||
breakpoint()
|
||||
encoded_feature = self.encoder(video)
|
||||
|
||||
moments = self.quant_conv(encoded_feature).to(video.dtype)
|
||||
|
|
@ -754,7 +761,7 @@ class VAE_3D_V2(nn.Module):
|
|||
|
||||
z = self.post_quant_conv(z)
|
||||
dec = self.decoder(z)
|
||||
|
||||
breakpoint()
|
||||
# SCH: moved decoder last conv layer here for separate first frame decoding
|
||||
if decode_first_frame_separately:
|
||||
left_pad, dec_ff, dec = dec[:, :, :self.time_padding], dec[:, :, self.time_padding], dec[:, :, (self.time_padding + 1):]
|
||||
|
|
@ -789,6 +796,7 @@ class VAE_3D_V2(nn.Module):
|
|||
batch, channels, frames = video.shape[:3]
|
||||
assert divisible_by(frames - int(video_contains_first_frame), self.time_downsample_factor), f'number of frames {frames} minus the first frame ({frames - int(video_contains_first_frame)}) must be divisible by the total downsample factor across time {self.time_downsample_factor}'
|
||||
|
||||
breakpoint()
|
||||
posterior = self.encode(
|
||||
video,
|
||||
video_contains_first_frame = video_contains_first_frame,
|
||||
|
|
@ -799,11 +807,14 @@ class VAE_3D_V2(nn.Module):
|
|||
else:
|
||||
z = posterior.mode()
|
||||
|
||||
breakpoint()
|
||||
|
||||
recon_video = self.decode(
|
||||
z,
|
||||
video_contains_first_frame = video_contains_first_frame
|
||||
)
|
||||
|
||||
breakpoint()
|
||||
recon_loss = F.mse_loss(video, recon_video)
|
||||
|
||||
kl_loss = posterior.kl()
|
||||
|
|
|
|||
Loading…
Reference in a new issue