diff --git a/scripts/inference-debug.py b/scripts/inference-debug.py index 0ef10ed..13b02f4 100644 --- a/scripts/inference-debug.py +++ b/scripts/inference-debug.py @@ -59,29 +59,29 @@ def main(): - # 2.3 DEBUG: USE BOOSTER - # 2.3. initialize ColossalAI booster - if cfg.plugin == "zero2": - plugin = LowLevelZeroPlugin( - stage=2, - precision=cfg.dtype, - initial_scale=2**16, - max_norm=cfg.grad_clip, - ) - set_data_parallel_group(dist.group.WORLD) - elif cfg.plugin == "zero2-seq": - plugin = ZeroSeqParallelPlugin( - sp_size=cfg.sp_size, - stage=2, - precision=cfg.dtype, - initial_scale=2**16, - max_norm=cfg.grad_clip, - ) - set_sequence_parallel_group(plugin.sp_group) - set_data_parallel_group(plugin.dp_group) - else: - raise ValueError(f"Unknown plugin {cfg.plugin}") - booster = Booster(plugin=plugin) + # # 2.3 DEBUG: USE BOOSTER + # # 2.3. initialize ColossalAI booster + # if cfg.plugin == "zero2": + # plugin = LowLevelZeroPlugin( + # stage=2, + # precision=cfg.dtype, + # initial_scale=2**16, + # max_norm=cfg.grad_clip, + # ) + # set_data_parallel_group(dist.group.WORLD) + # elif cfg.plugin == "zero2-seq": + # plugin = ZeroSeqParallelPlugin( + # sp_size=cfg.sp_size, + # stage=2, + # precision=cfg.dtype, + # initial_scale=2**16, + # max_norm=cfg.grad_clip, + # ) + # set_sequence_parallel_group(plugin.sp_group) + # set_data_parallel_group(plugin.dp_group) + # else: + # raise ValueError(f"Unknown plugin {cfg.plugin}") + # booster = Booster(plugin=plugin) # ====================================================== @@ -154,15 +154,15 @@ def main(): os.makedirs(save_dir, exist_ok=True) - ### TODO: DEBUG, USE booster - torch.set_default_dtype(dtype) - # vae, optimizer, _, dataloader, lr_scheduler = booster.boost( - # model=vae, optimizer=optimizer, lr_scheduler=lr_scheduler, dataloader=dataloader + # ### TODO: DEBUG, USE booster + # torch.set_default_dtype(dtype) + # # vae, optimizer, _, dataloader, lr_scheduler = booster.boost( + # # model=vae, optimizer=optimizer, lr_scheduler=lr_scheduler, dataloader=dataloader + # # ) + # vae, _, _, dataloader, _ = booster.boost( + # model=vae, dataloader=dataloader # ) - vae, _, _, dataloader, _ = booster.boost( - model=vae, dataloader=dataloader - ) - torch.set_default_dtype(torch.float) + # torch.set_default_dtype(torch.float) # load model using booster