diff --git a/configs/vae_magvit_v2/train/pipeline_16x128x128.py b/configs/vae_magvit_v2/train/pipeline_16x128x128.py
index 722f5cb..a26f8a8 100644
--- a/configs/vae_magvit_v2/train/pipeline_16x128x128.py
+++ b/configs/vae_magvit_v2/train/pipeline_16x128x128.py
@@ -8,6 +8,7 @@ data_path = "CSV_PATH"
 use_image_transform = False
 num_workers = 4
 video_contains_first_frame = False
+use_pipeline = True
 
 # Define acceleration
 dtype = "bf16"
@@ -17,5 +18,9 @@ sp_size = 1
 
 # Define model
+vae_2d = dict(
+    type="VideoAutoencoderKL",
+    from_pretrained="stabilityai/sd-vae-ft-ema",
+)
 model = dict(
     type="VAE_MAGVIT_V2",
diff --git a/scripts/train-vae-v2.py b/scripts/train-vae-v2.py
index 4a0116a..ed63341 100644
--- a/scripts/train-vae-v2.py
+++ b/scripts/train-vae-v2.py
@@ -136,6 +136,11 @@ def main():
     # 4. build model
     # ======================================================
     # 4.1. build model
+
+    if cfg.get("use_pipeline"):
+        # pipeline mode: a frozen 2D (spatial) VAE feeds the temporal VAE
+        vae_2d = build_module(cfg.vae_2d, MODELS)
+
     vae = build_module(cfg.model, MODELS, device=device)
     vae_numel, vae_numel_trainable = get_model_numel(vae)
     logger.info(
@@ -149,5 +154,8 @@
     )
 
     # 4.3. move to device
+    if cfg.get("use_pipeline"):
+        vae_2d.to(device, dtype).eval()  # frozen: eval mode, not trained
+
     vae = vae.to(device, dtype)
     discriminator = discriminator.to(device, dtype)
@@ -307,6 +315,11 @@
         else:
             video = x
 
+        # ===== Spatial (2D) VAE: frozen pre-encoding stage =====
+        if cfg.get("use_pipeline"):
+            with torch.no_grad():
+                video = vae_2d.encode(video)
+
         # ====== VAE ======
         recon_video, posterior = vae(
             video,