mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-05-21 11:59:01 +02:00
debug
This commit is contained in:
parent
a401630e7a
commit
a634f7327e
|
|
@ -46,3 +46,4 @@ save_dir = "outputs/samples"
|
|||
batch_size = 8
|
||||
|
||||
grad_clip = 1.0
|
||||
grad_checkpoint = True
|
||||
|
|
@ -25,6 +25,8 @@ from colossalai.booster import Booster
|
|||
from colossalai.booster.plugin import LowLevelZeroPlugin
|
||||
from opensora.acceleration.plugin import ZeroSeqParallelPlugin
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from opensora.acceleration.checkpoint import set_grad_checkpoint
|
||||
|
||||
|
||||
def main():
|
||||
|
|
@ -100,7 +102,7 @@ def main():
|
|||
dataset,
|
||||
batch_size=cfg.batch_size,
|
||||
num_workers=cfg.num_workers,
|
||||
shuffle=False,
|
||||
shuffle=True,
|
||||
drop_last=True,
|
||||
pin_memory=True,
|
||||
process_group=get_data_parallel_group(),
|
||||
|
|
@ -119,7 +121,18 @@ def main():
|
|||
# latent_size = vae.get_latent_size(input_size)
|
||||
|
||||
# 3.2. move to device & eval
|
||||
vae = vae.to(device, dtype).eval()
|
||||
vae = vae.to(device, dtype)
|
||||
|
||||
# 4.5. setup optimizer
|
||||
optimizer = HybridAdam(
|
||||
filter(lambda p: p.requires_grad, vae.parameters()), lr=cfg.lr, weight_decay=0, adamw_mode=True
|
||||
)
|
||||
lr_scheduler = None
|
||||
|
||||
# 4.6. prepare for training
|
||||
if cfg.grad_checkpoint:
|
||||
set_grad_checkpoint(vae)
|
||||
vae.train()
|
||||
|
||||
# # 3.3. build scheduler
|
||||
# scheduler = build_module(cfg.scheduler, SCHEDULERS)
|
||||
|
|
|
|||
Loading…
Reference in a new issue