mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-11 21:42:26 +02:00
Merge branch 'vae-clean-new' of https://github.com/hpcaitech/Open-Sora-dev into vae-clean-new
This commit is contained in:
commit
4328e9a507
|
|
@ -19,7 +19,7 @@ plugin = "zero2"
|
|||
# Define model
|
||||
model = dict(
|
||||
type="VideoAutoencoderPipeline",
|
||||
freeze_vae_2d=True,
|
||||
freeze_vae_2d=False,
|
||||
from_pretrained=None,
|
||||
vae_2d=dict(
|
||||
type="VideoAutoencoderKL",
|
||||
|
|
|
|||
|
|
@ -111,12 +111,17 @@ def main():
|
|||
running_nll = nll_loss.item() / loss_steps + running_nll * ((loss_steps - 1) / loss_steps)
|
||||
|
||||
if not use_dist or coordinator.is_master():
|
||||
ori_dir = f"{save_dir}_ori"
|
||||
rec_dir = f"{save_dir}_rec"
|
||||
ref_dir = f"{save_dir}_ref"
|
||||
os.makedirs(ori_dir, exist_ok=True)
|
||||
os.makedirs(rec_dir, exist_ok=True)
|
||||
os.makedirs(ref_dir, exist_ok=True)
|
||||
for idx, vid in enumerate(x):
|
||||
pos = step * cfg.batch_size + idx
|
||||
save_path = os.path.join(save_dir, f"sample_{pos:03d}")
|
||||
save_sample(vid, fps=cfg.fps, save_path=save_path + "_ori")
|
||||
save_sample(x_rec[idx], fps=cfg.fps, save_path=save_path + "_rec")
|
||||
save_sample(x_ref[idx], fps=cfg.fps, save_path=save_path + "_ref")
|
||||
save_sample(vid, fps=cfg.fps, save_path=f"{ori_dir}/{pos:03d}")
|
||||
save_sample(x_rec[idx], fps=cfg.fps, save_path=f"{rec_dir}/{pos:03d}")
|
||||
save_sample(x_ref[idx], fps=cfg.fps, save_path=f"{ref_dir}/{pos:03d}")
|
||||
|
||||
print("test vae loss:", running_loss)
|
||||
print("test nll loss:", running_nll)
|
||||
|
|
|
|||
Loading…
Reference in a new issue