mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-16 21:23:27 +02:00
add pipeline
This commit is contained in:
parent
f867ebc819
commit
aba1de6eb3
|
|
@ -8,6 +8,7 @@ data_path = "CSV_PATH"
|
|||
use_image_transform = False
|
||||
num_workers = 4
|
||||
video_contains_first_frame = False
|
||||
use_pipeline = True
|
||||
|
||||
# Define acceleration
|
||||
dtype = "bf16"
|
||||
|
|
@ -17,6 +18,10 @@ sp_size = 1
|
|||
|
||||
|
||||
# Define model
|
||||
vae_2d = dict(
|
||||
type="VideoAutoencoderKL",
|
||||
from_pretrained="stabilityai/sd-vae-ft-ema",
|
||||
)
|
||||
|
||||
model = dict(
|
||||
type="VAE_MAGVIT_V2",
|
||||
|
|
|
|||
|
|
@ -136,6 +136,11 @@ def main():
|
|||
# 4. build model
|
||||
# ======================================================
|
||||
# 4.1. build model
|
||||
|
||||
if cfg.get("use_pipeline") == True:
|
||||
# use 2D VAE, then temporal VAE
|
||||
vae_2d = build_module(cfg.vae_2d, MODELS)
|
||||
|
||||
vae = build_module(cfg.model, MODELS, device=device)
|
||||
vae_numel, vae_numel_trainable = get_model_numel(vae)
|
||||
logger.info(
|
||||
|
|
@ -149,6 +154,9 @@ def main():
|
|||
)
|
||||
|
||||
# 4.3. move to device
|
||||
if cfg.get("use_pipeline") == True:
|
||||
vae_2d.to(device, dtype).eval() # eval mode, not training!
|
||||
|
||||
vae = vae.to(device, dtype)
|
||||
discriminator = discriminator.to(device, dtype)
|
||||
|
||||
|
|
@ -307,6 +315,13 @@ def main():
|
|||
else:
|
||||
video = x
|
||||
|
||||
# ===== Spatial VAE =====
|
||||
if cfg.get("use_pipeline") == True:
|
||||
with torch.no_grad():
|
||||
video = vae.encode(video)
|
||||
|
||||
breakpoint()
|
||||
|
||||
# ====== VAE ======
|
||||
recon_video, posterior = vae(
|
||||
video,
|
||||
|
|
|
|||
Loading…
Reference in a new issue