This commit is contained in:
Shen-Chenhui 2024-04-08 17:17:22 +08:00
parent a401630e7a
commit a634f7327e
2 changed files with 16 additions and 2 deletions

View file

@@ -46,3 +46,4 @@ save_dir = "outputs/samples"
batch_size = 8
grad_clip = 1.0
+grad_checkpoint = True

View file

@@ -25,6 +25,8 @@ from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin
from opensora.acceleration.plugin import ZeroSeqParallelPlugin
from colossalai.utils import get_current_device
+from colossalai.nn.optimizer import HybridAdam
+from opensora.acceleration.checkpoint import set_grad_checkpoint
def main():
@@ -100,7 +102,7 @@ def main():
dataset,
batch_size=cfg.batch_size,
num_workers=cfg.num_workers,
-shuffle=False,
+shuffle=True,
drop_last=True,
pin_memory=True,
process_group=get_data_parallel_group(),
@@ -119,7 +121,18 @@ def main():
# latent_size = vae.get_latent_size(input_size)
# 3.2. move to device & eval
-vae = vae.to(device, dtype).eval()
+vae = vae.to(device, dtype)
+# 4.5. setup optimizer
+optimizer = HybridAdam(
+filter(lambda p: p.requires_grad, vae.parameters()), lr=cfg.lr, weight_decay=0, adamw_mode=True
+)
+lr_scheduler = None
+# 4.6. prepare for training
+if cfg.grad_checkpoint:
+set_grad_checkpoint(vae)
+vae.train()
# # 3.3. build scheduler
# scheduler = build_module(cfg.scheduler, SCHEDULERS)