Merge branch 'vae-clean-new' of https://github.com/hpcaitech/Open-Sora-dev into vae-clean-new

This commit is contained in:
Shen-Chenhui 2024-04-30 16:51:33 +08:00
commit 4328e9a507
2 changed files with 10 additions and 5 deletions

View file

@ -19,7 +19,7 @@ plugin = "zero2"
# Define model
model = dict(
type="VideoAutoencoderPipeline",
freeze_vae_2d=True,
freeze_vae_2d=False,
from_pretrained=None,
vae_2d=dict(
type="VideoAutoencoderKL",

View file

@ -111,12 +111,17 @@ def main():
running_nll = nll_loss.item() / loss_steps + running_nll * ((loss_steps - 1) / loss_steps)
if not use_dist or coordinator.is_master():
ori_dir = f"{save_dir}_ori"
rec_dir = f"{save_dir}_rec"
ref_dir = f"{save_dir}_ref"
os.makedirs(ori_dir, exist_ok=True)
os.makedirs(rec_dir, exist_ok=True)
os.makedirs(ref_dir, exist_ok=True)
for idx, vid in enumerate(x):
pos = step * cfg.batch_size + idx
save_path = os.path.join(save_dir, f"sample_{pos:03d}")
save_sample(vid, fps=cfg.fps, save_path=save_path + "_ori")
save_sample(x_rec[idx], fps=cfg.fps, save_path=save_path + "_rec")
save_sample(x_ref[idx], fps=cfg.fps, save_path=save_path + "_ref")
save_sample(vid, fps=cfg.fps, save_path=f"{ori_dir}/{pos:03d}")
save_sample(x_rec[idx], fps=cfg.fps, save_path=f"{rec_dir}/{pos:03d}")
save_sample(x_ref[idx], fps=cfg.fps, save_path=f"{ref_dir}/{pos:03d}")
print("test vae loss:", running_loss)
print("test nll loss:", running_nll)