From bc4aa4f217944d21e468226a7d900d743736e7ee Mon Sep 17 00:00:00 2001 From: hxwang Date: Wed, 26 Mar 2025 14:52:48 +0800 Subject: [PATCH] [ckpt] fix shape error when gathering weights under sp + dp parallelism --- opensora/utils/ckpt.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/opensora/utils/ckpt.py b/opensora/utils/ckpt.py index c1e499e..2a0efdc 100644 --- a/opensora/utils/ckpt.py +++ b/opensora/utils/ckpt.py @@ -304,13 +304,13 @@ def master_weights_gathering(model: torch.nn.Module, optimizer: LowLevelZeroOpti model_shape_dict (dict): The shape of the model parameters. device (torch.device): The device to gather the model to. """ - pg = get_data_parallel_group(get_mixed_dp_pg=True) - world_size = dist.get_world_size(pg) w2m = optimizer.get_working_to_master_map() for name, param in model.named_parameters(): master_p = w2m[id(param)] + zero_pg = optimizer.param_to_pg[param] + world_size = dist.get_world_size(zero_pg) all_params = [torch.empty_like(master_p) for _ in range(world_size)] - dist.all_gather(all_params, master_p, group=pg) + dist.all_gather(all_params, master_p, group=zero_pg) if dist.get_rank() == 0: all_params = torch.cat(all_params) gathered_param = remove_padding(all_params, param.shape).view(param.shape)