// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "habanalabs/perf_lib_layer_params.h"
#include "kernels/funcs.h"
#include "kernels/hpu_funcs.h"
#include "kernels/hpu_operator.h"
#include "paddle/extension.h"
#include "utils/utils.h"

namespace custom_kernel {

struct FusedRmsMlpResParams {
  ns_LayerNormKernel::Params rmsnorm_params;
  synSplitParams split_params;
};

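// Builds one fused Synapse graph for a transformer MLP block with a residual
// update: residual add, RMSNorm, packed gate/up GEMM, split, SiLU,
// elementwise multiply, and the down-projection GEMM.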
class FusedRmsMlpRes : public HpuFusedOperator {
 public:
  explicit FusedRmsMlpRes(synDataType dtype)
      : HpuFusedOperator("fused_rms_mlp_res_fwd_", false), dtype_(dtype) {}
  template <typename T>
  void AddNode(ConvertTensors& ct, FusedRmsMlpResParams params) {
    auto ins = ct.GetTensors();
    auto outs = ct.GetTensors(false);

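    // Both projection GEMMs use plain matmuls; neither operand is transposed.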
    synGEMMParams gemm_params;
    gemm_params.transpose_a = false;
    gemm_params.transpose_b = false;

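    // residual_out shares a memory section with residual_input, so the
    // residual add below updates the residual buffer in place.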
    synSectionHandle section = createSection();
    auto hidden_states = createTensorFromCT(&ct, 0);
    auto residual_input = createTensorFromCT(&ct, 4, true, section);
    auto residual_out = createTensorFromCT(&ct, 1, false, section);

    std::vector<synTensor> add_residual_in;
    add_residual_in.push_back(hidden_states);
    add_residual_in.push_back(residual_input);

    std::vector<synTensor> add_residual_out;
    add_residual_out.push_back(residual_out);

    AddNodeAdd<T>(add_residual_in, add_residual_out, guid_ + "add_residual");

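    // RMSNorm of the updated residual with the ln_scales weight; norm_var is
    // the kernel's auxiliary statistics output (last dimension reduced to 1).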
    auto ln_scales = createTensorFromCT(&ct, 1);
    std::vector<synTensor> rmsnorm_inputs;
    rmsnorm_inputs.push_back(residual_out);
    rmsnorm_inputs.push_back(ln_scales);

    auto tmp_dims = ins[0].dims;
    tmp_dims[2] = 1;
    auto norm_out = createTensorNoPresist("norm_out", ins[0].type, ins[0].dims);
    auto norm_var = createTensorNoPresist("norm_var", ins[0].type, tmp_dims);
    std::vector<synTensor> rmsnorm_outputs;
    rmsnorm_outputs.push_back(norm_out);
    rmsnorm_outputs.push_back(norm_var);

    AddNodeRmsNorm<T>(rmsnorm_inputs,
                      rmsnorm_outputs,
                      params.rmsnorm_params,
                      guid_ + "rmsnorm");

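    // Gate and up projections are fused into a single GEMM: proj_weight packs
    // both weight blocks along its output dimension.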
    auto proj_weight = createTensorFromCT(&ct, 2);
    std::vector<int64_t> proj_dims = {
        ins[0].dims[0], ins[0].dims[1], ins[2].dims[1]};
    auto proj_out = createTensorNoPresist("proj_out", ins[0].type, proj_dims);

    std::vector<synTensor> proj_inputs;
    proj_inputs.push_back(norm_out);
    proj_inputs.push_back(proj_weight);
    std::vector<synTensor> proj_outputs;
    proj_outputs.push_back(proj_out);

    AddNodeGemm(proj_inputs, proj_outputs, gemm_params, guid_ + "gemm_up_proj");

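    // Split the packed projection output in half along its last dimension
    // into the gate and up activations.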
    std::vector<int64_t> split_out_dims = {
        proj_dims[0], proj_dims[1], proj_dims[2] / 2};
    auto gate_out =
        createTensorNoPresist("gate_out", ins[0].type, split_out_dims);
    auto up_out = createTensorNoPresist("up_out", ins[0].type, split_out_dims);
    auto down_weight = createTensorFromCT(&ct, 3);

    std::vector<synTensor> split_inputs;
    split_inputs.push_back(proj_out);
    std::vector<synTensor> split_outputs;
    split_outputs.push_back(gate_out);
    split_outputs.push_back(up_out);

    AddNodeSplit(
        split_inputs, split_outputs, params.split_params, guid_ + "split");

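    // SiLU activation on the gate half.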
    auto silu_out =
        createTensorNoPresist("silu_out", ins[0].type, split_out_dims);
    std::vector<synTensor> silu_inputs;
    silu_inputs.push_back(gate_out);
    std::vector<synTensor> silu_outputs;
    silu_outputs.push_back(silu_out);

    AddNodeSilu<T>(silu_inputs, silu_outputs, guid_ + "silu");

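    // Elementwise product silu(gate) * up completes the SwiGLU activation.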
    auto multi_out =
        createTensorNoPresist("multi_out", ins[0].type, split_out_dims);
    std::vector<synTensor> multi_inputs;
    multi_inputs.push_back(silu_out);
    multi_inputs.push_back(up_out);
    std::vector<synTensor> multi_outputs;
    multi_outputs.push_back(multi_out);

    AddNodeMultiply<T>(multi_inputs, multi_outputs, guid_ + "_multi");

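    // Down projection produces the final MLP output.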
    auto mlp_out = createTensorFromCT(&ct, 0, false);
    std::vector<synTensor> down_inputs;
    down_inputs.push_back(multi_out);
    down_inputs.push_back(down_weight);
    std::vector<synTensor> down_outputs;
    down_outputs.push_back(mlp_out);

    AddNodeGemm(
        down_inputs, down_outputs, gemm_params, guid_ + "gemm_down_proj");
  }

 protected:
  synDataType dtype_;
};

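// Looks up (or builds and caches) the compiled recipe for the given
// shapes/dtype/params, then launches it on the device stream.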
template <typename T, typename Context>
void FusedRmsMlpResKernel(const Context& dev_ctx,
                          const phi::DenseTensor& x,
                          const phi::DenseTensor& residual,
                          const phi::DenseTensor& ln_scales,
                          const phi::DenseTensor& proj_weight,
                          const phi::DenseTensor& down_weight,
                          const phi::Scalar& epsilon,
                          phi::DenseTensor* out) {
  // allocate memory on device.
  dev_ctx.template Alloc<T>(out);
  if (out->numel() == 0) {
    return;
  }

  // Split the packed gate/up projection along its last tensor dimension.
  // Synapse uses reversed dim order, so tensor dim (rank - 1) maps to axis 0.
  int64_t axis = proj_weight.dims().size() - 1;

  FusedRmsMlpResParams params;
  memset(reinterpret_cast<void*>(&params), 0x00, sizeof(FusedRmsMlpResParams));
  params.rmsnorm_params.epsValid = true;
  params.rmsnorm_params.eps = epsilon.to<float>();

  params.split_params = {{0}};
  params.split_params.axis = proj_weight.dims().size() - 1 - axis;

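  // The order of Add() calls fixes the tensor indices consumed by AddNode():
  // inputs 0-4 = {x, ln_scales, proj_weight, down_weight, residual},
  // outputs 0-1 = {out, residual (updated in place)}.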
  ConvertTensors ct;
  ct.Add(x);
  ct.Add(ln_scales);
  ct.Add(proj_weight);
  ct.Add(down_weight);
  ct.Add(residual);
  ct.Add(*out, false);
  ct.Add(residual, false);
  std::vector<DIMS> inputs_dims = ct.GetDims();

  OpCacheOperator op_info;
  op_info.prepareOpInfo<T, FusedRmsMlpResParams>(
      "FusedRmsMlpResKernel", inputs_dims, &params);
  auto recipe = op_info.GetRecipe();

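  // First call for this shape/dtype/param combination: build and compile the
  // fused graph, then cache the recipe for reuse.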
  if (recipe == nullptr) {
    FusedRmsMlpRes op(op_info.datatype_);
    op.AddNode<T>(ct, params);
    op.Compile();
    op_info.setOp(op);

    recipe = op_info.GetRecipe();
  }

  std::map<std::string, uint64_t> tensors = ct.GetDeviceAddr();
  RecipeRunner runner(recipe);
  runner.Run(reinterpret_cast<C_Stream>(dev_ctx.stream()), tensors);
}

}  // namespace custom_kernel

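// Dtype dispatch: the fused path currently supports bfloat16 only.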
template <typename Context>
void CallFusedRmsMlpResKernel(const Context& dev_ctx,
                              const phi::DenseTensor& x,
                              const phi::DenseTensor& residual,
                              const phi::DenseTensor& ln_scales,
                              const phi::DenseTensor& proj_weight,
                              const phi::DenseTensor& down_weight,
                              const phi::Scalar& epsilon,
                              phi::DenseTensor* out) {
  if (x.dtype() == phi::DataType::BFLOAT16) {
    custom_kernel::FusedRmsMlpResKernel<phi::dtype::bfloat16>(dev_ctx,
                                                              x,
                                                              residual,
                                                              ln_scales,
                                                              proj_weight,
                                                              down_weight,
                                                              epsilon,
                                                              out);
  } else {
    throw std::runtime_error("Unsupported data type for FusedRmsMlpResKernel");
  }
}

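// Custom-op entry point: unwraps the paddle::Tensor arguments, allocates the
// output with x's shape, and dispatches to the fused kernel.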
std::vector<paddle::Tensor> FusedRmsMlpResForward(
    const paddle::Tensor& x,
    const paddle::Tensor& ln_scales,
    const paddle::Tensor& proj_weight,
    const paddle::Tensor& down_weight,
    const paddle::Tensor& residual,
    const float epsilon) {
  auto dev_ctx = static_cast<const phi::CustomContext*>(
      paddle::experimental::DeviceContextPool::Instance().Get(x.place()));

  auto x_tensor = static_cast<const phi::DenseTensor*>(x.impl().get());
  auto residual_tensor =
      static_cast<const phi::DenseTensor*>(residual.impl().get());

  auto ln_scales_tensor =
      static_cast<const phi::DenseTensor*>(ln_scales.impl().get());
  auto down_tensor =
      static_cast<const phi::DenseTensor*>(down_weight.impl().get());
  auto proj_tensor =
      static_cast<const phi::DenseTensor*>(proj_weight.impl().get());

  auto out_tensor = std::make_shared<phi::DenseTensor>();
  out_tensor->Resize(x_tensor->dims());

  CallFusedRmsMlpResKernel(*dev_ctx,
                           *x_tensor,
                           *residual_tensor,
                           *ln_scales_tensor,
                           *proj_tensor,
                           *down_tensor,
                           phi::Scalar(epsilon),
                           out_tensor.get());

  paddle::Tensor out(out_tensor);

  return {out};
}

std::vector<std::vector<int64_t>> FusedRmsMlpResInferShape(
    const std::vector<int64_t>& x_shape,
    const std::vector<int64_t>& ln_scales_shape,
    const std::vector<int64_t>& proj_weight_shape,
    const std::vector<int64_t>& down_weight_shape,
    const std::vector<int64_t>& residual_shape) {
  // The op registers a single output ("out"), shaped like x; the residual is
  // updated in place rather than returned.
  return {x_shape};
}

std::vector<paddle::DataType> FusedRmsMlpResInferDtype(
    const paddle::DataType& x_dtype,
    const paddle::DataType& ln_scales_dtype,
    const paddle::DataType& proj_weight_dtype,
    const paddle::DataType& down_weight_dtype,
    const paddle::DataType& residual_dtype) {
  return {x_dtype};
}

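// The residual buffer ("residual_in") is rewritten in place by the kernel, so
// only the MLP output is declared as an op output.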
PD_BUILD_OP(fused_rms_mlp_res)
    .Inputs({"x", "ln_scales", "proj_weight", "down_weight", "residual_in"})
    .Outputs({"out"})
    .Attrs({"epsilon: float"})
    .SetKernelFn(PD_KERNEL(FusedRmsMlpResForward))
    .SetInferShapeFn(PD_INFER_SHAPE(FusedRmsMlpResInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(FusedRmsMlpResInferDtype));