From dfae10e32bfbdf0ac571b54c7b2ea029660cfad6 Mon Sep 17 00:00:00 2001 From: Shen-Chenhui Date: Mon, 8 Apr 2024 16:56:28 +0800 Subject: [PATCH] debug --- scripts/inference-debug.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/scripts/inference-debug.py b/scripts/inference-debug.py index 94485a8..add814a 100644 --- a/scripts/inference-debug.py +++ b/scripts/inference-debug.py @@ -47,9 +47,9 @@ def main(): # ====================================================== # 2. runtime variables # ====================================================== - torch.set_grad_enabled(False) - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True + # torch.set_grad_enabled(False) + # torch.backends.cuda.matmul.allow_tf32 = True + # torch.backends.cudnn.allow_tf32 = True # device = "cuda" if torch.cuda.is_available() else "cpu" device = get_current_device() dtype = to_torch_dtype(cfg.dtype) @@ -174,11 +174,11 @@ def main(): loss_steps += 1 running_loss = loss.item()/ loss_steps + running_loss * ((loss_steps - 1) / loss_steps) - if coordinator.is_master(): - for idx, sample in enumerate(reconstructions): - pos = step * cfg.batch_size + idx - save_path = os.path.join(save_dir, f"sample_{pos}") - save_sample(sample, fps=cfg.fps, save_path=save_path) + # if coordinator.is_master(): + # for idx, sample in enumerate(reconstructions): + # pos = step * cfg.batch_size + idx + # save_path = os.path.join(save_dir, f"sample_{pos}") + # save_sample(sample, fps=cfg.fps, save_path=save_path) print("test loss:", running_loss)