Adding SME1 Convolution Kernel to convolve_kleidiai.cpp #26402
```diff
@@ -12,6 +12,7 @@
 #include <functional>
 #include <unordered_map>
 
+#include "kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa.h"
 #include "kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa.h"
 #include "kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme.h"
 #include "kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme.h"
```
```diff
@@ -160,24 +161,7 @@ static bool CheckCapabilitiesSme(const MLAS_CONV_PARAMETERS* Parameters) {
         return false;
     }
 
-    //optimization checks - is the implementation optimal for the conv request
-
-    const auto n_step = kai_get_n_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa();
-    const auto m_step = kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa();
-
-    auto M = ComputeConvOutSize(Parameters->InputShape[0], ComputeKernelSize(Parameters->DilationShape[0],
-             Parameters->KernelShape[0]), Parameters->Padding[0], Parameters->StrideShape[0]) *
-             ComputeConvOutSize(Parameters->InputShape[1], ComputeKernelSize(Parameters->DilationShape[1],
-             Parameters->KernelShape[1]), Parameters->Padding[1], Parameters->StrideShape[1]);
     auto N = Parameters->FilterCount;
-    auto K = Parameters->InputChannels * Parameters->KernelShape[0] * Parameters->KernelShape[1];
-
-    //Can use these variables to add other conditions as required
-    MLAS_UNREFERENCED_PARAMETER(M);
-    MLAS_UNREFERENCED_PARAMETER(K);
-    MLAS_UNREFERENCED_PARAMETER(m_step);
-    MLAS_UNREFERENCED_PARAMETER(n_step);
-
     if (N == 1 || Parameters->KernelShape[0] < 3 || Parameters->KernelShape[1] < 3) {
         return false;
     }
```
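For context on the block removed above: it derived the GEMM problem shape implied by the convolution. Below is a minimal, self-contained sketch of that mapping, assuming the textbook output-size formulas; the real `ComputeConvOutSize`/`ComputeKernelSize` helpers live elsewhere in this file, and the stand-ins here are illustrative only.

```cpp
#include <cstddef>

// Illustrative stand-ins for the helpers referenced in the diff; the real
// implementations live in convolve_kleidiai.cpp and may handle padding differently.
static size_t ComputeKernelSizeSketch(size_t dilation, size_t kernel) {
    return dilation * (kernel - 1) + 1;  // effective extent of a dilated kernel
}

static size_t ComputeConvOutSizeSketch(size_t in, size_t d_kernel, size_t padding, size_t stride) {
    return (in + 2 * padding - d_kernel) / stride + 1;  // standard conv output size
}

int main() {
    // Example: 56x56 input, 3x3 kernel, stride 1, dilation 1, padding 1,
    // 64 input channels, 128 filters.
    const size_t ih = 56, iw = 56, kh = 3, kw = 3, sh = 1, sw = 1, dh = 1, dw = 1, pad = 1;
    const size_t ci = 64, co = 128;

    // M = output pixels, N = filter count, K = im2col reduction depth.
    const size_t M = ComputeConvOutSizeSketch(ih, ComputeKernelSizeSketch(dh, kh), pad, sh) *
                     ComputeConvOutSizeSketch(iw, ComputeKernelSizeSketch(dw, kw), pad, sw);
    const size_t N = co;
    const size_t K = ci * kh * kw;
    return (M == 56 * 56 && N == 128 && K == 576) ? 0 : 1;  // sanity check
}
```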
```diff
@@ -312,8 +296,8 @@ static void MultiThreadedLHSPackSme(MLAS_THREADPOOL* ThreadPool, const size_t ci
                                     const size_t kw, const void * const* lhs_ptrs, std::byte* lhs_data,
                                     const float* in_data,
                                     const float* pad_ptr) {
-    auto m_step = kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa();
+    size_t m_step = ArmKleidiAI::UseSME2 ? kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa()
+                                         : kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa();
 
     // Minimize the kernel call count for the number of available threads
     auto RequiredTiles = MlasDivRoundup(m, m_step);
```
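The pattern introduced here, selecting the SME2 or SME1 step size at runtime and then sizing the work by whole tiles, can be sketched in isolation. The step values below are made up for illustration; the real ones come from the `kai_get_m_step_*` getters.

```cpp
#include <cstddef>

// Made-up step sizes standing in for the kai_get_m_step_* getters.
static size_t MStepSme2() { return 32; }
static size_t MStepSme()  { return 16; }

static size_t DivRoundup(size_t n, size_t d) { return (n + d - 1) / d; }

int main() {
    const bool use_sme2 = false;  // would come from CPU feature detection
    const size_t m_step = use_sme2 ? MStepSme2() : MStepSme();

    // One kernel invocation handles m_step rows; round up so the tail is covered.
    const size_t m = 100;
    const size_t required_tiles = DivRoundup(m, m_step);  // 7 tiles for SME1
    return required_tiles == 7 ? 0 : 1;
}
```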
```diff
@@ -391,7 +375,9 @@ static std::shared_ptr<const void*[]> LhsPtrFill(const size_t ci, const size_t i
     const auto m = ComputeConvOutSize(ih, kh, padding, sh) * ComputeConvOutSize(iw, kw, padding, sw);
 
-    const auto m_step = kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa();
+    const auto m_step = ArmKleidiAI::UseSME2 ? kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa()
+                                             : kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa();
 
     const auto lhs_ptrs_k = kh * kw;
     const auto lhs_ptrs_m = m_step * MlasDivRoundup(m, m_step);
     auto lhs_ptrs = std::shared_ptr<const void*[]>(new const void*[lhs_ptrs_k * lhs_ptrs_m],
```
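The allocation above sizes the pointer table so the packing routine can always read whole tiles: `kh * kw` pointers per output row, with the row count `m` rounded up to a multiple of `m_step`. A self-contained sketch with illustrative numbers:

```cpp
#include <cstddef>
#include <memory>

int main() {
    // Illustrative values, not taken from the PR.
    const size_t kh = 3, kw = 3;  // kernel footprint: pointers per output row
    const size_t m = 100;         // convolution output pixels
    const size_t m_step = 16;     // tile height reported by the selected kernel

    const size_t lhs_ptrs_k = kh * kw;
    // Round m up to whole tiles so the final (partial) tile still has valid slots.
    const size_t lhs_ptrs_m = m_step * ((m + m_step - 1) / m_step);  // 112 here

    auto lhs_ptrs = std::make_unique<const void*[]>(lhs_ptrs_k * lhs_ptrs_m);
    return lhs_ptrs ? 0 : 1;
}
```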
```diff
@@ -497,13 +483,13 @@ static std::unique_ptr<std::byte[]> LhsPackImageDataSme(const size_t ci, const s
 }
 
 static void ConvolveSme(const size_t co,       //channels out
-                        const size_t ci, //channels in
-                        const size_t ih, //image height
-                        const size_t iw, //image width
-                        const size_t kh, //kernel height
-                        const size_t kw, //kernel width
-                        const size_t sh, //kernel stride height
-                        const size_t sw, //kernel stride width
+                        const size_t ci,       //channels in
+                        const size_t ih,       //image height
+                        const size_t iw,       //image width
+                        const size_t kh,       //kernel height
+                        const size_t kw,       //kernel width
+                        const size_t sh,       //kernel stride height
+                        const size_t sw,       //kernel stride width
                         const size_t dilationh, //kernel dilation stride
                         const size_t dilationw, //kernel dilation stride
                         const size_t padding,  //padding size
```
```diff
@@ -524,10 +510,12 @@ static void ConvolveSme(const size_t co, //channels out
     const auto m = ComputeConvOutSize(ih, d_kh, padding, sh) *
                    ComputeConvOutSize(iw, d_kw, padding, sw);
 
-    auto n_step = kai_get_n_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa();
-    auto m_step = kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa();
+    size_t n_step = ArmKleidiAI::UseSME2 ? kai_get_n_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa()
+                                         : kai_get_n_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa();
+    size_t m_step = ArmKleidiAI::UseSME2 ? kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa()
+                                         : kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa();
 
-    //tile iteration dimensions
+    // tile iteration dimensions
     std::array<size_t,3> dim;
     dim[0] = 1;                         // B
     dim[1] = MlasDivRoundup(m, m_step); // M
```
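The `dim` array defines a B x M x N grid of output tiles that the parallel loop flattens into a single index. A runnable sketch of the same decomposition used in the lambda further down, with made-up grid sizes:

```cpp
#include <array>
#include <cstddef>
#include <cstdio>

int main() {
    // Made-up tile grid: 1 batch, 4 M-tiles, 3 N-tiles.
    const std::array<size_t, 3> dim = {1, 4, 3};  // {B, M, N}

    for (size_t tid = 0; tid < dim[0] * dim[1] * dim[2]; ++tid) {
        // Same arithmetic as the lambda in the diff (BIdx is commented out there).
        const size_t BIdx = tid / (dim[1] * dim[2]);
        const size_t MIdx = (tid % (dim[1] * dim[2])) / dim[2];
        const size_t NIdx = (tid % (dim[1] * dim[2])) % dim[2];
        std::printf("tid=%zu -> B=%zu M=%zu N=%zu\n", tid, BIdx, MIdx, NIdx);
    }
    return 0;
}
```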
```diff
@@ -563,29 +551,23 @@ static void ConvolveSme(const size_t co, //channels out
     auto lhs = LhsPackImageDataSme(ci, ih, iw, d_kh, d_kw, sh, sw, padding, in, ThreadPool);
     auto rhs = RhsPackWeightsBiasSme(co, ci, kh, kw, dilationh, dilationw, weights, bias, ThreadPool);
 
-    MlasTrySimpleParallel(ThreadPool,
-                          static_cast<ptrdiff_t>(dim[0]*dim[1]*dim[2]),
-                          [&](ptrdiff_t tid)
-    {
+    MlasTrySimpleParallel(ThreadPool, static_cast<ptrdiff_t>(dim[0] * dim[1] * dim[2]), [&](ptrdiff_t tid) {
         //compute B,M,N index from iteration index
         //ptrdiff_t BIdx = tid / (dim[1] * dim[2]);
         ptrdiff_t MIdx = (tid % (dim[1] * dim[2])) / dim[2];
         ptrdiff_t NIdx = (tid % (dim[1] * dim[2])) % dim[2];
 
         // Get rhs tile, B
-        const size_t rhs_packed_offset =
-            kai_get_rhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(NIdx*n_step,
-                                                                                               d_kh*d_kw,ci);
+        const size_t rhs_packed_offset = ArmKleidiAI::UseSME2 ? kai_get_rhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(NIdx * n_step, d_kh * d_kw, ci)
+                                                              : kai_get_rhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa(NIdx * n_step, d_kh * d_kw, ci);
 
         auto BTile = reinterpret_cast<const void*>(
             reinterpret_cast<const std::byte*>(rhs.get()) + rhs_packed_offset
         );
 
         // Get lhs tile, A
-        const size_t lhs_packed_offset =
-            kai_get_lhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(MIdx*m_step,
-                                                                                               d_kh*d_kw,ci);
+        const size_t lhs_packed_offset = ArmKleidiAI::UseSME2 ? kai_get_lhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(MIdx * m_step, d_kh * d_kw, ci)
+                                                              : kai_get_lhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa(MIdx * m_step, d_kh * d_kw, ci);
 
         auto ATile = reinterpret_cast<const float*>(
             reinterpret_cast<const std::byte*>(lhs.get()) + lhs_packed_offset
```
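Each worker turns its tile coordinates into a byte offset into the packed LHS/RHS buffers, then forms a raw tile pointer. The sketch below mimics that pointer arithmetic with a dense layout and a hypothetical offset helper; the real `kai_get_*_packed_offset_*` functions encode the actual packed SME layout.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical offset helper: pretends tiles are dense row blocks. The real
// kai_get_lhs_packed_offset_* functions know the packed layout instead.
static size_t PackedOffsetSketch(size_t tile_start_row, size_t row_bytes) {
    return tile_start_row * row_bytes;
}

int main() {
    const size_t m_step = 16, row_bytes = 64;            // illustrative tile geometry
    std::vector<std::byte> lhs(8 * m_step * row_bytes);  // stand-in packed buffer

    const size_t MIdx = 2;
    const size_t lhs_packed_offset = PackedOffsetSketch(MIdx * m_step, row_bytes);

    // Same idiom as the diff: base pointer plus packed byte offset, reinterpreted.
    auto ATile = reinterpret_cast<const float*>(lhs.data() + lhs_packed_offset);
    return ATile != nullptr ? 0 : 1;
}
```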
```diff
@@ -599,10 +581,19 @@ static void ConvolveSme(const size_t co, //channels out
                               MIdx * m_step * co * sizeof(float) +
                               NIdx * n_step * sizeof(float)];
 
-        kai_run_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(
-            TileSizeM, TileSizeN, d_kh*d_kw, ci, ATile, BTile, CTile, co * sizeof(float),
-            -std::numeric_limits<float>::max(), std::numeric_limits<float>::max()
-        );
+        if (ArmKleidiAI::UseSME2) {
+            KLEIDIAI_KERNEL_LOG("kai_run_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa" << " M=" << TileSizeM << " N=" << TileSizeN << " k_chunk_count=" << (d_kh * d_kw) << " k_chunk_length=" << ci);
+            kai_run_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(
+                TileSizeM, TileSizeN, d_kh * d_kw, ci, ATile, BTile, CTile, co * sizeof(float),
+                -std::numeric_limits<float>::max(), std::numeric_limits<float>::max()
+            );
+        } else {
+            KLEIDIAI_KERNEL_LOG("kai_run_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa" << " M=" << TileSizeM << " N=" << TileSizeN << " k_chunk_count=" << (d_kh * d_kw) << " k_chunk_length=" << ci);
+            kai_run_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa(
+                TileSizeM, TileSizeN, d_kh * d_kw, ci, ATile, BTile, CTile, co * sizeof(float),
+                -std::numeric_limits<float>::max(), std::numeric_limits<float>::max()
+            );
+        }
     });
 
     if (result == tmp_mlas_aligned) {
```

> Review comment on the `KLEIDIAI_KERNEL_LOG` lines: I guess the usage of the logging macros here means that we need to wait for the logging PR to be merged?
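Both branches pass (-FLT_MAX, +FLT_MAX) as the clamp bounds, which turns the kernels' fused clamp into a no-op for any finite accumulator value. A small demonstration:

```cpp
#include <algorithm>
#include <limits>

int main() {
    // The widest finite range: clamping to it leaves finite values untouched,
    // which is how the diff disables the fused activation.
    const float lo = -std::numeric_limits<float>::max();
    const float hi =  std::numeric_limits<float>::max();

    const float acc = 123.5f;                   // some accumulator value
    const float out = std::clamp(acc, lo, hi);  // unchanged
    return out == acc ? 0 : 1;
}
```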
```diff
@@ -702,11 +693,11 @@ ArmKleidiAI::MlasConv(
 )
 {
     if(!CheckCapabilitiesSme(Parameters)){
-        //Fallback to Default Mlas
+        // Fallback to Default Mlas
         return false;
     };
     ConvolveSme(Parameters->FilterCount, Parameters->InputChannels,          // channel out, in
-                Parameters->InputShape[0], Parameters->InputShape[1],        // image dimensions
+                Parameters->InputShape[0], Parameters->InputShape[1],         // image dimensions
                 Parameters->KernelShape[0], Parameters->KernelShape[1],      // kernel dimensions
                 Parameters->StrideShape[0], Parameters->StrideShape[1],      // kernel stride dimensions
                 Parameters->DilationShape[0], Parameters->DilationShape[1],  // kernel dilation
```
The PR also touches a second file:
```diff
@@ -210,9 +210,9 @@ MlasDynamicQGemmBatch (
     MLAS_THREADPOOL* ThreadPool
 ) {
 #if defined(USE_KLEIDIAI) && !defined(_MSC_VER)
-    //No fallback and putting in guards
-    if(MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME()){
-        ArmKleidiAI::MlasDynamicQGemmBatch(Shape, DataParams, BatchN, ThreadPool);
+    //No fallback and putting in guards. This implementation is SME2 specific.
+    if(ArmKleidiAI::UseSME2){
+        ArmKleidiAI::MlasDynamicQGemmBatch(Shape, DataParams, BatchN, ThreadPool);
     }
 #endif
```

> Review comment on the `if(ArmKleidiAI::UseSME2)` guard: I guess this change is no longer needed now that #26301 supports SME variants?
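`ArmKleidiAI::UseSME2` is presumably a feature flag cached once at startup. The sketch below shows one way such a guard could be derived; the probe here is an assumption for illustration, and the real flag is initialized inside MLAS through its own CPUID plumbing, not as shown.

```cpp
#include <cstdio>

// Assumed feature probe; MLAS's real detection goes through MLAS_CPUIDINFO.
static bool HasSme2Sketch() {
#if defined(__ARM_FEATURE_SME2)
    return true;   // target was compiled with SME2 available
#else
    return false;  // a runtime check (e.g. Linux HWCAP2 bits) would go here
#endif
}

namespace ArmKleidiAISketch {
    const bool UseSME2 = HasSme2Sketch();  // cached once, consulted at dispatch
}

int main() {
    if (ArmKleidiAISketch::UseSME2) {
        std::printf("dispatching SME2-specific dynamic QGEMM\n");
    } else {
        std::printf("falling back to the default implementation\n");
    }
    return 0;
}
```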
> Review comment: Looks like this can be bumped up to 1.15 now? Given that #26301 gets that update anyway?