diff --git a/scripts/inference-vae-v2.py b/scripts/inference-vae-v2.py index 4c743b8..8a8b42c 100644 --- a/scripts/inference-vae-v2.py +++ b/scripts/inference-vae-v2.py @@ -178,10 +178,10 @@ def main(): # ===== Spatial VAE ===== if cfg.get("use_pipeline") == True: with torch.no_grad(): - video = vae_2d.encode(video) + video_enc_spatail = vae_2d.encode(video) recon_video, posterior = vae( - video, + video_enc_spatail, video_contains_first_frame = video_contains_first_frame ) @@ -189,7 +189,7 @@ def main(): # ====== Calc Loss ====== # simple nll loss nll_loss, weighted_nll_loss, weighted_kl_loss = vae_loss_fn( - video, + video_enc_spatail, recon_video, posterior, split = "eval" @@ -208,7 +208,7 @@ def main(): vae_loss = weighted_nll_loss + weighted_kl_loss + adversarial_loss # ====== Discriminator Loss ====== - real_video = pad_at_dim(video, (disc_time_padding, 0), value = 0., dim = 2) + real_video = pad_at_dim(video_enc_spatail, (disc_time_padding, 0), value = 0., dim = 2) fake_video = pad_at_dim(recon_video, (disc_time_padding, 0), value = 0., dim = 2) if cfg.gradient_penalty_loss_weight is not None and cfg.gradient_penalty_loss_weight > 0.0: @@ -239,29 +239,24 @@ def main(): # ===== Spatial VAE ===== - if cfg.get("use_pipeline") == True: - with torch.no_grad(): - recon_video_decode_spatial = vae_2d.decode(recon_video) + if coordinator.is_master(): - for idx, sample in enumerate(recon_video): - pos = step * cfg.batch_size + idx - if cfg.get("use_pipeline") == True: - save_path = os.path.join(save_dir, f"sample_{pos}_time_decode") - else: - save_path = os.path.join(save_dir, f"sample_{pos}") - save_sample(sample, fps=cfg.fps, save_path=save_path) - if cfg.get("use_pipeline") == True: - # store intermediate encoded video (spatial) - for idx, sample in enumerate(video): + with torch.no_grad(): # 2nd stage decoding + recon_pipeline = vae_2d.decode(recon_video) + recon_2d = vae_2d.decode(video_enc_spatail) + + for idx, (sample_2d, sample_pipeline) in enumerate(zip(recon_pipeline, recon_2d)): pos = step * cfg.batch_size + idx - save_path = os.path.join(save_dir, f"sample_{pos}_space_encode") - save_sample(sample, fps=cfg.fps, save_path=save_path) - # store final decoded video (decompressed spatially) - for idx, sample in enumerate(recon_video_decode_spatial): + save_path = os.path.join(save_dir, f"sample_{pos}") + save_sample(sample_2d, fps=cfg.fps, save_path=save_path+"_2d") + save_sample(sample_pipeline, fps=cfg.fps, save_path=save_path+"_pipeline") + + else: + for idx, sample in enumerate(recon_video): pos = step * cfg.batch_size + idx - save_path = os.path.join(save_dir, f"sample_{pos}_final") + save_path = os.path.join(save_dir, f"sample_{pos}") save_sample(sample, fps=cfg.fps, save_path=save_path)