mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-10 12:49:38 +02:00
[ckpt] fix shape error when gathering weights under sp + dp parallelism
This commit is contained in:
parent
3455b7e0fb
commit
bc4aa4f217
|
|
@ -304,13 +304,13 @@ def master_weights_gathering(model: torch.nn.Module, optimizer: LowLevelZeroOpti
|
|||
model_shape_dict (dict): The shape of the model parameters.
|
||||
device (torch.device): The device to gather the model to.
|
||||
"""
|
||||
pg = get_data_parallel_group(get_mixed_dp_pg=True)
|
||||
world_size = dist.get_world_size(pg)
|
||||
w2m = optimizer.get_working_to_master_map()
|
||||
for name, param in model.named_parameters():
|
||||
master_p = w2m[id(param)]
|
||||
zero_pg = optimizer.param_to_pg[param]
|
||||
world_size = dist.get_world_size(zero_pg)
|
||||
all_params = [torch.empty_like(master_p) for _ in range(world_size)]
|
||||
dist.all_gather(all_params, master_p, group=pg)
|
||||
dist.all_gather(all_params, master_p, group=zero_pg)
|
||||
if dist.get_rank() == 0:
|
||||
all_params = torch.cat(all_params)
|
||||
gathered_param = remove_padding(all_params, param.shape).view(param.shape)
|
||||
|
|
|
|||
Loading…
Reference in a new issue