Enable pipeline VAE training

This commit is contained in:
Shen-Chenhui 2024-04-18 14:13:56 +08:00
parent aba1de6eb3
commit 246eddecdd
3 changed files with 16 additions and 12 deletions

View file

@ -25,7 +25,7 @@ vae_2d = dict(
model = dict(
type="VAE_MAGVIT_V2",
in_out_channels = 3,
in_out_channels = 4,
latent_embed_dim = 256,
filters = 128,
num_res_blocks = 4,
@ -44,7 +44,7 @@ discriminator = dict(
type="DISCRIMINATOR_3D",
image_size = image_size,
num_frames = num_frames,
in_channels = 3,
in_channels = 4,
filters = 128,
channel_multipliers = (2,4,4,4,4), # (2,4,4,4) for 64x64 resolution
)

View file

@ -932,6 +932,8 @@ class VAE_3D_V2(nn.Module): # , ModelMixin
video_contains_first_frame = video_contains_first_frame
)
breakpoint()
return recon_video, posterior
@ -991,13 +993,18 @@ class VEALoss(nn.Module):
if self.perceptual_loss_weight is not None and self.perceptual_loss_weight > 0.0:
# handle channels
channels = video.shape[1]
assert channels in {1,3,4}
if channels == 1:
input_vgg_input = repeat(input_vgg_input, 'b 1 h w -> b c h w', c = 3)
recon_vgg_input = repeat(recon_vgg_input, 'b 1 h w -> b c h w', c = 3)
input_vgg_input = repeat(video, 'b 1 h w -> b c h w', c = 3)
recon_vgg_input = repeat(recon_video, 'b 1 h w -> b c h w', c = 3)
elif channels == 4: # SCH: take the first 3 for perceptual loss calc
input_vgg_input = input_vgg_input[:, :3]
recon_vgg_input = recon_vgg_input[:, :3]
perceptual_loss = self.perceptual_loss_fn(video, recon_video)
input_vgg_input = video[:, :3]
recon_vgg_input = recon_video[:, :3]
else:
input_vgg_input = video
recon_vgg_input = recon_video
perceptual_loss = self.perceptual_loss_fn(input_vgg_input, recon_vgg_input)
recon_loss = recon_loss + self.perceptual_loss_weight * perceptual_loss
nll_loss = recon_loss / torch.exp(self.logvar) + self.logvar
@ -1223,7 +1230,6 @@ def load_checkpoint_with_inflation(model, ckpt_path):
"""
if ckpt_path.endswith(".pt") or ckpt_path.endswith(".pth"):
state_dict = find_model(ckpt_path)
breakpoint() # NOTE: need to manually check before first use
with torch.no_grad():
for key in state_dict:
if key in model:

View file

@ -314,13 +314,11 @@ def main():
video_contains_first_frame = True
else:
video = x
# ===== Spatial VAE =====
if cfg.get("use_pipeline") == True:
with torch.no_grad():
video = vae.encode(video)
breakpoint()
video = vae_2d.encode(video)
# ====== VAE ======
recon_video, posterior = vae(