Autotp training #6922

Open

wants to merge 84 commits into base: master

Changes from 82 commits (84 commits total)
674a873
auto tp training
inkcherry Apr 3, 2024
a2e4c47
update parallel_states
inkcherry Apr 23, 2024
f4eb142
Merge branch 'master' into HEAD
inkcherry Nov 19, 2024
dd081ed
WA skips assertions, the loss remains exactly consistent with the low…
inkcherry Nov 19, 2024
cdaed2f
save/load ckpt & save/load hf model basic POC
inkcherry Nov 22, 2024
9aad0e7
finish all the basic functionalities
inkcherry Nov 27, 2024
2bb11fd
update
inkcherry Nov 28, 2024
e75c1c2
use groups for parallel_states
inkcherry Dec 2, 2024
840a5f2
enable bwd allreduce, enable scale loss by gas
inkcherry Dec 2, 2024
60bd6ab
add dataloader check
inkcherry Dec 4, 2024
9266383
refactor autoTP step1
inkcherry Dec 4, 2024
07174a9
rm parallel_states
inkcherry Dec 5, 2024
ee6323e
refactor autoTP step2
inkcherry Dec 5, 2024
6461b84
update ut step1
inkcherry Dec 10, 2024
4d73011
update
inkcherry Dec 11, 2024
c79c3bb
add uts
inkcherry Dec 11, 2024
97e659c
finished all ut code base
inkcherry Dec 12, 2024
a15905b
addllr scheduler test
inkcherry Dec 12, 2024
e9802b0
refine ut
inkcherry Dec 12, 2024
88b8acf
fix bcast_objlist
inkcherry Dec 15, 2024
868be0b
refine layers.py
inkcherry Dec 15, 2024
3788e07
refine gather
inkcherry Dec 15, 2024
27b24f6
pass codegen350M +TP2 ut
inkcherry Dec 16, 2024
3d7b89f
add mode choice
inkcherry Dec 16, 2024
47a6b0b
fix chatglm
inkcherry Dec 16, 2024
3a23997
fix chatglm2 with transformers=4.40 version
inkcherry Dec 16, 2024
e3ec46e
uneven
inkcherry Dec 16, 2024
9685879
fix uneven
inkcherry Dec 16, 2024
7b99b03
fix training
inkcherry Dec 16, 2024
570645f
refine code
inkcherry Dec 17, 2024
3729b64
remove skip bcase&reduce
inkcherry Dec 17, 2024
62d8858
fix typo
inkcherry Dec 17, 2024
dd17313
format
inkcherry Dec 17, 2024
93cf6f5
refine code
inkcherry Dec 18, 2024
87c4bc2
refine code
inkcherry Dec 18, 2024
1714bb5
refine
inkcherry Dec 18, 2024
dadf915
update yuan
inkcherry Dec 19, 2024
86c9399
optimize usage of move function
inkcherry Dec 19, 2024
2526dc6
refine args usage
inkcherry Dec 19, 2024
c9fd699
format
inkcherry Dec 19, 2024
797e71f
zero1 compatible
inkcherry Dec 19, 2024
86ae65e
remove wa
inkcherry Dec 22, 2024
3e40024
fix cpu device name
inkcherry Dec 22, 2024
7d94b77
fix lm-head
inkcherry Dec 23, 2024
b297950
add detach
inkcherry Dec 23, 2024
67ce220
fix ipex intergration
inkcherry Dec 23, 2024
f818be9
fix tied_embedding
inkcherry Dec 24, 2024
11c98f6
Merge remote-tracking branch 'origin/master' into autotp_training
inkcherry Jan 2, 2025
e22b625
format
inkcherry Jan 2, 2025
8531b64
Merge branch 'master' into autotp_training
tjruwase Jan 6, 2025
8d19e01
Merge branch 'master' into autotp_training
loadams Jan 6, 2025
060d48b
remove outdated comments
inkcherry Jan 13, 2025
6667ba1
Enhance unit test coverage
inkcherry Jan 13, 2025
84c9335
update ut
inkcherry Jan 13, 2025
cb29d7c
sequential some tests
inkcherry Jan 13, 2025
a49e77e
format
inkcherry Jan 13, 2025
0ef5274
use parameterized save path
inkcherry Jan 13, 2025
481088d
Merge remote-tracking branch 'my/autotp_training' into autotp_training
inkcherry Jan 13, 2025
f740de0
refactor infer/training path
inkcherry Jan 15, 2025
726004d
format
inkcherry Jan 15, 2025
bd8de77
remove empty line
inkcherry Jan 15, 2025
c334da0
remove autotp_size config from zero scope
inkcherry Jan 15, 2025
29eef07
update
inkcherry Jan 15, 2025
ba47ed1
format
inkcherry Jan 15, 2025
bbde63f
fix layer typo and rename
inkcherry Jan 15, 2025
bdca62c
fix python3.9
inkcherry Jan 15, 2025
5d89422
refine code
inkcherry Jan 15, 2025
0a9caff
refine
inkcherry Jan 15, 2025
c923a3b
refine config
inkcherry Jan 16, 2025
92be193
improve ut coverage for save
inkcherry Jan 17, 2025
23bd0fc
fix process exit early
inkcherry Jan 17, 2025
358f395
improve ut coverage
inkcherry Jan 17, 2025
cdfb54c
Merge remote-tracking branch 'origin/master' into autotp_training
inkcherry Jan 17, 2025
6d030c4
fix zero1 regression
inkcherry Jan 17, 2025
f9e7756
Merge branch 'master' into autotp_training
inkcherry Jan 20, 2025
6e7f846
fix ci
inkcherry Jan 20, 2025
c4fde7e
Merge branch 'autotp_training' of https://github.com/inkcherry/DeepSp…
inkcherry Jan 20, 2025
05bcecd
skip overflow test
inkcherry Jan 21, 2025
86f1c77
Merge branch 'master' into autotp_training
inkcherry Jan 22, 2025
668cb1a
Skip xpu tests until the ci is updated
inkcherry Jan 23, 2025
2e042a4
Merge branch 'autotp_training' of https://github.com/inkcherry/DeepSp…
inkcherry Jan 23, 2025
e08a234
Merge branch 'master' into autotp_training
delock Jan 24, 2025
20588f2
Merge branch 'master' into autotp_training
tjruwase Jan 30, 2025
1e05996
Merge branch 'master' into autotp_training
hwchen2017 Jan 30, 2025
33 changes: 32 additions & 1 deletion deepspeed/__init__.py
@@ -37,7 +37,7 @@
from .runtime.config import DeepSpeedConfig, DeepSpeedConfigError
from .runtime.activation_checkpointing import checkpointing
from .ops.transformer import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig
from .module_inject import replace_transformer_layer, revert_transformer_layer
from .module_inject import replace_transformer_layer, revert_transformer_layer, set_autotp_mode

from .utils import log_dist, OnDevice, logger
from .comm.comm import init_distributed
@@ -364,3 +364,34 @@ def init_inference(model, config=None, **kwargs):
engine = InferenceEngine(model, config=ds_inference_config)

return engine


def tp_model_init(model, tp_size, dtype):
    """
    Initialize the model for tensor parallelism.

    Args:
        model (torch.nn.Module): The model to be initialized.
        tp_size (int): The tensor parallelism size.
        dtype (torch.dtype): The data type to be used for the model.

    Returns:
        torch.nn.Module: The initialized model with tensor parallelism.
    """
    # avoid re-entry
    assert not hasattr(
        model, 'ds_autotp_parsed'), "'ds_autotp_parsed' attribute already exists in the model, re-entry is not allowed."

    set_autotp_mode(training=True)

    from deepspeed.runtime.tensor_parallel import TpTrainingManager
    # The expected usage here is for it to be invoked by the transformers package.

    # TODO: We should provide a custom TP mapping solution without using autoTP,
    # as modifying the autoTP logic may be more difficult for users compared to configuring it.

    model = TpTrainingManager(model=model, tp_size=tp_size, dtype=dtype).module

    setattr(model, 'ds_autotp_parsed', True)

    return model
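
For context, here is a minimal usage sketch of the new tp_model_init entry point. It is illustrative only: the checkpoint name, TP degree, and config path are placeholders, and in the intended flow the transformers package (not user code) would invoke tp_model_init before deepspeed.initialize.

import torch
import deepspeed
from transformers import AutoModelForCausalLM

# Illustrative sketch: shard a Hugging Face model across 2 tensor-parallel ranks,
# then hand the sharded module to the DeepSpeed engine as usual.
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
model = deepspeed.tp_model_init(model, tp_size=2, dtype=torch.bfloat16)

engine, optimizer, _, _ = deepspeed.initialize(model=model,
                                               model_parameters=model.parameters(),
                                               config="ds_config.json")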
6 changes: 6 additions & 0 deletions deepspeed/comm/comm.py
@@ -224,6 +224,12 @@ def broadcast(tensor, src, group=None, async_op=False, prof=False, log_name='bro
    return cdb.broadcast(tensor=tensor, src=src, group=group, async_op=async_op)


@timed_op
def broadcast_object_list(object_list, src, group=None, device=None):
    global cdb
    return cdb.broadcast_object_list(object_list=object_list, src=src, group=group, device=device)


@timed_op
def all_gather(tensor_list,
               tensor,
4 changes: 4 additions & 0 deletions deepspeed/comm/torch.py
@@ -205,6 +205,10 @@ def broadcast(self, tensor, src, group=None, async_op=False):
        else:
            return torch.distributed.broadcast(tensor=tensor, src=src, group=group, async_op=async_op)

    @disable_compiler_collective
    def broadcast_object_list(self, object_list, src, group=None, device=None):
        return torch.distributed.broadcast_object_list(object_list=object_list, src=src, group=group, device=device)

    @disable_compiler_collective
    def all_gather(self, tensor_list, tensor, group=None, async_op=False):
        if DS_COMM_ALL_GATHER_OFF:
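
For reference, a small sketch of how the new collective could be exercised once a process group is up. The payload is illustrative and the optional group/device arguments are left at their defaults; this is not code from the PR itself.

import deepspeed
from deepspeed import comm as dist

deepspeed.init_distributed()

# Rank 0 provides the object; every other rank passes a placeholder that is
# overwritten in place by the broadcast.
objects = [{"vocab_size": 32000}] if dist.get_rank() == 0 else [None]
dist.broadcast_object_list(objects, src=0)
print(f"rank {dist.get_rank()} received {objects[0]}")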
1 change: 0 additions & 1 deletion deepspeed/inference/engine.py
@@ -15,7 +15,6 @@
from deepspeed.runtime.checkpoint_engine.torch_checkpoint_engine import TorchCheckpointEngine
from deepspeed.utils.timer import SynchronizedWallClockTimer
from deepspeed.runtime.compiler import is_compile_supported

from ..runtime.state_dict_factory import SDLoaderFactory
from ..runtime.weight_quantizer import WeightQuantization
from ..module_inject import replace_transformer_layer, generic_injection
2 changes: 1 addition & 1 deletion deepspeed/module_inject/__init__.py
@@ -6,5 +6,5 @@
from .replace_module import replace_transformer_layer, revert_transformer_layer, ReplaceWithTensorSlicing, GroupQuantizer, generic_injection
from .module_quantize import quantize_transformer_layer
from .replace_policy import HFBertLayerPolicy
from .layers import LinearAllreduce, LinearLayer, EmbeddingLayer, Normalize
from .layers import LinearAllreduce, LinearLayer, EmbeddingLayer, Normalize, set_autotp_mode
from .policy import DSPolicy
89 changes: 30 additions & 59 deletions deepspeed/module_inject/auto_tp.py
@@ -11,10 +11,12 @@
from typing import Optional
import torch
from deepspeed import comm as dist
from .layers import LinearAllreduce, LinearLayer, LmHeadLinearAllreduce
from .layers import LinearAllreduce, LinearLayer, LmHeadLinearAllreduce, Yuan_LinearAllreduce, Yuan_LinearLayer, GateUpPack_LinearLayer, Conv_LinearALlreduce, fused_LinearLayer, conv_LinearLayer
from deepspeed.accelerator import get_accelerator
from .fusedqkv_utils import require_tp_fused_qkvw, prepare_tp_fused_qkvw, shard_value_with_share_qk, shard_chunk_mlp
from .fusedqkv_utils import require_tp_fused_qkvw
from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list
from deepspeed.utils import groups
from deepspeed.module_inject.layers import is_autotp_training_mode


def move(tensor, device, copy=True):
@@ -333,10 +335,18 @@ def tp_parser(model):
        return policy_list

    def set_tensor_parallel_config(self, mp_size, mp_group):

        if is_autotp_training_mode():
            self.mp_group = groups.get_tensor_model_parallel_group()
            self.mp_size = groups.get_tensor_model_parallel_world_size()
            return

        self.mp_size = mp_size
        self.mp_group = mp_group

    def _replace(self, child, name, conv_linear_layer):
        # This function should clearly define the routing rules for specific layers
        # and avoid any complex shard-related logic.
        if getattr(child, "replaced", False) == True:
            return
        device_name = 'cpu' if self.keep_module_on_host else get_accelerator().current_device_name()
@@ -352,80 +362,41 @@ def _replace(self, child, name, conv_linear_layer):
# For Yuan model
if 'Yuan' in str(self.module):
if 'v_proj' in name:
weight, bias = shard_value_with_share_qk(child.weight.data, child.bias, dist.get_rank(),
dist.get_world_size(), True)
return LinearLayer(weight=weight, bias=bias)
return Yuan_LinearLayer(child, self.mp_group)

elif 'o_proj' in name:
weight, bias = shard_value_with_share_qk(child.weight.data, child.bias, dist.get_rank(),
dist.get_world_size(), False)
return LinearAllreduce(weight, bias, self.mp_group)
# For Arctic model, bypass to all_reduce replacement for w2 weights
return Yuan_LinearAllreduce(child, self.mp_group)

# For MLP including chunk layer.
if 'gate_up_proj' in name or ('dense_h_to_4h' in name and 'GLM' in str(self.module)):

Reviewer comment on this block: This additional code block handles the general "MLP including chunk layer" case, but the returned module/object is named with a GLM prefix. It would be better to rename GLM_LinearLayer to something like GateUpPack_LinearLayer.

Reply (PR author): Thanks for the comments, modified :)
return GateUpPack_LinearLayer(child, self.mp_group)
# For Arctic model, bypass to all_reduce replacement for w2 weights
arctic_w2_all_reduce_linear = False
if 'Arctic' in str(self.module) and 'w2' in name:
arctic_w2_all_reduce_linear = True
# For MoE MLP model, e.g., deepseek and jamba
down_proj = False
if 'down_proj' in name:
down_proj = True
# For MLP including chunk layer.
if 'gate_up_proj' in name or ('dense_h_to_4h' in name and 'GLM' in str(self.module)):
weight, bias = shard_chunk_mlp(child.weight.data, child.bias, dist.get_rank(), dist.get_world_size())
return LinearLayer(weight=weight, bias=bias)
if name in self.all_reduce_linears or arctic_w2_all_reduce_linear or down_proj:
# if conv_linear_layer [weight_shape[1], weight_shape[0] // mp_size]
# else [weight_shape[0], weight_shape[1] // mp_size]

setattr(child, "replaced", True)
if self.conv_linear_layer:
child.weight.data = child.weight.data.transpose(-1, -2).contiguous()
data = child.weight.data.split(get_shard_size_list(
weight_shape[0] if self.conv_linear_layer else weight_shape[1], self.mp_size, name),
dim=1)
data_dc = move(data[mp_replace.gpu_index], device_name, return_new_copy).detach()
del data
return Conv_LinearALlreduce(child, self.mp_group, name=name)
elif name == "lm_head" or name == 'embed_out':
return LmHeadLinearAllreduce(child, self.mp_group)

setattr(child, "replaced", True)
if name == "lm_head" or name == 'embed_out':
return LmHeadLinearAllreduce(
torch.nn.parameter.Parameter(data_dc, requires_grad=False), dist.get_rank(), dist.get_world_size(),
child.bias if child.bias is None else torch.nn.parameter.Parameter(
move(child.bias, device_name, return_new_copy)), self.mp_group)
return LinearAllreduce(torch.nn.parameter.Parameter(data_dc, requires_grad=False), child.bias if child.bias is None else \
torch.nn.parameter.Parameter(move(child.bias, device_name, return_new_copy)), self.mp_group)
return LinearAllreduce(child, self.mp_group, name=name)
else:

# if conv_linear_layer [weight_shape[1], weight_shape[0] // mp_size]
# else [weight_shape[0] // mp_size, weight_shape[1]]
setattr(child, "replaced", True)
if self.conv_linear_layer:
child.weight.data = child.weight.data.transpose(-1, -2).contiguous()

if require_tp_fused_qkvw(name, self.mp_size):
conv_LinearLayer(child, self.mp_group)
elif require_tp_fused_qkvw(name, self.mp_size):
#Check and handle fused qkv for TP
#The copy is a regular copy, The shape of dst and src is the same
data_dc = move(
prepare_tp_fused_qkvw(self.module, child.weight.data, self.mp_size, mp_replace.gpu_index),
device_name, return_new_copy)

bias_data_dc = None if child.bias is None else move(
prepare_tp_fused_qkvw(self.module, child.bias.data, self.mp_size, mp_replace.gpu_index),
device_name, return_new_copy)
else:
data = child.weight.data.split(get_shard_size_list(weight_shape[0], self.mp_size, name),
dim=1 if self.conv_linear_layer else 0)
data_dc = move(data[mp_replace.gpu_index], device_name, return_new_copy).detach()
del data

if child.bias is not None:
bias_data = child.bias.data.split(get_shard_size_list(
weight_shape[1] if self.conv_linear_layer else weight_shape[0], self.mp_size, name),
dim=0)
bias_data = move(bias_data[mp_replace.gpu_index], device_name, return_new_copy)
bias_data_dc = torch.nn.parameter.Parameter(bias_data, requires_grad=False)
del bias_data
else:
bias_data_dc = None
return fused_LinearLayer(child, self.mp_group, fused_module=self.module)

setattr(child, "replaced", True)
return LinearLayer(weight=torch.nn.parameter.Parameter(data_dc, requires_grad=False), bias=bias_data_dc)
return LinearLayer(child, self.mp_group, name=name)

def _slice_embedding(self, child, name, conv_linear_layer):
if getattr(child, "replaced", False) == True:
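
To summarize the routing intent stated in the new _replace comment (choose a tensor-parallel wrapper purely from the layer/model name and keep all shard logic inside the wrapper classes), here is a condensed, hypothetical dispatch sketch. The standalone route_linear function and its exact conditions are illustrative and are not the PR's actual code; the wrapper class names and constructor arguments follow the imports and call sites added in this diff, on this PR's branch.

from deepspeed.module_inject.layers import (LinearAllreduce, LinearLayer, LmHeadLinearAllreduce,
                                            Yuan_LinearAllreduce, Yuan_LinearLayer, GateUpPack_LinearLayer)


def route_linear(child, name, module_str, mp_group):
    # Hypothetical condensation of the name-based rules in _replace():
    # the returned wrapper owns all sharding details.
    if 'Yuan' in module_str:
        if 'v_proj' in name:
            return Yuan_LinearLayer(child, mp_group)
        if 'o_proj' in name:
            return Yuan_LinearAllreduce(child, mp_group)
    if 'gate_up_proj' in name:
        return GateUpPack_LinearLayer(child, mp_group)
    if name in ('lm_head', 'embed_out'):
        return LmHeadLinearAllreduce(child, mp_group)
    if 'down_proj' in name:
        return LinearAllreduce(child, mp_group, name=name)
    return LinearLayer(child, mp_group, name=name)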