inference script

This commit is contained in:
Shen-Chenhui 2024-04-19 11:06:13 +08:00
parent a78ffe95a6
commit c36d747b2f

View file

@ -178,10 +178,10 @@ def main():
# ===== Spatial VAE =====
if cfg.get("use_pipeline") == True:
with torch.no_grad():
video = vae_2d.encode(video)
video_enc_spatail = vae_2d.encode(video)
recon_video, posterior = vae(
video,
video_enc_spatail,
video_contains_first_frame = video_contains_first_frame
)
@ -189,7 +189,7 @@ def main():
# ====== Calc Loss ======
# simple nll loss
nll_loss, weighted_nll_loss, weighted_kl_loss = vae_loss_fn(
video,
video_enc_spatail,
recon_video,
posterior,
split = "eval"
@ -208,7 +208,7 @@ def main():
vae_loss = weighted_nll_loss + weighted_kl_loss + adversarial_loss
# ====== Discriminator Loss ======
real_video = pad_at_dim(video, (disc_time_padding, 0), value = 0., dim = 2)
real_video = pad_at_dim(video_enc_spatail, (disc_time_padding, 0), value = 0., dim = 2)
fake_video = pad_at_dim(recon_video, (disc_time_padding, 0), value = 0., dim = 2)
if cfg.gradient_penalty_loss_weight is not None and cfg.gradient_penalty_loss_weight > 0.0:
@ -239,29 +239,24 @@ def main():
# ===== Spatial VAE =====
if cfg.get("use_pipeline") == True:
with torch.no_grad():
recon_video_decode_spatial = vae_2d.decode(recon_video)
if coordinator.is_master():
for idx, sample in enumerate(recon_video):
pos = step * cfg.batch_size + idx
if cfg.get("use_pipeline") == True:
save_path = os.path.join(save_dir, f"sample_{pos}_time_decode")
else:
save_path = os.path.join(save_dir, f"sample_{pos}")
save_sample(sample, fps=cfg.fps, save_path=save_path)
if cfg.get("use_pipeline") == True:
# store intermediate encoded video (spatial)
for idx, sample in enumerate(video):
with torch.no_grad(): # 2nd stage decoding
recon_pipeline = vae_2d.decode(recon_video)
recon_2d = vae_2d.decode(video_enc_spatail)
for idx, (sample_2d, sample_pipeline) in enumerate(zip(recon_pipeline, recon_2d)):
pos = step * cfg.batch_size + idx
save_path = os.path.join(save_dir, f"sample_{pos}_space_encode")
save_sample(sample, fps=cfg.fps, save_path=save_path)
# store final decoded video (decompressed spatially)
for idx, sample in enumerate(recon_video_decode_spatial):
save_path = os.path.join(save_dir, f"sample_{pos}")
save_sample(sample_2d, fps=cfg.fps, save_path=save_path+"_2d")
save_sample(sample_pipeline, fps=cfg.fps, save_path=save_path+"_pipeline")
else:
for idx, sample in enumerate(recon_video):
pos = step * cfg.batch_size + idx
save_path = os.path.join(save_dir, f"sample_{pos}_final")
save_path = os.path.join(save_dir, f"sample_{pos}")
save_sample(sample, fps=cfg.fps, save_path=save_path)