diff --git a/backends/intel_hpu/custom_ops/llama_infer/fused_block_attention.cc b/backends/intel_hpu/custom_ops/llama_infer/fused_block_attention.cc new file mode 100644 index 0000000000..8e678c7182 --- /dev/null +++ b/backends/intel_hpu/custom_ops/llama_infer/fused_block_attention.cc @@ -0,0 +1,1818 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "habanalabs/perf_lib_layer_params.h" +#include "kernels/funcs.h" +#include "kernels/hpu_funcs.h" +#include "kernels/hpu_operator.h" +#include "paddle/extension.h" +#include "utils/utils.h" + +namespace custom_kernel { + +struct FusedBlockAttentionParams { + ns_LayerNormKernel::Params rmsnorm_params; + ns_ConstantKernel::Params const_params; + ns_GatherKernel::Params index_select_params; + ns_Reduction::Params reduce_params; + ns_IndexReduce::Params index_reduce_params; + + int head_dim; + int num_head; + int num_kv_head; +}; + +class FusedMHABlockAttention : public HpuFusedOperator { + public: + explicit FusedMHABlockAttention(synDataType dtype) + : HpuFusedOperator("fused_block_attention_fwd_", false), dtype_(dtype) {} + template + void AddNode(ConvertTensors& ct, FusedBlockAttentionParams& params) { + auto ins = ct.GetTensors(); + auto outs = ct.GetTensors(false); + + std::vector src_dims = std::vector(ins[0].dims); + + int64_t batch_size = src_dims[0]; + int64_t seq_length = src_dims[1]; + int64_t hidden_size = ins[13].dims[0]; + int64_t block_size = ins[3].dims[1]; + int64_t num_of_block = ins[6].dims[0]; + + int64_t num_head = params.num_head; + int64_t head_dim = params.head_dim; + int64_t num_kv_head = params.num_kv_head; + + synGEMMParams gemm_params_f_f; + gemm_params_f_f.transpose_a = false; + gemm_params_f_f.transpose_b = false; + + synGEMMParams gemm_params_t_f; + gemm_params_t_f.transpose_a = true; + gemm_params_t_f.transpose_b = false; + + synGEMMParams gemm_params_f_t; + gemm_params_f_t.transpose_a = false; + gemm_params_f_t.transpose_b = true; + + synSectionHandle residual_section = createSection(); + auto src = createTensorFromCT(&ct, 0); + auto residual = createTensorFromCT(&ct, 1, true, residual_section); + auto residual_out = createTensorFromCT(&ct, 3, false, residual_section); + + std::vector add_residual_in; + add_residual_in.push_back(src); + add_residual_in.push_back(residual); + + std::vector add_residual_out; + add_residual_out.push_back(residual_out); + + AddNodeAdd(add_residual_in, add_residual_out, guid_ + "add_residual"); + + auto ln_scales = createTensorFromCT(&ct, 11); + + std::vector rmsnorm_inputs; + rmsnorm_inputs.push_back(residual_out); + rmsnorm_inputs.push_back(ln_scales); + + auto tmp_dims = src_dims; + tmp_dims[2] = 1; + auto norm_out = createTensorNoPresist("norm_out", dtype_, src_dims); + auto norm_var = createTensorNoPresist("norm_var", dtype_, tmp_dims); + + std::vector rmsnorm_outputs; + rmsnorm_outputs.push_back(norm_out); + rmsnorm_outputs.push_back(norm_var); + + AddNodeRmsNorm(rmsnorm_inputs, + rmsnorm_outputs, + params.rmsnorm_params, + guid_ + "rmsnorm"); + + auto qkv_weights = createTensorFromCT(&ct, 12); + std::vector mul_inputs; + mul_inputs.push_back(norm_out); + mul_inputs.push_back(qkv_weights); + + auto wt_dims = ins[12].dims; + tmp_dims[2] = wt_dims[0]; + + auto qkv_out = createTensorNoPresist("qkv_out", dtype_, tmp_dims); + std::vector mul_outputs; + mul_outputs.push_back(qkv_out); + + synGEMMParams gemm_params; + gemm_params.transpose_a = false; + gemm_params.transpose_b = true; + AddNodeBatchGemm(mul_inputs, mul_outputs, gemm_params, guid_ + "batchgemm"); + + auto reshape_dims = src_dims; + reshape_dims[2] = num_head + 2 * num_kv_head; + reshape_dims.push_back(head_dim); + + std::vector reshape_outputs; + auto reshape_out = + createTensorNoPresist("reshape_out", dtype_, reshape_dims); + reshape_outputs.push_back(reshape_out); + + AddNodeReshape(mul_outputs, reshape_outputs, guid_ + "reshape_qkv"); + + std::vector q_dims; + q_dims.push_back(batch_size); + q_dims.push_back(seq_length); + q_dims.push_back(num_head); + q_dims.push_back(head_dim); + std::vector kv_dims; + kv_dims.push_back(batch_size); + kv_dims.push_back(seq_length); + kv_dims.push_back(num_kv_head); + kv_dims.push_back(head_dim); + + auto q_split = createTensorNoPresist("q_split", dtype_, q_dims); + auto k_split = createTensorNoPresist("k_split", dtype_, kv_dims); + auto v_split = createTensorNoPresist("v_split", dtype_, kv_dims); + std::vector split_outpus; + split_outpus.push_back(q_split); + split_outpus.push_back(k_split); + split_outpus.push_back(v_split); + + synSplitParams splitParams; + splitParams.axis = 1; + AddNodeSplit(reshape_outputs, split_outpus, splitParams, guid_ + "split"); + + std::vector rotary_embs_inputs; + auto rotary_embs_c = createTensorFromCT(&ct, 2); + rotary_embs_inputs.push_back(rotary_embs_c); + + auto rotary_embs_dims = ins[2].dims; + rotary_embs_dims[0] = 1; + + std::vector cos_inputs; + auto cos_in = createTensorNoPresist("cos_in", dtype_, rotary_embs_dims); + cos_inputs.push_back(cos_in); + + synSliceParamsV2 sliceParams; + for (uint64_t i = 0; i < rotary_embs_dims.size(); i++) { + sliceParams.axes[i] = i; + sliceParams.steps[i] = 1; + sliceParams.starts[i] = 0; + sliceParams.ends[i] = rotary_embs_dims[rotary_embs_dims.size() - 1 - i]; + } + AddNodeSlice( + rotary_embs_inputs, cos_inputs, sliceParams, guid_ + "slice_cos"); + + std::vector sin_inputs; + auto sin_in = createTensorNoPresist("sin_in", dtype_, rotary_embs_dims); + sin_inputs.push_back(sin_in); + sliceParams.starts[rotary_embs_dims.size() - 1] = 1; + sliceParams.ends[rotary_embs_dims.size() - 1] = 2; + AddNodeSlice( + rotary_embs_inputs, sin_inputs, sliceParams, guid_ + "slice_sin"); + + rotary_embs_dims.erase(rotary_embs_dims.begin()); + auto sin_sq = + createTensorNoPresist("sin_squeezed", dtype_, rotary_embs_dims); + std::vector sin_squeezed; + sin_squeezed.push_back(sin_sq); + + synSqueezeParams squeezeParams; + squeezeParams.axis = 4; + AddNodeSqueeze( + sin_inputs, sin_squeezed, squeezeParams, guid_ + "squeeze_sin"); + + auto cos_sq = + createTensorNoPresist("cos_squeezed", dtype_, rotary_embs_dims); + std::vector cos_squeezed; + cos_squeezed.push_back(cos_sq); + AddNodeSqueeze( + cos_inputs, cos_squeezed, squeezeParams, guid_ + "squeeze_cos"); + + std::vector inputs_q; + std::vector outputs_q; + inputs_q.push_back(q_split); + inputs_q.push_back(sin_sq); + inputs_q.push_back(cos_sq); + + auto q_states = createTensorNoPresist("q_states", dtype_, q_dims); + outputs_q.push_back(q_states); + + ns_RoPESt2::ParamsV2 ropeParams; + ropeParams.offset = 0; + ropeParams.mode = ROTARY_POS_EMBEDDING_MODE_BLOCKWISE; + AddNodeRope(inputs_q, outputs_q, ropeParams, guid_ + "rope_q"); + + std::vector inputs_k; + std::vector outputs_k; + inputs_k.push_back(k_split); + inputs_k.push_back(sin_sq); + inputs_k.push_back(cos_sq); + + auto k_rope = createTensorNoPresist("k_rope", dtype_, kv_dims); + outputs_k.push_back(k_rope); + AddNodeRope(inputs_k, outputs_k, ropeParams, guid_ + "rope_k"); + + ////////////////////////////////////////////////////////////////// + kv_dims.erase(kv_dims.begin() + 1); + + std::vector outputs_k_squeeze; + auto k_squeeze = createTensorNoPresist("k_squeeze", dtype_, kv_dims); + outputs_k_squeeze.push_back(k_squeeze); + AddNodeReshape(outputs_k, outputs_k_squeeze, guid_ + "squeeze_k"); + + std::vector inputs_v_squeeze; + inputs_v_squeeze.push_back(v_split); + std::vector outputs_v_squeeze; + auto v_squeeze = createTensorNoPresist("v_squeeze", dtype_, kv_dims); + outputs_v_squeeze.push_back(v_squeeze); + AddNodeReshape(inputs_v_squeeze, outputs_v_squeeze, guid_ + "squeeze_v"); + + std::vector indices_concat_dims = + std::vector(ins[9].dims); + indices_concat_dims.emplace_back(1); + + std::vector inputs_concat; + inputs_concat.push_back(createTensor(indices_concat_dims.size(), + ins[9].type, + indices_concat_dims, + true, + ins[9].name)); + inputs_concat.push_back(createTensor(indices_concat_dims.size(), + ins[10].type, + indices_concat_dims, + true, + ins[10].name)); + + std::vector outputs_concat; + indices_concat_dims.back() = 2; + auto indices_concat = createTensor(indices_concat_dims.size(), + ins[9].type, + indices_concat_dims, + false, + "indices_concat"); + outputs_concat.push_back(indices_concat); + + synConcatenateParams concatParams; + concatParams.axis = 0; + AddNodeConcat( + inputs_concat, outputs_concat, concatParams, guid_ + "concat"); + + synSectionHandle kCache_section = createSection(); + auto key_cache = createTensorFromCT(&ct, 3, true, kCache_section); + auto kCache_out = createTensorFromCT(&ct, 1, false, kCache_section); + std::vector inputs_scatter_k; + inputs_scatter_k.push_back(key_cache); + inputs_scatter_k.push_back(indices_concat); + inputs_scatter_k.push_back(k_squeeze); + std::vector outputs_scatter_k; + outputs_scatter_k.push_back(kCache_out); + AddNodeScatter( + inputs_scatter_k, outputs_scatter_k, guid_ + "index_put_k"); + + synSectionHandle vCache_section = createSection(); + auto value_cache = createTensorFromCT(&ct, 4, true, vCache_section); + auto vCache_out = createTensorFromCT(&ct, 2, false, vCache_section); + std::vector inputs_scatter_v; + inputs_scatter_v.push_back(value_cache); + inputs_scatter_v.push_back(indices_concat); + inputs_scatter_v.push_back(v_squeeze); + std::vector outputs_scatter_v; + outputs_scatter_v.push_back(vCache_out); + AddNodeScatter( + inputs_scatter_v, outputs_scatter_v, guid_ + "index_put_v"); + ////////////////////////////////////////////////////////////////// + + std::vector scaler_dims = {1}; + auto scaler_tensor = + createTensorNoPresist("scaler_tensor", syn_type_bf16, scaler_dims); + std::vector scaler; + scaler.push_back(scaler_tensor); + AddNodeFull(scaler, params.const_params, guid_ + "full_scale"); + + std::vector scaled_q_in; + scaled_q_in.push_back(q_states); + scaled_q_in.push_back(scaler_tensor); + + auto scaled_q = createTensorNoPresist("scaled_q", dtype_, q_dims); + std::vector scaled_q_out; + scaled_q_out.push_back(scaled_q); + + AddNodeMultiply(scaled_q_in, scaled_q_out, guid_ + "mul_scale_q"); + + std::vector reshape_q_dims; + reshape_q_dims.push_back(batch_size); + reshape_q_dims.push_back(hidden_size); + + auto reshaped_q = + createTensorNoPresist("reshaped_q", dtype_, reshape_q_dims); + std::vector reshape_q_out; + reshape_q_out.push_back(reshaped_q); + + AddNodeReshape(scaled_q_out, reshape_q_out, guid_ + "reshape_scale_q"); + + /*******************************/ + + std::vector map_q_in; + auto block_mapping = createTensorFromCT(&ct, 7); + map_q_in.push_back(block_mapping); + map_q_in.push_back(reshaped_q); + + std::vector map_q_dims; + map_q_dims.push_back(num_of_block); + map_q_dims.push_back(hidden_size); + auto mapped_q = createTensorNoPresist("mapped_q", dtype_, map_q_dims); + std::vector map_q_out; + map_q_out.push_back(mapped_q); + + AddNodeGemm(map_q_in, map_q_out, gemm_params_f_f, guid_ + "gemm_map_q"); + + std::vector reshape_map_q_dims; + reshape_map_q_dims.push_back(num_of_block); + reshape_map_q_dims.push_back(num_head); + reshape_map_q_dims.push_back(1); + reshape_map_q_dims.push_back(head_dim); + + auto reshaped_map_q = + createTensorNoPresist("reshaped_map_q", dtype_, reshape_map_q_dims); + std::vector reshape_map_q_out; + reshape_map_q_out.push_back(reshaped_map_q); + + AddNodeReshape(map_q_out, reshape_map_q_out, guid_ + "reshape_map_q"); + + /*******************************/ + + std::vector index_select_k_in; + std::vector index_select_v_in; + + auto block_list = createTensorFromCT(&ct, 6); + + index_select_k_in.push_back(kCache_out); + index_select_v_in.push_back(vCache_out); + index_select_k_in.push_back(block_list); + index_select_v_in.push_back(block_list); + + std::vector index_selected_dims; + index_selected_dims.push_back(num_of_block); + index_selected_dims.push_back(block_size); + index_selected_dims.push_back(num_kv_head); + index_selected_dims.push_back(head_dim); + + auto index_select_k_i = + createTensorNoPresist("index_select_k_i", dtype_, index_selected_dims); + auto index_select_v_i = + createTensorNoPresist("index_select_v_i", dtype_, index_selected_dims); + std::vector index_select_k_out; + index_select_k_out.push_back(index_select_k_i); + std::vector index_select_v_out; + index_select_v_out.push_back(index_select_v_i); + + AddNodeIndexSelect(index_select_k_in, + index_select_k_out, + params.index_select_params, + guid_ + "index_select_k_i"); + AddNodeIndexSelect(index_select_v_in, + index_select_v_out, + params.index_select_params, + guid_ + "index_select_v_i"); + + std::vector axis = {0, 2, 1, 3}; + synTransposeParams trans_params; + for (size_t i = 0; i < axis.size(); i++) { + trans_params.permutation[i] = + static_cast(axis[i]); + } + trans_params.tensorDim = 4; + + std::vector transpose_dims; + transpose_dims.push_back(num_of_block); + transpose_dims.push_back(num_kv_head); + transpose_dims.push_back(block_size); + transpose_dims.push_back(head_dim); + + auto transpose_k = + createTensorNoPresist("transpose_k", dtype_, transpose_dims); + std::vector trans_index_select_k; + trans_index_select_k.push_back(transpose_k); + + AddNodeTranspose(index_select_k_out, + trans_index_select_k, + trans_params, + guid_ + "transpose_k"); + + auto transpose_v = + createTensorNoPresist("transpose_v", dtype_, transpose_dims); + std::vector trans_index_select_v; + trans_index_select_v.push_back(transpose_v); + + AddNodeTranspose(index_select_v_out, + trans_index_select_v, + trans_params, + guid_ + "transpose_v"); + + std::vector q_k_in; + q_k_in.push_back(reshaped_map_q); + q_k_in.push_back(transpose_k); + + std::vector q_k_dims; + q_k_dims.push_back(num_of_block); + q_k_dims.push_back(num_head); + q_k_dims.push_back(1); + q_k_dims.push_back(block_size); + auto q_k = createTensorNoPresist("q_k", dtype_, q_k_dims); + std::vector q_k_out; + q_k_out.push_back(q_k); + + AddNodeBatchGemm(q_k_in, q_k_out, gemm_params_f_t, guid_ + "batchgemm_q_k"); + + /*******************************/ + + auto block_bias = createTensorFromCT(&ct, 8); + std::vector block_bias_in; + block_bias_in.push_back(block_bias); + + std::vector reshaped_bias_dims; + reshaped_bias_dims.push_back(num_of_block); + reshaped_bias_dims.push_back(1); + reshaped_bias_dims.push_back(1); + reshaped_bias_dims.push_back(block_size); + + auto reshaped_bias = + createTensorNoPresist("reshaped_bias", dtype_, reshaped_bias_dims); + std::vector block_bias_out; + block_bias_out.push_back(reshaped_bias); + + AddNodeReshape(block_bias_in, block_bias_out, guid_ + "reshaped_bias"); + + std::vector add_bias_in; + add_bias_in.push_back(q_k); + add_bias_in.push_back(reshaped_bias); + + auto add_bias = createTensorNoPresist("add_bias", dtype_, q_k_dims); + std::vector add_bias_out; + add_bias_out.push_back(add_bias); + + AddNodeAdd(add_bias_in, add_bias_out, guid_ + "add_bias"); + /*******************************/ + + std::vector block_max_dims; + block_max_dims.push_back(num_of_block); + block_max_dims.push_back(num_head); + block_max_dims.push_back(1); + block_max_dims.push_back(1); + + auto block_max = createTensorNoPresist("block_max", dtype_, block_max_dims); + std::vector block_max_out; + block_max_out.push_back(block_max); + + AddNodeReduceMax( + add_bias_out, block_max_out, params.reduce_params, guid_ + "reduceMax"); + + /**************************************************************/ + + std::vector sum_adjusted_dims; + sum_adjusted_dims.push_back(num_of_block); + sum_adjusted_dims.push_back(num_head); + + auto block_max_2D = + createTensorNoPresist("block_max_2D", dtype_, sum_adjusted_dims); + std::vector block_max_2D_out; + block_max_2D_out.push_back(block_max_2D); + + AddNodeReshape( + block_max_out, block_max_2D_out, guid_ + "squeeze_block_max"); + + std::vector group_max_dims; + group_max_dims.push_back(batch_size + 1); + group_max_dims.push_back(num_head); + + auto group_max = createTensorNoPresist("group_max", dtype_, group_max_dims); + std::vector group_max_tensor; + group_max_tensor.push_back(group_max); + + params.const_params.constant.f = -std::numeric_limits::infinity(); + + AddNodeFull(group_max_tensor, params.const_params, guid_ + "full_inf"); + + auto block_groups = createTensorFromCT(&ct, 5); + std::vector index_reduce_in; + index_reduce_in.push_back(group_max); + index_reduce_in.push_back(block_groups); + index_reduce_in.push_back(block_max_2D); + + auto reduced_group_max = + createTensorNoPresist("reduced_group_max", dtype_, group_max_dims); + std::vector index_reduce_out; + index_reduce_out.push_back(reduced_group_max); + + AddNodeIndexReduce(index_reduce_in, + index_reduce_out, + params.index_reduce_params, + guid_ + "index_reduce_amax"); + + std::vector index_select_groupmax_in; + index_select_groupmax_in.push_back(reduced_group_max); + index_select_groupmax_in.push_back(block_groups); + + auto selected_group_max = + createTensorNoPresist("selected_group_max", dtype_, sum_adjusted_dims); + std::vector index_select_groupmax_out; + index_select_groupmax_out.push_back(selected_group_max); + params.index_select_params.axis = 1; + AddNodeIndexSelect(index_select_groupmax_in, + index_select_groupmax_out, + params.index_select_params, + guid_ + "index_select_groupmax"); + + std::vector sub_group_max_in; + sub_group_max_in.push_back(block_max_2D); + sub_group_max_in.push_back(selected_group_max); + + auto sub_group_max = + createTensorNoPresist("sub_group_max", dtype_, sum_adjusted_dims); + std::vector sub_group_max_out; + sub_group_max_out.push_back(sub_group_max); + + AddNodeSub(sub_group_max_in, sub_group_max_out, guid_ + "sub_group_max"); + + auto block_adjustment = + createTensorNoPresist("block_adjustment", dtype_, sum_adjusted_dims); + std::vector block_adjustment_out; + block_adjustment_out.push_back(block_adjustment); + AddNodeExp(sub_group_max_out, + block_adjustment_out, + guid_ + "exp_block_adjustment"); + + /**************************************************************/ + + std::vector sub_block_max_in; + sub_block_max_in.push_back(add_bias); + sub_block_max_in.push_back(block_max); + + auto sub_block_max = + createTensorNoPresist("sub_block_max", dtype_, q_k_dims); + std::vector sub_block_max_out; + sub_block_max_out.push_back(sub_block_max); + AddNodeSub(sub_block_max_in, sub_block_max_out, guid_ + "sub_block_max"); + + auto score = createTensorNoPresist("score", dtype_, q_k_dims); + std::vector score_out; + score_out.push_back(score); + AddNodeExp(sub_block_max_out, score_out, guid_ + "exp_score"); + + /*******************************/ + + std::vector score_v_in; + score_v_in.push_back(score); + score_v_in.push_back(transpose_v); + + auto score_v = createTensorNoPresist("score_v", dtype_, reshape_map_q_dims); + std::vector score_v_out; + score_v_out.push_back(score_v); + + AddNodeBatchGemm( + score_v_in, score_v_out, gemm_params_f_f, guid_ + "batchgemm_score_v"); + + auto reduceSum = createTensorNoPresist("reduceSum", dtype_, block_max_dims); + std::vector reduceSum_out; + reduceSum_out.push_back(reduceSum); + + AddNodeReduceSum( + score_out, reduceSum_out, params.reduce_params, guid_ + "reduceSum"); + + auto block_sums_2D = + createTensorNoPresist("block_sums_2D", dtype_, sum_adjusted_dims); + std::vector block_sums_2D_out; + block_sums_2D_out.push_back(block_sums_2D); + + AddNodeReshape( + reduceSum_out, block_sums_2D_out, guid_ + "squeeze_block_sums"); + + std::vector sum_adjusted_in; + sum_adjusted_in.push_back(block_sums_2D); + sum_adjusted_in.push_back(block_adjustment); + + auto sum_adjusted = + createTensorNoPresist("sum_adjusted", dtype_, sum_adjusted_dims); + std::vector sum_adjusted_out; + sum_adjusted_out.push_back(sum_adjusted); + + AddNodeMultiply( + sum_adjusted_in, sum_adjusted_out, guid_ + "mul_sum_adjusted"); + /***************************************************************/ + + std::vector map_sum_adjusted_in; + map_sum_adjusted_in.push_back(block_mapping); + map_sum_adjusted_in.push_back(sum_adjusted); + + std::vector map_sum_adjusted_dims; + map_sum_adjusted_dims.push_back(batch_size); + map_sum_adjusted_dims.push_back(num_head); + auto mapped_sum_adjusted = createTensorNoPresist( + "mapped_sum_adjusted", dtype_, map_sum_adjusted_dims); + std::vector map_sum_adjusted_out; + map_sum_adjusted_out.push_back(mapped_sum_adjusted); + + AddNodeGemm(map_sum_adjusted_in, + map_sum_adjusted_out, + gemm_params_t_f, + guid_ + "gemm_map_sum_adjusted"); + + std::vector group_sum_adjusted_in; + group_sum_adjusted_in.push_back(block_mapping); + group_sum_adjusted_in.push_back(mapped_sum_adjusted); + + auto group_sum_adjusted = + createTensorNoPresist("group_sum_adjusted", dtype_, sum_adjusted_dims); + std::vector group_sum_adjusted_out; + group_sum_adjusted_out.push_back(group_sum_adjusted); + + AddNodeGemm(group_sum_adjusted_in, + group_sum_adjusted_out, + gemm_params_f_f, + guid_ + "gemm_group_sum_adjusted"); + + /*******************************/ + + auto reshaped_group_sum_adjusted = createTensorNoPresist( + "reshaped_group_sum_adjusted", dtype_, block_max_dims); + std::vector group_sum_adjusted_4D; + group_sum_adjusted_4D.push_back(reshaped_group_sum_adjusted); + + AddNodeReshape(group_sum_adjusted_out, + group_sum_adjusted_4D, + guid_ + "reshaped_group_sum_adjusted"); + + auto reshaped_sum_adjusted = + createTensorNoPresist("reshaped_sum_adjusted", dtype_, block_max_dims); + std::vector sum_adjusted_4D; + sum_adjusted_4D.push_back(reshaped_sum_adjusted); + + AddNodeReshape( + sum_adjusted_out, sum_adjusted_4D, guid_ + "reshaped_sum_adjusted"); + + auto reshaped_block_adjustment = createTensorNoPresist( + "reshaped_block_adjustment", dtype_, block_max_dims); + std::vector block_adjustment_4D; + block_adjustment_4D.push_back(reshaped_block_adjustment); + + AddNodeReshape(block_adjustment_out, + block_adjustment_4D, + guid_ + "reshaped_block_adjustment"); + + std::vector max_sum_adjust_in; + max_sum_adjust_in.push_back(reshaped_group_sum_adjusted); + max_sum_adjust_in.push_back(reshaped_sum_adjusted); + auto max_sum_adjust = + createTensorNoPresist("max_sum_adjust", dtype_, block_max_dims); + std::vector max_sum_adjust_out; + max_sum_adjust_out.push_back(max_sum_adjust); + + AddNodeMaximum( + max_sum_adjust_in, max_sum_adjust_out, guid_ + "max_sum_adjust"); + + std::vector rescale_in; + rescale_in.push_back(reshaped_block_adjustment); + rescale_in.push_back(max_sum_adjust); + auto rescale = createTensorNoPresist("rescale", dtype_, block_max_dims); + std::vector rescale_out; + rescale_out.push_back(rescale); + AddNodeDivide(rescale_in, rescale_out, guid_ + "div_rescale"); + + std::vector rescale_v_in; + rescale_v_in.push_back(rescale); + rescale_v_in.push_back(score_v); + + auto rescale_v = + createTensorNoPresist("rescale_v", dtype_, reshape_map_q_dims); + std::vector rescale_v_out; + rescale_v_out.push_back(rescale_v); + + AddNodeMultiply(rescale_v_in, rescale_v_out, guid_ + "mul_rescale_v"); + + auto reshape_attn = + createTensorNoPresist("reshape_attn", dtype_, map_q_dims); + std::vector reshape_attn_out; + reshape_attn_out.push_back(reshape_attn); + + AddNodeReshape(rescale_v_out, reshape_attn_out, guid_ + "reshape_attn"); + + std::vector map_attn_in; + map_attn_in.push_back(block_mapping); + map_attn_in.push_back(reshape_attn); + + auto mapped_attn = + createTensorNoPresist("mapped_attn", dtype_, reshape_q_dims); + std::vector map_attn_out; + map_attn_out.push_back(mapped_attn); + + AddNodeGemm( + map_attn_in, map_attn_out, gemm_params_t_f, guid_ + "gemm_map_attn"); + + std::vector reshape_attn_dims; + reshape_attn_dims.push_back(batch_size); + reshape_attn_dims.push_back(1); + reshape_attn_dims.push_back(hidden_size); + auto attn = createTensorNoPresist("attn", dtype_, reshape_attn_dims); + std::vector attn_out; + attn_out.push_back(attn); + + AddNodeReshape(map_attn_out, attn_out, guid_ + "attn"); + + std::vector proj_in; + auto linear_weights = createTensorFromCT(&ct, 13); + proj_in.push_back(attn); + proj_in.push_back(linear_weights); + + auto linear_out = createTensorFromCT(&ct, 0, false); + std::vector proj_out; + proj_out.push_back(linear_out); + + AddNodeBatchGemm( + proj_in, proj_out, gemm_params_f_f, guid_ + "batchgemm_proj"); + } + + protected: + synDataType dtype_; +}; + +class FusedGQABlockAttention : public HpuFusedOperator { + public: + explicit FusedGQABlockAttention(synDataType dtype) + : HpuFusedOperator("fused_block_attention_fwd_", false), dtype_(dtype) {} + template + void AddNode(ConvertTensors& ct, FusedBlockAttentionParams& params) { + auto ins = ct.GetTensors(); + auto outs = ct.GetTensors(false); + + std::vector src_dims = std::vector(ins[0].dims); + + int64_t batch_size = src_dims[0]; + int64_t seq_length = src_dims[1]; + int64_t hidden_size = ins[13].dims[0]; + int64_t block_size = ins[3].dims[1]; + int64_t num_of_block = ins[6].dims[0]; + + int64_t num_head = params.num_head; + int64_t head_dim = params.head_dim; + int64_t num_kv_head = params.num_kv_head; + int64_t ngroups = num_head / num_kv_head; + + synGEMMParams gemm_params_f_f; + gemm_params_f_f.transpose_a = false; + gemm_params_f_f.transpose_b = false; + + synGEMMParams gemm_params_t_f; + gemm_params_t_f.transpose_a = true; + gemm_params_t_f.transpose_b = false; + + synGEMMParams gemm_params_f_t; + gemm_params_f_t.transpose_a = false; + gemm_params_f_t.transpose_b = true; + + synSectionHandle residual_section = createSection(); + auto src = createTensorFromCT(&ct, 0); + auto residual = createTensorFromCT(&ct, 1, true, residual_section); + auto residual_out = createTensorFromCT(&ct, 3, false, residual_section); + + std::vector add_residual_in; + add_residual_in.push_back(src); + add_residual_in.push_back(residual); + + std::vector add_residual_out; + add_residual_out.push_back(residual_out); + + AddNodeAdd(add_residual_in, add_residual_out, guid_ + "add_residual"); + + auto ln_scales = createTensorFromCT(&ct, 11); + + std::vector rmsnorm_inputs; + rmsnorm_inputs.push_back(residual_out); + rmsnorm_inputs.push_back(ln_scales); + + auto tmp_dims = src_dims; + tmp_dims[2] = 1; + auto norm_out = createTensorNoPresist("norm_out", dtype_, src_dims); + auto norm_var = createTensorNoPresist("norm_var", dtype_, tmp_dims); + + std::vector rmsnorm_outputs; + rmsnorm_outputs.push_back(norm_out); + rmsnorm_outputs.push_back(norm_var); + + AddNodeRmsNorm(rmsnorm_inputs, + rmsnorm_outputs, + params.rmsnorm_params, + guid_ + "rmsnorm"); + + auto qkv_weights = createTensorFromCT(&ct, 12); + std::vector mul_inputs; + mul_inputs.push_back(norm_out); + mul_inputs.push_back(qkv_weights); + + auto wt_dims = ins[12].dims; + tmp_dims[2] = wt_dims[0]; + + auto qkv_out = createTensorNoPresist("qkv_out", dtype_, tmp_dims); + std::vector mul_outputs; + mul_outputs.push_back(qkv_out); + + synGEMMParams gemm_params; + gemm_params.transpose_a = false; + gemm_params.transpose_b = true; + AddNodeBatchGemm(mul_inputs, mul_outputs, gemm_params, guid_ + "batchgemm"); + + auto reshape_dims = src_dims; + reshape_dims[2] = num_head + 2 * num_kv_head; + reshape_dims.push_back(head_dim); + + std::vector reshape_outputs; + auto reshape_out = + createTensorNoPresist("reshape_out", dtype_, reshape_dims); + reshape_outputs.push_back(reshape_out); + + AddNodeReshape(mul_outputs, reshape_outputs, guid_ + "reshape_qkv"); + + std::vector q_dims; + q_dims.push_back(batch_size); + q_dims.push_back(seq_length); + q_dims.push_back(num_head); + q_dims.push_back(head_dim); + std::vector kv_dims; + kv_dims.push_back(batch_size); + kv_dims.push_back(seq_length); + kv_dims.push_back(num_kv_head); + kv_dims.push_back(head_dim); + + auto q_split = createTensorNoPresist("q_split", dtype_, q_dims); + auto k_split = createTensorNoPresist("k_split", dtype_, kv_dims); + auto v_split = createTensorNoPresist("v_split", dtype_, kv_dims); + std::vector split_outpus; + split_outpus.push_back(q_split); + split_outpus.push_back(k_split); + split_outpus.push_back(v_split); + + synSplitParams splitParams; + splitParams.axis = 1; + AddNodeSplit(reshape_outputs, split_outpus, splitParams, guid_ + "split"); + + std::vector rotary_embs_inputs; + auto rotary_embs_c = createTensorFromCT(&ct, 2); + rotary_embs_inputs.push_back(rotary_embs_c); + + auto rotary_embs_dims = ins[2].dims; + rotary_embs_dims[0] = 1; + + std::vector cos_inputs; + auto cos_in = createTensorNoPresist("cos_in", dtype_, rotary_embs_dims); + cos_inputs.push_back(cos_in); + + synSliceParamsV2 sliceParams; + for (uint64_t i = 0; i < rotary_embs_dims.size(); i++) { + sliceParams.axes[i] = i; + sliceParams.steps[i] = 1; + sliceParams.starts[i] = 0; + sliceParams.ends[i] = rotary_embs_dims[rotary_embs_dims.size() - 1 - i]; + } + AddNodeSlice( + rotary_embs_inputs, cos_inputs, sliceParams, guid_ + "slice_cos"); + + std::vector sin_inputs; + auto sin_in = createTensorNoPresist("sin_in", dtype_, rotary_embs_dims); + sin_inputs.push_back(sin_in); + sliceParams.starts[rotary_embs_dims.size() - 1] = 1; + sliceParams.ends[rotary_embs_dims.size() - 1] = 2; + AddNodeSlice( + rotary_embs_inputs, sin_inputs, sliceParams, guid_ + "slice_sin"); + + rotary_embs_dims.erase(rotary_embs_dims.begin()); + auto sin_sq = + createTensorNoPresist("sin_squeezed", dtype_, rotary_embs_dims); + std::vector sin_squeezed; + sin_squeezed.push_back(sin_sq); + + synSqueezeParams squeezeParams; + squeezeParams.axis = 4; + AddNodeSqueeze( + sin_inputs, sin_squeezed, squeezeParams, guid_ + "squeeze_sin"); + + auto cos_sq = + createTensorNoPresist("cos_squeezed", dtype_, rotary_embs_dims); + std::vector cos_squeezed; + cos_squeezed.push_back(cos_sq); + AddNodeSqueeze( + cos_inputs, cos_squeezed, squeezeParams, guid_ + "squeeze_cos"); + + std::vector inputs_q; + std::vector outputs_q; + inputs_q.push_back(q_split); + inputs_q.push_back(sin_sq); + inputs_q.push_back(cos_sq); + + auto q_states = createTensorNoPresist("q_states", dtype_, q_dims); + outputs_q.push_back(q_states); + + ns_RoPESt2::ParamsV2 ropeParams; + ropeParams.offset = 0; + ropeParams.mode = ROTARY_POS_EMBEDDING_MODE_BLOCKWISE; + AddNodeRope(inputs_q, outputs_q, ropeParams, guid_ + "rope_q"); + + std::vector inputs_k; + std::vector outputs_k; + inputs_k.push_back(k_split); + inputs_k.push_back(sin_sq); + inputs_k.push_back(cos_sq); + + auto k_rope = createTensorNoPresist("k_rope", dtype_, kv_dims); + outputs_k.push_back(k_rope); + AddNodeRope(inputs_k, outputs_k, ropeParams, guid_ + "rope_k"); + + ////////////////////////////////////////////////////////////////// + kv_dims.erase(kv_dims.begin() + 1); + + std::vector outputs_k_squeeze; + auto k_squeeze = createTensorNoPresist("k_squeeze", dtype_, kv_dims); + outputs_k_squeeze.push_back(k_squeeze); + AddNodeReshape(outputs_k, outputs_k_squeeze, guid_ + "squeeze_k"); + + std::vector inputs_v_squeeze; + inputs_v_squeeze.push_back(v_split); + std::vector outputs_v_squeeze; + auto v_squeeze = createTensorNoPresist("v_squeeze", dtype_, kv_dims); + outputs_v_squeeze.push_back(v_squeeze); + AddNodeReshape(inputs_v_squeeze, outputs_v_squeeze, guid_ + "squeeze_v"); + + std::vector indices_concat_dims = + std::vector(ins[9].dims); + indices_concat_dims.emplace_back(1); + + std::vector inputs_concat; + inputs_concat.push_back(createTensor(indices_concat_dims.size(), + ins[9].type, + indices_concat_dims, + true, + ins[9].name)); + inputs_concat.push_back(createTensor(indices_concat_dims.size(), + ins[10].type, + indices_concat_dims, + true, + ins[10].name)); + + std::vector outputs_concat; + indices_concat_dims.back() = 2; + auto indices_concat = createTensor(indices_concat_dims.size(), + ins[9].type, + indices_concat_dims, + false, + "indices_concat"); + outputs_concat.push_back(indices_concat); + + synConcatenateParams concatParams; + concatParams.axis = 0; + AddNodeConcat( + inputs_concat, outputs_concat, concatParams, guid_ + "concat"); + + synSectionHandle kCache_section = createSection(); + auto key_cache = createTensorFromCT(&ct, 3, true, kCache_section); + auto kCache_out = createTensorFromCT(&ct, 1, false, kCache_section); + std::vector inputs_scatter_k; + inputs_scatter_k.push_back(key_cache); + inputs_scatter_k.push_back(indices_concat); + inputs_scatter_k.push_back(k_squeeze); + std::vector outputs_scatter_k; + outputs_scatter_k.push_back(kCache_out); + AddNodeScatter( + inputs_scatter_k, outputs_scatter_k, guid_ + "index_put_k"); + + synSectionHandle vCache_section = createSection(); + auto value_cache = createTensorFromCT(&ct, 4, true, vCache_section); + auto vCache_out = createTensorFromCT(&ct, 2, false, vCache_section); + std::vector inputs_scatter_v; + inputs_scatter_v.push_back(value_cache); + inputs_scatter_v.push_back(indices_concat); + inputs_scatter_v.push_back(v_squeeze); + std::vector outputs_scatter_v; + outputs_scatter_v.push_back(vCache_out); + AddNodeScatter( + inputs_scatter_v, outputs_scatter_v, guid_ + "index_put_v"); + ////////////////////////////////////////////////////////////////// + + std::vector scaler_dims = {1}; + auto scaler_tensor = + createTensorNoPresist("scaler_tensor", syn_type_bf16, scaler_dims); + std::vector scaler; + scaler.push_back(scaler_tensor); + AddNodeFull(scaler, params.const_params, guid_ + "full_scale"); + + std::vector scaled_q_in; + scaled_q_in.push_back(q_states); + scaled_q_in.push_back(scaler_tensor); + + auto scaled_q = createTensorNoPresist("scaled_q", dtype_, q_dims); + std::vector scaled_q_out; + scaled_q_out.push_back(scaled_q); + + AddNodeMultiply(scaled_q_in, scaled_q_out, guid_ + "mul_scale_q"); + + std::vector reshape_q_dims; + reshape_q_dims.push_back(batch_size); + reshape_q_dims.push_back(hidden_size); + + auto reshaped_q = + createTensorNoPresist("reshaped_q", dtype_, reshape_q_dims); + std::vector reshape_q_out; + reshape_q_out.push_back(reshaped_q); + + AddNodeReshape(scaled_q_out, reshape_q_out, guid_ + "reshape_scale_q"); + + /*******************************/ + + std::vector map_q_in; + auto block_mapping = createTensorFromCT(&ct, 7); + map_q_in.push_back(block_mapping); + map_q_in.push_back(reshaped_q); + + std::vector map_q_dims; + map_q_dims.push_back(num_of_block); + map_q_dims.push_back(hidden_size); + auto mapped_q = createTensorNoPresist("mapped_q", dtype_, map_q_dims); + std::vector map_q_out; + map_q_out.push_back(mapped_q); + + AddNodeGemm(map_q_in, map_q_out, gemm_params_f_f, guid_ + "gemm_map_q"); + + std::vector reshape_map_q_dims; + reshape_map_q_dims.push_back(num_of_block); + reshape_map_q_dims.push_back(num_kv_head); + reshape_map_q_dims.push_back(ngroups); + reshape_map_q_dims.push_back(1); + reshape_map_q_dims.push_back(head_dim); + + auto reshaped_map_q = + createTensorNoPresist("reshaped_map_q", dtype_, reshape_map_q_dims); + std::vector reshape_map_q_out; + reshape_map_q_out.push_back(reshaped_map_q); + + AddNodeReshape(map_q_out, reshape_map_q_out, guid_ + "reshape_map_q"); + + /*******************************/ + + std::vector index_select_k_in; + std::vector index_select_v_in; + + auto block_list = createTensorFromCT(&ct, 6); + + index_select_k_in.push_back(kCache_out); + index_select_v_in.push_back(vCache_out); + index_select_k_in.push_back(block_list); + index_select_v_in.push_back(block_list); + + std::vector index_selected_dims; + index_selected_dims.push_back(num_of_block); + index_selected_dims.push_back(block_size); + index_selected_dims.push_back(num_kv_head); + index_selected_dims.push_back(head_dim); + + auto index_select_k_i = + createTensorNoPresist("index_select_k_i", dtype_, index_selected_dims); + auto index_select_v_i = + createTensorNoPresist("index_select_v_i", dtype_, index_selected_dims); + std::vector index_select_k_out; + index_select_k_out.push_back(index_select_k_i); + std::vector index_select_v_out; + index_select_v_out.push_back(index_select_v_i); + + AddNodeIndexSelect(index_select_k_in, + index_select_k_out, + params.index_select_params, + guid_ + "index_select_k_i"); + AddNodeIndexSelect(index_select_v_in, + index_select_v_out, + params.index_select_params, + guid_ + "index_select_v_i"); + + std::vector axis = {0, 2, 1, 3}; + synTransposeParams trans_params; + for (size_t i = 0; i < axis.size(); i++) { + trans_params.permutation[i] = + static_cast(axis[i]); + } + trans_params.tensorDim = 4; + + std::vector transpose_dims; + transpose_dims.push_back(num_of_block); + transpose_dims.push_back(num_kv_head); + transpose_dims.push_back(block_size); + transpose_dims.push_back(head_dim); + + auto transpose_k = + createTensorNoPresist("transpose_k", dtype_, transpose_dims); + std::vector trans_index_select_k; + trans_index_select_k.push_back(transpose_k); + + AddNodeTranspose(index_select_k_out, + trans_index_select_k, + trans_params, + guid_ + "transpose_k"); + + auto transpose_v = + createTensorNoPresist("transpose_v", dtype_, transpose_dims); + std::vector trans_index_select_v; + trans_index_select_v.push_back(transpose_v); + + AddNodeTranspose(index_select_v_out, + trans_index_select_v, + trans_params, + guid_ + "transpose_v"); + + std::vector reshape_kv_dims; + reshape_kv_dims.push_back(num_of_block); + reshape_kv_dims.push_back(num_kv_head); + reshape_kv_dims.push_back(1); + reshape_kv_dims.push_back(block_size); + reshape_kv_dims.push_back(head_dim); + + auto index_select_k = + createTensorNoPresist("index_select_k", dtype_, reshape_kv_dims); + std::vector reshape_index_select_k; + reshape_index_select_k.push_back(index_select_k); + + AddNodeReshape( + trans_index_select_k, reshape_index_select_k, guid_ + "reshape_k"); + + auto index_select_v = + createTensorNoPresist("index_select_v", dtype_, reshape_kv_dims); + std::vector reshape_index_select_v; + reshape_index_select_v.push_back(index_select_v); + + AddNodeReshape( + trans_index_select_v, reshape_index_select_v, guid_ + "reshape_v"); + + std::vector q_k_in; + q_k_in.push_back(reshaped_map_q); + q_k_in.push_back(index_select_k); + + std::vector q_k_dims; + q_k_dims.push_back(num_of_block); + q_k_dims.push_back(num_kv_head); + q_k_dims.push_back(ngroups); + q_k_dims.push_back(1); + q_k_dims.push_back(block_size); + auto q_k = createTensorNoPresist("q_k", dtype_, q_k_dims); + std::vector q_k_out; + q_k_out.push_back(q_k); + + AddNodeBatchGemm(q_k_in, q_k_out, gemm_params_f_t, guid_ + "batchgemm_q_k"); + + /*******************************/ + + auto block_bias = createTensorFromCT(&ct, 8); + std::vector block_bias_in; + block_bias_in.push_back(block_bias); + + std::vector reshaped_bias_dims; + reshaped_bias_dims.push_back(num_of_block); + reshaped_bias_dims.push_back(1); + reshaped_bias_dims.push_back(1); + reshaped_bias_dims.push_back(1); + reshaped_bias_dims.push_back(block_size); + + auto reshaped_bias = + createTensorNoPresist("reshaped_bias", dtype_, reshaped_bias_dims); + std::vector block_bias_out; + block_bias_out.push_back(reshaped_bias); + + AddNodeReshape(block_bias_in, block_bias_out, guid_ + "reshaped_bias"); + + std::vector add_bias_in; + add_bias_in.push_back(q_k); + add_bias_in.push_back(reshaped_bias); + + auto add_bias = createTensorNoPresist("add_bias", dtype_, q_k_dims); + std::vector add_bias_out; + add_bias_out.push_back(add_bias); + + AddNodeAdd(add_bias_in, add_bias_out, guid_ + "add_bias"); + /*******************************/ + + std::vector block_max_dims; + block_max_dims.push_back(num_of_block); + block_max_dims.push_back(num_kv_head); + block_max_dims.push_back(ngroups); + block_max_dims.push_back(1); + block_max_dims.push_back(1); + + auto block_max = createTensorNoPresist("block_max", dtype_, block_max_dims); + std::vector block_max_out; + block_max_out.push_back(block_max); + + AddNodeReduceMax( + add_bias_out, block_max_out, params.reduce_params, guid_ + "reduceMax"); + + /**************************************************************/ + + std::vector sum_adjusted_dims; + sum_adjusted_dims.push_back(num_of_block); + sum_adjusted_dims.push_back(num_kv_head); + sum_adjusted_dims.push_back(ngroups); + + auto block_max_2D = + createTensorNoPresist("block_max_2D", dtype_, sum_adjusted_dims); + std::vector block_max_2D_out; + block_max_2D_out.push_back(block_max_2D); + + AddNodeReshape( + block_max_out, block_max_2D_out, guid_ + "squeeze_block_max"); + + std::vector group_max_dims; + group_max_dims.push_back(batch_size + 1); + group_max_dims.push_back(num_kv_head); + group_max_dims.push_back(ngroups); + + auto group_max = createTensorNoPresist("group_max", dtype_, group_max_dims); + std::vector group_max_tensor; + group_max_tensor.push_back(group_max); + + params.const_params.constant.f = -std::numeric_limits::infinity(); + + AddNodeFull(group_max_tensor, params.const_params, guid_ + "full_inf"); + + auto block_groups = createTensorFromCT(&ct, 5); + std::vector index_reduce_in; + index_reduce_in.push_back(group_max); + index_reduce_in.push_back(block_groups); + index_reduce_in.push_back(block_max_2D); + + auto reduced_group_max = + createTensorNoPresist("reduced_group_max", dtype_, group_max_dims); + std::vector index_reduce_out; + index_reduce_out.push_back(reduced_group_max); + + AddNodeIndexReduce(index_reduce_in, + index_reduce_out, + params.index_reduce_params, + guid_ + "index_reduce_amax"); + + std::vector index_select_groupmax_in; + index_select_groupmax_in.push_back(reduced_group_max); + index_select_groupmax_in.push_back(block_groups); + + auto selected_group_max = + createTensorNoPresist("selected_group_max", dtype_, sum_adjusted_dims); + std::vector index_select_groupmax_out; + index_select_groupmax_out.push_back(selected_group_max); + params.index_select_params.axis = 2; + AddNodeIndexSelect(index_select_groupmax_in, + index_select_groupmax_out, + params.index_select_params, + guid_ + "index_select_groupmax"); + + std::vector sub_group_max_in; + sub_group_max_in.push_back(block_max_2D); + sub_group_max_in.push_back(selected_group_max); + + auto sub_group_max = + createTensorNoPresist("sub_group_max", dtype_, sum_adjusted_dims); + std::vector sub_group_max_out; + sub_group_max_out.push_back(sub_group_max); + + AddNodeSub(sub_group_max_in, sub_group_max_out, guid_ + "sub_group_max"); + + auto block_adjustment = + createTensorNoPresist("block_adjustment", dtype_, sum_adjusted_dims); + std::vector block_adjustment_out; + block_adjustment_out.push_back(block_adjustment); + AddNodeExp(sub_group_max_out, + block_adjustment_out, + guid_ + "exp_block_adjustment"); + + /**************************************************************/ + + std::vector sub_block_max_in; + sub_block_max_in.push_back(add_bias); + sub_block_max_in.push_back(block_max); + + auto sub_block_max = + createTensorNoPresist("sub_block_max", dtype_, q_k_dims); + std::vector sub_block_max_out; + sub_block_max_out.push_back(sub_block_max); + AddNodeSub(sub_block_max_in, sub_block_max_out, guid_ + "sub_block_max"); + + auto score = createTensorNoPresist("score", dtype_, q_k_dims); + std::vector score_out; + score_out.push_back(score); + AddNodeExp(sub_block_max_out, score_out, guid_ + "exp_score"); + + /*******************************/ + + std::vector score_v_in; + score_v_in.push_back(score); + score_v_in.push_back(index_select_v); + + auto score_v = createTensorNoPresist("score_v", dtype_, reshape_map_q_dims); + std::vector score_v_out; + score_v_out.push_back(score_v); + + AddNodeBatchGemm( + score_v_in, score_v_out, gemm_params_f_f, guid_ + "batchgemm_score_v"); + + auto reduceSum = createTensorNoPresist("reduceSum", dtype_, block_max_dims); + std::vector reduceSum_out; + reduceSum_out.push_back(reduceSum); + + AddNodeReduceSum( + score_out, reduceSum_out, params.reduce_params, guid_ + "reduceSum"); + + auto block_sums_2D = + createTensorNoPresist("block_sums_2D", dtype_, sum_adjusted_dims); + std::vector block_sums_2D_out; + block_sums_2D_out.push_back(block_sums_2D); + + AddNodeReshape( + reduceSum_out, block_sums_2D_out, guid_ + "squeeze_block_sums"); + + std::vector sum_adjusted_in; + sum_adjusted_in.push_back(block_sums_2D); + sum_adjusted_in.push_back(block_adjustment); + + auto sum_adjusted = + createTensorNoPresist("sum_adjusted", dtype_, sum_adjusted_dims); + std::vector sum_adjusted_out; + sum_adjusted_out.push_back(sum_adjusted); + + AddNodeMultiply( + sum_adjusted_in, sum_adjusted_out, guid_ + "mul_sum_adjusted"); + /***************************************************************/ + + std::vector reshaped_sum_adjusted_dims; + reshaped_sum_adjusted_dims.push_back(num_of_block); + reshaped_sum_adjusted_dims.push_back(num_head); + + auto flatten_sum_adjusted = createTensorNoPresist( + "flatten_sum_adjusted", dtype_, reshaped_sum_adjusted_dims); + std::vector reshaped_sum_adjusted_out; + reshaped_sum_adjusted_out.push_back(flatten_sum_adjusted); + + AddNodeReshape(sum_adjusted_out, + reshaped_sum_adjusted_out, + guid_ + "flatten_sum_adjusted"); + + std::vector map_sum_adjusted_in; + map_sum_adjusted_in.push_back(block_mapping); + map_sum_adjusted_in.push_back(flatten_sum_adjusted); + + std::vector map_sum_adjusted_dims; + map_sum_adjusted_dims.push_back(batch_size); + map_sum_adjusted_dims.push_back(num_head); + auto mapped_sum_adjusted = createTensorNoPresist( + "mapped_sum_adjusted", dtype_, map_sum_adjusted_dims); + std::vector map_sum_adjusted_out; + map_sum_adjusted_out.push_back(mapped_sum_adjusted); + + AddNodeGemm(map_sum_adjusted_in, + map_sum_adjusted_out, + gemm_params_t_f, + guid_ + "gemm_map_sum_adjusted"); + + std::vector group_sum_adjusted_in; + group_sum_adjusted_in.push_back(block_mapping); + group_sum_adjusted_in.push_back(mapped_sum_adjusted); + + std::vector matmul_sum_adjusted_dims; + matmul_sum_adjusted_dims.push_back(num_of_block); + matmul_sum_adjusted_dims.push_back(num_head); + auto group_sum_adjusted = createTensorNoPresist( + "group_sum_adjusted", dtype_, matmul_sum_adjusted_dims); + std::vector group_sum_adjusted_out; + group_sum_adjusted_out.push_back(group_sum_adjusted); + + AddNodeGemm(group_sum_adjusted_in, + group_sum_adjusted_out, + gemm_params_f_f, + guid_ + "gemm_group_sum_adjusted"); + + /*******************************/ + + auto reshaped_group_sum_adjusted = createTensorNoPresist( + "reshaped_group_sum_adjusted", dtype_, block_max_dims); + std::vector group_sum_adjusted_4D; + group_sum_adjusted_4D.push_back(reshaped_group_sum_adjusted); + + AddNodeReshape(group_sum_adjusted_out, + group_sum_adjusted_4D, + guid_ + "reshaped_group_sum_adjusted"); + + auto reshaped_sum_adjusted = + createTensorNoPresist("reshaped_sum_adjusted", dtype_, block_max_dims); + std::vector sum_adjusted_4D; + sum_adjusted_4D.push_back(reshaped_sum_adjusted); + + AddNodeReshape( + sum_adjusted_out, sum_adjusted_4D, guid_ + "reshaped_sum_adjusted"); + + auto reshaped_block_adjustment = createTensorNoPresist( + "reshaped_block_adjustment", dtype_, block_max_dims); + std::vector block_adjustment_4D; + block_adjustment_4D.push_back(reshaped_block_adjustment); + + AddNodeReshape(block_adjustment_out, + block_adjustment_4D, + guid_ + "reshaped_block_adjustment"); + + std::vector max_sum_adjust_in; + max_sum_adjust_in.push_back(reshaped_group_sum_adjusted); + max_sum_adjust_in.push_back(reshaped_sum_adjusted); + auto max_sum_adjust = + createTensorNoPresist("max_sum_adjust", dtype_, block_max_dims); + std::vector max_sum_adjust_out; + max_sum_adjust_out.push_back(max_sum_adjust); + + AddNodeMaximum( + max_sum_adjust_in, max_sum_adjust_out, guid_ + "max_sum_adjust"); + + std::vector rescale_in; + rescale_in.push_back(reshaped_block_adjustment); + rescale_in.push_back(max_sum_adjust); + auto rescale = createTensorNoPresist("rescale", dtype_, block_max_dims); + std::vector rescale_out; + rescale_out.push_back(rescale); + AddNodeDivide(rescale_in, rescale_out, guid_ + "div_rescale"); + + std::vector rescale_v_in; + rescale_v_in.push_back(rescale); + rescale_v_in.push_back(score_v); + + auto rescale_v = + createTensorNoPresist("rescale_v", dtype_, reshape_map_q_dims); + std::vector rescale_v_out; + rescale_v_out.push_back(rescale_v); + + AddNodeMultiply(rescale_v_in, rescale_v_out, guid_ + "mul_rescale_v"); + + auto reshape_attn = + createTensorNoPresist("reshape_attn", dtype_, map_q_dims); + std::vector reshape_attn_out; + reshape_attn_out.push_back(reshape_attn); + + AddNodeReshape(rescale_v_out, reshape_attn_out, guid_ + "reshape_attn"); + + std::vector map_attn_in; + map_attn_in.push_back(block_mapping); + map_attn_in.push_back(reshape_attn); + + auto mapped_attn = + createTensorNoPresist("mapped_attn", dtype_, reshape_q_dims); + std::vector map_attn_out; + map_attn_out.push_back(mapped_attn); + + AddNodeGemm( + map_attn_in, map_attn_out, gemm_params_t_f, guid_ + "gemm_map_attn"); + + std::vector reshape_attn_dims; + reshape_attn_dims.push_back(batch_size); + reshape_attn_dims.push_back(1); + reshape_attn_dims.push_back(hidden_size); + auto attn = createTensorNoPresist("attn", dtype_, reshape_attn_dims); + std::vector attn_out; + attn_out.push_back(attn); + + AddNodeReshape(map_attn_out, attn_out, guid_ + "attn"); + + std::vector proj_in; + auto linear_weights = createTensorFromCT(&ct, 13); + proj_in.push_back(attn); + proj_in.push_back(linear_weights); + + auto linear_out = createTensorFromCT(&ct, 0, false); + std::vector proj_out; + proj_out.push_back(linear_out); + + AddNodeBatchGemm( + proj_in, proj_out, gemm_params_f_f, guid_ + "batchgemm_proj"); + } + + protected: + synDataType dtype_; +}; + +template +void FusedBlockAttentionKernel(const Context& dev_ctx, + const phi::DenseTensor& src, + const phi::DenseTensor& residual, + const phi::DenseTensor& rotary_embs, + const phi::DenseTensor& key_cache, + const phi::DenseTensor& value_cache, + const phi::DenseTensor& block_groups, + const phi::DenseTensor& block_list, + const phi::DenseTensor& block_mapping, + const phi::DenseTensor& block_bias, + const phi::DenseTensor& block_indices, + const phi::DenseTensor& block_offsets, + const phi::DenseTensor& ln_scales, + const phi::DenseTensor& qkv_weights, + const phi::DenseTensor& linear_weights, + phi::DenseTensor* out_linear, + const phi::Scalar& epsilon, + const phi::Scalar& head_dim, + const phi::Scalar& num_head, + const phi::Scalar& scaling_factor) { + std::vector src_dims = phi::vectorize(src.dims()); + std::vector qkv_weights_dims = + phi::vectorize(qkv_weights.dims()); + + int head_dim_ = head_dim.to(); + int num_head_ = num_head.to(); + const int64_t fused_hidden_size = qkv_weights_dims[0]; + const int num_kv_head = + (fused_hidden_size - num_head_ * head_dim_) / head_dim_ / 2; + + ConvertTensors ct; + ct.Add(src); + ct.Add(residual); + ct.Add(rotary_embs); + ct.Add(key_cache); + ct.Add(value_cache); + ct.Add(block_groups); + std::vector inputs_dims = ct.GetDims(); + ct.Add(block_list); + ct.Add(block_mapping); + ct.Add(block_bias); + ct.Add(block_indices); + ct.Add(block_offsets); + ct.Add(ln_scales); + ct.Add(qkv_weights); + ct.Add(linear_weights); + ct.Add(out_linear, false); + ct.Add(key_cache, false); + ct.Add(value_cache, false); + ct.Add(residual, false); + + OpCacheOperator op_info; + op_info.prepareOpInfo( + "fused_block_attention_fwd_", inputs_dims, nullptr); + auto recipe = op_info.GetRecipe(); + + if (recipe == nullptr) { + FusedBlockAttentionParams params; + memset(reinterpret_cast(¶ms), + 0x00, + sizeof(FusedBlockAttentionParams)); + params.rmsnorm_params.epsValid = true; + params.rmsnorm_params.eps = epsilon.to(); + params.const_params.constant.f = scaling_factor.to(); + params.index_select_params.axis = 3; + params.reduce_params.reductionDimension = 0; + params.index_reduce_params.mode = INDEX_REDUCE_AMAX; + params.index_reduce_params.include_self = true; + params.index_reduce_params.axis = 0; + params.head_dim = head_dim_; + params.num_head = num_head_; + params.num_kv_head = num_kv_head; + + if (num_head_ == num_kv_head) { + FusedMHABlockAttention op(op_info.datatype_); + op.AddNode(ct, params); + op.Compile(); + op_info.setOp(op); + } else { + FusedGQABlockAttention op(op_info.datatype_); + op.AddNode(ct, params); + op.Compile(); + op_info.setOp(op); + } + + recipe = op_info.GetRecipe(); + } + + std::map tensors = ct.GetDeviceAddr(); + RecipeRunner runner(recipe); + runner.Run(reinterpret_cast(dev_ctx.stream()), tensors); +} + +} // namespace custom_kernel + +template +void CallFusedBlockAttentionKernel(const Context& dev_ctx, + const phi::DenseTensor& src, + const phi::DenseTensor& residual, + const phi::DenseTensor& rotary_embs, + const phi::DenseTensor& key_cache, + const phi::DenseTensor& value_cache, + const phi::DenseTensor& block_groups, + const phi::DenseTensor& block_list, + const phi::DenseTensor& block_mapping, + const phi::DenseTensor& block_bias, + const phi::DenseTensor& block_indices, + const phi::DenseTensor& block_offsets, + const phi::DenseTensor& ln_scales, + const phi::DenseTensor& qkv_weights, + const phi::DenseTensor& linear_weights, + phi::DenseTensor* out_linear, + const phi::Scalar& epsilon, + const phi::Scalar& head_dim, + const phi::Scalar& num_head, + const phi::Scalar& scaling_factor) { + if (src.dtype() == phi::DataType::FLOAT16) { + custom_kernel::FusedBlockAttentionKernel( + dev_ctx, + src, + residual, + rotary_embs, + key_cache, + value_cache, + block_groups, + block_list, + block_mapping, + block_bias, + block_indices, + block_offsets, + ln_scales, + qkv_weights, + linear_weights, + out_linear, + epsilon, + head_dim, + num_head, + scaling_factor); + } else if (src.dtype() == phi::DataType::BFLOAT16) { + custom_kernel::FusedBlockAttentionKernel( + dev_ctx, + src, + residual, + rotary_embs, + key_cache, + value_cache, + block_groups, + block_list, + block_mapping, + block_bias, + block_indices, + block_offsets, + ln_scales, + qkv_weights, + linear_weights, + out_linear, + epsilon, + head_dim, + num_head, + scaling_factor); + } else { + throw std::runtime_error( + "Unsupported data type for FusedBlockAttentionKernel"); + } +} + +std::vector FusedBlockAttentionForward( + const paddle::Tensor& src, + const paddle::Tensor& residual, + const paddle::Tensor& rotary_embs, + const paddle::Tensor& key_cache, + const paddle::Tensor& value_cache, + const paddle::Tensor& block_groups, + const paddle::Tensor& block_list, + const paddle::Tensor& block_mapping, + const paddle::Tensor& block_bias, + const paddle::Tensor& block_indices, + const paddle::Tensor& block_offsets, + const paddle::Tensor& ln_scales, + const paddle::Tensor& qkv_weights, + const paddle::Tensor& linear_weights, + float epsilon, + int head_dim, + int num_head, + float scaling_factor) { + auto dev_ctx = static_cast( + paddle::experimental::DeviceContextPool::Instance().Get(src.place())); + auto src_tensor = static_cast(src.impl().get()); + auto residual_tensor = + static_cast(residual.impl().get()); + auto rotary_embs_tensor = + static_cast(rotary_embs.impl().get()); + auto key_cache_tensor = + static_cast(key_cache.impl().get()); + auto value_cache_tensor = + static_cast(value_cache.impl().get()); + auto block_groups_tensor = + static_cast(block_groups.impl().get()); + auto block_list_tensor = + static_cast(block_list.impl().get()); + auto block_mapping_tensor = + static_cast(block_mapping.impl().get()); + auto block_bias_tensor = + static_cast(block_bias.impl().get()); + auto block_indices_tensor = + static_cast(block_indices.impl().get()); + auto block_offsets_tensor = + static_cast(block_offsets.impl().get()); + auto ln_scales_tensor = + static_cast(ln_scales.impl().get()); + auto qkv_weights_tensor = + static_cast(qkv_weights.impl().get()); + auto linear_weights_tensor = + static_cast(linear_weights.impl().get()); + + // allocate memory on device. + int64_t batch_size = src.dims()[0]; + int64_t out_features = linear_weights.dims()[1]; + + std::shared_ptr out_linear = + std::make_shared(); + out_linear->Resize(phi::make_ddim({batch_size, 1, out_features})); + dev_ctx->Alloc(out_linear.get(), src_tensor->dtype()); + + CallFusedBlockAttentionKernel(*dev_ctx, + *src_tensor, + *residual_tensor, + *rotary_embs_tensor, + *key_cache_tensor, + *value_cache_tensor, + *block_groups_tensor, + *block_list_tensor, + *block_mapping_tensor, + *block_bias_tensor, + *block_indices_tensor, + *block_offsets_tensor, + *ln_scales_tensor, + *qkv_weights_tensor, + *linear_weights_tensor, + out_linear.get(), + phi::Scalar(epsilon), + phi::Scalar(head_dim), + phi::Scalar(num_head), + phi::Scalar(scaling_factor)); + return {paddle::Tensor(out_linear)}; +} + +std::vector> FusedBlockAttentionShape( + const std::vector& src_shape, + const std::vector& residual_shape, + const std::vector& rotary_embs_shape, + const std::vector& key_cache_shape, + const std::vector& value_cache_shape, + const std::vector& block_groups_shape, + const std::vector& block_list_shape, + const std::vector& block_mapping_shape, + const std::vector& block_bias_shape, + const std::vector& block_indices_shape, + const std::vector& block_offsets_shape, + const std::vector& ln_scales_shape, + const std::vector& qkv_weights_shape, + const std::vector& linear_weights_shape, + float epsilon, + int head_dim, + int num_head, + float scaling_factor) { + int64_t batch_size = src_shape[0]; + int64_t out_features = linear_weights_shape[1]; + return {{batch_size, 1, out_features}}; +} + +std::vector FusedBlockAttentionDtype( + const paddle::DataType& src_dtype, + const paddle::DataType& residual_dtype, + const paddle::DataType& rotary_embs_dtype, + const paddle::DataType& key_cache_dtype, + const paddle::DataType& value_cache_dtype, + const paddle::DataType& block_groups_dtype, + const paddle::DataType& block_list_dtype, + const paddle::DataType& block_mapping_dtype, + const paddle::DataType& block_bias_dtype, + const paddle::DataType& block_indices_dtype, + const paddle::DataType& block_offsets_dtype, + const paddle::DataType& ln_scales_dtype, + const paddle::DataType& qkv_weights_dtype, + const paddle::DataType& linear_weights_dtype) { + return {src_dtype}; +} + +PD_BUILD_OP(fused_block_attention) + .Inputs({"src", + "residual", + "rotary_embs", + "key_cache", + "value_cache", + "block_groups", + "block_list", + "block_mapping", + "block_bias", + "block_indices", + "block_offsets", + "ln_scales", + "qkv_weights", + "linear_weights"}) + .Outputs({"out_linear"}) + .Attrs({"epsilon: float", + "head_dim: int", + "num_head: int", + "scaling_factor: float"}) + .SetKernelFn(PD_KERNEL(FusedBlockAttentionForward)) + .SetInferShapeFn(PD_INFER_SHAPE(FusedBlockAttentionShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(FusedBlockAttentionDtype)); diff --git a/backends/intel_hpu/custom_ops/llama_infer/fused_rms_mlp_add.cc b/backends/intel_hpu/custom_ops/llama_infer/fused_rms_mlp_add.cc new file mode 100644 index 0000000000..93814e0255 --- /dev/null +++ b/backends/intel_hpu/custom_ops/llama_infer/fused_rms_mlp_add.cc @@ -0,0 +1,284 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "habanalabs/perf_lib_layer_params.h" +#include "kernels/funcs.h" +#include "kernels/hpu_funcs.h" +#include "kernels/hpu_operator.h" +#include "paddle/extension.h" +#include "utils/utils.h" + +namespace custom_kernel { + +struct FusedRmsMlpResParams { + ns_LayerNormKernel::Params rmsnorm_params; + synSplitParams split_params; +}; + +class FusedRmsMlpRes : public HpuFusedOperator { + public: + explicit FusedRmsMlpRes(synDataType dtype) + : HpuFusedOperator("fused_rms_mlp_res_fwd_", false), dtype_(dtype) {} + template + void AddNode(ConvertTensors& ct, FusedRmsMlpResParams params) { + auto ins = ct.GetTensors(); + auto outs = ct.GetTensors(false); + + synGEMMParams gemm_params; + gemm_params.transpose_a = false; + gemm_params.transpose_b = false; + + synSectionHandle section = createSection(); + auto hidden_states = createTensorFromCT(&ct, 0); + auto residual_input = createTensorFromCT(&ct, 4, true, section); + auto residual_out = createTensorFromCT(&ct, 1, false, section); + + std::vector add_residual_in; + add_residual_in.push_back(hidden_states); + add_residual_in.push_back(residual_input); + + std::vector add_residual_out; + add_residual_out.push_back(residual_out); + + AddNodeAdd(add_residual_in, add_residual_out, guid_ + "add_residual"); + + auto ln_scales = createTensorFromCT(&ct, 1); + std::vector rmsnorm_inputs; + rmsnorm_inputs.push_back(residual_out); + rmsnorm_inputs.push_back(ln_scales); + + auto tmp_dims = ins[0].dims; + tmp_dims[2] = 1; + auto norm_out = createTensorNoPresist("norm_out", ins[0].type, ins[0].dims); + auto norm_var = createTensorNoPresist("norm_var", ins[0].type, tmp_dims); + std::vector rmsnorm_outputs; + rmsnorm_outputs.push_back(norm_out); + rmsnorm_outputs.push_back(norm_var); + + AddNodeRmsNorm(rmsnorm_inputs, + rmsnorm_outputs, + params.rmsnorm_params, + guid_ + "rmsnorm"); + + auto proj_weight = createTensorFromCT(&ct, 2); + std::vector proj_dims = { + ins[0].dims[0], ins[0].dims[1], ins[2].dims[1]}; + auto proj_out = createTensorNoPresist("proj_out", ins[0].type, proj_dims); + + std::vector proj_inputs; + proj_inputs.push_back(norm_out); + proj_inputs.push_back(proj_weight); + std::vector proj_outputs; + proj_outputs.push_back(proj_out); + + AddNodeGemm(proj_inputs, proj_outputs, gemm_params, guid_ + "gemm_up_proj"); + + std::vector split_out_dims = { + proj_dims[0], proj_dims[1], proj_dims[2] / 2}; + auto gate_out = + createTensorNoPresist("gate_out", ins[0].type, split_out_dims); + auto up_out = createTensorNoPresist("up_out", ins[0].type, split_out_dims); + auto down_weight = createTensorFromCT(&ct, 3); + + std::vector split_inputs; + split_inputs.push_back(proj_out); + std::vector split_outputs; + split_outputs.push_back(gate_out); + split_outputs.push_back(up_out); + + AddNodeSplit( + split_inputs, split_outputs, params.split_params, guid_ + "split"); + + auto silu_out = + createTensorNoPresist("silu_out", ins[0].type, split_out_dims); + std::vector silu_inputs; + silu_inputs.push_back(gate_out); + std::vector silu_outputs; + silu_outputs.push_back(silu_out); + + AddNodeSilu(silu_inputs, silu_outputs, guid_ + "silu"); + + auto multi_out = + createTensorNoPresist("multi_out", ins[0].type, split_out_dims); + std::vector multi_inputs; + multi_inputs.push_back(silu_out); + multi_inputs.push_back(up_out); + std::vector multi_outputs; + multi_outputs.push_back(multi_out); + + AddNodeMultiply(multi_inputs, multi_outputs, guid_ + "_multi"); + + auto mlp_out = createTensorFromCT(&ct, 0, false); + std::vector down_inputs; + down_inputs.push_back(multi_out); + down_inputs.push_back(down_weight); + std::vector down_outputs; + down_outputs.push_back(mlp_out); + + AddNodeGemm( + down_inputs, down_outputs, gemm_params, guid_ + "gemm_down_proj"); + } + + protected: + synDataType dtype_; +}; + +template +void FusedRmsMlpResKernel(const Context& dev_ctx, + const phi::DenseTensor& x, + const phi::DenseTensor& residual, + const phi::DenseTensor& ln_scales, + const phi::DenseTensor& proj_weight, + const phi::DenseTensor& down_weight, + const phi::Scalar& epsilon, + phi::DenseTensor* out) { + // allocate memory on device. + dev_ctx.template Alloc(out); + if (out->numel() == 0) { + return; + } + + std::vector ln_scales_dims = + phi::vectorize(ln_scales.dims()); + + const phi::Scalar axis_scalar = proj_weight.dims().size() - 1; + int64_t axis = axis_scalar.to(); + if (axis < 0) { + axis = proj_weight.dims().size() + axis; + } + FusedRmsMlpResParams params; + memset(reinterpret_cast(¶ms), 0x00, sizeof(FusedRmsMlpResParams)); + params.rmsnorm_params.epsValid = true; + params.rmsnorm_params.eps = epsilon.to(); + + params.split_params = {{0}}; + params.split_params.axis = proj_weight.dims().size() - 1 - axis; + + ConvertTensors ct; + ct.Add(x); + ct.Add(ln_scales); + ct.Add(proj_weight); + ct.Add(down_weight); + ct.Add(residual); + ct.Add(*out, false); + ct.Add(residual, false); + std::vector inputs_dims = ct.GetDims(); + + OpCacheOperator op_info; + op_info.prepareOpInfo( + "FusedRmsMlpResKernel", inputs_dims, ¶ms); + auto recipe = op_info.GetRecipe(); + + if (recipe == nullptr) { + FusedRmsMlpRes op(op_info.datatype_); + op.AddNode(ct, params); + op.Compile(); + op_info.setOp(op); + + recipe = op_info.GetRecipe(); + } + + std::map tensors = ct.GetDeviceAddr(); + RecipeRunner runner(recipe); + runner.Run(reinterpret_cast(dev_ctx.stream()), tensors); +} + +} // namespace custom_kernel + +template +void CallFusedRmsMlpResKernel(const Context& dev_ctx, + const phi::DenseTensor& x, + const phi::DenseTensor& residual, + const phi::DenseTensor& ln_scales, + const phi::DenseTensor& proj_weight, + const phi::DenseTensor& down_weight, + const phi::Scalar& epsilon, + phi::DenseTensor* out) { + if (x.dtype() == phi::DataType::BFLOAT16) { + custom_kernel::FusedRmsMlpResKernel(dev_ctx, + x, + residual, + ln_scales, + proj_weight, + down_weight, + epsilon, + out); + } else { + throw std::runtime_error("Unsupported data type for FusedRmsMlpResKernel"); + } +} + +std::vector FusedRmsMlpResForward( + const paddle::Tensor& x, + const paddle::Tensor& ln_scales, + const paddle::Tensor& proj_weight, + const paddle::Tensor& down_weight, + const paddle::Tensor& residual, + const float epsilon) { + auto dev_ctx = static_cast( + paddle::experimental::DeviceContextPool::Instance().Get(x.place())); + + auto x_tensor = static_cast(x.impl().get()); + auto residual_tensor = + static_cast(residual.impl().get()); + + auto ln_scales_tensor = + static_cast(ln_scales.impl().get()); + auto down_tensor = + static_cast(down_weight.impl().get()); + auto proj_tensor = + static_cast(proj_weight.impl().get()); + + auto out_tensor = std::make_shared(); + out_tensor->Resize(x_tensor->dims()); + + CallFusedRmsMlpResKernel(*dev_ctx, + *x_tensor, + *residual_tensor, + *ln_scales_tensor, + *proj_tensor, + *down_tensor, + phi::Scalar(epsilon), + out_tensor.get()); + + paddle::Tensor out(out_tensor); + + return {out}; +} + +std::vector> FusedRmsMlpResInferShape( + const std::vector& x_shape, + const std::vector& ln_scales_shape, + const std::vector& proj_weight_shape, + const std::vector& down_weight_shape, + const std::vector& residual_shape) { + return {x_shape, residual_shape}; +} + +std::vector FusedRmsMlpResInferDtype( + const paddle::DataType& x_dtype, + const paddle::DataType& ln_scales_dtype, + const paddle::DataType& proj_weight_dtype, + const paddle::DataType& down_weight_dtype, + const paddle::DataType& residual_dtype) { + return {x_dtype, residual_dtype}; +} + +PD_BUILD_OP(fused_rms_mlp_res) + .Inputs({"x", "ln_scales", "proj_weight", "down_weight", "residual_in"}) + .Outputs({"out"}) + .Attrs({"epsilon: float"}) + .SetKernelFn(PD_KERNEL(FusedRmsMlpResForward)) + .SetInferShapeFn(PD_INFER_SHAPE(FusedRmsMlpResInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(FusedRmsMlpResInferDtype)); diff --git a/backends/intel_hpu/custom_ops/llama_infer/fused_rms_qkv_rope_t.cc b/backends/intel_hpu/custom_ops/llama_infer/fused_rms_qkv_rope_t.cc index d6873cb41a..e542b38051 100644 --- a/backends/intel_hpu/custom_ops/llama_infer/fused_rms_qkv_rope_t.cc +++ b/backends/intel_hpu/custom_ops/llama_infer/fused_rms_qkv_rope_t.cc @@ -14,6 +14,7 @@ #include "habanalabs/perf_lib_layer_params.h" #include "kernels/funcs.h" +#include "kernels/hpu_funcs.h" #include "kernels/hpu_operator.h" #include "paddle/extension.h" #include "utils/utils.h" @@ -28,176 +29,100 @@ struct FusedRmsQkvRopeParams { int kv_num_head; }; -class FusedRmsQkvRopeT : public HpuOperator { +class FusedRmsQkvRopeT : public HpuFusedOperator { public: explicit FusedRmsQkvRopeT(synDataType dtype) - : HpuOperator("fused_rms_qkv_rope_t_fwd_"), dtype_(dtype) {} - - void AddNode(const std::vector& ins, - const std::vector& outs, - FusedRmsQkvRopeParams& params) { - synStatus status = synFail; - - std::string name_reshape = guid_ + "reshape"; - std::string name_concat = guid_ + "concat"; - std::string name_rmsnorm = guid_ + "rmsnorm"; - std::string name_rope = guid_ + "rope"; - - std::string guid_reshape = "reshape"; - std::string guid_concat = "concat"; - std::string guid_rmsnorm = "rms_norm_ex_fwd_"; - std::string guid_rope = "rotary_pos_embedding_fwd_"; - - // std::string guid_rope = "rope_st2_fwd_"; - - if (dtype_ == syn_type_fp16) { - guid_rmsnorm = guid_rmsnorm + "f16"; - guid_rope = guid_rope + "f16"; - } else if (dtype_ == syn_type_bf16) { - guid_rmsnorm = guid_rmsnorm + "bf16"; - guid_rope = guid_rope + "bf16"; - } + : HpuFusedOperator("fused_rms_qkv_rope_t_fwd_"), dtype_(dtype) {} + template + void AddNode(ConvertTensors& ct, FusedRmsQkvRopeParams& params) { + auto ins = ct.GetTensors(); + auto outs = ct.GetTensors(false); + + synSectionHandle section = createSection(); + auto src = createTensorFromCT(&ct, 0); + auto residual = createTensorFromCT(&ct, 4, true, section); + auto residual_out = createTensorFromCT(&ct, 2, false, section); + + std::vector add_residual_in; + add_residual_in.push_back(src); + add_residual_in.push_back(residual); + + std::vector add_residual_out; + add_residual_out.push_back(residual_out); - auto src = createTensor(ins[0].size(), dtype_, ins[0], true, "src"); - auto ln_scales = - createTensor(ins[1].size(), dtype_, ins[1], true, "ln_scales"); + AddNodeAdd(add_residual_in, add_residual_out, guid_ + "add_residual"); + + auto ln_scales = createTensorFromCT(&ct, 1); std::vector rmsnorm_inputs; - rmsnorm_inputs.push_back(src); + rmsnorm_inputs.push_back(residual_out); rmsnorm_inputs.push_back(ln_scales); - auto tmp_dims = ins[0]; + auto tmp_dims = ins[0].dims; tmp_dims[2] = 1; - auto norm_out = - createTensor(ins[0].size(), dtype_, ins[0], false, "norm_out"); - auto norm_var = - createTensor(tmp_dims.size(), dtype_, tmp_dims, false, "norm_var"); + auto norm_out = createTensorNoPresist("norm_out", dtype_, ins[0].dims); + auto norm_var = createTensorNoPresist("norm_var", dtype_, tmp_dims); std::vector rmsnorm_outputs; rmsnorm_outputs.push_back(norm_out); rmsnorm_outputs.push_back(norm_var); - status = synNodeCreate(graphHandle_, - rmsnorm_inputs.data(), - rmsnorm_outputs.data(), - 2, - 2, - ¶ms.rmsnorm_params, - sizeof(params.rmsnorm_params), - guid_rmsnorm.c_str(), - name_rmsnorm.c_str(), - nullptr, - nullptr); - PD_CHECK(status == synSuccess, - "[RUNTIME] FusedRmsQkvRopeKernel synNodeCreate (norm) failed = ", - status); - - auto qkv_weights = - createTensor(ins[2].size(), dtype_, ins[2], true, "qkv_weights"); + AddNodeRmsNorm(rmsnorm_inputs, + rmsnorm_outputs, + params.rmsnorm_params, + guid_ + "rmsnorm"); + + auto qkv_weights = createTensorFromCT(&ct, 2); std::vector mul_inputs; mul_inputs.push_back(norm_out); mul_inputs.push_back(qkv_weights); - auto wt_dims = ins[2]; + auto wt_dims = ins[2].dims; tmp_dims[2] = wt_dims[0]; - auto qkv_out = - createTensor(tmp_dims.size(), dtype_, tmp_dims, false, "qkv_out"); + auto qkv_out = createTensorNoPresist("qkv_out", dtype_, tmp_dims); std::vector mul_outputs; mul_outputs.push_back(qkv_out); synGEMMParams gemm_params; gemm_params.transpose_a = false; gemm_params.transpose_b = true; - std::string guid_gemm = "batch_gemm"; - std::string gemm_name = guid_ + "gemm"; - status = synNodeCreate(graphHandle_, - mul_inputs.data(), - mul_outputs.data(), - 2, - 1, - &gemm_params, - sizeof(gemm_params), - guid_gemm.c_str(), - gemm_name.c_str(), - nullptr, - nullptr); - PD_CHECK(status == synSuccess, - "[RUNTIME] FusedRmsQkvRopeKernel synNodeCreate (matmul) failed = ", - status); - - auto reshape_dims = ins[0]; + AddNodeBatchGemm(mul_inputs, mul_outputs, gemm_params, guid_ + "batchgemm"); + + auto reshape_dims = ins[0].dims; reshape_dims[2] = params.num_head + 2 * params.kv_num_head; reshape_dims.push_back(params.head_dim); std::vector reshape_outputs; - auto reshape_out = createTensor( - reshape_dims.size(), dtype_, reshape_dims, false, "reshape_out"); + auto reshape_out = + createTensorNoPresist("reshape_out", dtype_, reshape_dims); reshape_outputs.push_back(reshape_out); - status = synNodeCreate(graphHandle_, - mul_outputs.data(), - reshape_outputs.data(), - 1, - 1, - nullptr, - 0, - guid_reshape.c_str(), - name_reshape.c_str(), - nullptr, - nullptr); - PD_CHECK( - status == synSuccess, - "[RUNTIME] FusedRmsQkvRopeKernel synNodeCreate (reshape) failed = ", - status); - - auto kv_dims = outs[1]; - kv_dims.erase(kv_dims.begin()); - - synSplitParams splitParams; - splitParams.axis = 1; + AddNodeReshape(mul_outputs, reshape_outputs, guid_ + "reshape_qkv"); + auto kv_dims = outs[1].dims; + kv_dims.erase(kv_dims.begin()); + auto q_split = createTensorNoPresist("q_split", dtype_, outs[0].dims); + auto k_split = createTensorNoPresist("k_split", dtype_, kv_dims); + auto v_split = createTensorNoPresist("v_split", dtype_, kv_dims); std::vector split_outpus; - auto q_split = - createTensor(outs[0].size(), dtype_, outs[0], false, "q_split"); split_outpus.push_back(q_split); - - auto k_split = - createTensor(kv_dims.size(), dtype_, kv_dims, false, "k_split"); split_outpus.push_back(k_split); - - auto v_split = - createTensor(kv_dims.size(), dtype_, kv_dims, false, "v_split"); split_outpus.push_back(v_split); - std::string split_guid = "split"; - std::string split_name = guid_ + "split"; - status = synNodeCreate(graphHandle_, - reshape_outputs.data(), - split_outpus.data(), - 1, - split_outpus.size(), - &splitParams, - sizeof(splitParams), - split_guid.c_str(), - split_name.c_str(), - nullptr, - nullptr); - PD_CHECK(status == synSuccess, - "[RUNTIME] FusedRmsQkvRopeKernel synNodeCreate (split) failed = ", - status); + synSplitParams splitParams; + splitParams.axis = 1; + AddNodeSplit(reshape_outputs, split_outpus, splitParams, guid_ + "split"); std::vector rotary_embs_inputs; - auto rotary_embs_c = - createTensor(ins[3].size(), dtype_, ins[3], true, "rotary_embs"); + auto rotary_embs_c = createTensorFromCT(&ct, 3); rotary_embs_inputs.push_back(rotary_embs_c); - auto rotary_embs_dims = ins[3]; + auto rotary_embs_dims = ins[3].dims; rotary_embs_dims[0] = 1; std::vector cos_inputs; - auto cos_in = createTensor( - rotary_embs_dims.size(), dtype_, rotary_embs_dims, false, "cos_in"); + auto cos_in = createTensorNoPresist("cos_in", dtype_, rotary_embs_dims); cos_inputs.push_back(cos_in); synSliceParamsV2 sliceParams; @@ -207,103 +132,34 @@ class FusedRmsQkvRopeT : public HpuOperator { sliceParams.starts[i] = 0; sliceParams.ends[i] = rotary_embs_dims[rotary_embs_dims.size() - 1 - i]; } - - std::string slice_guid = "slice"; - std::string slice_name = guid_ + "slice"; - std::string slice_name_cos = slice_name + "_cos"; - status = synNodeCreate(graphHandle_, - rotary_embs_inputs.data(), - cos_inputs.data(), - rotary_embs_inputs.size(), - cos_inputs.size(), - &sliceParams, - sizeof(sliceParams), - slice_guid.c_str(), - slice_name_cos.c_str(), - nullptr, - nullptr); - PD_CHECK( - status == synSuccess, - "[RUNTIME] FusedRmsQkvRopeKernel synNodeCreate (slice/cos) failed = ", - status); + AddNodeSlice( + rotary_embs_inputs, cos_inputs, sliceParams, guid_ + "slice_cos"); std::vector sin_inputs; - auto sin_in = createTensor( - rotary_embs_dims.size(), dtype_, rotary_embs_dims, false, "sin_in"); + auto sin_in = createTensorNoPresist("sin_in", dtype_, rotary_embs_dims); sin_inputs.push_back(sin_in); sliceParams.starts[rotary_embs_dims.size() - 1] = 1; sliceParams.ends[rotary_embs_dims.size() - 1] = 2; - std::string slice_name_sin = slice_name + "_sin"; - status = synNodeCreate(graphHandle_, - rotary_embs_inputs.data(), - sin_inputs.data(), - rotary_embs_inputs.size(), - sin_inputs.size(), - &sliceParams, - sizeof(sliceParams), - slice_guid.c_str(), - slice_name_sin.c_str(), - nullptr, - nullptr); - PD_CHECK( - status == synSuccess, - "[RUNTIME] FusedRmsQkvRopeKernel synNodeCreate (slice/sin) failed = ", - status); - - synSqueezeParams squeezeParams; - squeezeParams.axis = 4; - std::string squeeze_guid = "squeeze"; - std::string squeeze_name = guid_ + "squeeze"; + AddNodeSlice( + rotary_embs_inputs, sin_inputs, sliceParams, guid_ + "slice_sin"); rotary_embs_dims.erase(rotary_embs_dims.begin()); - + auto sin_sq = + createTensorNoPresist("sin_squeezed", dtype_, rotary_embs_dims); std::vector sin_squeezed; - auto sin_sq = createTensor(rotary_embs_dims.size(), - dtype_, - rotary_embs_dims, - false, - "sin_squeezed"); sin_squeezed.push_back(sin_sq); - std::string squeeze_name_sin = squeeze_name + "_sin"; - status = synNodeCreate(graphHandle_, - sin_inputs.data(), - sin_squeezed.data(), - 1, - 1, - &squeezeParams, - sizeof(squeezeParams), - squeeze_guid.c_str(), - squeeze_name_sin.c_str(), - nullptr, - nullptr); - PD_CHECK( - status == synSuccess, - "[RUNTIME] FusedRmsQkvRopeKernel synNodeCreate (squeeze/sin) failed = ", - status); + synSqueezeParams squeezeParams; + squeezeParams.axis = 4; + AddNodeSqueeze( + sin_inputs, sin_squeezed, squeezeParams, guid_ + "squeeze_sin"); + + auto cos_sq = + createTensorNoPresist("cos_squeezed", dtype_, rotary_embs_dims); std::vector cos_squeezed; - auto cos_sq = createTensor(rotary_embs_dims.size(), - dtype_, - rotary_embs_dims, - false, - "cos_squeezed"); cos_squeezed.push_back(cos_sq); - std::string squeeze_name_cos = squeeze_name + "_cos"; - status = synNodeCreate(graphHandle_, - cos_inputs.data(), - cos_squeezed.data(), - 1, - 1, - &squeezeParams, - sizeof(squeezeParams), - squeeze_guid.c_str(), - squeeze_name_cos.c_str(), - nullptr, - nullptr); - PD_CHECK( - status == synSuccess, - "[RUNTIME] FusedRmsQkvRopeKernel synNodeCreate (squeeze/cos) failed = ", - status); + AddNodeSqueeze( + cos_inputs, cos_squeezed, squeezeParams, guid_ + "squeeze_cos"); std::vector inputs_q; std::vector outputs_q; @@ -311,28 +167,13 @@ class FusedRmsQkvRopeT : public HpuOperator { inputs_q.push_back(sin_sq); inputs_q.push_back(cos_sq); - auto q_states = - createTensor(outs[0].size(), dtype_, outs[0], true, "query_states"); + auto q_states = createTensorFromCT(&ct, 0, false); outputs_q.push_back(q_states); ns_RoPESt2::ParamsV2 ropeParams; ropeParams.offset = 0; ropeParams.mode = ROTARY_POS_EMBEDDING_MODE_BLOCKWISE; - - status = synNodeCreate(graphHandle_, - inputs_q.data(), - outputs_q.data(), - inputs_q.size(), - outputs_q.size(), - &ropeParams, - sizeof(ropeParams), - guid_rope.c_str(), - name_rope.c_str(), - nullptr, - nullptr); - PD_CHECK(status == synSuccess, - "[RUNTIME] FusedRmsQkvRopeKernel synNodeCreate (rope/q) failed = ", - status); + AddNodeRope(inputs_q, outputs_q, ropeParams, guid_ + "rope_q"); std::vector inputs_k; std::vector outputs_k; @@ -340,25 +181,9 @@ class FusedRmsQkvRopeT : public HpuOperator { inputs_k.push_back(sin_sq); inputs_k.push_back(cos_sq); - auto k_rope = - createTensor(kv_dims.size(), dtype_, kv_dims, false, "k_rope"); + auto k_rope = createTensorNoPresist("k_rope", dtype_, kv_dims); outputs_k.push_back(k_rope); - - status = synNodeCreate(graphHandle_, - inputs_k.data(), - outputs_k.data(), - inputs_k.size(), - outputs_k.size(), - &ropeParams, - sizeof(ropeParams), - guid_rope.c_str(), - name_rope.c_str(), - nullptr, - nullptr); - - PD_CHECK(status == synSuccess, - "[RUNTIME] FusedRmsQkvRopeKernel synNodeCreate (rope/k) failed = ", - status); + AddNodeRope(inputs_k, outputs_k, ropeParams, guid_ + "rope_k"); std::vector inputs_concat; std::vector outputs_concat; @@ -366,48 +191,20 @@ class FusedRmsQkvRopeT : public HpuOperator { inputs_concat.push_back(v_split); kv_dims[0] *= 2; - auto kv_concat = - createTensor(kv_dims.size(), dtype_, kv_dims, false, "kv_concat"); + auto kv_concat = createTensorNoPresist("kv_concat", dtype_, kv_dims); outputs_concat.push_back(kv_concat); - unsigned concatParams = 3; - status = synNodeCreate(graphHandle_, - inputs_concat.data(), - outputs_concat.data(), - inputs_concat.size(), - outputs_concat.size(), - &concatParams, - sizeof(concatParams), - guid_concat.c_str(), - name_concat.c_str(), - nullptr, - nullptr); - PD_CHECK(status == synSuccess, - "[RUNTIME] FusedRmsQkvRopeKernel synNodeCreate (stack/concat) " - "failed = ", - status); + synConcatenateParams concatParams; + concatParams.axis = 3; + AddNodeConcat( + inputs_concat, outputs_concat, concatParams, guid_ + "concat"); std::vector outputs_stack; - auto kv_state = - createTensor(outs[1].size(), dtype_, outs[1], true, "key_value_states"); + auto kv_state = createTensorFromCT(&ct, 1, false); outputs_stack.push_back(kv_state); - status = synNodeCreate(graphHandle_, - outputs_concat.data(), - outputs_stack.data(), - outputs_concat.size(), - outputs_stack.size(), - nullptr, - 0, - guid_reshape.c_str(), - name_reshape.c_str(), - nullptr, - nullptr); - PD_CHECK(status == synSuccess, - "[RUNTIME] FusedRmsQkvRopeKernel synNodeCreate (stack/reshape) " - "failed = ", - status); + AddNodeReshape(outputs_concat, outputs_stack, guid_ + "reshaped_kv"); } protected: @@ -415,8 +212,9 @@ class FusedRmsQkvRopeT : public HpuOperator { }; template -void FusedRmsQkvRopeKernelT(const Context& dev_ctx, +void FusedRmsQkvRopeTKernel(const Context& dev_ctx, const phi::DenseTensor& src, + const phi::DenseTensor& residual, const phi::DenseTensor& ln_scales, const phi::DenseTensor& qkv_weights, const phi::DenseTensor& rotary_embs, @@ -426,31 +224,24 @@ void FusedRmsQkvRopeKernelT(const Context& dev_ctx, const phi::Scalar& head_dim, const phi::Scalar& num_head) { std::vector src_dims = phi::vectorize(src.dims()); - std::vector ln_scales_dims = - phi::vectorize(ln_scales.dims()); std::vector qkv_weights_dims = phi::vectorize(qkv_weights.dims()); - std::vector rotary_embs_dims = - phi::vectorize(rotary_embs.dims()); - - std::vector out_q_dim = - phi::vectorize(query_states->dims()); - std::vector out_kv_dim = - phi::vectorize(key_value_states->dims()); - - std::vector inputs = { - src_dims, ln_scales_dims, qkv_weights_dims, rotary_embs_dims}; - std::vector outputs = {out_q_dim, out_kv_dim}; int head_dim_ = head_dim.to(); int num_head_ = num_head.to(); - // const int64_t bsz = src_dims[0]; - // const int64_t seq_len = src_dims[1]; const int64_t fused_hidden_size = qkv_weights_dims[0]; - // const int64_t hidden_size = qkv_weights_dims[1]; const int kv_num_head = (fused_hidden_size - num_head_ * head_dim_) / head_dim_ / 2; - // const int num_groups = num_head_ / kv_num_head; + + ConvertTensors ct; + ct.Add(src); + ct.Add(ln_scales); + ct.Add(qkv_weights); + ct.Add(rotary_embs); + ct.Add(residual); + ct.Add(query_states, false); + ct.Add(key_value_states, false); + ct.Add(residual, false); OpCacheOperator op_info; op_info.prepareOpInfo( @@ -468,31 +259,24 @@ void FusedRmsQkvRopeKernelT(const Context& dev_ctx, params.kv_num_head = kv_num_head; FusedRmsQkvRopeT op(op_info.datatype_); - op.AddNode(inputs, outputs, params); + op.AddNode(ct, params); op.Compile(); op_info.setOp(op); recipe = op_info.GetRecipe(); } - std::map tensors; - tensors["src"] = reinterpret_cast(src.data()); - tensors["ln_scales"] = reinterpret_cast(ln_scales.data()); - tensors["qkv_weights"] = reinterpret_cast(qkv_weights.data()); - tensors["rotary_embs"] = reinterpret_cast(rotary_embs.data()); - - tensors["query_states"] = reinterpret_cast(query_states->data()); - tensors["key_value_states"] = - reinterpret_cast(key_value_states->data()); - + std::map tensors = ct.GetDeviceAddr(); RecipeRunner runner(recipe); runner.Run(reinterpret_cast(dev_ctx.stream()), tensors); } + } // namespace custom_kernel template -void CallFusedRmsQkvRopeKernelT(const Context& dev_ctx, +void CallFusedRmsQkvRopeTKernel(const Context& dev_ctx, const phi::DenseTensor& src, + const phi::DenseTensor& residual, const phi::DenseTensor& ln_scales, const phi::DenseTensor& qkv_weights, const phi::DenseTensor& rotary_embs, @@ -502,8 +286,9 @@ void CallFusedRmsQkvRopeKernelT(const Context& dev_ctx, const phi::Scalar& head_dim, const phi::Scalar& num_head) { if (src.dtype() == phi::DataType::FLOAT16) { - custom_kernel::FusedRmsQkvRopeKernelT(dev_ctx, + custom_kernel::FusedRmsQkvRopeTKernel(dev_ctx, src, + residual, ln_scales, qkv_weights, rotary_embs, @@ -513,9 +298,10 @@ void CallFusedRmsQkvRopeKernelT(const Context& dev_ctx, head_dim, num_head); } else if (src.dtype() == phi::DataType::BFLOAT16) { - custom_kernel::FusedRmsQkvRopeKernelT( + custom_kernel::FusedRmsQkvRopeTKernel( dev_ctx, src, + residual, ln_scales, qkv_weights, rotary_embs, @@ -525,7 +311,8 @@ void CallFusedRmsQkvRopeKernelT(const Context& dev_ctx, head_dim, num_head); } else { - throw std::runtime_error("Unsupported data type for FusedRmsQkvRopeKernel"); + throw std::runtime_error( + "Unsupported data type for FusedRmsQkvRopeTKernel"); } } @@ -533,6 +320,7 @@ std::vector FusedRmsQkvRopeT(const paddle::Tensor& src, const paddle::Tensor& ln_scales, const paddle::Tensor& qkv_weights, const paddle::Tensor& rotary_embs, + const paddle::Tensor& residual, float epsilon, int head_dim, int num_head) { @@ -545,6 +333,8 @@ std::vector FusedRmsQkvRopeT(const paddle::Tensor& src, static_cast(qkv_weights.impl().get()); auto rotary_embs_tensor = static_cast(rotary_embs.impl().get()); + auto residual_tensor = + static_cast(residual.impl().get()); // allocate memory on device. int64_t bsz = src.dims()[0]; @@ -563,8 +353,9 @@ std::vector FusedRmsQkvRopeT(const paddle::Tensor& src, phi::make_ddim({2, bsz, seq_len, kv_num_head, head_dim})); dev_ctx->Alloc(key_value_states.get(), src_tensor->dtype()); - CallFusedRmsQkvRopeKernelT(*dev_ctx, + CallFusedRmsQkvRopeTKernel(*dev_ctx, *src_tensor, + *residual_tensor, *ln_scales_tensor, *qkv_weights_tensor, *rotary_embs_tensor, @@ -581,6 +372,7 @@ std::vector> FusedRmsQkvRopeTShape( const std::vector& ln_scales_shape, const std::vector& qkv_weights_shape, const std::vector& rotary_embs_shape, + const std::vector& residual_shape, float epsilon, int head_dim, int num_head) { @@ -596,12 +388,13 @@ std::vector FusedRmsQkvRopeTDtype( const paddle::DataType& src_dtype, const paddle::DataType& ln_scales_dtype, const paddle::DataType& qkv_weights_dtype, - const paddle::DataType& rotary_embs_dtype) { + const paddle::DataType& rotary_embs_dtype, + const paddle::DataType& residual_dtype) { return {src_dtype, src_dtype}; } PD_BUILD_OP(fused_rms_qkv_rope_t) - .Inputs({"src", "ln_scales", "qkv_weights", "rotary_embs"}) + .Inputs({"src", "ln_scales", "qkv_weights", "rotary_embs", "residual"}) .Outputs({"query_states", "key_value_states"}) .Attrs({"epsilon: float", "head_dim: int", "num_head: int"}) .SetKernelFn(PD_KERNEL(FusedRmsQkvRopeT)) diff --git a/backends/intel_hpu/custom_ops/llama_infer/prepare_block_metadata.cc b/backends/intel_hpu/custom_ops/llama_infer/prepare_block_metadata.cc index 90cdbc0e87..d80f485cf5 100644 --- a/backends/intel_hpu/custom_ops/llama_infer/prepare_block_metadata.cc +++ b/backends/intel_hpu/custom_ops/llama_infer/prepare_block_metadata.cc @@ -89,6 +89,14 @@ void pad_fill(const T* input_p, } } +template +void pad_fill(const T* input_p, T* padded, std::vector valid_batches) { +#pragma omp parallel for num_threads(OMP_THREAD_NUM) + for (int i = 0; i < static_cast(valid_batches.size()); ++i) { + padded[i] = input_p[valid_batches[i]]; + } +} + // in: seq_lens_decoder, block_tables // out: block_indices, block_offset // return last_block_pos, seq_lens @@ -151,7 +159,6 @@ std::vector PrepareBlockMetadata( auto hpu_place = rope_emb.place(); auto dev_ctx = static_cast( paddle::experimental::DeviceContextPool::Instance().Get(hpu_place)); - auto input_ids_cpu = input_ids.copy_to(paddle::CPUPlace(), true); auto block_tables_cpu = block_tables.copy_to(paddle::CPUPlace(), true); auto seq_lens_encoder_cpu = seq_lens_encoder.copy_to(paddle::CPUPlace(), true); @@ -178,6 +185,7 @@ std::vector PrepareBlockMetadata( if (enc_count > 0) { int total_batch = find_bucket(enc_count, batch_step, max_batches); + auto input_ids_cpu = input_ids.copy_to(paddle::CPUPlace(), true); int max_buckets = (max_enc_len + block_size - 1) / block_size; int max_prompt_len = max_buckets * block_size; @@ -238,22 +246,22 @@ std::vector PrepareBlockMetadata( } else if (dec_count > 0) { int total_batch = find_bucket(dec_count, batch_step, max_batches); + auto input_ids_column_0 = + paddle::experimental::slice(input_ids, {1}, {0}, {1}, {}, {}); + auto input_ids_cpu = input_ids_column_0.copy_to(paddle::CPUPlace(), true); + auto src_padded = paddle::full( {total_batch}, 0, paddle::DataType::INT64, paddle::CPUPlace()); pad_fill(const_cast(input_ids_cpu.data()), reinterpret_cast(src_padded.data()), - valid_batches_dec, - max_seq_len, - 1); + valid_batches_dec); auto seq_lens_padded = paddle::full( {total_batch}, 0, paddle::DataType::INT32, paddle::CPUPlace()); pad_fill( const_cast(seq_lens_decoder_cpu.data()), reinterpret_cast(seq_lens_padded.data()), - valid_batches_dec, - 1, - 1); + valid_batches_dec); std::shared_ptr seq_lens_padded_hpu = std::make_shared(); diff --git a/backends/intel_hpu/custom_ops/python/paddlenlp_ops/layers.py b/backends/intel_hpu/custom_ops/python/paddlenlp_ops/layers.py index 6d74433359..62d34f9152 100644 --- a/backends/intel_hpu/custom_ops/python/paddlenlp_ops/layers.py +++ b/backends/intel_hpu/custom_ops/python/paddlenlp_ops/layers.py @@ -92,12 +92,13 @@ def __init__(self, ln_scales, qkv_weights, epsilon, head_dim, num_head): self.head_dim = head_dim self.num_head = num_head - def forward(self, i, src, rotary_embs): + def forward(self, i, src, rotary_embs, residual): query_states, kv_states = fused_rms_qkv_rope_t( src, self.ln_scales[i], self.qkv_weights[i], rotary_embs, + residual, self.epsilon, self.head_dim, self.num_head, @@ -208,6 +209,64 @@ def forward( return out_linear_out +class Fused_Block_Attention(paddle.nn.Layer): + def __init__( + self, + ln_scales, + qkv_weights, + epsilon, + head_dim, + num_head, + scaling_factor, + linear_weights, + ): + super().__init__() + self.ln_scales = ln_scales + self.qkv_weights = qkv_weights + self.epsilon = epsilon + self.head_dim = head_dim + self.num_head = num_head + self.scaling_factor = scaling_factor + self.linear_weights = linear_weights + + def forward( + self, + i, + src, + residual, + rotary_embs, + k_caches, + v_caches, + block_groups, + block_list, + block_mapping, + block_bias, + block_indices, + block_offsets, + ): + out_linear_out = fused_block_attention( + src, + residual, + rotary_embs, + k_caches, + v_caches, + block_groups, + block_list, + block_mapping, + block_bias, + block_indices, + block_offsets, + self.ln_scales[i], + self.qkv_weights[i], + self.linear_weights[i], + self.epsilon, + self.head_dim, + self.num_head, + self.scaling_factor, + ) + return out_linear_out + + class Fused_Mlp(paddle.nn.Layer): def __init__(self, proj_weight, up_weight, down_weight): super().__init__() @@ -244,6 +303,26 @@ def forward(self, i, x): return fused_rms_mlp_out +class Fused_Rms_Mlp_Res(paddle.nn.Layer): + def __init__(self, ln_scales, epsilon, proj_weight, down_weight): + super().__init__() + self.ln_scales = ln_scales + self.epsilon = epsilon + self.proj_weight = proj_weight + self.down_weight = down_weight + + def forward(self, i, x, residual): + fused_rms_mlp_out = fused_rms_mlp_res( + x, + self.ln_scales[i], + self.proj_weight[i], + self.down_weight[i], + residual, + self.epsilon, + ) + return fused_rms_mlp_out + + class Prepare_Block_Metadata(paddle.nn.Layer): def __init__(self, block_size): super().__init__() diff --git a/backends/intel_hpu/custom_ops/tests/test_fused_block_attention.py b/backends/intel_hpu/custom_ops/tests/test_fused_block_attention.py new file mode 100644 index 0000000000..c57493817a --- /dev/null +++ b/backends/intel_hpu/custom_ops/tests/test_fused_block_attention.py @@ -0,0 +1,225 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddlenlp_ops + +paddle.device.set_device("intel_hpu:1") + +# paddle.seed(102) + + +class TestFusedBlockAttention: + def __init__(self): + self.head_dim = 128 + self.num_head = 32 + self.kv_num_heads = 32 + self.hidden_size = self.num_head * self.head_dim + + self.epsilon = 1e-06 + + self.use_neox = True + self.position_offset = 0 + self.rope_theta = 10000 + + def init_decode_params(self): + self.test_name = "TestFusedBlockAttentionDecode" + self.batch_size = 16 + self.seq_len = 1 + self.block_size = 128 + self.num_of_block = 32 + self.total_block_num = 20 + position_id = paddle.to_tensor([80]) + self.position_ids = paddle.expand( + position_id, shape=[self.batch_size, self.seq_len] + ) + + def create_tensors(self): + self.k_cache = ( + paddle.rand( + [ + self.total_block_num, + self.block_size, + self.kv_num_heads, + self.head_dim, + ], + dtype=paddle.float32, + ) + * 1000 + ) + self.k_cache = self.k_cache.to(paddle.bfloat16) + self.k_cache_test = self.k_cache.clone() + self.v_cache = ( + paddle.rand( + [ + self.total_block_num, + self.block_size, + self.kv_num_heads, + self.head_dim, + ], + dtype=paddle.float32, + ) + * 1000 + ) + self.v_cache = self.v_cache.to(paddle.bfloat16) + self.v_cache_test = self.v_cache.clone() + + self.input_ids = paddle.zeros( + [self.batch_size, self.seq_len], dtype=paddle.bfloat16 + ) + self.src = paddle.rand( + [self.batch_size, self.seq_len, self.hidden_size], dtype=paddle.float32 + ).to(paddle.bfloat16) + self.residual = paddle.rand( + [self.batch_size, self.seq_len, self.hidden_size], dtype=paddle.float32 + ).to(paddle.bfloat16) + self.residual_test = self.residual.clone() + + self.ln_scales = paddle.rand([self.hidden_size], dtype=paddle.bfloat16) + self.qkv_weights = paddle.rand( + [self.hidden_size * 3, self.hidden_size], dtype=paddle.float32 + ) + self.qkv_weights = self.qkv_weights.to(paddle.bfloat16) + + self.linear_weights = paddle.rand( + [self.hidden_size, self.hidden_size], dtype=paddle.float32 + ).to(paddle.bfloat16) + + self.head_dim_shape_tensor = paddle.ones(self.head_dim, dtype="int8") + self.new_rope = paddlenlp_ops.fused_get_rotary_embedding( + self.input_ids, + self.position_ids, + self.head_dim_shape_tensor, + self.position_offset, + self.rope_theta, + self.use_neox, + ).to(paddle.bfloat16) + + self.block_indices = paddle.randint( + 0, + self.total_block_num, + [ + self.batch_size, + ], + dtype=paddle.int32, + ) + self.block_offsets = paddle.randint( + 0, + self.block_size, + [ + self.batch_size, + ], + dtype=paddle.int32, + ) + + self.block_groups = paddle.randint( + 0, + self.batch_size, + [ + self.num_of_block, + ], + dtype=paddle.int32, + ) + self.block_list = paddle.randint( + 0, + self.num_of_block, + [ + self.num_of_block, + ], + dtype=paddle.int32, + ) + self.block_mapping = paddle.randint( + 0, 2, [self.num_of_block, self.batch_size], dtype=paddle.int32 + ).to(paddle.bfloat16) + self.block_bias = paddle.rand( + [self.num_of_block, self.block_size], dtype=paddle.bfloat16 + ) + + def run_test(self): + query_states, key_value_states = paddlenlp_ops.fused_rms_qkv_rope_t( + self.src, + self.ln_scales, + self.qkv_weights, + self.new_rope.transpose([0, 1, 3, 2, 4]), + self.residual, + self.epsilon, + self.head_dim, + self.num_head, + ) + key_states = key_value_states[0].squeeze(1) + value_states = key_value_states[1].squeeze(1) + + self.k_cache.index_put_((self.block_indices, self.block_offsets), key_states) + self.v_cache.index_put_((self.block_indices, self.block_offsets), value_states) + + out_linear_out_ref = paddlenlp_ops.fused_flatpa_proj( + query_states, + self.k_cache, + self.v_cache, + self.block_groups, + self.block_list, + self.block_mapping, + self.block_bias, + self.linear_weights, + scaling_factor=self.head_dim**-0.5, + ) + + out_linear_out = paddlenlp_ops.fused_block_attention( + self.src, + self.residual_test, + self.new_rope.transpose([0, 1, 3, 2, 4]), + self.k_cache_test, + self.v_cache_test, + self.block_groups, + self.block_list, + self.block_mapping, + self.block_bias, + self.block_indices, + self.block_offsets, + self.ln_scales, + self.qkv_weights, + self.linear_weights, + self.epsilon, + self.head_dim, + self.num_head, + scaling_factor=self.head_dim**-0.5, + ) + + assert ( + (out_linear_out_ref == out_linear_out).all().item() + ), f"Test failed for {self.test_name} fused_block_attention out_linear_out" + assert ( + (self.k_cache == self.k_cache_test).all().item() + ), f"Test failed for {self.test_name} fused_block_attention k_cache" + assert ( + (self.v_cache == self.v_cache_test).all().item() + ), f"Test failed for {self.test_name} fused_block_attention v_cache" + assert ( + (self.residual == self.residual_test).all().item() + ), f"Test failed for {self.test_name} fused_block_attention residual" + + # ===============summary============== + print(f"Test Pass for {self.test_name} testcase") + + +class test_case_decode(TestFusedBlockAttention): + def __init__(self): + super().__init__() + self.init_decode_params() + self.create_tensors() + + +if __name__ == "__main__": + test_1 = test_case_decode() + test_1.run_test() diff --git a/backends/intel_hpu/custom_ops/tests/test_fused_rms_mlp.py b/backends/intel_hpu/custom_ops/tests/test_fused_rms_mlp.py index 3c211b58e7..710c82c34d 100644 --- a/backends/intel_hpu/custom_ops/tests/test_fused_rms_mlp.py +++ b/backends/intel_hpu/custom_ops/tests/test_fused_rms_mlp.py @@ -34,6 +34,9 @@ def init_data( x = paddle.rand( [batch_size, seqence_len, hidden_size], dtype=paddle.float32 ).to(paddle.bfloat16) + residual = paddle.rand( + [batch_size, seqence_len, hidden_size], dtype=paddle.float32 + ).to(paddle.bfloat16) ln_scales = paddle.rand([hidden_size], dtype=paddle.bfloat16) gate_weight = paddle.normal( @@ -49,7 +52,16 @@ def init_data( epsilon = 1e-06 - return x, ln_scales, proj_weight, gate_weight, up_weight, down_weight, epsilon + return ( + x, + ln_scales, + proj_weight, + gate_weight, + up_weight, + down_weight, + residual, + epsilon, + ) def ref_rms_mlp( @@ -90,8 +102,11 @@ def __init__(self): self.gate_weight, self.up_weight, self.down_weight, + self.residual, self.epsilon, ) = init_data() + self.x = self.x + self.residual + self.residual = self.x def forward(self): mlp_out_ref = ref_rms_mlp( @@ -115,8 +130,11 @@ def __init__(self): _, _, self.down_weight, + self.residual, self.epsilon, ) = init_data() + self.x = self.x + self.residual + self.residual = self.x def forward(self): fused_rms_mlp_out = paddlenlp_ops.fused_rms_mlp( @@ -129,6 +147,32 @@ def forward(self): return fused_rms_mlp_out +class fusedRmsMlpResOP(paddle.nn.Layer): + def __init__(self): + super().__init__() + ( + self.x, + self.ln_scales, + self.proj_weight, + _, + _, + self.down_weight, + self.residual, + self.epsilon, + ) = init_data() + + def forward(self): + fused_rms_mlp_out = paddlenlp_ops.fused_rms_mlp_res( + self.x, + self.ln_scales, + self.proj_weight, + self.down_weight, + self.residual, + self.epsilon, + ) + return fused_rms_mlp_out + + def run_profile(my_profile_func): prof = profiler.Profiler( targets=[profiler.ProfilerTarget.CPU, profiler.ProfilerTarget.CUSTOM_DEVICE], @@ -144,11 +188,15 @@ def run_profile(my_profile_func): def run_accuracy_check(): ref_rms_mlp = refRmsMlpOP() fused_rms_mlp = fusedRmsMlpOP() + fused_rms_mlp_residual = fusedRmsMlpResOP() golden_res = ref_rms_mlp() fused_rms_res = fused_rms_mlp() + fused_rms_mlp_residual_res = fused_rms_mlp_residual() print((fused_rms_res == golden_res).all()) + print((fused_rms_res == fused_rms_mlp_residual_res).all()) + print((ref_rms_mlp.residual == fused_rms_mlp_residual.residual).all()) def main(): diff --git a/backends/intel_hpu/custom_ops/tests/test_rms_qkv_rope.py b/backends/intel_hpu/custom_ops/tests/test_rms_qkv_rope.py index 8e24283cc6..20b168e9f1 100644 --- a/backends/intel_hpu/custom_ops/tests/test_rms_qkv_rope.py +++ b/backends/intel_hpu/custom_ops/tests/test_rms_qkv_rope.py @@ -90,6 +90,8 @@ def create_tensors(self): self.src = paddle.rand( [self.batch_size, self.seq_len, self.hidden_size], dtype=paddle.bfloat16 ) + self.residual = paddle.zeros_like(self.src, dtype=paddle.bfloat16) + self.ln_scales = paddle.rand([self.hidden_size], dtype=paddle.bfloat16) self.qkv_weights = paddle.rand( [self.hidden_size * 3, self.hidden_size], dtype=paddle.float32 @@ -201,6 +203,7 @@ def run_test(self): self.ln_scales, self.qkv_weights, self.new_rope.transpose([0, 1, 3, 2, 4]), + self.residual, self.epsilon, self.head_dim, self.num_head, @@ -217,6 +220,8 @@ def run_test(self): # ===============summary============== print(f"Test Pass for {self.test_name} testcase") + print((self.src == self.residual).all().item()) + # print(self.residual.data_ptr() == residual.data_ptr()) class test_case_padding(TestFusedRmsQkvRope): diff --git a/backends/intel_hpu/kernels/hpu_funcs.h b/backends/intel_hpu/kernels/hpu_funcs.h index dbf9f6f4ce..24aabe67e6 100644 --- a/backends/intel_hpu/kernels/hpu_funcs.h +++ b/backends/intel_hpu/kernels/hpu_funcs.h @@ -294,6 +294,14 @@ class HpuFusedOperator : public HpuOperator { inputs, outputs, params, guid, node_name); } + template + inline void AddNodeScatter(std::vector inputs, + std::vector outputs, + std::string node_name) { + std::string guid = "scatter_nd_onnx_fwd_" + guid_dtype(); + AddNode_IO(inputs, outputs, guid, node_name); + } + template inline void AddNodeSilu(std::vector inputs, std::vector outputs, @@ -302,13 +310,55 @@ class HpuFusedOperator : public HpuOperator { AddNode_IO(inputs, outputs, guid, node_name); } - template + inline void AddNodeConcat(std::vector inputs, + std::vector outputs, + synConcatenateParams params, + std::string node_name) { + std::string guid = "concat"; + AddNode_IOP(inputs, outputs, params, guid, node_name); + } + inline void AddNodeSplit(std::vector inputs, std::vector outputs, synSplitParams params, std::string node_name) { std::string guid = "split"; - AddNode_IOP(inputs, outputs, params, guid, node_name); + AddNode_IOP(inputs, outputs, params, guid, node_name); + } + + inline void AddNodeSlice(std::vector inputs, + std::vector outputs, + synSliceParamsV2 params, + std::string node_name) { + std::string guid = "slice"; + AddNode_IOP(inputs, outputs, params, guid, node_name); + } + + inline void AddNodeSqueeze(std::vector inputs, + std::vector outputs, + synSqueezeParams params, + std::string node_name) { + std::string guid = "squeeze"; + AddNode_IOP(inputs, outputs, params, guid, node_name); + } + + template + inline void AddNodeRmsNorm(std::vector inputs, + std::vector outputs, + ns_LayerNormKernel::Params params, + std::string node_name) { + std::string guid = "rms_norm_ex_fwd_" + guid_dtype(); + AddNode_IOP( + inputs, outputs, params, guid, node_name); + } + + template + inline void AddNodeRope(std::vector inputs, + std::vector outputs, + ns_RoPESt2::ParamsV2 params, + std::string node_name) { + std::string guid = "rotary_pos_embedding_fwd_" + guid_dtype(); + AddNode_IOP(inputs, outputs, params, guid, node_name); } }; diff --git a/backends/intel_hpu/kernels/swiglu_kernel.cc b/backends/intel_hpu/kernels/swiglu_kernel.cc index 92742cc19a..f6952e4c31 100644 --- a/backends/intel_hpu/kernels/swiglu_kernel.cc +++ b/backends/intel_hpu/kernels/swiglu_kernel.cc @@ -80,7 +80,7 @@ class SwiGlu : public HpuFusedOperator { std::vector split_in = {cast_x}; std::vector split_out = {split_x, split_y}; std::string node_name = guid_ + "split"; - AddNodeSplit(split_in, split_out, params, node_name); + AddNodeSplit(split_in, split_out, params, node_name); } else { split_x = cast_x; split_y = cast_y;