From 1f9c0ab007e5c221e9f201a511740f7b3b7f4b3a Mon Sep 17 00:00:00 2001 From: zhengzangw Date: Tue, 30 Apr 2024 08:37:07 +0000 Subject: [PATCH] update inference vae --- configs/vae/train/video.py | 2 +- scripts/inference-vae.py | 13 +++++++++---- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/configs/vae/train/video.py b/configs/vae/train/video.py index 453a226..5414242 100644 --- a/configs/vae/train/video.py +++ b/configs/vae/train/video.py @@ -19,7 +19,7 @@ plugin = "zero2" # Define model model = dict( type="VideoAutoencoderPipeline", - freeze_vae_2d=True, + freeze_vae_2d=False, from_pretrained=None, vae_2d=dict( type="VideoAutoencoderKL", diff --git a/scripts/inference-vae.py b/scripts/inference-vae.py index fc9a952..ac25fbf 100644 --- a/scripts/inference-vae.py +++ b/scripts/inference-vae.py @@ -111,12 +111,17 @@ def main(): running_nll = nll_loss.item() / loss_steps + running_nll * ((loss_steps - 1) / loss_steps) if not use_dist or coordinator.is_master(): + ori_dir = f"{save_dir}_ori" + rec_dir = f"{save_dir}_rec" + ref_dir = f"{save_dir}_ref" + os.makedirs(ori_dir, exist_ok=True) + os.makedirs(rec_dir, exist_ok=True) + os.makedirs(ref_dir, exist_ok=True) for idx, vid in enumerate(x): pos = step * cfg.batch_size + idx - save_path = os.path.join(save_dir, f"sample_{pos:03d}") - save_sample(vid, fps=cfg.fps, save_path=save_path + "_ori") - save_sample(x_rec[idx], fps=cfg.fps, save_path=save_path + "_rec") - save_sample(x_ref[idx], fps=cfg.fps, save_path=save_path + "_ref") + save_sample(vid, fps=cfg.fps, save_path=f"{ori_dir}/{pos:03d}") + save_sample(x_rec[idx], fps=cfg.fps, save_path=f"{rec_dir}/{pos:03d}") + save_sample(x_ref[idx], fps=cfg.fps, save_path=f"{ref_dir}/{pos:03d}") print("test vae loss:", running_loss) print("test nll loss:", running_nll)