[ckpt] fix shape error when gathering weights under sp + dp parallelism

This commit is contained in:
hxwang 2025-03-26 14:52:48 +08:00 committed by botbw
parent 3455b7e0fb
commit bc4aa4f217

View file

@@ -304,13 +304,13 @@ def master_weights_gathering(model: torch.nn.Module, optimizer: LowLevelZeroOpti
model_shape_dict (dict): The shape of the model parameters.
device (torch.device): The device to gather the model to.
"""
- pg = get_data_parallel_group(get_mixed_dp_pg=True)
- world_size = dist.get_world_size(pg)
w2m = optimizer.get_working_to_master_map()
for name, param in model.named_parameters():
master_p = w2m[id(param)]
+ zero_pg = optimizer.param_to_pg[param]
+ world_size = dist.get_world_size(zero_pg)
all_params = [torch.empty_like(master_p) for _ in range(world_size)]
- dist.all_gather(all_params, master_p, group=pg)
+ dist.all_gather(all_params, master_p, group=zero_pg)
if dist.get_rank() == 0:
all_params = torch.cat(all_params)
gathered_param = remove_padding(all_params, param.shape).view(param.shape)