Skip to content

moe support for llama4 and mllama4 #740

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions vllm_ascend/attention/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -668,6 +668,7 @@ def __init__(
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
use_irope: bool = False,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
Expand Down
1 change: 1 addition & 0 deletions vllm_ascend/attention/attention_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ def __init__(
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
use_irope: bool = False,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
Expand Down
31 changes: 23 additions & 8 deletions vllm_ascend/ops/fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ def fused_experts(
topk_ids: torch.Tensor,
top_k: int,
expert_map: torch.Tensor = None,
apply_router_weight_on_input: bool = False,
) -> torch.Tensor:
"""
Fused experts with top-k routing.
Expand Down Expand Up @@ -188,6 +189,9 @@ def fused_experts(
# assert dtype in [torch.float32, torch.float16, torch.bfloat16
# ], "Only float32, float16, and bfloat16 are supported"

if apply_router_weight_on_input:
hidden_states = hidden_states * topk_weights.to(hidden_states.dtype)

if expert_map is not None:
# Generate token indices and flatten
token_indices = (torch.arange(num_tokens,
Expand Down Expand Up @@ -289,14 +293,15 @@ def fused_experts(
torch.zeros_like(weighted_down_out)).to(dtype)
final_hidden_states.index_add_(0, sorted_token_indices, valid_output)
else:
scales = None if apply_router_weight_on_input else topk_weights
# TODO: Reorder device memory 2 times here, replace the current
# implementation here when suitable operators become available.
final_hidden_states = torch_npu.npu_moe_finalize_routing(
down_out_list,
skip1=None,
skip2=None,
bias=None,
scales=topk_weights,
scales=scales,
expanded_src_to_dst_row=expanded_row_idx,
export_for_source_row=topk_ids,
)
Expand Down Expand Up @@ -363,9 +368,6 @@ def select_experts(
Raises:
ValueError: If an unsupported scoring function is provided.
"""
if custom_routing_function is not None:
raise NotImplementedError(
"Custom routing function is not supported now")

if scoring_func == "softmax":
# NOTE: vLLM use dtype=torch.float here
Expand Down Expand Up @@ -402,9 +404,15 @@ def select_experts(
k=top_k,
dim=-1,
sorted=False)
else:
elif custom_routing_function is None:
topk_weights, topk_ids = topk_weights.topk(top_k, dim=-1)
topk_weights = topk_weights.to(hidden_states.dtype)
else:
topk_weights, topk_ids = custom_routing_function(
hidden_states=hidden_states,
gating_output=router_logits,
topk=top_k,
renormalize=renormalize)

# Required by npu_moe_init_routing
topk_ids = topk_ids.to(torch.int32)
Expand Down Expand Up @@ -462,6 +470,7 @@ def apply(
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
is_prefill=False,
apply_router_weight_on_input=False,
**kwargs,
):
# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
Expand Down Expand Up @@ -510,7 +519,8 @@ def apply(
topk_weights=topk_weights,
topk_ids=topk_ids,
top_k=top_k,
expert_map=expert_map)
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input)


class AscendFusedMoE(FusedMoE):
Expand All @@ -534,7 +544,9 @@ def __init__(self,
custom_routing_function=None,
scoring_func="softmax",
e_score_correction_bias=None,
activation="silu"):
apply_router_weight_on_input: bool = False,
activation="silu",
):
super(FusedMoE, self).__init__()

if params_dtype is None:
Expand Down Expand Up @@ -564,6 +576,7 @@ def __init__(self,
self.e_score_correction_bias = e_score_correction_bias
self.expert_map = None
self.activation = activation
self.apply_router_weight_on_input = apply_router_weight_on_input

if self.ep_size > 1:
# Create a tensor of size num_experts filled with -1
Expand Down Expand Up @@ -652,7 +665,9 @@ def forward(self,
custom_routing_function=self.custom_routing_function,
scoring_func=self.scoring_func,
e_score_correction_bias=self.e_score_correction_bias,
is_prefill=is_prefill)
is_prefill=is_prefill,
apply_router_weight_on_input=self.apply_router_weight_on_input,
)

if self.dp_size > 1:
if int(os.environ.get("VLLM_ENABLE_MC2", '0') # type: ignore
Expand Down
Loading