This commit is contained in:
Shen-Chenhui 2024-04-08 18:00:30 +08:00
parent e151b64319
commit 45ea2bd29d

View file

@@ -59,29 +59,29 @@ def main():
# 2.3 DEBUG: USE BOOSTER
# 2.3. initialize ColossalAI booster
if cfg.plugin == "zero2":
plugin = LowLevelZeroPlugin(
stage=2,
precision=cfg.dtype,
initial_scale=2**16,
max_norm=cfg.grad_clip,
)
set_data_parallel_group(dist.group.WORLD)
elif cfg.plugin == "zero2-seq":
plugin = ZeroSeqParallelPlugin(
sp_size=cfg.sp_size,
stage=2,
precision=cfg.dtype,
initial_scale=2**16,
max_norm=cfg.grad_clip,
)
set_sequence_parallel_group(plugin.sp_group)
set_data_parallel_group(plugin.dp_group)
else:
raise ValueError(f"Unknown plugin {cfg.plugin}")
booster = Booster(plugin=plugin)
# # 2.3 DEBUG: USE BOOSTER
# # 2.3. initialize ColossalAI booster
# if cfg.plugin == "zero2":
# plugin = LowLevelZeroPlugin(
# stage=2,
# precision=cfg.dtype,
# initial_scale=2**16,
# max_norm=cfg.grad_clip,
# )
# set_data_parallel_group(dist.group.WORLD)
# elif cfg.plugin == "zero2-seq":
# plugin = ZeroSeqParallelPlugin(
# sp_size=cfg.sp_size,
# stage=2,
# precision=cfg.dtype,
# initial_scale=2**16,
# max_norm=cfg.grad_clip,
# )
# set_sequence_parallel_group(plugin.sp_group)
# set_data_parallel_group(plugin.dp_group)
# else:
# raise ValueError(f"Unknown plugin {cfg.plugin}")
# booster = Booster(plugin=plugin)
# ======================================================
@@ -154,15 +154,15 @@ def main():
os.makedirs(save_dir, exist_ok=True)
### TODO: DEBUG, USE booster
torch.set_default_dtype(dtype)
# vae, optimizer, _, dataloader, lr_scheduler = booster.boost(
# model=vae, optimizer=optimizer, lr_scheduler=lr_scheduler, dataloader=dataloader
# ### TODO: DEBUG, USE booster
# torch.set_default_dtype(dtype)
# # vae, optimizer, _, dataloader, lr_scheduler = booster.boost(
# # model=vae, optimizer=optimizer, lr_scheduler=lr_scheduler, dataloader=dataloader
# # )
# vae, _, _, dataloader, _ = booster.boost(
# model=vae, dataloader=dataloader
# )
vae, _, _, dataloader, _ = booster.boost(
model=vae, dataloader=dataloader
)
torch.set_default_dtype(torch.float)
# torch.set_default_dtype(torch.float)
# load model using booster