diff --git a/opensora/utils/ckpt.py b/opensora/utils/ckpt.py index c1e499e..2a0efdc 100644 --- a/opensora/utils/ckpt.py +++ b/opensora/utils/ckpt.py @@ -304,13 +304,13 @@ def master_weights_gathering(model: torch.nn.Module, optimizer: LowLevelZeroOpti model_shape_dict (dict): The shape of the model parameters. device (torch.device): The device to gather the model to. """ - pg = get_data_parallel_group(get_mixed_dp_pg=True) - world_size = dist.get_world_size(pg) w2m = optimizer.get_working_to_master_map() for name, param in model.named_parameters(): master_p = w2m[id(param)] + zero_pg = optimizer.param_to_pg[param] + world_size = dist.get_world_size(zero_pg) all_params = [torch.empty_like(master_p) for _ in range(world_size)] - dist.all_gather(all_params, master_p, group=pg) + dist.all_gather(all_params, master_p, group=zero_pg) if dist.get_rank() == 0: all_params = torch.cat(all_params) gathered_param = remove_padding(all_params, param.shape).view(param.shape)