mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-05-07 13:02:35 +02:00
debug
This commit is contained in:
parent
e151b64319
commit
45ea2bd29d
|
|
@ -59,29 +59,29 @@ def main():
|
|||
|
||||
|
||||
|
||||
# 2.3 DEBUG: USE BOOSTER
|
||||
# 2.3. initialize ColossalAI booster
|
||||
if cfg.plugin == "zero2":
|
||||
plugin = LowLevelZeroPlugin(
|
||||
stage=2,
|
||||
precision=cfg.dtype,
|
||||
initial_scale=2**16,
|
||||
max_norm=cfg.grad_clip,
|
||||
)
|
||||
set_data_parallel_group(dist.group.WORLD)
|
||||
elif cfg.plugin == "zero2-seq":
|
||||
plugin = ZeroSeqParallelPlugin(
|
||||
sp_size=cfg.sp_size,
|
||||
stage=2,
|
||||
precision=cfg.dtype,
|
||||
initial_scale=2**16,
|
||||
max_norm=cfg.grad_clip,
|
||||
)
|
||||
set_sequence_parallel_group(plugin.sp_group)
|
||||
set_data_parallel_group(plugin.dp_group)
|
||||
else:
|
||||
raise ValueError(f"Unknown plugin {cfg.plugin}")
|
||||
booster = Booster(plugin=plugin)
|
||||
# # 2.3 DEBUG: USE BOOSTER
|
||||
# # 2.3. initialize ColossalAI booster
|
||||
# if cfg.plugin == "zero2":
|
||||
# plugin = LowLevelZeroPlugin(
|
||||
# stage=2,
|
||||
# precision=cfg.dtype,
|
||||
# initial_scale=2**16,
|
||||
# max_norm=cfg.grad_clip,
|
||||
# )
|
||||
# set_data_parallel_group(dist.group.WORLD)
|
||||
# elif cfg.plugin == "zero2-seq":
|
||||
# plugin = ZeroSeqParallelPlugin(
|
||||
# sp_size=cfg.sp_size,
|
||||
# stage=2,
|
||||
# precision=cfg.dtype,
|
||||
# initial_scale=2**16,
|
||||
# max_norm=cfg.grad_clip,
|
||||
# )
|
||||
# set_sequence_parallel_group(plugin.sp_group)
|
||||
# set_data_parallel_group(plugin.dp_group)
|
||||
# else:
|
||||
# raise ValueError(f"Unknown plugin {cfg.plugin}")
|
||||
# booster = Booster(plugin=plugin)
|
||||
|
||||
|
||||
# ======================================================
|
||||
|
|
@ -154,15 +154,15 @@ def main():
|
|||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
|
||||
### TODO: DEBUG, USE booster
|
||||
torch.set_default_dtype(dtype)
|
||||
# vae, optimizer, _, dataloader, lr_scheduler = booster.boost(
|
||||
# model=vae, optimizer=optimizer, lr_scheduler=lr_scheduler, dataloader=dataloader
|
||||
# ### TODO: DEBUG, USE booster
|
||||
# torch.set_default_dtype(dtype)
|
||||
# # vae, optimizer, _, dataloader, lr_scheduler = booster.boost(
|
||||
# # model=vae, optimizer=optimizer, lr_scheduler=lr_scheduler, dataloader=dataloader
|
||||
# # )
|
||||
# vae, _, _, dataloader, _ = booster.boost(
|
||||
# model=vae, dataloader=dataloader
|
||||
# )
|
||||
vae, _, _, dataloader, _ = booster.boost(
|
||||
model=vae, dataloader=dataloader
|
||||
)
|
||||
torch.set_default_dtype(torch.float)
|
||||
# torch.set_default_dtype(torch.float)
|
||||
|
||||
|
||||
# load model using booster
|
||||
|
|
|
|||
Loading…
Reference in a new issue