This commit is contained in:
Shen-Chenhui 2024-04-08 16:56:28 +08:00
parent 78f9ec8710
commit dfae10e32b

View file

@ -47,9 +47,9 @@ def main():
# ======================================================
# 2. runtime variables
# ======================================================
torch.set_grad_enabled(False)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# torch.set_grad_enabled(False)
# torch.backends.cuda.matmul.allow_tf32 = True
# torch.backends.cudnn.allow_tf32 = True
# device = "cuda" if torch.cuda.is_available() else "cpu"
device = get_current_device()
dtype = to_torch_dtype(cfg.dtype)
@ -174,11 +174,11 @@ def main():
loss_steps += 1
running_loss = loss.item()/ loss_steps + running_loss * ((loss_steps - 1) / loss_steps)
if coordinator.is_master():
for idx, sample in enumerate(reconstructions):
pos = step * cfg.batch_size + idx
save_path = os.path.join(save_dir, f"sample_{pos}")
save_sample(sample, fps=cfg.fps, save_path=save_path)
# if coordinator.is_master():
# for idx, sample in enumerate(reconstructions):
# pos = step * cfg.batch_size + idx
# save_path = os.path.join(save_dir, f"sample_{pos}")
# save_sample(sample, fps=cfg.fps, save_path=save_path)
print("test loss:", running_loss)