diff --git a/vllm_ascend/attention/attention.py b/vllm_ascend/attention/attention.py
index 3de391fbe..c3c6c9b48 100644
--- a/vllm_ascend/attention/attention.py
+++ b/vllm_ascend/attention/attention.py
@@ -668,6 +668,7 @@ def __init__(
         blocksparse_params: Optional[Dict[str, Any]] = None,
         logits_soft_cap: Optional[float] = None,
         attn_type: str = AttentionType.DECODER,
+        use_irope: bool = False,
     ) -> None:
         self.num_heads = num_heads
         self.head_size = head_size
diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
index 03d0bcca4..62266f7bf 100644
--- a/vllm_ascend/attention/attention_v1.py
+++ b/vllm_ascend/attention/attention_v1.py
@@ -172,6 +172,7 @@ def __init__(
         blocksparse_params: Optional[Dict[str, Any]] = None,
         logits_soft_cap: Optional[float] = None,
         attn_type: str = AttentionType.DECODER,
+        use_irope: bool = False,
     ) -> None:
         self.num_heads = num_heads
         self.head_size = head_size
diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py
index 7eebc7d07..0a2a39c6c 100644
--- a/vllm_ascend/ops/fused_moe.py
+++ b/vllm_ascend/ops/fused_moe.py
@@ -150,6 +150,7 @@ def fused_experts(
     topk_ids: torch.Tensor,
     top_k: int,
     expert_map: torch.Tensor = None,
+    apply_router_weight_on_input: bool = False,
 ) -> torch.Tensor:
     """
     Fused experts with top-k routing.
@@ -188,6 +189,9 @@ def fused_experts(
     # assert dtype in [torch.float32, torch.float16, torch.bfloat16
     #                  ], "Only float32, float16, and bfloat16 are supported"
 
+    if apply_router_weight_on_input:
+        hidden_states = hidden_states * topk_weights.to(hidden_states.dtype)
+
     if expert_map is not None:
         # Generate token indices and flatten
         token_indices = (torch.arange(num_tokens,
@@ -289,6 +293,7 @@ def fused_experts(
             torch.zeros_like(weighted_down_out)).to(dtype)
         final_hidden_states.index_add_(0, sorted_token_indices, valid_output)
     else:
+        scales = None if apply_router_weight_on_input else topk_weights
         # TODO: Reorder device memory 2 times here, replace the current
         # implementation here when suitable operators become available.
         final_hidden_states = torch_npu.npu_moe_finalize_routing(
@@ -296,7 +301,7 @@ def fused_experts(
             skip1=None,
             skip2=None,
             bias=None,
-            scales=topk_weights,
+            scales=scales,
             expanded_src_to_dst_row=expanded_row_idx,
             export_for_source_row=topk_ids,
         )
@@ -363,9 +368,6 @@ def select_experts(
     Raises:
         ValueError: If an unsupported scoring function is provided.
     """
-    if custom_routing_function is not None:
-        raise NotImplementedError(
-            "Custom routing function is not supported now")
 
     if scoring_func == "softmax":
         # NOTE: vLLM use dtype=torch.float here
@@ -402,9 +404,15 @@ def select_experts(
             k=top_k,
             dim=-1,
             sorted=False)
-    else:
+    elif custom_routing_function is None:
         topk_weights, topk_ids = topk_weights.topk(top_k, dim=-1)
         topk_weights = topk_weights.to(hidden_states.dtype)
+    else:
+        topk_weights, topk_ids = custom_routing_function(
+            hidden_states=hidden_states,
+            gating_output=router_logits,
+            topk=top_k,
+            renormalize=renormalize)
 
     # Required by npu_moe_init_routing
     topk_ids = topk_ids.to(torch.int32)
@@ -462,6 +470,7 @@ def apply(
         scoring_func: str = "softmax",
         e_score_correction_bias: Optional[torch.Tensor] = None,
         is_prefill=False,
+        apply_router_weight_on_input=False,
         **kwargs,
     ):
         # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
@@ -504,13 +513,15 @@ def apply(
                 expert_map=expert_map,
                 moe_all_to_all_group_name=self.moe_all_to_all_group_name)
         else:
-            return fused_experts(hidden_states=x,
-                                 w1=layer.w13_weight,
-                                 w2=layer.w2_weight,
-                                 topk_weights=topk_weights,
-                                 topk_ids=topk_ids,
-                                 top_k=top_k,
-                                 expert_map=expert_map)
+            return fused_experts(
+                hidden_states=x,
+                w1=layer.w13_weight,
+                w2=layer.w2_weight,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                top_k=top_k,
+                expert_map=expert_map,
+                apply_router_weight_on_input=apply_router_weight_on_input)
 
 
 class AscendFusedMoE(FusedMoE):
@@ -534,6 +545,7 @@ def __init__(self,
                  custom_routing_function=None,
                  scoring_func="softmax",
                  e_score_correction_bias=None,
+                 apply_router_weight_on_input: bool = False,
                  activation="silu"):
         super(FusedMoE, self).__init__()
 
@@ -564,6 +576,7 @@ def __init__(self,
         self.e_score_correction_bias = e_score_correction_bias
         self.expert_map = None
         self.activation = activation
+        self.apply_router_weight_on_input = apply_router_weight_on_input
 
         if self.ep_size > 1:
             # Create a tensor of size num_experts filled with -1
@@ -652,7 +665,8 @@ def forward(self,
             custom_routing_function=self.custom_routing_function,
             scoring_func=self.scoring_func,
             e_score_correction_bias=self.e_score_correction_bias,
-            is_prefill=is_prefill)
+            is_prefill=is_prefill,
+            apply_router_weight_on_input=self.apply_router_weight_on_input)
 
         if self.dp_size > 1:
             if int(os.environ.get("VLLM_ENABLE_MC2", '0')  # type: ignore