
Return routing weights in dispatch send & recv kernel #7

@zoezeng1433


Hi team, the Perplexity MoE kernels currently don't transmit routing weights during dispatch. Although `route.weight` is extracted in the `a2a_dispatch_send` kernel, it is not written to the send buffer:

```cpp
#include <cstdint>

struct ExpertAndOffset {
    uint32_t expert;
    uint32_t offset;
    uint32_t position;
    float weight;  // extracted but not sent
};
```

This reduces communication overhead; however, it makes it difficult to apply the routing probabilities inside the expert MLP (e.g. `swiglu(y) * routing_weights`) rather than only at the combine stage.
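To make the use case concrete, here is a toy single-token sketch in plain Python (no torch). The function name `expert_mlp` and the weight matrices are illustrative only and are not part of pplx-kernels. The sketch applies the routing weight right after the SwiGLU activation, which is what would require the weight to be available on the receiving rank.

```python
import math


def silu(v: float) -> float:
    return v / (1.0 + math.exp(-v))


def expert_mlp(x, w_gate, w_up, w_down, routing_weight=None):
    """Toy SwiGLU expert on plain Python lists (illustrative, not a kernel).

    x: d_model floats; w_gate, w_up: d_ff rows of d_model floats;
    w_down: d_model rows of d_ff floats.
    """
    gate = [sum(g * xi for g, xi in zip(row, x)) for row in w_gate]
    up = [sum(u * xi for u, xi in zip(row, x)) for row in w_up]
    h = [silu(g) * u for g, u in zip(gate, up)]  # swiglu(y)
    if routing_weight is not None:
        # Apply the router probability inside the expert, before the down projection.
        h = [routing_weight * hi for hi in h]
    return [sum(d * hi for d, hi in zip(row, h)) for row in w_down]


x = [0.5, -1.0]
w_gate = [[1.0, 0.0], [0.5, 0.5], [-1.0, 2.0]]
w_up = [[0.2, 0.1], [1.0, -1.0], [0.3, 0.3]]
w_down = [[1.0, 0.0, 0.5], [0.0, 1.0, -0.5]]
w = 0.25

inside = expert_mlp(x, w_gate, w_up, w_down, routing_weight=w)      # weight applied in the MLP
at_combine = [w * v for v in expert_mlp(x, w_gate, w_up, w_down)]   # weight applied at combine
```

With a bias-free linear down projection, both paths agree up to floating-point rounding. Applying the weight inside the expert therefore mainly makes a difference when something non-linear, such as re-quantizing the activations or a bias, sits between the activation and the combine step.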

Would it be possible to add an optional output parameter for the routing weights? For example:

```python
from typing import Optional

import torch


def dispatch(
        self,
        out_expert_num_tokens: torch.Tensor,
        out_expert_x: torch.Tensor,
        out_expert_prob: Optional[torch.Tensor],  # shape [num_recv_tokens, top-k], permuted to match out_expert_x
        out_expert_x_scale: Optional[torch.Tensor],
        dp_x: torch.Tensor,
        dp_x_scale: Optional[torch.Tensor],
        indices: torch.Tensor,
        weights: torch.Tensor,
        bound_m: Optional[torch.Tensor] = None,
        do_send: bool = True,
        do_recv: bool = True,
    ) -> None:
    ...  # proposed signature only; body omitted
```

This could be gated by an optional flag (e.g., return_weights=False by default) to preserve the current bandwidth benefits while returning routing weights when needed.
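As a way to pin down the intended output layout, here is a single-rank reference sketch with no communication. The names `reference_dispatch` and `return_weights` are hypothetical and do not come from pplx-kernels. Token rows are grouped per expert, mirroring `out_expert_x`, and when the flag is set, each row's routing weight is stored at the same position, mirroring the proposed `out_expert_prob`.

```python
from typing import List, Optional, Tuple


def reference_dispatch(
    x: List[List[float]],           # [num_tokens][hidden]
    indices: List[List[int]],       # [num_tokens][top_k] expert ids
    weights: List[List[float]],     # [num_tokens][top_k] routing weights
    num_experts: int,
    return_weights: bool = False,   # hypothetical flag, off by default
) -> Tuple[List[int], List[List[List[float]]], Optional[List[List[float]]]]:
    """Group token rows per expert and, optionally, the matching routing weights."""
    expert_x: List[List[List[float]]] = [[] for _ in range(num_experts)]
    expert_prob: List[List[float]] = [[] for _ in range(num_experts)]
    for t, (experts, probs) in enumerate(zip(indices, weights)):
        for e, p in zip(experts, probs):
            # Each received row is one (token, expert) pair, so it carries one scalar weight.
            expert_x[e].append(x[t])
            expert_prob[e].append(p)
    num_tokens_per_expert = [len(rows) for rows in expert_x]
    return num_tokens_per_expert, expert_x, (expert_prob if return_weights else None)


x = [[1.0], [2.0], [3.0]]
indices = [[0, 2], [1, 2], [0, 1]]
weights = [[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]]
counts, ex, ep = reference_dispatch(x, indices, weights, num_experts=3, return_weights=True)
```

In this sketch, rows within an expert are ordered by source token index. In a real multi-rank dispatch, the weights would simply have to follow whatever order the receive kernel uses for `out_expert_x`.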

One option would be to include the weight in the send buffer, at the cost of an extra `top_k * sizeof(float)` bytes (4 bytes per selected expert) per token. Alternatively, the weights could be reconstructed on the receive side from the indices and the original routing weights, although this would likely still require transmitting some form of index data.
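The byte cost of the first option can be estimated with a simplified, hypothetical record layout. The sketch below assumes each (token, expert) record holds only raw activation bytes plus an optional float32 weight. It ignores scale factors, headers, and the kernel's real buffer layout and alignment, and it assumes one record per selected expert. None of the helper names come from pplx-kernels.

```python
import struct
from typing import Optional


def record_format(hidden_bytes: int, send_weight: bool) -> str:
    # "<" = little-endian with no implicit padding; "s" = raw activation bytes;
    # "f" = optional float32 routing weight appended after the activations.
    return f"<{hidden_bytes}s" + ("f" if send_weight else "")


def pack_record(activations: bytes, weight: Optional[float]) -> bytes:
    fmt = record_format(len(activations), weight is not None)
    args = (activations,) if weight is None else (activations, weight)
    return struct.pack(fmt, *args)


def unpack_weight(record: bytes, hidden_bytes: int) -> float:
    # On receive, the weight sits right after the activation bytes.
    (weight,) = struct.unpack_from("<f", record, hidden_bytes)
    return weight


def per_token_bytes(hidden_bytes: int, top_k: int, send_weight: bool) -> int:
    # One record per selected expert (simplifying assumption).
    return top_k * struct.calcsize(record_format(hidden_bytes, send_weight))


without = per_token_bytes(7168, 8, send_weight=False)
with_w = per_token_bytes(7168, 8, send_weight=True)
print(without, with_w, with_w - without)  # 57344 57376 32
```

For an fp8 hidden size of 7168 with top-8 routing, the weights add 32 bytes to 57,344, roughly 0.06%. Scale factors and any alignment padding in the actual buffer would shift the exact figures.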

Any insights or guidance would be much appreciated.
