From fa0ca3983e82d43c612bfafd95d39851775e5eaa Mon Sep 17 00:00:00 2001
From: Shen-Chenhui
Date: Mon, 8 Apr 2024 15:40:14 +0800
Subject: [PATCH] debug inference code

---
 opensora/models/vae/README.md |   8 +-
 scripts/train-vae.py          | 174 ++++++++++++++++++++--------------
 2 files changed, 109 insertions(+), 73 deletions(-)

diff --git a/opensora/models/vae/README.md b/opensora/models/vae/README.md
index 6cdbe14..2f251a0 100644
--- a/opensora/models/vae/README.md
+++ b/opensora/models/vae/README.md
@@ -11,11 +11,11 @@ WANDB_API_KEY= CUDA_VISIBLE_DEVICES= torchrun --master_port=
diff --git a/scripts/train-vae.py b/scripts/train-vae.py
--- a/scripts/train-vae.py
+++ b/scripts/train-vae.py
-                # Save checkpoint
-                if cfg.ckpt_every > 0 and (global_step + 1) % cfg.ckpt_every == 0:
-                    save_dir = os.path.join(exp_dir, f"epoch{epoch}-global_step{global_step+1}")
-                    os.makedirs(os.path.join(save_dir, "model"), exist_ok=True)
-                    # TODO: save in model?
-                    booster.save_model(vae, os.path.join(save_dir, "model"), shard=True)
-                    booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True, size_per_shard=4096)
-                    if lr_scheduler is not None:
-                        booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler"))
-                    running_states = {
-                        "epoch": epoch,
-                        "step": step+1,
-                        "global_step": global_step+1,
-                        "sample_start_index": (step+1) * cfg.batch_size,
-                    }
-                    if coordinator.is_master():
-                        save_json(running_states, os.path.join(save_dir, "running_states.json"))
-                    dist.barrier()
-                    logger.info(
-                        f"Saved checkpoint at epoch {epoch} step {step + 1} global_step {global_step + 1} to {exp_dir}"
-                    )
+                # # Save checkpoint
+                # if cfg.ckpt_every > 0 and (global_step + 1) % cfg.ckpt_every == 0:
+                #     save_dir = os.path.join(exp_dir, f"epoch{epoch}-global_step{global_step+1}")
+                #     os.makedirs(os.path.join(save_dir, "model"), exist_ok=True)
+                #     # TODO: save in model?
+                #     booster.save_model(vae, os.path.join(save_dir, "model"), shard=True)
+                #     booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True, size_per_shard=4096)
+                #     if lr_scheduler is not None:
+                #         booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler"))
+                #     running_states = {
+                #         "epoch": epoch,
+                #         "step": step+1,
+                #         "global_step": global_step+1,
+                #         "sample_start_index": (step+1) * cfg.batch_size,
+                #     }
+                #     if coordinator.is_master():
+                #         save_json(running_states, os.path.join(save_dir, "running_states.json"))
+                #     dist.barrier()
+                #     logger.info(
+                #         f"Saved checkpoint at epoch {epoch} step {step + 1} global_step {global_step + 1} to {exp_dir}"
+                #     )
 
-        # the continue epochs are not resumed, so we need to reset the sampler start index and start step
-        dataloader.sampler.set_start_index(0)
-        start_step = 0
+        # # the continue epochs are not resumed, so we need to reset the sampler start index and start step
+        # dataloader.sampler.set_start_index(0)
+        # start_step = 0
+    # DEBUG inference
+
+    # 4.1. batch generation
+
+    # define loss function
+    loss_function = VEA3DLoss(kl_weight=cfg.kl_weight, perceptual_weight=cfg.perceptual_weight).to(device, dtype)
+    running_loss = 0.0
+    loss_steps = 0
+
+    from opensora.datasets import save_sample
+
+    total_steps = len(dataloader)
+    dataloader_iter = iter(dataloader)
+
+    with tqdm(
+        range(total_steps),
+        # desc=f"Avg Loss: {running_loss}",
+        disable=not coordinator.is_master(),
+        total=total_steps,
+        initial=0,
+    ) as pbar:
+        for step in pbar:
+            batch = next(dataloader_iter)
+            x = batch["video"].to(device, dtype)  # [B, C, T, H, W]
+            reconstructions, posterior = vae(x)
+            loss = loss_function(x, reconstructions, posterior)
+            loss_steps += 1
+            # incremental mean: new_mean = x_n / n + old_mean * (n - 1) / n
+            running_loss = loss.item() / loss_steps + running_loss * ((loss_steps - 1) / loss_steps)
+
+            if coordinator.is_master():
+                for idx, sample in enumerate(reconstructions):
+                    pos = step * cfg.batch_size + idx
+                    save_path = os.path.join("outputs/debug", f"sample_{pos}")
+                    save_sample(sample, fps=cfg.fps, save_path=save_path)
+
+    print("test loss:", running_loss)
 
 if __name__ == "__main__":
     main()
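Note on the loss used above: VEA3DLoss is only referenced by this patch, not defined in it. The sketch below is an assumed reading of what a loss taking kl_weight and perceptual_weight arguments typically computes for video tensors; the class body, the posterior.kl() method, and the lpips perceptual term are all assumptions, not code from this repository.

# Sketch only; interface mirrors the call loss_function(x, reconstructions, posterior) above.
import torch
import torch.nn as nn
import torch.nn.functional as F
import lpips  # assumed perceptual metric; inputs expected in [-1, 1]

class SketchVAE3DLoss(nn.Module):
    def __init__(self, kl_weight=1e-6, perceptual_weight=1.0):
        super().__init__()
        self.kl_weight = kl_weight
        self.perceptual_weight = perceptual_weight
        self.perceptual = lpips.LPIPS(net="vgg").eval()

    def forward(self, x, reconstructions, posterior):
        # pixel-space reconstruction term
        rec_loss = F.l1_loss(reconstructions, x)
        # perceptual term: fold time into the batch dim so frames are 2D images
        b, c, t, h, w = x.shape
        x2d = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
        r2d = reconstructions.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
        p_loss = self.perceptual(x2d, r2d).mean()
        # KL term: assumes a diagonal-Gaussian posterior exposing kl()
        kl_loss = posterior.kl().mean()
        return rec_loss + self.perceptual_weight * p_loss + self.kl_weight * kl_loss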
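The running_loss update in the debug loop is the usual incremental mean, m_n = x_n / n + m_{n-1} * (n - 1) / n, so no separate sum over steps needs to be kept. A tiny self-check with hypothetical per-step losses:

# hypothetical values, just to illustrate the recurrence
values = [0.8, 0.6, 0.7]
running, n = 0.0, 0
for v in values:
    n += 1
    running = v / n + running * ((n - 1) / n)
assert abs(running - sum(values) / len(values)) < 1e-12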