@@ -17,240 +17,76 @@ limitations under the License.
1717
1818#include " mlu_ops_api.h"
1919
20- namespace {
21- torch::Tensor create_group_gemm_output (
22- const torch::Tensor& a,
23- const torch::Tensor& b,
24- const torch::Tensor& group_list,
25- torch::ScalarType dtype = torch::ScalarType::BFloat16) {
26- torch::TensorOptions target_options = a.options ().dtype (dtype);
27- if (b.dim () != 2 ) {
28- return torch::empty ({a.size (0 ), b.size (1 )}, target_options);
29- }
30- return torch::empty ({group_list.size (0 ), a.size (0 ), b.size (0 )},
31- target_options);
32- }
33- } // namespace
34-
3520namespace xllm ::kernel::mlu {
36- torch::Tensor fused_moe (
37- const torch::Tensor& hidden_states,
38- const torch::Tensor& gating_output,
39- const torch::Tensor& w1,
40- const torch::Tensor& w2,
41- const std::optional<torch::Tensor>& bias1,
42- const std::optional<torch::Tensor>& bias2,
43- const std::optional<torch::Tensor>& residual,
44- const std::optional<torch::Tensor>& input_smooth,
45- const std::optional<torch::Tensor>& act_smooth,
46- const std::optional<torch::Tensor>& w1_scale,
47- const std::optional<torch::Tensor>& w2_scale,
48- const std::optional<torch::Tensor>& e_score_correction_bias,
21+
22+ std::tuple<torch::Tensor, torch::Tensor> moe_active_topk (
23+ const torch::Tensor& input,
4924 int64_t topk,
50- bool renormalize,
51- bool gated,
52- const std::string& act_mode,
53- const std::string& scoring_func,
5425 int64_t num_expert_group,
5526 int64_t topk_group,
27+ bool normalize,
28+ const std::optional<torch::Tensor>& mask,
29+ const std::string& normed_by,
30+ const std::string& scoring_func,
5631 double route_scale,
57- int64_t start_expert_id,
58- bool avg_moe,
59- const std::optional<torch::List<int64_t >>& w1_quant_flag,
60- const std::optional<torch::List<int64_t >>& w2_quant_flag) {
61- auto dtype = hidden_states.dtype ();
62- auto ori_input_shape = hidden_states.sizes ();
63-
64- auto hidden_states_2d = hidden_states.reshape ({-1 , hidden_states.size (-1 )});
65- int64_t tokens = hidden_states_2d.size (0 );
66- auto gating_output_2d = gating_output.reshape ({-1 , gating_output.size (-1 )});
67-
68- std::optional<torch::Tensor> residual_2d = std::nullopt ;
69- if (residual.has_value ()) {
70- residual_2d = residual.value ().reshape ({-1 , residual.value ().size (-1 )});
71- }
72-
73- // check smooth quant variables
74- bool all_present = input_smooth && act_smooth && w1_scale && w2_scale;
75- bool all_none = !input_smooth && !act_smooth && !w1_scale && !w2_scale;
76- CHECK (all_none || all_present)
77- << " input_smooth, act_smooth, w1_scale and w2_scale must be present or "
78- " absent at the same time." ;
79- bool is_smoothquant = all_present;
80- int64_t expert_num = gating_output_2d.size (-1 );
81- int64_t expert_size = w1.size (0 );
82-
83- // apply softmax_topk or sigmoid_topk
84- auto reduce_weight = torch::empty (
85- {gating_output_2d.size (0 ), topk},
86- torch::dtype (torch::kFloat ).device (gating_output_2d.device ()));
87- auto expert_id = torch::empty (
88- {gating_output_2d.size (0 ), topk},
89- torch::dtype (torch::kInt32 ).device (gating_output_2d.device ()));
90-
91- tmo::torch_api::moe_active_topk (gating_output_2d,
32+ const std::optional<torch::Tensor>& e_score_correction_bias) {
33+ auto reduce_weight =
34+ torch::empty ({input.size (0 ), topk},
35+ torch::dtype (torch::kFloat ).device (input.device ()));
36+ auto expert_id =
37+ torch::empty ({input.size (0 ), topk},
38+ torch::dtype (torch::kInt32 ).device (input.device ()));
39+ tmo::torch_api::moe_active_topk (input,
9240 topk,
9341 num_expert_group,
9442 topk_group,
95- renormalize ,
96- /* mask= */ std:: nullopt ,
97- /* normed_by= */ " topk_logit " ,
43+ normalize ,
44+ mask,
45+ normed_by,
9846 scoring_func,
9947 route_scale,
10048 e_score_correction_bias,
10149 reduce_weight,
10250 expert_id);
51+ return std::make_tuple (reduce_weight, expert_id);
52+ }
10353
104- auto output_vec = tmo::torch_api::moe_gen_idx (expert_id, expert_num);
105- auto expand_idx = output_vec[0 ];
106- auto combine_idx = output_vec[1 ];
107- auto token_count = output_vec[2 ];
108- auto cusum_token_count = output_vec[3 ];
109-
110- // prepare the parameters for the first group gemm
111- auto token_count_slice =
112- token_count.slice (0 , start_expert_id, start_expert_id + expert_size);
113- auto gather_index_start_position =
114- cusum_token_count.index ({start_expert_id}).unsqueeze (0 );
115- torch::Tensor expand_hidden_states;
116- torch::Tensor input_scale;
117-
118- if (is_smoothquant) {
119- // w8a8 path: quantize input hidden states directly (fused with
120- // moe_expand_input)
121- std::tie (expand_hidden_states, input_scale) =
122- scaled_quantize (hidden_states_2d, // Use original hidden_states_2d
123- // instead of expand_hidden_states
124- input_smooth.value (),
125- /* zero=*/ std::nullopt ,
126- token_count_slice,
127- expand_idx,
128- gather_index_start_position,
129- /* output=*/ std::nullopt ,
130- /* output_scale=*/ std::nullopt ,
131- /* act_mode=*/ " none" ,
132- /* active_coef=*/ 1.0 ,
133- /* is_gated=*/ false ,
134- /* quant_type=*/ torch::kChar );
135- } else {
136- // bf16/fp32 path: expand input hidden states
137- expand_hidden_states = tmo::torch_api::moe_expand_input (hidden_states_2d,
138- expand_idx,
139- cusum_token_count,
140- start_expert_id,
141- expert_size);
142- }
143-
144- torch::Tensor gemm1_out = create_group_gemm_output (
145- expand_hidden_states, w1, token_count_slice, dtype.toScalarType ());
146-
147- // Unified group_gemm call using input_scale/w1_scale/quant_flag only if
148- // present
149- tmo::torch_api::group_gemm (
150- expand_hidden_states,
151- w1,
152- token_count_slice,
153- gemm1_out,
154- /* gather_idx=*/ std::nullopt ,
155- /* c=*/ std::nullopt ,
156- /* alpha=*/ std::nullopt ,
157- /* beta=*/ std::nullopt ,
158- /* a_scale=*/ input_scale.defined () ? std::make_optional (input_scale)
159- : std::nullopt ,
160- /* b_scale=*/ w1_scale.has_value () ? std::make_optional (w1_scale.value ())
161- : std::nullopt ,
162- /* bias=*/ std::nullopt ,
163- /* a_calibration=*/ std::nullopt ,
164- /* b_calibration=*/ std::nullopt ,
165- /* quant_flag=*/ w1_quant_flag.has_value () ? w1_quant_flag : std::nullopt ,
166- /* b_offset=*/ std::nullopt ,
167- /* tile_config=*/ std::nullopt ,
168- /* max_dim=*/ tokens,
169- /* trans_a=*/ false ,
170- /* trans_b=*/ true ,
171- /* a_quant_bit=*/ is_smoothquant ? 8 : -1 );
172-
173- // prepare the parameters for the second group gemm
174- torch::Tensor act_out;
175- torch::Tensor act_out_scale;
176- if (is_smoothquant) {
177- // w8a8 path: reuse quantized_input and input_scale from first group_gemm
178- act_out = gated ? expand_hidden_states.slice (1 , 0 , gemm1_out.size (1 ) / 2 )
179- : expand_hidden_states.slice (1 , 0 , gemm1_out.size (1 ));
180- act_out_scale = input_scale.slice (0 , 0 , gemm1_out.size (0 ));
181-
182- // Quantize gemm1_out directly (fused with active operation) using reused
183- // tensors
184- auto [quantized_activation, activation_scale] =
185- scaled_quantize (gemm1_out,
186- act_smooth.value (),
187- /* zero=*/ std::nullopt ,
188- /* token_count=*/ token_count_slice,
189- /* gather_index=*/ std::nullopt ,
190- /* gather_index_start_position=*/ std::nullopt ,
191- act_out, // output - reuse from quantized_input
192- act_out_scale, // output_scale - reuse from input_scale
193- /* act_mode=*/ act_mode,
194- /* active_coef=*/ 1.0 ,
195- /* is_gated=*/ gated,
196- /* quant_type=*/ torch::kChar );
197- act_out = quantized_activation;
198- act_out_scale = activation_scale;
199- } else {
200- // bf16/fp32 path: apply activation function first
201- act_out = gated ? gemm1_out.slice (1 , 0 , gemm1_out.size (1 ) / 2 ) : gemm1_out;
202- tmo::torch_api::active (gemm1_out,
203- act_out,
204- bias1,
205- cusum_token_count,
206- act_mode,
207- gated,
208- start_expert_id,
209- expert_size);
210- }
211-
212- torch::Tensor gemm2_out = create_group_gemm_output (
213- act_out, w2, token_count_slice, dtype.toScalarType ());
54+ std::vector<torch::Tensor> moe_gen_idx (const torch::Tensor& expert_id,
55+ int64_t expert_num) {
56+ return tmo::torch_api::moe_gen_idx (expert_id, expert_num);
57+ }
21458
215- // Unified group_gemm call, now only checks the existance of
216- // input_scale/w1_scale for smoothquant
217- tmo::torch_api::group_gemm (
218- act_out,
219- w2,
220- token_count_slice,
221- gemm2_out,
222- /* gather_idx=*/ std::nullopt ,
223- /* c=*/ std::nullopt ,
224- /* alpha=*/ std::nullopt ,
225- /* beta=*/ std::nullopt ,
226- act_out_scale.defined () ? std::make_optional (act_out_scale)
227- : std::nullopt , // a_scale
228- w2_scale.has_value () ? std::make_optional (w2_scale.value ())
229- : std::nullopt , // b_scale
230- /* bias=*/ std::nullopt ,
231- /* a_calibration=*/ std::nullopt ,
232- /* b_calibration=*/ std::nullopt ,
233- w2_quant_flag.has_value () ? w2_quant_flag : std::nullopt , // quant_flag
234- /* b_offset=*/ std::nullopt ,
235- /* tile_config=*/ std::nullopt ,
236- /* max_dim=*/ tokens,
237- /* trans_a=*/ false ,
238- /* trans_b=*/ true ,
239- /* a_quant_bit=*/ is_smoothquant ? 8 : -1 );
59+ torch::Tensor moe_expand_input (
60+ const torch::Tensor& input,
61+ const torch::Tensor& gather_index,
62+ const std::optional<torch::Tensor>& cusum_token_count,
63+ int64_t start_expert_id,
64+ int64_t expert_size) {
65+ return tmo::torch_api::moe_expand_input (
66+ input, gather_index, cusum_token_count, start_expert_id, expert_size);
67+ }
24068
241- auto output = torch::empty ({reduce_weight.size (0 ), gemm2_out.size (1 )},
242- gemm2_out.options ());
243- tmo::torch_api::moe_combine_result (gemm2_out,
69+ torch::Tensor moe_combine_result (
70+ const torch::Tensor& input,
71+ const torch::Tensor& reduce_weight,
72+ const torch::Tensor& gather_ids,
73+ const std::optional<torch::Tensor>& residual,
74+ const std::optional<torch::Tensor>& cusum_token_count,
75+ const int64_t start_expert_id,
76+ const int64_t expert_size,
77+ const std::optional<torch::Tensor>& bias) {
78+ auto output =
79+ torch::empty ({reduce_weight.size (0 ), input.size (1 )}, input.options ());
80+ tmo::torch_api::moe_combine_result (input,
24481 output,
24582 reduce_weight,
246- combine_idx ,
247- residual_2d ,
83+ gather_ids ,
84+ residual ,
24885 cusum_token_count,
24986 start_expert_id,
25087 expert_size,
251- bias2);
252-
253- return output.reshape (ori_input_shape);
88+ bias);
89+ return output;
25490}
25591
25692} // namespace xllm::kernel::mlu
0 commit comments