Open-Sora/configs/vae/train/video_dc_ae.py
Zheng Zangwei (Alex Zheng) febf3ad4b2
Update Open-Sora 2.0 (#807)
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: Shen-Chenhui <shen_chenhui@u.nus.edu>
Co-authored-by: wuxiwen <wuxiwen.simon@gmail.com>
2025-03-12 13:14:22 +08:00

# ============
# model config
# ============
model = dict(
    type="dc_ae",
    model_name="dc-ae-f32t4c128",
    from_scratch=True,
    from_pretrained=None,
)
# ============
# data config
# ============
dataset = dict(
    type="video_text",
    transform_name="resize_crop",
    data_path="datasets/pexels_45k_necessary.csv",
    fps_max=24,
)
# bucket layout: {resolution: {num_frames: (keep_probability, batch_size)}}
bucket_config = {
    "256px_ar1:1": {32: (1.0, 1)},
}
num_bucket_build_workers = 64
num_workers = 12
prefetch_factor = 2
# ============
# train config
# ============
optim = dict(
    cls="HybridAdam",
    lr=5e-5,
    eps=1e-8,
    weight_decay=0.0,
    adamw_mode=True,
    betas=(0.9, 0.98),
)
lr_scheduler = dict(warmup_steps=0)
mixed_strategy = "mixed_video_image"
mixed_image_ratio = 0.2  # image:video = 1:4
dtype = "bf16"
plugin = "zero2"
plugin_config = dict(
    reduce_bucket_size_in_m=128,
    overlap_allgather=False,
)
grad_clip = 1.0
grad_checkpoint = False
pin_memory_cache_pre_alloc_numels = [50 * 1024 * 1024] * num_workers * prefetch_factor
seed = 42
outputs = "outputs"
epochs = 100
log_every = 10
ckpt_every = 3000
keep_n_latest = 50
ema_decay = 0.99
wandb_project = "dcae"
update_warmup_steps = True
# ============
# loss config
# ============
vae_loss_config = dict(
    perceptual_loss_weight=0.5,
    kl_loss_weight=0,
)
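
A few quantities in this config are derived rather than stated, so a small standalone sketch may help when sanity-checking changes. It copies the relevant values from the config above and works out the number and size of the pre-allocated pinned-memory buffers, the image-to-video ratio implied by `mixed_image_ratio`, and a nominal latent shape for the single bucket. The reading of `dc-ae-f32t4c128` as "32x spatial downsampling, 4x temporal downsampling, 128 latent channels" is an assumption based on the naming pattern, and the helper `nominal_latent_shape` is hypothetical: it is not part of Open-Sora, and it ignores any padding the real model may apply.

```python
import re
from fractions import Fraction

# Values copied from the config above.
num_workers = 12
prefetch_factor = 2
mixed_image_ratio = 0.2
model_name = "dc-ae-f32t4c128"

# The list holds one entry per batch a DataLoader keeps in flight
# (num_workers * prefetch_factor); each entry is an element count.
numels = [50 * 1024 * 1024] * num_workers * prefetch_factor
num_buffers = len(numels)        # 24
total_elements = sum(numels)     # 1258291200

# A 0.2 image fraction means 1 image for every 4 videos.
frac = Fraction(mixed_image_ratio).limit_denominator(100)
image_to_video = (frac.numerator, frac.denominator - frac.numerator)  # (1, 4)


def nominal_latent_shape(name, frames, height, width):
    """Hypothetical helper: (channels, frames, height, width) of the latent.

    ASSUMPTION: in the model name, f = spatial factor, t = temporal factor,
    c = latent channels. Uses plain floor division, no padding.
    """
    match = re.search(r"f(\d+)t(\d+)c(\d+)", name)
    if match is None:
        raise ValueError(f"unrecognised model name: {name}")
    spatial, temporal, channels = (int(g) for g in match.groups())
    return (channels, frames // temporal, height // spatial, width // spatial)


# The only bucket is 256px, 1:1 aspect ratio, 32 frames.
latent_shape = nominal_latent_shape(model_name, frames=32, height=256, width=256)

print(num_buffers, total_elements, image_to_video, latent_shape)
# 24 1258291200 (1, 4) (128, 8, 8, 8)
```

Under these assumptions a 32-frame 256x256 clip maps to a 128x8x8x8 latent. The buffer totals are element counts, so the byte size depends on the dtype of the pinned tensors.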