mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-24 17:44:47 +02:00
debug
This commit is contained in:
parent
78f9ec8710
commit
dfae10e32b
|
|
@ -47,9 +47,9 @@ def main():
|
|||
# ======================================================
|
||||
# 2. runtime variables
|
||||
# ======================================================
|
||||
torch.set_grad_enabled(False)
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
# torch.set_grad_enabled(False)
|
||||
# torch.backends.cuda.matmul.allow_tf32 = True
|
||||
# torch.backends.cudnn.allow_tf32 = True
|
||||
# device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
device = get_current_device()
|
||||
dtype = to_torch_dtype(cfg.dtype)
|
||||
|
|
@ -174,11 +174,11 @@ def main():
|
|||
loss_steps += 1
|
||||
running_loss = loss.item()/ loss_steps + running_loss * ((loss_steps - 1) / loss_steps)
|
||||
|
||||
if coordinator.is_master():
|
||||
for idx, sample in enumerate(reconstructions):
|
||||
pos = step * cfg.batch_size + idx
|
||||
save_path = os.path.join(save_dir, f"sample_{pos}")
|
||||
save_sample(sample, fps=cfg.fps, save_path=save_path)
|
||||
# if coordinator.is_master():
|
||||
# for idx, sample in enumerate(reconstructions):
|
||||
# pos = step * cfg.batch_size + idx
|
||||
# save_path = os.path.join(save_dir, f"sample_{pos}")
|
||||
# save_sample(sample, fps=cfg.fps, save_path=save_path)
|
||||
|
||||
print("test loss:", running_loss)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue