Remove duplicated _initialize_parameter_parallel_groups definition in engine (#709)

Co-authored-by: Jiangang Zhu <[email protected]>
Co-authored-by: Jeff Rasley <[email protected]>
3 people authored Apr 30, 2021
1 parent de694b9 commit 6da4fcc
Showing 1 changed file with 1 addition and 20 deletions.
21 changes: 1 addition & 20 deletions deepspeed/runtime/engine.py
@@ -19,7 +19,7 @@
from deepspeed.runtime.zero.stage2 import FP16_DeepSpeedZeroOptimizer
from deepspeed.runtime.zero.stage1 import FP16_DeepSpeedZeroOptimizer_Stage1
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
-from deepspeed.runtime.zero.utils import is_zero_supported_optimizer
+from deepspeed.runtime.zero.utils import is_zero_supported_optimizer, _initialize_parameter_parallel_groups
from deepspeed.runtime.activation_checkpointing import checkpointing as activation_checkpointing
from deepspeed.runtime.fp16.fused_optimizer import FP16_Optimizer
from deepspeed.runtime.fp16.unfused_optimizer import FP16_UnfusedOptimizer
@@ -71,25 +71,6 @@ def split_half_float_double_csr(tensors):
    return buckets


-def _initialize_parameter_parallel_groups(parameter_parallel_size=None):
-    data_parallel_size = int(dist.get_world_size())
-    if parameter_parallel_size is None:
-        parameter_parallel_size = int(data_parallel_size)
-    logger.info("data_parallel_size: %s, parameter_parallel_size: %s",
-                data_parallel_size,
-                parameter_parallel_size)
-    assert data_parallel_size % parameter_parallel_size == 0, \
-        'world size should be divisible by parameter parallel size'
-    rank = dist.get_rank()
-    my_group = None
-    for i in range(dist.get_world_size() // parameter_parallel_size):
-        ranks = range(i * parameter_parallel_size, (i + 1) * parameter_parallel_size)
-        group = torch.distributed.new_group(ranks)
-        if rank in ranks:
-            my_group = group
-    return my_group
-
-
def print_configuration(args, name):
    logger.info('{}:'.format(name))
    for arg in sorted(vars(args)):
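For reference (illustrative, not part of this commit): the removed helper, which the engine now imports from deepspeed.runtime.zero.utils, splits the world into consecutive blocks of parameter_parallel_size ranks and returns the process group containing the current rank. A minimal sketch of that grouping arithmetic, using a hypothetical stand-alone function parameter_parallel_groups so it can run without a torch.distributed backend:

def parameter_parallel_groups(data_parallel_size, parameter_parallel_size=None):
    # Default: one group spanning all data-parallel ranks.
    if parameter_parallel_size is None:
        parameter_parallel_size = data_parallel_size
    assert data_parallel_size % parameter_parallel_size == 0, \
        'world size should be divisible by parameter parallel size'
    # Consecutive blocks of parameter_parallel_size ranks form one group each.
    return [list(range(i * parameter_parallel_size, (i + 1) * parameter_parallel_size))
            for i in range(data_parallel_size // parameter_parallel_size)]

print(parameter_parallel_groups(8, 4))  # [[0, 1, 2, 3], [4, 5, 6, 7]]

In the real helper, each block of ranks is turned into a process group via torch.distributed.new_group(ranks), and the group containing dist.get_rank() is returned.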
