Specifically, our VAE consists of a pipeline of a [spatial VAE](https://huggingface.co/PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers) followed by a temporal VAE.
For the temporal VAE, we follow the implementation of [MAGVIT-v2](https://arxiv.org/abs/2310.05737), with the following modifications:
* We remove the codebook-specific parts of the architecture, since we use a continuous latent space rather than a discrete one.
* We do not use a discriminator; training instead uses the VAE reconstruction loss, KL loss, and perceptual loss.
* The last linear layer of the encoder projects down to a 4-channel diagonal Gaussian distribution, matching our previously trained STDiT, which takes 4-channel inputs.
* The decoder architecture mirrors the encoder.
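The shape contract of the temporal encoder's final projection can be sketched as follows. This is a minimal illustration, not the actual implementation: the module name, channel counts, and layer choices are assumptions; the only point carried over from the text is that the encoder ends by projecting to a 4-channel diagonal Gaussian (mean and log-variance) over spatio-temporal latents.

```python
import torch
import torch.nn as nn

class TemporalEncoderSketch(nn.Module):
    """Hypothetical stand-in for the temporal VAE encoder's output head."""

    def __init__(self, spatial_channels=8, hidden=16, latent_channels=4):
        super().__init__()
        self.conv = nn.Conv3d(spatial_channels, hidden, kernel_size=3, padding=1)
        # Final projection: 2 * latent_channels, split into mean and log-variance
        # of a diagonal Gaussian, as described above.
        self.to_gaussian = nn.Conv3d(hidden, 2 * latent_channels, kernel_size=1)

    def forward(self, x):  # x: (B, C, T, H, W) latents from the spatial VAE
        h = torch.relu(self.conv(x))
        mean, logvar = self.to_gaussian(h).chunk(2, dim=1)
        return mean, logvar  # each (B, 4, T, H, W)

enc = TemporalEncoderSketch()
mean, logvar = enc(torch.randn(1, 8, 17, 8, 8))
print(mean.shape)  # torch.Size([1, 4, 17, 8, 8])
```

The sample is drawn during training via the usual reparameterization, `mean + torch.exp(0.5 * logvar) * torch.randn_like(mean)`.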
## Training
We train the model in multiple stages.
First, we train only the temporal VAE, keeping the spatial VAE frozen, for 380k steps on a single machine (8 GPUs).
We use an additional identity loss to make features from the 3D VAE similar to the features from the 2D VAE.
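The identity loss can be sketched as a simple distance between the two models' features. This is an illustrative sketch only: the function name and the choice of MSE are assumptions, and `feat_3d`/`feat_2d` stand in for intermediate activations of the temporal (3D) VAE and the frozen spatial (2D) VAE on the same frames.

```python
import torch
import torch.nn.functional as F

def identity_loss(feat_3d: torch.Tensor, feat_2d: torch.Tensor) -> torch.Tensor:
    """Hypothetical identity loss: pull 3D-VAE features toward the frozen
    2D-VAE's per-frame features so the two latent spaces stay compatible."""
    return F.mse_loss(feat_3d, feat_2d)

feat_2d = torch.randn(2, 4, 17, 8, 8)                 # frozen spatial-VAE features
feat_3d = feat_2d + 0.1 * torch.randn_like(feat_2d)   # nearby temporal-VAE features
loss = identity_loss(feat_3d, feat_2d)
print(loss.item())
```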
We train the VAE on a mixture of 20% images and 80% videos, each video consisting of 17 frames.
Note that you need to adjust `epochs` in the config file according to the size of your own csv dataset.
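A back-of-the-envelope way to pick `epochs` is to divide the target step count by the steps per epoch your csv yields. The function name and the example numbers (global batch size, sample count) below are illustrative, not taken from the repo:

```python
def epochs_for_steps(total_steps: int, global_batch_size: int, num_samples: int) -> int:
    """Smallest epoch count that covers at least `total_steps` optimizer steps.

    Assumes one optimizer step per global batch; numbers are hypothetical.
    """
    steps_per_epoch = num_samples // global_batch_size
    # Ceiling division so the run reaches the target step count.
    return -(-total_steps // steps_per_epoch)

# e.g. the 380k steps above, global batch size 8, and a csv with 100k samples
print(epochs_for_steps(380_000, 8, 100_000))  # → 31
```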
## Inference
To visually check the performance of the VAE, you may run the following inference.
It saves three outputs, each to your specified video directory with a distinguishing postfix:

* the original video, with the `_ori` postfix (i.e. `"YOUR_VIDEO_DIR"_ori`);
* the reconstruction from the full pipeline, with the `_rec` postfix (i.e. `"YOUR_VIDEO_DIR"_rec`);
* the reconstruction from 2D (spatial) compression and decompression only, with the `_spatial` postfix (i.e. `"YOUR_VIDEO_DIR"_spatial`).
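The naming convention above can be summarized with a small helper. This is purely illustrative (the function does not exist in the repo); it only encodes the postfix scheme described in the text:

```python
def vae_output_dirs(video_dir: str) -> dict:
    """Hypothetical helper: map a video directory to the three output
    directories produced by VAE inference, per the postfix convention above."""
    return {
        "original": video_dir + "_ori",        # input videos, saved back out
        "reconstructed": video_dir + "_rec",   # full spatial + temporal round trip
        "spatial_only": video_dir + "_spatial",  # 2D compression/decompression only
    }

print(vae_output_dirs("samples"))
# {'original': 'samples_ori', 'reconstructed': 'samples_rec', 'spatial_only': 'samples_spatial'}
```

Comparing the `_rec` and `_spatial` outputs isolates how much quality is lost by the temporal VAE on top of the spatial one.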