
Commit 5fde7cb

phantomlei3 authored and root committed
refactor: breakdown fused moe kernel for deepseek all2all setup.
1 parent 25e16fa commit 5fde7cb

File tree

10 files changed: +629 -350 lines

xllm/core/kernels/mlu/fused_moe.cpp

Lines changed: 49 additions & 213 deletions
@@ -17,240 +17,76 @@ limitations under the License.
 
 #include "mlu_ops_api.h"
 
-namespace {
-torch::Tensor create_group_gemm_output(
-    const torch::Tensor& a,
-    const torch::Tensor& b,
-    const torch::Tensor& group_list,
-    torch::ScalarType dtype = torch::ScalarType::BFloat16) {
-  torch::TensorOptions target_options = a.options().dtype(dtype);
-  if (b.dim() != 2) {
-    return torch::empty({a.size(0), b.size(1)}, target_options);
-  }
-  return torch::empty({group_list.size(0), a.size(0), b.size(0)},
-                      target_options);
-}
-}  // namespace
-
 namespace xllm::kernel::mlu {
-torch::Tensor fused_moe(
-    const torch::Tensor& hidden_states,
-    const torch::Tensor& gating_output,
-    const torch::Tensor& w1,
-    const torch::Tensor& w2,
-    const std::optional<torch::Tensor>& bias1,
-    const std::optional<torch::Tensor>& bias2,
-    const std::optional<torch::Tensor>& residual,
-    const std::optional<torch::Tensor>& input_smooth,
-    const std::optional<torch::Tensor>& act_smooth,
-    const std::optional<torch::Tensor>& w1_scale,
-    const std::optional<torch::Tensor>& w2_scale,
-    const std::optional<torch::Tensor>& e_score_correction_bias,
+
+std::tuple<torch::Tensor, torch::Tensor> moe_active_topk(
+    const torch::Tensor& input,
     int64_t topk,
-    bool renormalize,
-    bool gated,
-    const std::string& act_mode,
-    const std::string& scoring_func,
     int64_t num_expert_group,
     int64_t topk_group,
+    bool normalize,
+    const std::optional<torch::Tensor>& mask,
+    const std::string& normed_by,
+    const std::string& scoring_func,
     double route_scale,
-    int64_t start_expert_id,
-    bool avg_moe,
-    const std::optional<torch::List<int64_t>>& w1_quant_flag,
-    const std::optional<torch::List<int64_t>>& w2_quant_flag) {
-  auto dtype = hidden_states.dtype();
-  auto ori_input_shape = hidden_states.sizes();
-
-  auto hidden_states_2d = hidden_states.reshape({-1, hidden_states.size(-1)});
-  int64_t tokens = hidden_states_2d.size(0);
-  auto gating_output_2d = gating_output.reshape({-1, gating_output.size(-1)});
-
-  std::optional<torch::Tensor> residual_2d = std::nullopt;
-  if (residual.has_value()) {
-    residual_2d = residual.value().reshape({-1, residual.value().size(-1)});
-  }
-
-  // check smooth quant variables
-  bool all_present = input_smooth && act_smooth && w1_scale && w2_scale;
-  bool all_none = !input_smooth && !act_smooth && !w1_scale && !w2_scale;
-  CHECK(all_none || all_present)
-      << "input_smooth, act_smooth, w1_scale and w2_scale must be present or "
-         "absent at the same time.";
-  bool is_smoothquant = all_present;
-  int64_t expert_num = gating_output_2d.size(-1);
-  int64_t expert_size = w1.size(0);
-
-  // apply softmax_topk or sigmoid_topk
-  auto reduce_weight = torch::empty(
-      {gating_output_2d.size(0), topk},
-      torch::dtype(torch::kFloat).device(gating_output_2d.device()));
-  auto expert_id = torch::empty(
-      {gating_output_2d.size(0), topk},
-      torch::dtype(torch::kInt32).device(gating_output_2d.device()));
-
-  tmo::torch_api::moe_active_topk(gating_output_2d,
+    const std::optional<torch::Tensor>& e_score_correction_bias) {
+  auto reduce_weight =
+      torch::empty({input.size(0), topk},
+                   torch::dtype(torch::kFloat).device(input.device()));
+  auto expert_id =
+      torch::empty({input.size(0), topk},
+                   torch::dtype(torch::kInt32).device(input.device()));
+  tmo::torch_api::moe_active_topk(input,
                                   topk,
                                   num_expert_group,
                                   topk_group,
-                                  renormalize,
-                                  /*mask=*/std::nullopt,
-                                  /*normed_by=*/"topk_logit",
+                                  normalize,
+                                  mask,
+                                  normed_by,
                                   scoring_func,
                                   route_scale,
                                   e_score_correction_bias,
                                   reduce_weight,
                                   expert_id);
+  return std::make_tuple(reduce_weight, expert_id);
+}
 
-  auto output_vec = tmo::torch_api::moe_gen_idx(expert_id, expert_num);
-  auto expand_idx = output_vec[0];
-  auto combine_idx = output_vec[1];
-  auto token_count = output_vec[2];
-  auto cusum_token_count = output_vec[3];
-
-  // prepare the parameters for the first group gemm
-  auto token_count_slice =
-      token_count.slice(0, start_expert_id, start_expert_id + expert_size);
-  auto gather_index_start_position =
-      cusum_token_count.index({start_expert_id}).unsqueeze(0);
-  torch::Tensor expand_hidden_states;
-  torch::Tensor input_scale;
-
-  if (is_smoothquant) {
-    // w8a8 path: quantize input hidden states directly (fused with
-    // moe_expand_input)
-    std::tie(expand_hidden_states, input_scale) =
-        scaled_quantize(hidden_states_2d,  // Use original hidden_states_2d
-                                           // instead of expand_hidden_states
-                        input_smooth.value(),
-                        /*zero=*/std::nullopt,
-                        token_count_slice,
-                        expand_idx,
-                        gather_index_start_position,
-                        /*output=*/std::nullopt,
-                        /*output_scale=*/std::nullopt,
-                        /*act_mode=*/"none",
-                        /*active_coef=*/1.0,
-                        /*is_gated=*/false,
-                        /*quant_type=*/torch::kChar);
-  } else {
-    // bf16/fp32 path: expand input hidden states
-    expand_hidden_states = tmo::torch_api::moe_expand_input(hidden_states_2d,
-                                                            expand_idx,
-                                                            cusum_token_count,
-                                                            start_expert_id,
-                                                            expert_size);
-  }
-
-  torch::Tensor gemm1_out = create_group_gemm_output(
-      expand_hidden_states, w1, token_count_slice, dtype.toScalarType());
-
-  // Unified group_gemm call using input_scale/w1_scale/quant_flag only if
-  // present
-  tmo::torch_api::group_gemm(
-      expand_hidden_states,
-      w1,
-      token_count_slice,
-      gemm1_out,
-      /*gather_idx=*/std::nullopt,
-      /*c=*/std::nullopt,
-      /*alpha=*/std::nullopt,
-      /*beta=*/std::nullopt,
-      /*a_scale=*/input_scale.defined() ? std::make_optional(input_scale)
-                                        : std::nullopt,
-      /*b_scale=*/w1_scale.has_value() ? std::make_optional(w1_scale.value())
-                                       : std::nullopt,
-      /*bias=*/std::nullopt,
-      /*a_calibration=*/std::nullopt,
-      /*b_calibration=*/std::nullopt,
-      /*quant_flag=*/w1_quant_flag.has_value() ? w1_quant_flag : std::nullopt,
-      /*b_offset=*/std::nullopt,
-      /*tile_config=*/std::nullopt,
-      /*max_dim=*/tokens,
-      /*trans_a=*/false,
-      /*trans_b=*/true,
-      /*a_quant_bit=*/is_smoothquant ? 8 : -1);
-
-  // prepare the parameters for the second group gemm
-  torch::Tensor act_out;
-  torch::Tensor act_out_scale;
-  if (is_smoothquant) {
-    // w8a8 path: reuse quantized_input and input_scale from first group_gemm
-    act_out = gated ? expand_hidden_states.slice(1, 0, gemm1_out.size(1) / 2)
-                    : expand_hidden_states.slice(1, 0, gemm1_out.size(1));
-    act_out_scale = input_scale.slice(0, 0, gemm1_out.size(0));
-
-    // Quantize gemm1_out directly (fused with active operation) using reused
-    // tensors
-    auto [quantized_activation, activation_scale] =
-        scaled_quantize(gemm1_out,
-                        act_smooth.value(),
-                        /*zero=*/std::nullopt,
-                        /*token_count=*/token_count_slice,
-                        /*gather_index=*/std::nullopt,
-                        /*gather_index_start_position=*/std::nullopt,
-                        act_out,        // output - reuse from quantized_input
-                        act_out_scale,  // output_scale - reuse from input_scale
-                        /*act_mode=*/act_mode,
-                        /*active_coef=*/1.0,
-                        /*is_gated=*/gated,
-                        /*quant_type=*/torch::kChar);
-    act_out = quantized_activation;
-    act_out_scale = activation_scale;
-  } else {
-    // bf16/fp32 path: apply activation function first
-    act_out = gated ? gemm1_out.slice(1, 0, gemm1_out.size(1) / 2) : gemm1_out;
-    tmo::torch_api::active(gemm1_out,
-                           act_out,
-                           bias1,
-                           cusum_token_count,
-                           act_mode,
-                           gated,
-                           start_expert_id,
-                           expert_size);
-  }
-
-  torch::Tensor gemm2_out = create_group_gemm_output(
-      act_out, w2, token_count_slice, dtype.toScalarType());
+std::vector<torch::Tensor> moe_gen_idx(const torch::Tensor& expert_id,
+                                       int64_t expert_num) {
+  return tmo::torch_api::moe_gen_idx(expert_id, expert_num);
+}
 
-  // Unified group_gemm call, now only checks the existance of
-  // input_scale/w1_scale for smoothquant
-  tmo::torch_api::group_gemm(
-      act_out,
-      w2,
-      token_count_slice,
-      gemm2_out,
-      /*gather_idx=*/std::nullopt,
-      /*c=*/std::nullopt,
-      /*alpha=*/std::nullopt,
-      /*beta=*/std::nullopt,
-      act_out_scale.defined() ? std::make_optional(act_out_scale)
-                              : std::nullopt,  // a_scale
-      w2_scale.has_value() ? std::make_optional(w2_scale.value())
-                           : std::nullopt,  // b_scale
-      /*bias=*/std::nullopt,
-      /*a_calibration=*/std::nullopt,
-      /*b_calibration=*/std::nullopt,
-      w2_quant_flag.has_value() ? w2_quant_flag : std::nullopt,  // quant_flag
-      /*b_offset=*/std::nullopt,
-      /*tile_config=*/std::nullopt,
-      /*max_dim=*/tokens,
-      /*trans_a=*/false,
-      /*trans_b=*/true,
-      /*a_quant_bit=*/is_smoothquant ? 8 : -1);
+torch::Tensor moe_expand_input(
+    const torch::Tensor& input,
+    const torch::Tensor& gather_index,
+    const std::optional<torch::Tensor>& cusum_token_count,
+    int64_t start_expert_id,
+    int64_t expert_size) {
+  return tmo::torch_api::moe_expand_input(
+      input, gather_index, cusum_token_count, start_expert_id, expert_size);
+}
 
-  auto output = torch::empty({reduce_weight.size(0), gemm2_out.size(1)},
-                             gemm2_out.options());
-  tmo::torch_api::moe_combine_result(gemm2_out,
+torch::Tensor moe_combine_result(
+    const torch::Tensor& input,
+    const torch::Tensor& reduce_weight,
+    const torch::Tensor& gather_ids,
+    const std::optional<torch::Tensor>& residual,
+    const std::optional<torch::Tensor>& cusum_token_count,
+    const int64_t start_expert_id,
+    const int64_t expert_size,
+    const std::optional<torch::Tensor>& bias) {
+  auto output =
+      torch::empty({reduce_weight.size(0), input.size(1)}, input.options());
+  tmo::torch_api::moe_combine_result(input,
                                      output,
                                      reduce_weight,
-                                     combine_idx,
-                                     residual_2d,
+                                     gather_ids,
+                                     residual,
                                      cusum_token_count,
                                      start_expert_id,
                                      expert_size,
-                                     bias2);
-
-  return output.reshape(ori_input_shape);
+                                     bias);
+  return output;
 }
 
 }  // namespace xllm::kernel::mlu
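To make the new surface concrete, here is a minimal caller-side sketch (not part of this commit) of how the decomposed ops chain together on the bf16/fp32 gated path. Variable names such as gating_output_2d, hidden_states_2d, expert_num, start_expert_id, expert_size, tokens, and act_mode are carried over from the deleted fused_moe body; an activation wrapper is not visible in this diff, so the raw tmo::torch_api::active call from the old code is used. The point of the breakdown shows up between steps 2 and 3: the all2all token dispatch can now sit between routing and the expert GEMMs (which use the group_gemm wrapper added below).

using namespace xllm::kernel::mlu;

// 1. Route: per-token top-k expert ids and combine weights.
auto [reduce_weight, expert_id] =
    moe_active_topk(gating_output_2d,
                    topk,
                    num_expert_group,
                    topk_group,
                    /*normalize=*/true,
                    /*mask=*/std::nullopt,
                    /*normed_by=*/"topk_logit",
                    /*scoring_func=*/"softmax",
                    /*route_scale=*/1.0,
                    /*e_score_correction_bias=*/std::nullopt);

// 2. Build gather/scatter indices and per-expert token counts.
auto idx = moe_gen_idx(expert_id, expert_num);
auto expand_idx = idx[0];         // token order -> expert-sorted order
auto combine_idx = idx[1];        // expert-sorted order -> token order
auto token_count = idx[2];        // tokens assigned to each expert
auto cusum_token_count = idx[3];  // prefix sum of token_count

// <-- in the all2all setup, dispatch tokens across ranks here -->

// 3. Gather this rank's tokens into expert-contiguous layout.
auto expanded = moe_expand_input(hidden_states_2d,
                                 expand_idx,
                                 cusum_token_count,
                                 start_expert_id,
                                 expert_size);
auto token_count_slice =
    token_count.slice(0, start_expert_id, start_expert_id + expert_size);

// 4. First expert GEMM (w1 is 3D, so the output is [rows, w1.size(1)]).
auto gemm1_out =
    torch::empty({expanded.size(0), w1.size(1)}, expanded.options());
group_gemm(expanded, w1, token_count_slice, gemm1_out,
           /*a_scale=*/std::nullopt, /*b_scale=*/std::nullopt,
           /*quant_flag=*/std::nullopt, /*max_dim=*/tokens,
           /*trans_a=*/false, /*trans_b=*/true, /*a_quant_bit=*/-1);

// 5. Gated activation written into the first half of gemm1_out.
auto act_out = gemm1_out.slice(1, 0, gemm1_out.size(1) / 2);
tmo::torch_api::active(gemm1_out, act_out, /*bias=*/std::nullopt,
                       cusum_token_count, act_mode, /*gated=*/true,
                       start_expert_id, expert_size);

// 6. Second expert GEMM, then weighted scatter back to token order.
auto gemm2_out =
    torch::empty({act_out.size(0), w2.size(1)}, act_out.options());
group_gemm(act_out, w2, token_count_slice, gemm2_out,
           std::nullopt, std::nullopt, std::nullopt,
           /*max_dim=*/tokens, /*trans_a=*/false, /*trans_b=*/true,
           /*a_quant_bit=*/-1);

// <-- combine all2all: return expert outputs to their source ranks -->

auto output = moe_combine_result(gemm2_out, reduce_weight, combine_idx,
                                 /*residual=*/std::nullopt,
                                 cusum_token_count, start_expert_id,
                                 expert_size, /*bias=*/std::nullopt);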
Lines changed: 54 additions & 0 deletions
@@ -0,0 +1,54 @@
+/* Copyright 2025 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "mlu_ops_api.h"
+
+namespace xllm::kernel::mlu {
+
+torch::Tensor group_gemm(const torch::Tensor& a,
+                         const torch::Tensor& b,
+                         const torch::Tensor& token_count,
+                         torch::Tensor& output,
+                         const std::optional<torch::Tensor>& a_scale,
+                         const std::optional<torch::Tensor>& b_scale,
+                         const std::optional<torch::List<int64_t>>& quant_flag,
+                         const int64_t max_dim,
+                         const bool trans_a,
+                         const bool trans_b,
+                         const int64_t a_quant_bit) {
+  tmo::torch_api::group_gemm(a,
+                             b,
+                             token_count,
+                             output,
+                             /*gather_idx=*/std::nullopt,
+                             /*c=*/std::nullopt,
+                             /*alpha=*/std::nullopt,
+                             /*beta=*/std::nullopt,
+                             a_scale,
+                             b_scale,
+                             /*bias=*/std::nullopt,
+                             /*a_calibration=*/std::nullopt,
+                             /*b_calibration=*/std::nullopt,
+                             quant_flag,
+                             /*b_offset=*/std::nullopt,
+                             /*tile_config=*/std::nullopt,
+                             max_dim,
+                             trans_a,
+                             trans_b,
+                             a_quant_bit);
+  return output;
+}
+
+}  // namespace xllm::kernel::mlu
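For reference, a hedged usage sketch of this wrapper on the w8a8 path; int8_input, input_scale, w1_scale, w1_quant_flag, token_count_slice, and tokens are illustrative names assumed to come from the surrounding MoE flow (e.g. the scaled_quantize step of the deleted fused path), not part of this commit:

// Sketch only: int8 activations with per-token scales; the output dtype
// follows the preallocated tensor (bf16 here).
auto gemm_out = torch::empty({int8_input.size(0), w1.size(1)},
                             int8_input.options().dtype(torch::kBFloat16));
xllm::kernel::mlu::group_gemm(int8_input, w1, token_count_slice, gemm_out,
                              /*a_scale=*/input_scale,
                              /*b_scale=*/w1_scale,
                              /*quant_flag=*/w1_quant_flag,
                              /*max_dim=*/tokens,
                              /*trans_a=*/false,
                              /*trans_b=*/true,
                              /*a_quant_bit=*/8);

The wrapper pins the rarely used tmo arguments (gather_idx, alpha/beta, calibration tensors, b_offset, tile_config) to std::nullopt, so all2all call sites only spell out the quantization-relevant parameters.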
