Commands

1. VAE 3D

1.1 Train

# train on the Pexels dataset
WANDB_API_KEY=<wandb_api_key> CUDA_VISIBLE_DEVICES=<n> torchrun --master_port=<port_num> --nnodes=1 --nproc_per_node=1 scripts/train-vae.py configs/vae_3d/train/16x256x256.py --data-path /home/shenchenhui/data/pexels/train.csv --wandb True
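If you launch runs from Python (for example, to sweep GPUs or ports), the training command above can be assembled programmatically. This is an illustrative sketch only. build_train_command is a hypothetical helper, not part of Open-Sora. Enabling --wandb only when an API key is passed is our own assumption.

```python
import os


def build_train_command(gpu, port, data_path, wandb_key=None,
                        config="configs/vae_3d/train/16x256x256.py"):
    """Return (env, argv) mirroring the single-node, single-GPU training launch."""
    # Copy the current environment so PATH etc. survive, then pin the visible GPU.
    env = dict(os.environ, CUDA_VISIBLE_DEVICES=str(gpu))
    argv = [
        "torchrun",
        f"--master_port={port}",
        "--nnodes=1",
        "--nproc_per_node=1",
        "scripts/train-vae.py",
        config,
        "--data-path",
        data_path,
    ]
    # Assumption: only turn on W&B logging when an API key is supplied.
    if wandb_key is not None:
        env["WANDB_API_KEY"] = wandb_key
        argv += ["--wandb", "True"]
    return env, argv
```

Run it from the repository root with subprocess.run(argv, env=env, check=True). Passing the variables through env, rather than prefixing a command string, avoids going through a shell.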

1.2 Inference

CUDA_VISIBLE_DEVICES=6 torchrun --standalone --nnodes=1 --nproc_per_node=1 scripts/inference-vae.py configs/vae_3d/inference/16x256x256.py --ckpt-path /home/shenchenhui/Open-Sora-dev/outputs/train_pexel_028/epoch3-global_step20000/ --data-path /home/shenchenhui/data/pexels/debug.csv --save-dir outputs/pexel


1.3 Resume Training

# debug run: resume training from a saved checkpoint
CUDA_VISIBLE_DEVICES=5 torchrun --master_port=29530 --nnodes=1 --nproc_per_node=1 scripts/train-vae.py configs/vae_3d/train/16x256x256.py --data-path /home/shenchenhui/data/pexels/debug.csv --load /home/shenchenhui/Open-Sora-dev/outputs/006-F16S3-VAE_3D_B/epoch49-global_step50
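The checkpoint folders passed to --ckpt-path and --load follow an epoch<E>-global_step<S> naming pattern. That pattern is inferred from the paths above, not from the training script. Assuming every checkpoint uses it, the most recent one can be picked by global step rather than by name. Sorting by name would rank epoch9 above epoch10. pick_latest and latest_checkpoint are hypothetical helpers.

```python
import re
from pathlib import Path

# Naming pattern inferred from the paths above, e.g. "epoch49-global_step50".
CKPT_PATTERN = re.compile(r"^epoch(\d+)-global_step(\d+)$")


def pick_latest(names):
    """Return the checkpoint name with the highest global step, or None if none match."""
    candidates = []
    for name in names:
        match = CKPT_PATTERN.match(name)
        if match:
            candidates.append((int(match.group(2)), name))
    return max(candidates)[1] if candidates else None


def latest_checkpoint(run_dir):
    """Return the path of the newest checkpoint folder inside run_dir, or None."""
    run_dir = Path(run_dir)
    names = [child.name for child in run_dir.iterdir() if child.is_dir()]
    latest = pick_latest(names)
    return run_dir / latest if latest is not None else None
```

Pass str(latest_checkpoint("outputs/<run_name>")) to --load, after checking that the result is not None.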

2. MAGVIT-v2

2.1 Dependencies

'accelerate>=0.24.0',
'beartype',
'einops>=0.7.0',
'ema-pytorch>=0.2.4',
'pytorch-warmup',
'gateloop-transformer>=0.2.2',
'kornia',
'opencv-python',
'pillow',
'pytorch-custom-utils>=0.0.9',
'numpy',
'vector-quantize-pytorch>=1.11.8',
'taylor-series-linear-attention>=0.1.5',
'torch',
'torchvision',
'x-transformers'
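One way to confirm the environment meets these minimum versions is to query installed package metadata with the standard library. This is a rough sketch. It only understands the >= pins used above and compares numeric release components, ignoring pre-release and local-version semantics. If the packaging library is available, packaging.version.Version is the more robust choice. The helper names are hypothetical.

```python
import re
from importlib.metadata import PackageNotFoundError, version

# Copied from the dependency list above; only ">=" pins appear there.
REQUIREMENTS = [
    "accelerate>=0.24.0", "beartype", "einops>=0.7.0", "ema-pytorch>=0.2.4",
    "pytorch-warmup", "gateloop-transformer>=0.2.2", "kornia", "opencv-python",
    "pillow", "pytorch-custom-utils>=0.0.9", "numpy",
    "vector-quantize-pytorch>=1.11.8", "taylor-series-linear-attention>=0.1.5",
    "torch", "torchvision", "x-transformers",
]


def version_tuple(text):
    """Numeric release components only: '2.1.0+cu121' -> (2, 1, 0), '1.0rc1' -> (1, 0)."""
    parts = []
    for piece in text.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
        if match.end() != len(piece):  # stop at a suffix such as 'rc1' or '+cu121'
            break
    return tuple(parts)


def at_least(installed, minimum):
    """Compare release tuples after zero-padding them to the same length."""
    width = max(len(installed), len(minimum))
    return installed + (0,) * (width - len(installed)) >= minimum + (0,) * (width - len(minimum))


def parse_requirement(spec):
    """Split 'name>=1.2.3' into ('name', (1, 2, 3)); an unpinned name gives (name, None)."""
    name, _, minimum = spec.partition(">=")
    return name.strip(), (version_tuple(minimum) if minimum else None)


def check_requirements(specs):
    """Return a list of human-readable problems; an empty list means all checks passed."""
    problems = []
    for spec in specs:
        name, minimum = parse_requirement(spec)
        try:
            installed = version(name)
        except PackageNotFoundError:
            problems.append(f"{name}: not installed")
            continue
        if minimum is not None and not at_least(version_tuple(installed), minimum):
            problems.append(f"{name}: {installed} is older than {'.'.join(map(str, minimum))}")
    return problems
```

To report problems, loop over check_requirements(REQUIREMENTS) and print each entry. A package installed under a different distribution name, such as opencv-python-headless, will be reported as missing.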

Note: this setup uses the hotfix/zero branch of https://github.com/ver217/ColossalAI.git. Clone the repository, check out that branch, then run pip install . from the repository root.
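The same three steps can be scripted in Python. This is a sketch. colossalai_install_steps and run_steps are hypothetical helpers, and the repository URL and branch come from the note above. pip is invoked through sys.executable so the package is installed into the interpreter running the script.

```python
import subprocess
import sys

REPO_URL = "https://github.com/ver217/ColossalAI.git"
BRANCH = "hotfix/zero"


def colossalai_install_steps(dest="ColossalAI"):
    """Return (argv, cwd) pairs: clone, check out the branch, install from source."""
    return [
        (["git", "clone", REPO_URL, dest], None),
        (["git", "checkout", BRANCH], dest),
        # "pip install ." run inside the clone, exactly as the note describes.
        ([sys.executable, "-m", "pip", "install", "."], dest),
    ]


def run_steps(steps):
    """Execute each step in order, stopping at the first failing command."""
    for argv, cwd in steps:
        subprocess.run(argv, cwd=cwd, check=True)
```

Calling run_steps(colossalai_install_steps()) performs the clone and install. Keeping the step list separate from its execution makes it easy to print or review the commands first.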

2.2 Train

CUDA_VISIBLE_DEVICES=7 torchrun --master_port=29510 --nnodes=1 --nproc_per_node=1 scripts/train-vae-v2.py configs/vae_magvit_v2/train/17x128x128.py --data-path /home/shenchenhui/data/pexels/train.csv
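The config name 17x128x128 appears to encode frames x height x width. Suppose magvit2.py follows the tokenizer design described in the MAGVIT-v2 paper: a causal encoder that encodes the first frame on its own, downsamples the remaining frames 4x in time, and downsamples every frame 8x in space. Then a 17-frame clip maps to 1 + 16/4 = 5 latent frames, which would explain why this config uses 17 frames rather than 16. Both the naming convention and the downsampling factors below are assumptions, not values read from the repository.

```python
from pathlib import PurePosixPath


def parse_clip_shape(config_path):
    """Read (frames, height, width) from a config name like '17x128x128.py' (assumed convention)."""
    frames, height, width = (int(part) for part in PurePosixPath(config_path).stem.split("x"))
    return frames, height, width


def causal_latent_shape(frames, height, width, t_down=4, s_down=8):
    """Latent (T, H, W) when the first frame is kept and the remaining frames are downsampled."""
    if (frames - 1) % t_down != 0:
        raise ValueError(f"expected frames = 1 + k * {t_down}, got {frames}")
    if height % s_down != 0 or width % s_down != 0:
        raise ValueError(f"height and width must be divisible by {s_down}")
    return 1 + (frames - 1) // t_down, height // s_down, width // s_down
```

With these assumed factors, causal_latent_shape(*parse_clip_shape("configs/vae_magvit_v2/train/17x128x128.py")) returns (5, 16, 16).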