
Commit 32b3e92

yuweivvvzzhx1 authored and committed

Update linear_op.py

Co-authored-by: zzhx1 <[email protected]>
Signed-off-by: 子潜 <[email protected]>
1 parent 9741455 commit 32b3e92

File tree

5 files changed: +51 -37 lines


vllm_ascend/distributed/parallel_state.py

Lines changed: 8 additions & 5 deletions
@@ -38,11 +38,13 @@ def get_lmhead_tp_group() -> GroupCoordinator:
         "lm head tensor parallel group is not initialized")
     return _LMTP
 
+
 def get_dftp_group() -> GroupCoordinator:
     assert _DFTP is not None, (
         "denseffn tensor parallel group is not initialized")
     return _DFTP
 
+
 def get_flashcomm2_otp_group() -> GroupCoordinator:
     return _FLASHCOMM2_OTP
 

@@ -183,17 +185,18 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
                                           get_world_group().local_rank,
                                           backend,
                                           group_name="lmheadtp")
-
-    denseffn_tensor_parallel_size = get_ascend_config().denseffn_tensor_parallel_size
+
+    denseffn_tensor_parallel_size = get_ascend_config(
+    ).denseffn_tensor_parallel_size
     if denseffn_tensor_parallel_size is not None:
         group_ranks = []
         global _DFTP
-        num_denseffn_tensor_parallel_groups: int = (world_size //
-                                                    denseffn_tensor_parallel_size)
+        num_denseffn_tensor_parallel_groups: int = (
+            world_size // denseffn_tensor_parallel_size)
        for i in range(num_denseffn_tensor_parallel_groups):
            ranks = list(
                range(i * denseffn_tensor_parallel_size,
-                       (i + 1) * denseffn_tensor_parallel_size))
+                      (i + 1) * denseffn_tensor_parallel_size))
            group_ranks.append(ranks)
        _DFTP = init_model_parallel_group(group_ranks,
                                          get_world_group().local_rank,
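For reference, the group construction in this hunk just slices consecutive ranks into `world_size // denseffn_tensor_parallel_size` groups before handing them to `init_model_parallel_group`. A minimal standalone sketch of that grouping; the world size and TP size below are made-up example values, not taken from this commit:

```python
# Minimal sketch of the denseffn TP rank grouping; world_size and
# denseffn_tensor_parallel_size are illustrative values only.
world_size = 8
denseffn_tensor_parallel_size = 4

num_denseffn_tensor_parallel_groups = world_size // denseffn_tensor_parallel_size
group_ranks = [
    list(range(i * denseffn_tensor_parallel_size,
               (i + 1) * denseffn_tensor_parallel_size))
    for i in range(num_denseffn_tensor_parallel_groups)
]
print(group_ranks)  # [[0, 1, 2, 3], [4, 5, 6, 7]]
```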

vllm_ascend/ops/linear.py

Lines changed: 2 additions & 1 deletion
@@ -37,7 +37,8 @@
 from vllm.model_executor.utils import set_weight_attrs
 
 from vllm_ascend.ops.linear_op import get_parallel_op, get_replicated_op
-from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_enable_nz, is_first_k_dense
+from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, is_enable_nz,
+                               is_first_k_dense)
 
 
 class AscendUnquantizedLinearMethod(UnquantizedLinearMethod):

vllm_ascend/ops/linear_op.py

Lines changed: 20 additions & 16 deletions
@@ -52,17 +52,17 @@
 from vllm.forward_context import get_forward_context
 
 from vllm_ascend.ascend_config import get_ascend_config
-from vllm_ascend.distributed.parallel_state import (get_flashcomm2_odp_group,
+from vllm_ascend.distributed.parallel_state import (get_dftp_group,
+                                                    get_flashcomm2_odp_group,
                                                     get_flashcomm2_otp_group,
                                                     get_mlp_tp_group,
-                                                    get_otp_group,
-                                                    get_dftp_group)
-from vllm_ascend.utils import (dense_optim_enable, enable_sp,
-                               flashcomm2_enable,
+                                                    get_otp_group)
+from vllm_ascend.utils import (dense_optim_enable, denseffn_tp_enable,
+                               enable_sp, flashcomm2_enable,
                                get_flashcomm2_reorgnized_batch_ids,
-                               matmul_allreduce_enable, mlp_tp_enable,
-                               oproj_tp_enable, shared_expert_dp_enabled,
-                               denseffn_tp_enable, is_first_k_dense)
+                               is_first_k_dense, matmul_allreduce_enable,
+                               mlp_tp_enable, oproj_tp_enable,
+                               shared_expert_dp_enabled)
 
 
 class CustomLinearOp:

@@ -161,10 +161,10 @@ def __init__(self, layer):
 
     @property
     def comm_group(self):
-        if denseffn_tp_enable():
-            return get_dftp_group()
-        else:
+        if mlp_tp_enable():
             return get_mlp_tp_group()
+        else:
+            return get_dftp_group()
 
     def apply_impl(
         self,

@@ -187,10 +187,10 @@ def __init__(self, layer):
 
     @property
     def comm_group(self):
-        if denseffn_tp_enable():
-            return get_dftp_group()
-        else:
+        if mlp_tp_enable():
             return get_mlp_tp_group()
+        else:
+            return get_dftp_group()
 
     def apply_impl(
         self, input_: torch.Tensor
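Both `comm_group` properties now check `mlp_tp_enable()` first and fall back to the denseffn TP group, reversing the previous precedence. A self-contained sketch of that selection order, with the enable helper and group getters stubbed out for illustration (the real ones query the ascend config and the initialized process groups):

```python
# Stubbed stand-ins for the real helpers in vllm_ascend.
def mlp_tp_enable() -> bool:
    return False  # pretend the MLP TP optimization is off

def get_mlp_tp_group() -> str:
    return "mlp_tp_group"

def get_dftp_group() -> str:
    return "dftp_group"

def comm_group() -> str:
    # MLP TP takes precedence; otherwise use the denseffn TP group.
    if mlp_tp_enable():
        return get_mlp_tp_group()
    else:
        return get_dftp_group()

print(comm_group())  # dftp_group
```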
@@ -613,7 +613,9 @@ def update_attrs(self):
 def _get_column_parallel_op(
     prefix, layer
 ) -> Optional[Union[MLPColumnParallelOp, SequenceColumnParallelOp]]:
-    if (mlp_tp_enable() or (denseffn_tp_enable() and is_first_k_dense(prefix))) and "gate_up_proj" in prefix:
+    if (mlp_tp_enable() or
+        (denseffn_tp_enable()
+         and is_first_k_dense(prefix))) and "gate_up_proj" in prefix:
         return MLPColumnParallelOp(layer)
     if enable_sp():
         if "shared_expert" in prefix:

@@ -633,7 +635,9 @@ def _get_row_parallel_op(
 ) -> Optional[Union[MLPRowParallelOp, OProjRowParallelOp,
                     Flashcomm2OProjRowParallelOp, MatmulAllreduceRowParallelOp,
                     SequenceRowParallelOp]]:
-    if "down_proj" in prefix and (mlp_tp_enable() or (denseffn_tp_enable() and is_first_k_dense(prefix))):
+    if "down_proj" in prefix and (mlp_tp_enable() or
+                                  (denseffn_tp_enable()
+                                   and is_first_k_dense(prefix))):
         return MLPRowParallelOp(layer)
     if "o_proj" in prefix and oproj_tp_enable():
         return OProjRowParallelOp(layer)
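The dispatch condition for `gate_up_proj` (and, symmetrically, `down_proj`) now also accepts denseffn TP, but only for the first-k dense layers. A hedged sketch of that predicate in isolation; `is_first_k_dense` and the enable flags are stubbed here, whereas the real helpers parse the layer index and consult the ascend/model config:

```python
def mlp_tp_enable() -> bool:
    return False          # illustrative: MLP TP off

def denseffn_tp_enable() -> bool:
    return True           # illustrative: denseffn TP on

def is_first_k_dense(prefix: str) -> bool:
    # Stand-in: treat only layer 0 as a first-k dense layer.
    return "layers.0." in prefix

def wants_mlp_column_op(prefix: str) -> bool:
    # Same predicate shape as the reformatted condition above.
    return (mlp_tp_enable() or
            (denseffn_tp_enable()
             and is_first_k_dense(prefix))) and "gate_up_proj" in prefix

print(wants_mlp_column_op("model.layers.0.mlp.gate_up_proj"))   # True
print(wants_mlp_column_op("model.layers.10.mlp.gate_up_proj"))  # False
```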

vllm_ascend/quantization/quant_config.py

Lines changed: 9 additions & 6 deletions
@@ -36,14 +36,15 @@
 from vllm.model_executor.utils import set_weight_attrs
 
 from vllm_ascend.ascend_config import get_ascend_config
-from vllm_ascend.distributed.parallel_state import (get_flashcomm2_otp_group,
+from vllm_ascend.distributed.parallel_state import (get_dftp_group,
+                                                    get_flashcomm2_otp_group,
                                                     get_mlp_tp_group,
-                                                    get_otp_group,
-                                                    get_dftp_group)
+                                                    get_otp_group)
 from vllm_ascend.ops.fused_moe.fused_moe import AscendUnquantizedFusedMoEMethod
 from vllm_ascend.ops.linear import AscendUnquantizedLinearMethod
-from vllm_ascend.utils import (ASCEND_QUANTIZATION_METHOD, flashcomm2_enable,
-                               mlp_tp_enable, oproj_tp_enable, denseffn_tp_enable)
+from vllm_ascend.utils import (ASCEND_QUANTIZATION_METHOD, denseffn_tp_enable,
+                               flashcomm2_enable, mlp_tp_enable,
+                               oproj_tp_enable)
 
 from .utils import get_quant_method
 

@@ -349,7 +350,9 @@ def apply(
         if isinstance(layer, RowParallelLinear):
             if layer.prefix.find("o_proj") != -1 and oproj_tp_enable():
                 tp_rank = get_otp_group().rank_in_group
-            elif layer.prefix.find("down_proj") != -1 and (mlp_tp_enable() or (denseffn_tp_enable() and layer.is_first_k_dense)):
+            elif layer.prefix.find("down_proj") != -1 and (
+                    mlp_tp_enable() or
+                    (denseffn_tp_enable() and layer.is_first_k_dense)):
                 if denseffn_tp_enable() and layer.is_first_k_dense:
                     tp_rank = get_dftp_group().rank_in_group
                 else:
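The quantized path mirrors the same precedence when choosing which group's rank to shard `down_proj` weights with. A simplified standalone sketch of the selection; the flags and rank values are plain parameters here, whereas the real code reads them from the ascend config and from `get_otp_group()` / `get_dftp_group()` / `get_mlp_tp_group()`:

```python
def pick_tp_rank(prefix: str, oproj_tp: bool, mlp_tp: bool,
                 denseffn_tp: bool, is_first_k_dense: bool,
                 otp_rank: int = 0, dftp_rank: int = 1,
                 mlp_rank: int = 2, default_rank: int = 3) -> int:
    # Mirrors the branch order in the hunk above, with groups stubbed as ints.
    if "o_proj" in prefix and oproj_tp:
        return otp_rank
    if "down_proj" in prefix and (mlp_tp or
                                  (denseffn_tp and is_first_k_dense)):
        # First-k dense layers shard along the denseffn TP group when enabled.
        return dftp_rank if (denseffn_tp and is_first_k_dense) else mlp_rank
    return default_rank

# Example: a down_proj in a first-k dense layer with denseffn TP enabled.
print(pick_tp_rank("model.layers.0.mlp.down_proj", oproj_tp=False,
                   mlp_tp=False, denseffn_tp=True,
                   is_first_k_dense=True))  # 1 (dftp_rank)
```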

vllm_ascend/utils.py

Lines changed: 12 additions & 9 deletions
@@ -768,9 +768,11 @@ def lmhead_tp_enable() -> bool:
 def oproj_tp_enable() -> bool:
     return get_ascend_config().oproj_tensor_parallel_size is not None
 
+
 def denseffn_tp_enable() -> bool:
     return get_ascend_config().denseffn_tensor_parallel_size is not None
 
+
 def mlp_tp_enable() -> bool:
     return envs_ascend.VLLM_ASCEND_ENABLE_MLP_OPTIMIZE
 
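`denseffn_tp_enable()` follows the same opt-in convention as the neighbouring helpers: the feature is on exactly when its TP size is set in the Ascend config. A small self-contained sketch of that pattern; the config dataclass below is a stand-in, not the real `AscendConfig`:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class FakeAscendConfig:
    # Stand-in for the real ascend config object.
    denseffn_tensor_parallel_size: Optional[int] = None

def denseffn_tp_enable(cfg: FakeAscendConfig) -> bool:
    # Enabled iff a denseffn TP size was explicitly configured.
    return cfg.denseffn_tensor_parallel_size is not None

print(denseffn_tp_enable(FakeAscendConfig()))   # False
print(denseffn_tp_enable(FakeAscendConfig(4)))  # True
```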

@@ -1018,24 +1020,25 @@ def get_flashcomm2_reorgnized_batch_ids(global_tp_size) -> list[list[int]]:
 
     return reorgnized_batch_ids
 
+
 def is_first_k_dense(prefix: str) -> bool:
     from vllm.config import get_current_vllm_config
     match = re.search(r'layers\.(\d+)\.', prefix)
     if not match:
         return False
 
     layer_idx = int(match.group(1))
-
+
     vllm_config = get_current_vllm_config()
     if vllm_config is None:
-        raise ValueError("get_current_vllm_config() returned None. "
-                         "Ensure this function is called within the model initialization context.")
+        raise ValueError(
+            "get_current_vllm_config() returned None. "
+            "Ensure this function is called within the model initialization context."
+        )
     config = vllm_config.model_config.hf_config
 
-    is_moe_layer = (
-        config.n_routed_experts is not None and
-        layer_idx >= config.first_k_dense_replace and
-        layer_idx % config.moe_layer_freq == 0
-    )
+    is_moe_layer = (config.n_routed_experts is not None
+                    and layer_idx >= config.first_k_dense_replace
+                    and layer_idx % config.moe_layer_freq == 0)
 
-    return not is_moe_layer
+    return not is_moe_layer
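`is_first_k_dense` extracts the layer index from the weight prefix and reports True when that layer is not a routed-MoE layer under the model's MoE layout. A standalone sketch with a stubbed HF config; the field names follow the DeepSeek-style attributes used above, but the concrete values are invented for the example:

```python
import re
from types import SimpleNamespace

# Stubbed HF config with DeepSeek-style MoE fields; values are illustrative.
config = SimpleNamespace(n_routed_experts=64,
                         first_k_dense_replace=3,
                         moe_layer_freq=1)

def is_first_k_dense(prefix: str) -> bool:
    match = re.search(r'layers\.(\d+)\.', prefix)
    if not match:
        return False
    layer_idx = int(match.group(1))
    is_moe_layer = (config.n_routed_experts is not None
                    and layer_idx >= config.first_k_dense_replace
                    and layer_idx % config.moe_layer_freq == 0)
    return not is_moe_layer

print(is_first_k_dense("model.layers.1.mlp.down_proj"))   # True  (dense layer)
print(is_first_k_dense("model.layers.10.mlp.down_proj"))  # False (MoE layer)
```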
