mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-05-20 17:35:58 +02:00
inference script
This commit is contained in:
parent
a78ffe95a6
commit
c36d747b2f
|
|
@ -178,10 +178,10 @@ def main():
|
|||
# ===== Spatial VAE =====
|
||||
if cfg.get("use_pipeline") == True:
|
||||
with torch.no_grad():
|
||||
video = vae_2d.encode(video)
|
||||
video_enc_spatail = vae_2d.encode(video)
|
||||
|
||||
recon_video, posterior = vae(
|
||||
video,
|
||||
video_enc_spatail,
|
||||
video_contains_first_frame = video_contains_first_frame
|
||||
)
|
||||
|
||||
|
|
@ -189,7 +189,7 @@ def main():
|
|||
# ====== Calc Loss ======
|
||||
# simple nll loss
|
||||
nll_loss, weighted_nll_loss, weighted_kl_loss = vae_loss_fn(
|
||||
video,
|
||||
video_enc_spatail,
|
||||
recon_video,
|
||||
posterior,
|
||||
split = "eval"
|
||||
|
|
@ -208,7 +208,7 @@ def main():
|
|||
vae_loss = weighted_nll_loss + weighted_kl_loss + adversarial_loss
|
||||
|
||||
# ====== Discriminator Loss ======
|
||||
real_video = pad_at_dim(video, (disc_time_padding, 0), value = 0., dim = 2)
|
||||
real_video = pad_at_dim(video_enc_spatail, (disc_time_padding, 0), value = 0., dim = 2)
|
||||
fake_video = pad_at_dim(recon_video, (disc_time_padding, 0), value = 0., dim = 2)
|
||||
|
||||
if cfg.gradient_penalty_loss_weight is not None and cfg.gradient_penalty_loss_weight > 0.0:
|
||||
|
|
@ -239,29 +239,24 @@ def main():
|
|||
|
||||
|
||||
# ===== Spatial VAE =====
|
||||
if cfg.get("use_pipeline") == True:
|
||||
with torch.no_grad():
|
||||
recon_video_decode_spatial = vae_2d.decode(recon_video)
|
||||
|
||||
|
||||
if coordinator.is_master():
|
||||
for idx, sample in enumerate(recon_video):
|
||||
pos = step * cfg.batch_size + idx
|
||||
if cfg.get("use_pipeline") == True:
|
||||
save_path = os.path.join(save_dir, f"sample_{pos}_time_decode")
|
||||
else:
|
||||
save_path = os.path.join(save_dir, f"sample_{pos}")
|
||||
save_sample(sample, fps=cfg.fps, save_path=save_path)
|
||||
|
||||
if cfg.get("use_pipeline") == True:
|
||||
# store intermediate encoded video (spatial)
|
||||
for idx, sample in enumerate(video):
|
||||
with torch.no_grad(): # 2nd stage decoding
|
||||
recon_pipeline = vae_2d.decode(recon_video)
|
||||
recon_2d = vae_2d.decode(video_enc_spatail)
|
||||
|
||||
for idx, (sample_2d, sample_pipeline) in enumerate(zip(recon_pipeline, recon_2d)):
|
||||
pos = step * cfg.batch_size + idx
|
||||
save_path = os.path.join(save_dir, f"sample_{pos}_space_encode")
|
||||
save_sample(sample, fps=cfg.fps, save_path=save_path)
|
||||
# store final decoded video (decompressed spatially)
|
||||
for idx, sample in enumerate(recon_video_decode_spatial):
|
||||
save_path = os.path.join(save_dir, f"sample_{pos}")
|
||||
save_sample(sample_2d, fps=cfg.fps, save_path=save_path+"_2d")
|
||||
save_sample(sample_pipeline, fps=cfg.fps, save_path=save_path+"_pipeline")
|
||||
|
||||
else:
|
||||
for idx, sample in enumerate(recon_video):
|
||||
pos = step * cfg.batch_size + idx
|
||||
save_path = os.path.join(save_dir, f"sample_{pos}_final")
|
||||
save_path = os.path.join(save_dir, f"sample_{pos}")
|
||||
save_sample(sample, fps=cfg.fps, save_path=save_path)
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue