Description
Hi team, currently the Perplexity MoE kernels don't transmit routing weights during dispatch. While route.weight is extracted in the a2a_dispatch_send kernel, it is not written into the send buffer:

```cpp
struct ExpertAndOffset {
    uint32_t expert;
    uint32_t offset;
    uint32_t position;
    float weight;  // extracted but not sent
};
```
This reduces communication overhead; however, it makes it difficult to apply routing probabilities inside the expert MLP (e.g., swiglu(y) * routing_weights) rather than only at the combine stage.
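For concreteness, here is a minimal sketch of what "applying the routing probability inside the expert MLP" means; the shapes and names are illustrative, not taken from the actual kernels:

```python
import torch

def swiglu(y: torch.Tensor) -> torch.Tensor:
    # SwiGLU: split the up-projection output in half and gate one half with SiLU.
    a, b = y.chunk(2, dim=-1)
    return torch.nn.functional.silu(a) * b

# Hypothetical shapes: num_recv_tokens tokens routed to this expert, each
# carrying its routing probability (already permuted to match the expert input).
num_recv_tokens, hidden = 4, 8
y = torch.randn(num_recv_tokens, 2 * hidden)      # up-projection output
routing_weights = torch.rand(num_recv_tokens, 1)  # per-token routing prob

# Scale inside the expert MLP instead of deferring to the combine stage:
h = swiglu(y) * routing_weights                   # [num_recv_tokens, hidden]
```

This is only possible on the receive side if the routing probabilities arrive alongside (or can be matched to) the permuted expert inputs, which is what this issue asks about.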
Would it be possible to add an optional parameter to return routing weights? For example:
```python
def dispatch(
    self,
    out_expert_num_tokens: torch.Tensor,
    out_expert_x: torch.Tensor,
    out_expert_prob: Optional[torch.Tensor],  # shape [num_recv_tokens, top-k], permuted to match out_expert_x
    out_expert_x_scale: Optional[torch.Tensor],
    dp_x: torch.Tensor,
    dp_x_scale: Optional[torch.Tensor],
    indices: torch.Tensor,
    weights: torch.Tensor,
    bound_m: Optional[torch.Tensor] = None,
    do_send: bool = True,
    do_recv: bool = True,
) -> None:
```
This could be gated by an optional flag (e.g., return_weights=False by default) to preserve the current bandwidth benefits while returning routing weights when needed.
One option would be to include the weight in the send buffer, at the cost of an additional sizeof(float32) * top-k per token (e.g., 32 bytes per token at top-k = 8). Alternatively, the weights could be reconstructed on receive from indices and the original routing weights, though this might still require transmitting some form of index data.
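The reconstruct-on-receive alternative amounts to a gather: if each received token knows which source token and which top-k slot it came from, its weight can be looked up in the original routing table. A hedged sketch, with entirely hypothetical index names and shapes:

```python
import torch

# Original router output on the send side: one weight per (token, top-k slot).
num_tokens, top_k = 3, 2
routing_weights = torch.rand(num_tokens, top_k)

# Hypothetical per-received-token source indices, in the permuted order
# matching out_expert_x. As noted above, this index data would itself
# need to be transmitted (or be derivable from what is already sent).
src_token = torch.tensor([0, 2, 1, 0])  # which source token
src_slot = torch.tensor([1, 0, 1, 0])   # which of its top-k slots

# Reconstruct the per-received-token weights with a single gather:
out_expert_prob = routing_weights[src_token, src_slot]  # [num_recv_tokens]
```

Whether this is cheaper than sending the float directly depends on whether the required index data is already implied by the existing dispatch metadata.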
Any insights or guidance would be much appreciated.