Commit 5fe5810

Merge branch 'master' into olruwase/ds_5241
tjruwase authored Feb 6, 2025
2 parents 21bfca0 + f04649d commit 5fe5810
Showing 17 changed files with 1,662 additions and 164 deletions.
33 changes: 32 additions & 1 deletion deepspeed/__init__.py
@@ -37,7 +37,7 @@
from .runtime.config import DeepSpeedConfig, DeepSpeedConfigError
from .runtime.activation_checkpointing import checkpointing
from .ops.transformer import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig
from .module_inject import replace_transformer_layer, revert_transformer_layer
from .module_inject import replace_transformer_layer, revert_transformer_layer, set_autotp_mode

from .utils import log_dist, OnDevice, logger
from .comm.comm import init_distributed
@@ -364,3 +364,34 @@ def init_inference(model, config=None, **kwargs):
engine = InferenceEngine(model, config=ds_inference_config)

return engine


def tp_model_init(model, tp_size, dtype):
    """
    Initialize the model for tensor parallelism.
    Args:
        model (torch.nn.Module): The model to be initialized.
        tp_size (int): The tensor parallelism size.
        dtype (torch.dtype): The data type to be used for the model.
    Returns:
        torch.nn.Module: The initialized model with tensor parallelism.
    """
    # avoid re-entry
    assert not hasattr(
        model, 'ds_autotp_parsed'), "'ds_autotp_parsed' attribute already exists in the model, re-entry is not allowed."

    set_autotp_mode(training=True)

    from deepspeed.runtime.tensor_parallel import TpTrainingManager
    # The expected usage is for this to be invoked by the transformers package.

    # TODO: We should provide a custom TP mapping solution without using autoTP,
    # as modifying the autoTP logic may be more difficult for users than configuring it.

    model = TpTrainingManager(model=model, tp_size=tp_size, dtype=dtype).module

    setattr(model, 'ds_autotp_parsed', True)

    return model
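
Below is a minimal usage sketch of the new tp_model_init entry point (not part of the commit); the model name, tp_size value, and dtype are illustrative placeholders, and it assumes the distributed backend has already been initialized:

# Hedged sketch: wrap a Hugging Face model for tensor-parallel training.
import torch
import deepspeed
from transformers import AutoModelForCausalLM  # hypothetical example model source

deepspeed.init_distributed()  # the communication backend must be up before sharding

model = AutoModelForCausalLM.from_pretrained("gpt2")  # placeholder checkpoint
model = deepspeed.tp_model_init(model, tp_size=2, dtype=torch.bfloat16)
# The returned module is marked with ds_autotp_parsed and its linear layers
# are partitioned across the tensor-parallel group.
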
6 changes: 6 additions & 0 deletions deepspeed/comm/comm.py
@@ -224,6 +224,12 @@ def broadcast(tensor, src, group=None, async_op=False, prof=False, log_name='bro
return cdb.broadcast(tensor=tensor, src=src, group=group, async_op=async_op)


@timed_op
def broadcast_object_list(object_list, src, group=None, device=None):
global cdb
return cdb.broadcast_object_list(object_list=object_list, src=src, group=group, device=device)


@timed_op
def all_gather(tensor_list,
tensor,
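
Below is a hedged usage sketch of the new broadcast_object_list wrapper (not part of the commit); like torch.distributed.broadcast_object_list, every rank must pass a list of the same length, and non-source ranks receive the objects in place:

# Hedged sketch: broadcast picklable Python objects from rank 0 to all ranks.
import deepspeed
from deepspeed import comm as dist

deepspeed.init_distributed()

# Rank 0 supplies the payload; other ranks pass a placeholder of the same length.
objects = [{"step": 100, "lr": 1e-4}] if dist.get_rank() == 0 else [None]
dist.broadcast_object_list(objects, src=0)
# After the call, objects[0] holds the rank-0 dict on every rank.
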
4 changes: 4 additions & 0 deletions deepspeed/comm/torch.py
@@ -205,6 +205,10 @@ def broadcast(self, tensor, src, group=None, async_op=False):
else:
return torch.distributed.broadcast(tensor=tensor, src=src, group=group, async_op=async_op)

@disable_compiler_collective
def broadcast_object_list(self, object_list, src, group=None, device=None):
return torch.distributed.broadcast_object_list(object_list=object_list, src=src, group=group, device=device)

@disable_compiler_collective
def all_gather(self, tensor_list, tensor, group=None, async_op=False):
if DS_COMM_ALL_GATHER_OFF:
1 change: 0 additions & 1 deletion deepspeed/inference/engine.py
@@ -15,7 +15,6 @@
from deepspeed.runtime.checkpoint_engine.torch_checkpoint_engine import TorchCheckpointEngine
from deepspeed.utils.timer import SynchronizedWallClockTimer
from deepspeed.runtime.compiler import is_compile_supported

from ..runtime.state_dict_factory import SDLoaderFactory
from ..runtime.weight_quantizer import WeightQuantization
from ..module_inject import replace_transformer_layer, generic_injection
2 changes: 1 addition & 1 deletion deepspeed/module_inject/__init__.py
@@ -6,5 +6,5 @@
from .replace_module import replace_transformer_layer, revert_transformer_layer, ReplaceWithTensorSlicing, GroupQuantizer, generic_injection
from .module_quantize import quantize_transformer_layer
from .replace_policy import HFBertLayerPolicy
from .layers import LinearAllreduce, LinearLayer, EmbeddingLayer, Normalize
from .layers import LinearAllreduce, LinearLayer, EmbeddingLayer, Normalize, set_autotp_mode
from .policy import DSPolicy
89 changes: 30 additions & 59 deletions deepspeed/module_inject/auto_tp.py
@@ -11,10 +11,12 @@
from typing import Optional
import torch
from deepspeed import comm as dist
from .layers import LinearAllreduce, LinearLayer, LmHeadLinearAllreduce
from .layers import LinearAllreduce, LinearLayer, LmHeadLinearAllreduce, Yuan_LinearAllreduce, Yuan_LinearLayer, GateUpPack_LinearLayer, Conv_LinearALlreduce, fused_LinearLayer, conv_LinearLayer
from deepspeed.accelerator import get_accelerator
from .fusedqkv_utils import require_tp_fused_qkvw, prepare_tp_fused_qkvw, shard_value_with_share_qk, shard_chunk_mlp
from .fusedqkv_utils import require_tp_fused_qkvw
from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list
from deepspeed.utils import groups
from deepspeed.module_inject.layers import is_autotp_training_mode


def move(tensor, device, copy=True):
@@ -333,10 +335,18 @@ def tp_parser(model):
return policy_list

def set_tensor_parallel_config(self, mp_size, mp_group):

if is_autotp_training_mode():
self.mp_group = groups.get_tensor_model_parallel_group()
self.mp_size = groups.get_tensor_model_parallel_world_size()
return

self.mp_size = mp_size
self.mp_group = mp_group

def _replace(self, child, name, conv_linear_layer):
# This function should clearly define the routing rules for specific layers
# and avoid any complex shard-related logic.
if getattr(child, "replaced", False) == True:
return
device_name = 'cpu' if self.keep_module_on_host else get_accelerator().current_device_name()
@@ -352,80 +362,41 @@ def _replace(self, child, name, conv_linear_layer)
# For Yuan model
if 'Yuan' in str(self.module):
if 'v_proj' in name:
weight, bias = shard_value_with_share_qk(child.weight.data, child.bias, dist.get_rank(),
dist.get_world_size(), True)
return LinearLayer(weight=weight, bias=bias)
return Yuan_LinearLayer(child, self.mp_group)

elif 'o_proj' in name:
weight, bias = shard_value_with_share_qk(child.weight.data, child.bias, dist.get_rank(),
dist.get_world_size(), False)
return LinearAllreduce(weight, bias, self.mp_group)
# For Arctic model, bypass to all_reduce replacement for w2 weights
return Yuan_LinearAllreduce(child, self.mp_group)

# For MLP including chunk layer.
if 'gate_up_proj' in name or ('dense_h_to_4h' in name and 'GLM' in str(self.module)):
return GateUpPack_LinearLayer(child, self.mp_group)
# For Arctic model, bypass to all_reduce replacement for w2 weights
arctic_w2_all_reduce_linear = False
if 'Arctic' in str(self.module) and 'w2' in name:
arctic_w2_all_reduce_linear = True
# For MoE MLP model, e.g., deepseek and jamba
down_proj = False
if 'down_proj' in name:
down_proj = True
# For MLP including chunk layer.
if 'gate_up_proj' in name or ('dense_h_to_4h' in name and 'GLM' in str(self.module)):
weight, bias = shard_chunk_mlp(child.weight.data, child.bias, dist.get_rank(), dist.get_world_size())
return LinearLayer(weight=weight, bias=bias)
if name in self.all_reduce_linears or arctic_w2_all_reduce_linear or down_proj:
# if conv_linear_layer [weight_shape[1], weight_shape[0] // mp_size]
# else [weight_shape[0], weight_shape[1] // mp_size]

setattr(child, "replaced", True)
if self.conv_linear_layer:
child.weight.data = child.weight.data.transpose(-1, -2).contiguous()
data = child.weight.data.split(get_shard_size_list(
weight_shape[0] if self.conv_linear_layer else weight_shape[1], self.mp_size, name),
dim=1)
data_dc = move(data[mp_replace.gpu_index], device_name, return_new_copy).detach()
del data
return Conv_LinearALlreduce(child, self.mp_group, name=name)
elif name == "lm_head" or name == 'embed_out':
return LmHeadLinearAllreduce(child, self.mp_group)

setattr(child, "replaced", True)
if name == "lm_head" or name == 'embed_out':
return LmHeadLinearAllreduce(
torch.nn.parameter.Parameter(data_dc, requires_grad=False), dist.get_rank(), dist.get_world_size(),
child.bias if child.bias is None else torch.nn.parameter.Parameter(
move(child.bias, device_name, return_new_copy)), self.mp_group)
return LinearAllreduce(torch.nn.parameter.Parameter(data_dc, requires_grad=False), child.bias if child.bias is None else \
torch.nn.parameter.Parameter(move(child.bias, device_name, return_new_copy)), self.mp_group)
return LinearAllreduce(child, self.mp_group, name=name)
else:

# if conv_linear_layer [weight_shape[1], weight_shape[0] // mp_size]
# else [weight_shape[0] // mp_size, weight_shape[1]]
setattr(child, "replaced", True)
if self.conv_linear_layer:
child.weight.data = child.weight.data.transpose(-1, -2).contiguous()

if require_tp_fused_qkvw(name, self.mp_size):
conv_LinearLayer(child, self.mp_group)
elif require_tp_fused_qkvw(name, self.mp_size):
#Check and handle fused qkv for TP
#The copy is a regular copy, The shape of dst and src is the same
data_dc = move(
prepare_tp_fused_qkvw(self.module, child.weight.data, self.mp_size, mp_replace.gpu_index),
device_name, return_new_copy)

bias_data_dc = None if child.bias is None else move(
prepare_tp_fused_qkvw(self.module, child.bias.data, self.mp_size, mp_replace.gpu_index),
device_name, return_new_copy)
else:
data = child.weight.data.split(get_shard_size_list(weight_shape[0], self.mp_size, name),
dim=1 if self.conv_linear_layer else 0)
data_dc = move(data[mp_replace.gpu_index], device_name, return_new_copy).detach()
del data

if child.bias is not None:
bias_data = child.bias.data.split(get_shard_size_list(
weight_shape[1] if self.conv_linear_layer else weight_shape[0], self.mp_size, name),
dim=0)
bias_data = move(bias_data[mp_replace.gpu_index], device_name, return_new_copy)
bias_data_dc = torch.nn.parameter.Parameter(bias_data, requires_grad=False)
del bias_data
else:
bias_data_dc = None
return fused_LinearLayer(child, self.mp_group, fused_module=self.module)

setattr(child, "replaced", True)
return LinearLayer(weight=torch.nn.parameter.Parameter(data_dc, requires_grad=False), bias=bias_data_dc)
return LinearLayer(child, self.mp_group, name=name)

def _slice_embedding(self, child, name, conv_linear_layer):
if getattr(child, "replaced", False) == True:
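
The refactored _replace above delegates all shard handling to the new wrapper layer classes and keeps only name-based routing. A minimal sketch of that routing idea follows (not DeepSpeed's actual code; the helper name and arguments are illustrative, while the wrapper constructors match the ones used in the diff):

# Hedged sketch: route a linear submodule to a tensor-parallel wrapper by name,
# leaving all sharding logic to the wrapper class itself.
import torch.nn as nn
from deepspeed.module_inject.layers import LinearAllreduce, LinearLayer

def route_linear(child: nn.Linear, name: str, mp_group, all_reduce_names):
    # Row-parallel layers whose partial outputs must be summed across ranks.
    if name in all_reduce_names or 'down_proj' in name:
        return LinearAllreduce(child, mp_group, name=name)
    # Column-parallel layers whose outputs are split across ranks.
    return LinearLayer(child, mp_group, name=name)
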