add pipeline

This commit is contained in:
Shen-Chenhui 2024-04-18 13:49:54 +08:00
parent f867ebc819
commit aba1de6eb3
2 changed files with 20 additions and 0 deletions

View file

@ -8,6 +8,7 @@ data_path = "CSV_PATH"
use_image_transform = False
num_workers = 4
video_contains_first_frame = False
use_pipeline = True
# Define acceleration
dtype = "bf16"
@ -17,6 +18,10 @@ sp_size = 1
# Define model
vae_2d = dict(
type="VideoAutoencoderKL",
from_pretrained="stabilityai/sd-vae-ft-ema",
)
model = dict(
type="VAE_MAGVIT_V2",

View file

@ -136,6 +136,11 @@ def main():
# 4. build model
# ======================================================
# 4.1. build model
if cfg.get("use_pipeline") == True:
# use 2D VAE, then temporal VAE
vae_2d = build_module(cfg.vae_2d, MODELS)
vae = build_module(cfg.model, MODELS, device=device)
vae_numel, vae_numel_trainable = get_model_numel(vae)
logger.info(
@ -149,6 +154,9 @@ def main():
)
# 4.3. move to device
if cfg.get("use_pipeline") == True:
vae_2d.to(device, dtype).eval() # eval mode, not training!
vae = vae.to(device, dtype)
discriminator = discriminator.to(device, dtype)
@ -307,6 +315,13 @@ def main():
else:
video = x
# ===== Spatial VAE =====
if cfg.get("use_pipeline") == True:
with torch.no_grad():
video = vae.encode(video)
breakpoint()
# ====== VAE ======
recon_video, posterior = vae(
video,