Skip to content

Commit aa78d09

Browse files
committed
[INTEL_HPU] Add fused Mixture of Experts op
1 parent 3db38f4 commit aa78d09

File tree

2 files changed

+799
-0
lines changed

2 files changed

+799
-0
lines changed
Lines changed: 316 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,316 @@
1+
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License"); you may
4+
// not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include <cstring>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>
#include <string_view>

#include "habanalabs/perf_lib_layer_params.h"
#include "kernels/funcs.h"
#include "kernels/hpu_funcs.h"
#include "kernels/hpu_operator.h"
#include "paddle/extension.h"
#include "utils/utils.h"
21+
22+
namespace custom_kernel {
23+
24+
// Maps the user-facing activation name (the "activation" op attribute)
// to the HPU MoE kernel's activation enum. Lookups against this map
// must check for a miss -- unknown names are simply absent.
static const std::map<std::string_view, MoeActivationMode_t> activationModeMap =
    {{"gelu", MoeActivationMode_t::MOE_ACTIVATION_MODE_GELU},
     {"relu", MoeActivationMode_t::MOE_ACTIVATION_MODE_RELU},
     {"silu", MoeActivationMode_t::MOE_ACTIVATION_MODE_SILU}};
28+
29+
// Plain-data bundle of everything needed to configure the fused-MoE HPU
// node. Filled (after zeroing) in FusedMoEKernel and translated into
// synapse kernel params by FillMixtureOfExpertsParams.
struct FusedMoEConfig {
  bool permuted_weights;   // forwarded as MOE_FLAGS_PERMUTED_WEIGHTS
  bool fused_gemm;         // set when gate_up and down weight counts match;
                           // selects 2 (vs 3) weight tensors per expert
  bool measurement_mode;   // adds MOE_FLAGS_CALC_AMAX and a 2nd graph output
  std::string_view activation_mode;  // key into activationModeMap; non-owning
                                     // view -- the referenced string must
                                     // outlive this struct
  int32_t num_experts;     // taken from router_weights.dims()[1]
  int32_t experts_min;     // copied into router.experts_min
  int32_t experts_max;     // copied into router.experts_max
  bool dynamic_scale;           // not read anywhere in this file
  bool blockwise_quantization;  // not read anywhere in this file
  int32_t block_size;           // not read anywhere in this file
};
41+
42+
std::shared_ptr<ns_MoeKernel::ParamsV2> FillMixtureOfExpertsParams(
43+
const FusedMoEConfig& config) {
44+
auto moe_params = std::make_shared<ns_MoeKernel::ParamsV2>();
45+
memset(reinterpret_cast<void*>(moe_params.get()),
46+
0x00,
47+
sizeof(ns_MoeKernel::ParamsV2));
48+
49+
auto activationIterator = activationModeMap.find(config.activation_mode);
50+
moe_params->experts.activation = activationIterator->second;
51+
52+
moe_params->router.experts_min = config.experts_min;
53+
moe_params->router.experts_max = config.experts_max;
54+
55+
moe_params->flags =
56+
config.permuted_weights ? MoeFlags_t::MOE_FLAGS_PERMUTED_WEIGHTS : 0;
57+
moe_params->flags |=
58+
(config.fused_gemm ? MoeFlags_t::MOE_FLAGS_FUSED_GEMM : 0);
59+
moe_params->flags |=
60+
(config.measurement_mode ? MoeFlags_t::MOE_FLAGS_CALC_AMAX : 0);
61+
62+
return moe_params;
63+
}
64+
65+
class FusedMixtureOfExperts : public HpuFusedOperator {
66+
public:
67+
explicit FusedMixtureOfExperts(synDataType dtype)
68+
: HpuFusedOperator("moe_", false), dtype_(dtype) {}
69+
70+
template <typename T>
71+
void AddNodeMoeForward(std::vector<synTensor> inputs,
72+
std::vector<synTensor> outputs,
73+
std::shared_ptr<ns_MoeKernel::ParamsV2> params) {
74+
std::string node_name = "moe_fwd";
75+
76+
std::string guid = guid_ + guid_dtype<T>();
77+
78+
AddNode_IOP<ns_MoeKernel::ParamsV2>(
79+
inputs, outputs, *params, guid, node_name);
80+
}
81+
82+
template <typename T>
83+
void AddNode(ConvertTensors* ct, FusedMoEConfig config) {
84+
auto weights_per_expert = config.fused_gemm ? 2 : 3;
85+
std::vector<synTensor> inputs;
86+
87+
int64_t input_count = 3 + config.num_experts * weights_per_expert;
88+
for (int64_t i = 0; i < input_count; i++) {
89+
inputs.push_back(createTensorFromCT(ct, i));
90+
}
91+
92+
const bool measurement_mode = config.measurement_mode;
93+
std::vector<synTensor> outputs;
94+
if (measurement_mode) {
95+
for (size_t i = 0; i < 2; i++) {
96+
outputs.push_back(createTensorFromCT(ct, i, false));
97+
}
98+
} else {
99+
outputs.push_back(createTensorFromCT(ct, 0, false));
100+
}
101+
102+
auto params = FillMixtureOfExpertsParams(config);
103+
AddNodeMoeForward<T>(inputs, outputs, params);
104+
}
105+
106+
protected:
107+
synDataType dtype_;
108+
};
109+
110+
// Runs the fused Mixture-of-Experts forward pass on the HPU for element
// type T. Builds a synapse recipe on first use and caches it keyed on
// "fused_moe_" + the input dims; later calls with the same shapes reuse
// the compiled recipe.
//
// Tensor registration order below must match the index order used by
// FusedMixtureOfExperts::AddNode (inputs: hidden_states, routing table,
// router weights, gate_up weights, down weights; outputs: final hidden
// states, amax).
template <typename T, typename Context>
void FusedMoEKernel(const Context& dev_ctx,
                    const phi::DenseTensor& hidden_states,
                    const phi::DenseTensor& expert_routing_table,
                    const phi::DenseTensor& router_weights,
                    const std::vector<phi::DenseTensor>& gate_up_weights,
                    const std::vector<phi::DenseTensor>& down_weights,
                    const bool permuted_weights,
                    const std::string& activation,
                    const int experts_min,
                    const int experts_max,
                    const bool measurement_mode,
                    phi::DenseTensor* final_hidden_states,
                    phi::DenseTensor* amax_per_expert) {
  ConvertTensors ct;
  ct.Add(hidden_states);
  ct.Add(expert_routing_table);
  ct.Add(router_weights);
  for (const auto& t : gate_up_weights) {
    ct.Add(t);
  }
  for (const auto& t : down_weights) {
    ct.Add(t);
  }
  // Only the input dims participate in the cache key.
  std::vector<DIMS> inputs_dims = ct.GetDims();

  // NOTE(review): amax_per_expert is registered unconditionally, but
  // AddNode only creates its graph output in measurement mode --
  // confirm the runner tolerates the extra registered tensor.
  ct.Add(final_hidden_states, false);
  ct.Add(amax_per_expert, false);

  OpCacheOperator op_info;
  op_info.prepareOpInfo<T, nullptr_t>("fused_moe_", inputs_dims, nullptr);
  auto recipe = op_info.GetRecipe();

  if (recipe == nullptr) {
    // Cache miss: build the graph once for this shape combination.
    FusedMoEConfig config;
    memset(reinterpret_cast<void*>(&config), 0x00, sizeof(FusedMoEConfig));

    config.permuted_weights = permuted_weights;
    // Equal counts imply gate and up projections share one fused tensor.
    config.fused_gemm = (gate_up_weights.size() == down_weights.size());
    config.measurement_mode = measurement_mode;
    // string_view into `activation`; valid for the duration of this call.
    config.activation_mode = activation;
    config.experts_min = experts_min;
    config.experts_max = experts_max;
    config.num_experts = router_weights.dims()[1];

    FusedMixtureOfExperts op(op_info.datatype_);
    op.AddNode<T>(&ct, config);
    op.Compile();
    op_info.setOp(op);

    recipe = op_info.GetRecipe();
  }

  // Bind device addresses and launch on the context's stream.
  std::map<std::string, uint64_t> tensors = ct.GetDeviceAddr();
  RecipeRunner runner(recipe);
  runner.Run(reinterpret_cast<C_Stream>(dev_ctx.stream()), tensors);
}
167+
168+
} // namespace custom_kernel
169+
170+
template <typename Context>
171+
void CallFusedMoEKernel(const Context& dev_ctx,
172+
const phi::DenseTensor& hidden_states,
173+
const phi::DenseTensor& expert_routing_table,
174+
const phi::DenseTensor& router_weights,
175+
const std::vector<phi::DenseTensor>& gate_up_weights,
176+
const std::vector<phi::DenseTensor>& down_weights,
177+
const bool permuted_weights,
178+
const std::string& activation,
179+
const int experts_min,
180+
const int experts_max,
181+
const bool measurement_mode,
182+
phi::DenseTensor* final_hidden_states,
183+
phi::DenseTensor* amax_per_expert) {
184+
if (hidden_states.dtype() == phi::DataType::FLOAT16) {
185+
custom_kernel::FusedMoEKernel<phi::dtype::float16>(dev_ctx,
186+
hidden_states,
187+
expert_routing_table,
188+
router_weights,
189+
gate_up_weights,
190+
down_weights,
191+
permuted_weights,
192+
activation,
193+
experts_min,
194+
experts_max,
195+
measurement_mode,
196+
final_hidden_states,
197+
amax_per_expert);
198+
} else if (hidden_states.dtype() == phi::DataType::BFLOAT16) {
199+
custom_kernel::FusedMoEKernel<phi::dtype::bfloat16>(dev_ctx,
200+
hidden_states,
201+
expert_routing_table,
202+
router_weights,
203+
gate_up_weights,
204+
down_weights,
205+
permuted_weights,
206+
activation,
207+
experts_min,
208+
experts_max,
209+
measurement_mode,
210+
final_hidden_states,
211+
amax_per_expert);
212+
} else {
213+
throw std::runtime_error("Unsupported data type for FusedMoEKernel");
214+
}
215+
}
216+
217+
// Custom-op entry point. Unwraps the paddle::Tensor arguments to
// phi::DenseTensor, allocates the two outputs on the device, and
// dispatches to the dtype-specific kernel.
//
// Returns {final_hidden_states [num_tokens, hidden_dims],
//          amax_per_expert [num_experts] (float32)}.
//
// NOTE(review): assumes hidden_states is 2-D [tokens, dims] and
// router_weights has num_experts as its second dim -- confirm with the
// Python-side caller.
std::vector<paddle::Tensor> MixtureOfExpertsForward(
    const paddle::Tensor& hidden_states,
    const paddle::Tensor& expert_routing_table,
    const paddle::Tensor& router_weights,
    const std::vector<paddle::Tensor>& gate_up_weights,
    const std::vector<paddle::Tensor>& down_weights,
    const bool permuted_weights,
    const std::string& activation,
    const int experts_min,
    const int experts_max,
    const bool measurement_mode) {
  auto dev_ctx = static_cast<const phi::CustomContext*>(
      paddle::experimental::DeviceContextPool::Instance().Get(
          hidden_states.place()));
  // Non-owning views of the underlying dense tensors.
  auto hidden_states_tensor =
      static_cast<const phi::DenseTensor*>(hidden_states.impl().get());
  auto expert_routing_table_tensor =
      static_cast<const phi::DenseTensor*>(expert_routing_table.impl().get());
  auto router_weights_tensor =
      static_cast<const phi::DenseTensor*>(router_weights.impl().get());

  std::vector<phi::DenseTensor> gate_up_weights_vec;
  for (const auto& t : gate_up_weights) {
    gate_up_weights_vec.push_back(
        *static_cast<const phi::DenseTensor*>(t.impl().get()));
  }
  std::vector<phi::DenseTensor> down_weights_vec;
  for (const auto& t : down_weights) {
    down_weights_vec.push_back(
        *static_cast<const phi::DenseTensor*>(t.impl().get()));
  }

  // allocate memory on device.
  int64_t num_tokens = hidden_states.dims()[0];
  int64_t hidden_dims = hidden_states.dims()[1];
  int64_t num_experts = router_weights.dims()[1];

  // Output 0: same shape/dtype as the input activations.
  std::shared_ptr<phi::DenseTensor> final_hidden_states =
      std::make_shared<phi::DenseTensor>();
  final_hidden_states->Resize(phi::make_ddim({num_tokens, hidden_dims}));
  dev_ctx->Alloc(final_hidden_states.get(), hidden_states.dtype());

  // Output 1: always allocated (float32), even when measurement_mode is
  // false and the kernel does not write it.
  std::shared_ptr<phi::DenseTensor> amax_per_expert =
      std::make_shared<phi::DenseTensor>();
  amax_per_expert->Resize(phi::make_ddim({num_experts}));
  dev_ctx->Alloc(amax_per_expert.get(), paddle::DataType::FLOAT32);

  CallFusedMoEKernel(*dev_ctx,
                     *hidden_states_tensor,
                     *expert_routing_table_tensor,
                     *router_weights_tensor,
                     gate_up_weights_vec,
                     down_weights_vec,
                     permuted_weights,
                     activation,
                     experts_min,
                     experts_max,
                     measurement_mode,
                     final_hidden_states.get(),
                     amax_per_expert.get());

  return {paddle::Tensor(final_hidden_states), paddle::Tensor(amax_per_expert)};
}
280+
281+
// Static shape inference for the op's two outputs.
// Output 0 (final_hidden_states): [num_tokens, hidden_dims], taken from
// the hidden-states input shape. Output 1 (amax_per_expert):
// [num_experts], taken from the router weights' second dim.
std::vector<std::vector<int64_t>> MixtureOfExpertsInferShape(
    const std::vector<int64_t>& hidden_states_shape,
    const std::vector<int64_t>& expert_routing_table_shape,
    const std::vector<int64_t>& router_weights_shape,
    const std::vector<int64_t>& gate_up_weights_shape,
    const std::vector<int64_t>& down_weights_shape) {
  const std::vector<int64_t> states_shape{hidden_states_shape[0],
                                          hidden_states_shape[1]};
  const std::vector<int64_t> amax_shape{router_weights_shape[1]};
  return {states_shape, amax_shape};
}
292+
293+
// Static dtype inference: final_hidden_states follows the activations'
// dtype; amax_per_expert is always float32 (matches the Alloc in
// MixtureOfExpertsForward).
std::vector<paddle::DataType> MixtureOfExpertsInferDtype(
    const paddle::DataType& hidden_states_dtype,
    const paddle::DataType& expert_routing_table_dtype,
    const paddle::DataType& router_weights_dtype,
    const paddle::DataType& gate_up_weights_dtype,
    const paddle::DataType& down_weights_dtype) {
  return {hidden_states_dtype, paddle::DataType::FLOAT32};
}
301+
302+
// Registers the "mixture_of_experts" custom op with Paddle.
// "amax_per_expert" is declared optional; the kernel always allocates
// it, but it appears to be populated only when measurement_mode is
// true -- confirm before relying on it otherwise.
PD_BUILD_OP(mixture_of_experts)
    .Inputs({"hidden_states",
             "expert_routing_table",
             "router_weights",
             paddle::Vec("gate_up_weights"),
             paddle::Vec("down_weights")})
    .Outputs({"final_hidden_states", paddle::Optional("amax_per_expert")})
    .Attrs({"permuted_weights: bool",
            "activation: std::string",
            "experts_min: int",
            "experts_max: int",
            "measurement_mode: bool"})
    .SetKernelFn(PD_KERNEL(MixtureOfExpertsForward))
    .SetInferShapeFn(PD_INFER_SHAPE(MixtureOfExpertsInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(MixtureOfExpertsInferDtype));

0 commit comments

Comments
 (0)