From 2cc67d2215c60ada5fb7452c185c26d1a3027000 Mon Sep 17 00:00:00 2001 From: Piotr Sowa Date: Sun, 6 Oct 2019 09:28:42 +0200 Subject: [PATCH 1/3] Experimental Multi-GPU for GEMM --- .../blas/specialCases/GemmSpecialCases.cpp | 303 +++++++++--------- src/library/blas/xgemm.cc | 181 ++++++----- 2 files changed, 255 insertions(+), 229 deletions(-) diff --git a/src/library/blas/specialCases/GemmSpecialCases.cpp b/src/library/blas/specialCases/GemmSpecialCases.cpp index be11f41a..baa88344 100644 --- a/src/library/blas/specialCases/GemmSpecialCases.cpp +++ b/src/library/blas/specialCases/GemmSpecialCases.cpp @@ -33,7 +33,7 @@ /* template clblasStatus SGEMM_SPLIT_CALLS( - cl_kernel *ClKernel, clblasOrder order, + cl_kernel ClKernel, clblasOrder order, unsigned int tile_size, unsigned int WG_size, unsigned int M_split_factor, unsigned int N_split_factor, @@ -54,7 +54,7 @@ clblasStatus SGEMM_SPLIT_CALLS( */ template clblasStatus GEMM_SPLIT_CALLS( - cl_kernel *ClKernel, clblasOrder order, + cl_kernel ClKernel, clblasOrder order, unsigned int tile_size, unsigned int WG_size, unsigned int M_split_factor, unsigned int N_split_factor, @@ -104,7 +104,6 @@ clblasStatus GEMM_SPLIT_CALLS( if (transA == clblasNoTrans && transB == clblasTrans) { - unsigned int small_M = M / M_split_factor; unsigned int small_N = N / N_split_factor; unsigned int small_K = K / K_split_factor; @@ -117,11 +116,11 @@ clblasStatus GEMM_SPLIT_CALLS( precision betaone = 1; - error = clSetKernelArg(*ClKernel, 5, sizeof(cl_uint), &small_M); + error = clSetKernelArg(ClKernel, 5, sizeof(cl_uint), &small_M); assert(error == CL_SUCCESS); - error = clSetKernelArg(*ClKernel, 6, sizeof(cl_uint), &small_N); + error = clSetKernelArg(ClKernel, 6, sizeof(cl_uint), &small_N); assert(error == CL_SUCCESS); - error = clSetKernelArg(*ClKernel, 7, sizeof(cl_uint), &small_K); + error = clSetKernelArg(ClKernel, 7, sizeof(cl_uint), &small_K); assert(error == CL_SUCCESS); for (int M_split_index = 0; M_split_index < M_split_factor; M_split_index++) @@ -129,21 +128,21 @@ clblasStatus GEMM_SPLIT_CALLS( for (int N_split_index = 0; N_split_index < N_split_factor; N_split_index++) { unsigned int offc_C = ldc*N / N_split_factor * N_split_index + M / M_split_factor * M_split_index + offC; - error = clSetKernelArg(*ClKernel, 13, sizeof(cl_uint), &offc_C); + error = clSetKernelArg(ClKernel, 13, sizeof(cl_uint), &offc_C); assert(error == CL_SUCCESS); for (int K_split_index = 0; K_split_index < K_split_factor; K_split_index++) { unsigned int offa_A = (M / M_split_factor * M_split_index) + (lda * K / K_split_factor * K_split_index) + offA; unsigned int offb_B = (N / N_split_factor * N_split_index) + (ldb * K / K_split_factor * K_split_index) + offB; - error = clSetKernelArg(*ClKernel, 11, sizeof(cl_uint), &offa_A); + error = clSetKernelArg(ClKernel, 11, sizeof(cl_uint), &offa_A); assert(error == CL_SUCCESS); - error = clSetKernelArg(*ClKernel, 12, sizeof(cl_uint), &offb_B); + error = clSetKernelArg(ClKernel, 12, sizeof(cl_uint), &offb_B); assert(error == CL_SUCCESS); if (K_split_index == 0) { - error = clSetKernelArg(*ClKernel, 4, sizeof(precision), &(beta)); + error = clSetKernelArg(ClKernel, 4, sizeof(precision), &(beta)); assert(error == CL_SUCCESS); if (M_split_index == 0 && N_split_index == 0) @@ -152,39 +151,39 @@ clblasStatus GEMM_SPLIT_CALLS( if ((M_split_factor == 1) && (N_split_factor == 1) && (K_split_factor == 1)) { //also very last GEMM call - error = clEnqueueNDRangeKernel(commandQueues[0], *ClKernel, 2, NULL, + error = clEnqueueNDRangeKernel(commandQueues[0], ClKernel, 2, NULL, gs, wgsize, numEventsInWaitList, eventWaitList, &events[0]); assert(error == CL_SUCCESS); } else { - error = clEnqueueNDRangeKernel(commandQueues[0], *ClKernel, 2, NULL, + error = clEnqueueNDRangeKernel(commandQueues[0], ClKernel, 2, NULL, gs, wgsize, numEventsInWaitList, eventWaitList, NULL); assert(error == CL_SUCCESS); } } else { - error = clEnqueueNDRangeKernel(commandQueues[0], *ClKernel, 2, NULL, + error = clEnqueueNDRangeKernel(commandQueues[0], ClKernel, 2, NULL, gs, wgsize, 0, NULL, NULL); assert(error == CL_SUCCESS); } } else { - error = clSetKernelArg(*ClKernel, 4, sizeof(precision), &betaone); + error = clSetKernelArg(ClKernel, 4, sizeof(precision), &betaone); assert(error == CL_SUCCESS); if ((M_split_index == (M_split_factor - 1)) && (N_split_index == (N_split_factor - 1)) && (K_split_index == (K_split_factor - 1))) { //very last GEMM call - error = clEnqueueNDRangeKernel(commandQueues[0], *ClKernel, 2, NULL, + error = clEnqueueNDRangeKernel(commandQueues[0], ClKernel, 2, NULL, gs, wgsize, 0, NULL, events); assert(error == CL_SUCCESS); } else { - error = clEnqueueNDRangeKernel(commandQueues[0], *ClKernel, 2, NULL, + error = clEnqueueNDRangeKernel(commandQueues[0], ClKernel, 2, NULL, gs, wgsize, 0, NULL, NULL); assert(error == CL_SUCCESS); } @@ -215,7 +214,7 @@ clblasStatus SGEMM_mod1024( bool &specialCaseHandled) { const char *tileKernelSource = NULL; - cl_kernel *tileClKernel = NULL; + cl_kernel tileClKernel = NULL; size_t tileKernelBinarySize = 0; cl_int err; @@ -259,39 +258,39 @@ clblasStatus SGEMM_mod1024( } tileKernelSource = sgemm_Col_NT_B1_MX128_NX128_KX16_src; - tileClKernel = &sgemm_Col_NT_B1_MX128_NX128_KX16_clKernel; + tileClKernel = sgemm_Col_NT_B1_MX128_NX128_KX16_clKernel; tileKernelBinary = sgemm_Col_NT_B1_MX128_NX128_KX16_bin; tileKernelBinarySize = sgemm_Col_NT_B1_MX128_NX128_KX16_binSize; - makeGemmKernel(tileClKernel, commandQueues[0], tileKernelSource, User_srcBuildOptions, &tileKernelBinary, &tileKernelBinarySize, User_binBuildOptions); + makeGemmKernel(&tileClKernel, commandQueues[0], tileKernelSource, User_srcBuildOptions, &tileKernelBinary, &tileKernelBinarySize, User_binBuildOptions); - err = clSetKernelArg(*tileClKernel, 0, sizeof(cl_mem), &A); + err = clSetKernelArg(tileClKernel, 0, sizeof(cl_mem), &A); CL_CHECK(err); - err = clSetKernelArg(*tileClKernel, 1, sizeof(cl_mem), &B); + err = clSetKernelArg(tileClKernel, 1, sizeof(cl_mem), &B); CL_CHECK(err); - err = clSetKernelArg(*tileClKernel, 2, sizeof(cl_mem), &C); + err = clSetKernelArg(tileClKernel, 2, sizeof(cl_mem), &C); CL_CHECK(err); - err = clSetKernelArg(*tileClKernel, 3, sizeof(cl_float), &alpha); + err = clSetKernelArg(tileClKernel, 3, sizeof(cl_float), &alpha); CL_CHECK(err); - err = clSetKernelArg(*tileClKernel, 4, sizeof(cl_float), &beta); + err = clSetKernelArg(tileClKernel, 4, sizeof(cl_float), &beta); CL_CHECK(err); - err = clSetKernelArg(*tileClKernel, 5, sizeof(cl_uint), &M); + err = clSetKernelArg(tileClKernel, 5, sizeof(cl_uint), &M); CL_CHECK(err); - err = clSetKernelArg(*tileClKernel, 6, sizeof(cl_uint), &N); + err = clSetKernelArg(tileClKernel, 6, sizeof(cl_uint), &N); CL_CHECK(err); - err = clSetKernelArg(*tileClKernel, 7, sizeof(cl_uint), &K); + err = clSetKernelArg(tileClKernel, 7, sizeof(cl_uint), &K); CL_CHECK(err); - err = clSetKernelArg(*tileClKernel, 8, sizeof(cl_uint), &lda); + err = clSetKernelArg(tileClKernel, 8, sizeof(cl_uint), &lda); CL_CHECK(err); - err = clSetKernelArg(*tileClKernel, 9, sizeof(cl_uint), &ldb); + err = clSetKernelArg(tileClKernel, 9, sizeof(cl_uint), &ldb); CL_CHECK(err); - err = clSetKernelArg(*tileClKernel, 10, sizeof(cl_uint), &ldc); + err = clSetKernelArg(tileClKernel, 10, sizeof(cl_uint), &ldc); CL_CHECK(err); - err = clSetKernelArg(*tileClKernel, 11, sizeof(cl_uint), &offA); + err = clSetKernelArg(tileClKernel, 11, sizeof(cl_uint), &offA); CL_CHECK(err); - err = clSetKernelArg(*tileClKernel, 12, sizeof(cl_uint), &offB); + err = clSetKernelArg(tileClKernel, 12, sizeof(cl_uint), &offB); CL_CHECK(err); - err = clSetKernelArg(*tileClKernel, 13, sizeof(cl_uint), &offC); + err = clSetKernelArg(tileClKernel, 13, sizeof(cl_uint), &offC); CL_CHECK(err); status = GEMM_SPLIT_CALLS( @@ -334,39 +333,39 @@ clblasStatus SGEMM_mod1024( tileKernelSource = sgemm_Col_NT_B1_MX096_NX096_KX16_src; - tileClKernel = &sgemm_Col_NT_B1_MX096_NX096_KX16_clKernel; + tileClKernel = sgemm_Col_NT_B1_MX096_NX096_KX16_clKernel; tileKernelBinary = sgemm_Col_NT_B1_MX096_NX096_KX16_bin; tileKernelBinarySize = sgemm_Col_NT_B1_MX096_NX096_KX16_binSize; - makeGemmKernel(tileClKernel, commandQueues[0], tileKernelSource, User_srcBuildOptions, &tileKernelBinary, &tileKernelBinarySize, User_binBuildOptions); + makeGemmKernel(&tileClKernel, commandQueues[0], tileKernelSource, User_srcBuildOptions, &tileKernelBinary, &tileKernelBinarySize, User_binBuildOptions); - err = clSetKernelArg(*tileClKernel, 0, sizeof(cl_mem), &A); + err = clSetKernelArg(tileClKernel, 0, sizeof(cl_mem), &A); CL_CHECK(err); - err = clSetKernelArg(*tileClKernel, 1, sizeof(cl_mem), &B); + err = clSetKernelArg(tileClKernel, 1, sizeof(cl_mem), &B); CL_CHECK(err); - err = clSetKernelArg(*tileClKernel, 2, sizeof(cl_mem), &C); + err = clSetKernelArg(tileClKernel, 2, sizeof(cl_mem), &C); CL_CHECK(err); - err = clSetKernelArg(*tileClKernel, 3, sizeof(cl_float), &alpha); + err = clSetKernelArg(tileClKernel, 3, sizeof(cl_float), &alpha); CL_CHECK(err); - err = clSetKernelArg(*tileClKernel, 4, sizeof(cl_float), &beta); + err = clSetKernelArg(tileClKernel, 4, sizeof(cl_float), &beta); CL_CHECK(err); - err = clSetKernelArg(*tileClKernel, 5, sizeof(cl_uint), &M); + err = clSetKernelArg(tileClKernel, 5, sizeof(cl_uint), &M); CL_CHECK(err); - err = clSetKernelArg(*tileClKernel, 6, sizeof(cl_uint), &N); + err = clSetKernelArg(tileClKernel, 6, sizeof(cl_uint), &N); CL_CHECK(err); - err = clSetKernelArg(*tileClKernel, 7, sizeof(cl_uint), &K); + err = clSetKernelArg(tileClKernel, 7, sizeof(cl_uint), &K); CL_CHECK(err); - err = clSetKernelArg(*tileClKernel, 8, sizeof(cl_uint), &lda); + err = clSetKernelArg(tileClKernel, 8, sizeof(cl_uint), &lda); CL_CHECK(err); - err = clSetKernelArg(*tileClKernel, 9, sizeof(cl_uint), &ldb); + err = clSetKernelArg(tileClKernel, 9, sizeof(cl_uint), &ldb); CL_CHECK(err); - err = clSetKernelArg(*tileClKernel, 10, sizeof(cl_uint), &ldc); + err = clSetKernelArg(tileClKernel, 10, sizeof(cl_uint), &ldc); CL_CHECK(err); - err = clSetKernelArg(*tileClKernel, 11, sizeof(cl_uint), &offA); + err = clSetKernelArg(tileClKernel, 11, sizeof(cl_uint), &offA); CL_CHECK(err); - err = clSetKernelArg(*tileClKernel, 12, sizeof(cl_uint), &offB); + err = clSetKernelArg(tileClKernel, 12, sizeof(cl_uint), &offB); CL_CHECK(err); - err = clSetKernelArg(*tileClKernel, 13, sizeof(cl_uint), &offC); + err = clSetKernelArg(tileClKernel, 13, sizeof(cl_uint), &offC); CL_CHECK(err); @@ -416,17 +415,17 @@ clblasStatus SGEMM_SPLIT64_32( cl_event *events, bool &specialCaseHandled) { - //all the mod32 sizes that is not mod64 or mod96 ranging from 1184 to 3872 + //all the mod32 sizes that is not mod64 or mod96 ranging from 1184 to 3872 //non mod32 cases are not implemented in this approach and are of less interest const char *tileKernelSource = NULL; const char *rowKernelSource = NULL; const char *columnKernelSource = NULL; const char *singleKernelSource = NULL; - cl_kernel *tileClKernel = NULL; - cl_kernel *rowClKernel = NULL; - cl_kernel *columnClKernel = NULL; - cl_kernel *singleClKernel = NULL; + cl_kernel tileClKernel = NULL; + cl_kernel rowClKernel = NULL; + cl_kernel columnClKernel = NULL; + cl_kernel singleClKernel = NULL; const unsigned char *tileKernelBinary = NULL; const unsigned char *rowKernelBinary = NULL; @@ -439,7 +438,7 @@ clblasStatus SGEMM_SPLIT64_32( size_t singleKernelBinarySize = 0; cl_int err; - + if ((M >= 1184 && N >= 1184) && (M <= 3872 && N <= 3872) && (M % 64 != 0 && N % 64 != 0) && (M % 96 != 0 && N % 96 != 0) && (K % 16 == 0)) { if ((M % 32 == 0 && N % 32 == 0) && (transA == clblasNoTrans && transB == clblasTrans)) @@ -455,83 +454,83 @@ clblasStatus SGEMM_SPLIT64_32( size_t wgsize[2] = { 16, 16 }; tileKernelSource = sgemm_Col_NT_B1_MX064_NX064_KX16_src; - tileClKernel = &sgemm_Col_NT_B1_MX064_NX064_KX16_clKernel; + tileClKernel = sgemm_Col_NT_B1_MX064_NX064_KX16_clKernel; tileKernelBinary = sgemm_Col_NT_B1_MX064_NX064_KX16_bin; tileKernelBinarySize = sgemm_Col_NT_B1_MX064_NX064_KX16_binSize; rowKernelSource = sgemm_Col_NT_B1_MX032_NX064_KX16_ROW_src; - rowClKernel = &sgemm_Col_NT_B1_MX032_NX064_KX16_ROW_clKernel; + rowClKernel = sgemm_Col_NT_B1_MX032_NX064_KX16_ROW_clKernel; rowKernelBinary = sgemm_Col_NT_B1_MX032_NX064_KX16_ROW_bin; rowKernelBinarySize = sgemm_Col_NT_B1_MX032_NX064_KX16_ROW_binSize; columnKernelSource = sgemm_Col_NT_B1_MX064_NX032_KX16_COLUMN_src; - columnClKernel = &sgemm_Col_NT_B1_MX064_NX032_KX16_COLUMN_clKernel; + columnClKernel = sgemm_Col_NT_B1_MX064_NX032_KX16_COLUMN_clKernel; columnKernelBinary = sgemm_Col_NT_B1_MX064_NX032_KX16_COLUMN_bin; columnKernelBinarySize = sgemm_Col_NT_B1_MX064_NX032_KX16_COLUMN_binSize; singleKernelSource = sgemm_Col_NT_B1_MX032_NX032_KX16_SINGLE_src; - singleClKernel = &sgemm_Col_NT_B1_MX032_NX032_KX16_SINGLE_clKernel; + singleClKernel = sgemm_Col_NT_B1_MX032_NX032_KX16_SINGLE_clKernel; singleKernelBinary = sgemm_Col_NT_B1_MX032_NX032_KX16_SINGLE_bin; singleKernelBinarySize = sgemm_Col_NT_B1_MX032_NX032_KX16_SINGLE_binSize; - cl_kernel * Kernels[4] = { tileClKernel, rowClKernel, columnClKernel, singleClKernel }; + cl_kernel Kernels[4] = { tileClKernel, rowClKernel, columnClKernel, singleClKernel }; - makeGemmKernel(tileClKernel, commandQueues[0], tileKernelSource, User_srcBuildOptions, &tileKernelBinary, &tileKernelBinarySize, User_binBuildOptions); - makeGemmKernel(rowClKernel, commandQueues[0], rowKernelSource, User_srcBuildOptions, &rowKernelBinary, &rowKernelBinarySize, User_binBuildOptions); - makeGemmKernel(columnClKernel, commandQueues[0], columnKernelSource, User_srcBuildOptions, &columnKernelBinary, &columnKernelBinarySize, User_binBuildOptions); - makeGemmKernel(singleClKernel, commandQueues[0], singleKernelSource, User_srcBuildOptions, &singleKernelBinary, &singleKernelBinarySize, User_binBuildOptions); + makeGemmKernel(&tileClKernel, commandQueues[0], tileKernelSource, User_srcBuildOptions, &tileKernelBinary, &tileKernelBinarySize, User_binBuildOptions); + makeGemmKernel(&rowClKernel, commandQueues[0], rowKernelSource, User_srcBuildOptions, &rowKernelBinary, &rowKernelBinarySize, User_binBuildOptions); + makeGemmKernel(&columnClKernel, commandQueues[0], columnKernelSource, User_srcBuildOptions, &columnKernelBinary, &columnKernelBinarySize, User_binBuildOptions); + makeGemmKernel(&singleClKernel, commandQueues[0], singleKernelSource, User_srcBuildOptions, &singleKernelBinary, &singleKernelBinarySize, User_binBuildOptions); for (int i = 0; i < 4; i++) { - err = clSetKernelArg(*Kernels[i], 0, sizeof(cl_mem), &A); + err = clSetKernelArg(Kernels[i], 0, sizeof(cl_mem), &A); CL_CHECK(err); - err = clSetKernelArg(*Kernels[i], 1, sizeof(cl_mem), &B); + err = clSetKernelArg(Kernels[i], 1, sizeof(cl_mem), &B); CL_CHECK(err); - err = clSetKernelArg(*Kernels[i], 2, sizeof(cl_mem), &C); + err = clSetKernelArg(Kernels[i], 2, sizeof(cl_mem), &C); CL_CHECK(err); - err = clSetKernelArg(*Kernels[i], 3, sizeof(cl_float), &alpha); + err = clSetKernelArg(Kernels[i], 3, sizeof(cl_float), &alpha); CL_CHECK(err); - err = clSetKernelArg(*Kernels[i], 4, sizeof(cl_float), &beta); + err = clSetKernelArg(Kernels[i], 4, sizeof(cl_float), &beta); CL_CHECK(err); - err = clSetKernelArg(*Kernels[i], 5, sizeof(cl_uint), &M); + err = clSetKernelArg(Kernels[i], 5, sizeof(cl_uint), &M); CL_CHECK(err); - err = clSetKernelArg(*Kernels[i], 6, sizeof(cl_uint), &N); + err = clSetKernelArg(Kernels[i], 6, sizeof(cl_uint), &N); CL_CHECK(err); - err = clSetKernelArg(*Kernels[i], 7, sizeof(cl_uint), &K); + err = clSetKernelArg(Kernels[i], 7, sizeof(cl_uint), &K); CL_CHECK(err); - err = clSetKernelArg(*Kernels[i], 8, sizeof(cl_uint), &lda); + err = clSetKernelArg(Kernels[i], 8, sizeof(cl_uint), &lda); CL_CHECK(err); - err = clSetKernelArg(*Kernels[i], 9, sizeof(cl_uint), &ldb); + err = clSetKernelArg(Kernels[i], 9, sizeof(cl_uint), &ldb); CL_CHECK(err); - err = clSetKernelArg(*Kernels[i], 10, sizeof(cl_uint), &ldc); + err = clSetKernelArg(Kernels[i], 10, sizeof(cl_uint), &ldc); CL_CHECK(err); - err = clSetKernelArg(*Kernels[i], 11, sizeof(cl_uint), &offA); + err = clSetKernelArg(Kernels[i], 11, sizeof(cl_uint), &offA); CL_CHECK(err); - err = clSetKernelArg(*Kernels[i], 12, sizeof(cl_uint), &offB); + err = clSetKernelArg(Kernels[i], 12, sizeof(cl_uint), &offB); CL_CHECK(err); - err = clSetKernelArg(*Kernels[i], 13, sizeof(cl_uint), &offC); + err = clSetKernelArg(Kernels[i], 13, sizeof(cl_uint), &offC); CL_CHECK(err); } - err = clEnqueueNDRangeKernel(commandQueues[0], *Kernels[0], 2, NULL, gs, wgsize, numEventsInWaitList, eventWaitList, NULL); + err = clEnqueueNDRangeKernel(commandQueues[0], Kernels[0], 2, NULL, gs, wgsize, numEventsInWaitList, eventWaitList, NULL); gs[0] = 16; - err |= clEnqueueNDRangeKernel(commandQueues[0], *Kernels[1], 2, NULL, gs, wgsize, 0, NULL, NULL); + err |= clEnqueueNDRangeKernel(commandQueues[0], Kernels[1], 2, NULL, gs, wgsize, 0, NULL, NULL); gs[1] = 16; gs[0] = GlobalX; - err |= clEnqueueNDRangeKernel(commandQueues[0], *Kernels[2], 2, NULL, gs, wgsize, 0, NULL, NULL); + err |= clEnqueueNDRangeKernel(commandQueues[0], Kernels[2], 2, NULL, gs, wgsize, 0, NULL, NULL); gs[0] = 16; gs[1] = 16; - err |= clEnqueueNDRangeKernel(commandQueues[0], *Kernels[3], 2, NULL, gs, wgsize, 0, NULL, events); + err |= clEnqueueNDRangeKernel(commandQueues[0], Kernels[3], 2, NULL, gs, wgsize, 0, NULL, events); if (err == 0) return clblasSuccess; } } - + return clblasNotImplemented; } @@ -552,7 +551,7 @@ clblasStatus SGEMM_BRANCH_32( bool &specialCaseHandled) { const char *tileKernelSource = NULL; - cl_kernel *tileClKernel = NULL; + cl_kernel tileClKernel = NULL; size_t tileKernelBinarySize = 0; cl_int err; @@ -573,42 +572,42 @@ clblasStatus SGEMM_BRANCH_32( { specialCaseHandled = true; tileKernelSource = sgemm_Col_NN_B1_MX032_NX032_KX16_BRANCH_src; - tileClKernel = &sgemm_Col_NN_B1_MX032_NX032_KX16_BRANCH_clKernel; + tileClKernel = sgemm_Col_NN_B1_MX032_NX032_KX16_BRANCH_clKernel; tileKernelBinary = sgemm_Col_NN_B1_MX032_NX032_KX16_BRANCH_bin; tileKernelBinarySize = sgemm_Col_NN_B1_MX032_NX032_KX16_BRANCH_binSize; - makeGemmKernel(tileClKernel, commandQueues[0], tileKernelSource, User_srcBuildOptions, &tileKernelBinary, &tileKernelBinarySize, User_binBuildOptions); + makeGemmKernel(&tileClKernel, commandQueues[0], tileKernelSource, User_srcBuildOptions, &tileKernelBinary, &tileKernelBinarySize, User_binBuildOptions); - err = clSetKernelArg(*tileClKernel, 0, sizeof(cl_mem), &A); + err = clSetKernelArg(tileClKernel, 0, sizeof(cl_mem), &A); CL_CHECK(err); - err = clSetKernelArg(*tileClKernel, 1, sizeof(cl_mem), &B); + err = clSetKernelArg(tileClKernel, 1, sizeof(cl_mem), &B); CL_CHECK(err); - err = clSetKernelArg(*tileClKernel, 2, sizeof(cl_mem), &C); + err = clSetKernelArg(tileClKernel, 2, sizeof(cl_mem), &C); CL_CHECK(err); - err = clSetKernelArg(*tileClKernel, 3, sizeof(cl_float), &alpha); + err = clSetKernelArg(tileClKernel, 3, sizeof(cl_float), &alpha); CL_CHECK(err); - err = clSetKernelArg(*tileClKernel, 4, sizeof(cl_float), &beta); + err = clSetKernelArg(tileClKernel, 4, sizeof(cl_float), &beta); CL_CHECK(err); - err = clSetKernelArg(*tileClKernel, 5, sizeof(cl_uint), &M); + err = clSetKernelArg(tileClKernel, 5, sizeof(cl_uint), &M); CL_CHECK(err); - err = clSetKernelArg(*tileClKernel, 6, sizeof(cl_uint), &N); + err = clSetKernelArg(tileClKernel, 6, sizeof(cl_uint), &N); CL_CHECK(err); - err = clSetKernelArg(*tileClKernel, 7, sizeof(cl_uint), &K); + err = clSetKernelArg(tileClKernel, 7, sizeof(cl_uint), &K); CL_CHECK(err); - err = clSetKernelArg(*tileClKernel, 8, sizeof(cl_uint), &lda); + err = clSetKernelArg(tileClKernel, 8, sizeof(cl_uint), &lda); CL_CHECK(err); - err = clSetKernelArg(*tileClKernel, 9, sizeof(cl_uint), &ldb); + err = clSetKernelArg(tileClKernel, 9, sizeof(cl_uint), &ldb); CL_CHECK(err); - err = clSetKernelArg(*tileClKernel, 10, sizeof(cl_uint), &ldc); + err = clSetKernelArg(tileClKernel, 10, sizeof(cl_uint), &ldc); CL_CHECK(err); - err = clSetKernelArg(*tileClKernel, 11, sizeof(cl_uint), &offA); + err = clSetKernelArg(tileClKernel, 11, sizeof(cl_uint), &offA); CL_CHECK(err); - err = clSetKernelArg(*tileClKernel, 12, sizeof(cl_uint), &offB); + err = clSetKernelArg(tileClKernel, 12, sizeof(cl_uint), &offB); CL_CHECK(err); - err = clSetKernelArg(*tileClKernel, 13, sizeof(cl_uint), &offC); + err = clSetKernelArg(tileClKernel, 13, sizeof(cl_uint), &offC); CL_CHECK(err); - err = clEnqueueNDRangeKernel(commandQueues[0], *tileClKernel, 2, NULL, + err = clEnqueueNDRangeKernel(commandQueues[0], tileClKernel, 2, NULL, gs, wgsize, numEventsInWaitList, eventWaitList, &events[0]); if (err == 0) @@ -618,42 +617,42 @@ clblasStatus SGEMM_BRANCH_32( { specialCaseHandled = true; tileKernelSource = sgemm_Col_NT_B1_MX032_NX032_KX16_BRANCH_src; - tileClKernel = &sgemm_Col_NT_B1_MX032_NX032_KX16_BRANCH_clKernel; + tileClKernel = sgemm_Col_NT_B1_MX032_NX032_KX16_BRANCH_clKernel; tileKernelBinary = sgemm_Col_NT_B1_MX032_NX032_KX16_BRANCH_bin; tileKernelBinarySize = sgemm_Col_NT_B1_MX032_NX032_KX16_BRANCH_binSize; - makeGemmKernel(tileClKernel, commandQueues[0], tileKernelSource, User_srcBuildOptions, &tileKernelBinary, &tileKernelBinarySize, User_binBuildOptions); + makeGemmKernel(&tileClKernel, commandQueues[0], tileKernelSource, User_srcBuildOptions, &tileKernelBinary, &tileKernelBinarySize, User_binBuildOptions); - err = clSetKernelArg(*tileClKernel, 0, sizeof(cl_mem), &A); + err = clSetKernelArg(tileClKernel, 0, sizeof(cl_mem), &A); CL_CHECK(err); - err = clSetKernelArg(*tileClKernel, 1, sizeof(cl_mem), &B); + err = clSetKernelArg(tileClKernel, 1, sizeof(cl_mem), &B); CL_CHECK(err); - err = clSetKernelArg(*tileClKernel, 2, sizeof(cl_mem), &C); + err = clSetKernelArg(tileClKernel, 2, sizeof(cl_mem), &C); CL_CHECK(err); - err = clSetKernelArg(*tileClKernel, 3, sizeof(cl_float), &alpha); + err = clSetKernelArg(tileClKernel, 3, sizeof(cl_float), &alpha); CL_CHECK(err); - err = clSetKernelArg(*tileClKernel, 4, sizeof(cl_float), &beta); + err = clSetKernelArg(tileClKernel, 4, sizeof(cl_float), &beta); CL_CHECK(err); - err = clSetKernelArg(*tileClKernel, 5, sizeof(cl_uint), &M); + err = clSetKernelArg(tileClKernel, 5, sizeof(cl_uint), &M); CL_CHECK(err); - err = clSetKernelArg(*tileClKernel, 6, sizeof(cl_uint), &N); + err = clSetKernelArg(tileClKernel, 6, sizeof(cl_uint), &N); CL_CHECK(err); - err = clSetKernelArg(*tileClKernel, 7, sizeof(cl_uint), &K); + err = clSetKernelArg(tileClKernel, 7, sizeof(cl_uint), &K); CL_CHECK(err); - err = clSetKernelArg(*tileClKernel, 8, sizeof(cl_uint), &lda); + err = clSetKernelArg(tileClKernel, 8, sizeof(cl_uint), &lda); CL_CHECK(err); - err = clSetKernelArg(*tileClKernel, 9, sizeof(cl_uint), &ldb); + err = clSetKernelArg(tileClKernel, 9, sizeof(cl_uint), &ldb); CL_CHECK(err); - err = clSetKernelArg(*tileClKernel, 10, sizeof(cl_uint), &ldc); + err = clSetKernelArg(tileClKernel, 10, sizeof(cl_uint), &ldc); CL_CHECK(err); - err = clSetKernelArg(*tileClKernel, 11, sizeof(cl_uint), &offA); + err = clSetKernelArg(tileClKernel, 11, sizeof(cl_uint), &offA); CL_CHECK(err); - err = clSetKernelArg(*tileClKernel, 12, sizeof(cl_uint), &offB); + err = clSetKernelArg(tileClKernel, 12, sizeof(cl_uint), &offB); CL_CHECK(err); - err = clSetKernelArg(*tileClKernel, 13, sizeof(cl_uint), &offC); + err = clSetKernelArg(tileClKernel, 13, sizeof(cl_uint), &offC); CL_CHECK(err); - err = clEnqueueNDRangeKernel(commandQueues[0], *tileClKernel, 2, NULL, + err = clEnqueueNDRangeKernel(commandQueues[0], tileClKernel, 2, NULL, gs, wgsize, numEventsInWaitList, eventWaitList, &events[0]); if (err == 0) @@ -663,42 +662,42 @@ clblasStatus SGEMM_BRANCH_32( { specialCaseHandled = true; tileKernelSource = sgemm_Col_TN_B1_MX032_NX032_KX16_BRANCH_src; - tileClKernel = &sgemm_Col_TN_B1_MX032_NX032_KX16_BRANCH_clKernel; + tileClKernel = sgemm_Col_TN_B1_MX032_NX032_KX16_BRANCH_clKernel; tileKernelBinary = sgemm_Col_TN_B1_MX032_NX032_KX16_BRANCH_bin; tileKernelBinarySize = sgemm_Col_TN_B1_MX032_NX032_KX16_BRANCH_binSize; - makeGemmKernel(tileClKernel, commandQueues[0], tileKernelSource, User_srcBuildOptions, &tileKernelBinary, &tileKernelBinarySize, User_binBuildOptions); + makeGemmKernel(&tileClKernel, commandQueues[0], tileKernelSource, User_srcBuildOptions, &tileKernelBinary, &tileKernelBinarySize, User_binBuildOptions); - err = clSetKernelArg(*tileClKernel, 0, sizeof(cl_mem), &A); + err = clSetKernelArg(tileClKernel, 0, sizeof(cl_mem), &A); CL_CHECK(err); - err = clSetKernelArg(*tileClKernel, 1, sizeof(cl_mem), &B); + err = clSetKernelArg(tileClKernel, 1, sizeof(cl_mem), &B); CL_CHECK(err); - err = clSetKernelArg(*tileClKernel, 2, sizeof(cl_mem), &C); + err = clSetKernelArg(tileClKernel, 2, sizeof(cl_mem), &C); CL_CHECK(err); - err = clSetKernelArg(*tileClKernel, 3, sizeof(cl_float), &alpha); + err = clSetKernelArg(tileClKernel, 3, sizeof(cl_float), &alpha); CL_CHECK(err); - err = clSetKernelArg(*tileClKernel, 4, sizeof(cl_float), &beta); + err = clSetKernelArg(tileClKernel, 4, sizeof(cl_float), &beta); CL_CHECK(err); - err = clSetKernelArg(*tileClKernel, 5, sizeof(cl_uint), &M); + err = clSetKernelArg(tileClKernel, 5, sizeof(cl_uint), &M); CL_CHECK(err); - err = clSetKernelArg(*tileClKernel, 6, sizeof(cl_uint), &N); + err = clSetKernelArg(tileClKernel, 6, sizeof(cl_uint), &N); CL_CHECK(err); - err = clSetKernelArg(*tileClKernel, 7, sizeof(cl_uint), &K); + err = clSetKernelArg(tileClKernel, 7, sizeof(cl_uint), &K); CL_CHECK(err); - err = clSetKernelArg(*tileClKernel, 8, sizeof(cl_uint), &lda); + err = clSetKernelArg(tileClKernel, 8, sizeof(cl_uint), &lda); CL_CHECK(err); - err = clSetKernelArg(*tileClKernel, 9, sizeof(cl_uint), &ldb); + err = clSetKernelArg(tileClKernel, 9, sizeof(cl_uint), &ldb); CL_CHECK(err); - err = clSetKernelArg(*tileClKernel, 10, sizeof(cl_uint), &ldc); + err = clSetKernelArg(tileClKernel, 10, sizeof(cl_uint), &ldc); CL_CHECK(err); - err = clSetKernelArg(*tileClKernel, 11, sizeof(cl_uint), &offA); + err = clSetKernelArg(tileClKernel, 11, sizeof(cl_uint), &offA); CL_CHECK(err); - err = clSetKernelArg(*tileClKernel, 12, sizeof(cl_uint), &offB); + err = clSetKernelArg(tileClKernel, 12, sizeof(cl_uint), &offB); CL_CHECK(err); - err = clSetKernelArg(*tileClKernel, 13, sizeof(cl_uint), &offC); + err = clSetKernelArg(tileClKernel, 13, sizeof(cl_uint), &offC); CL_CHECK(err); - err = clEnqueueNDRangeKernel(commandQueues[0], *tileClKernel, 2, NULL, + err = clEnqueueNDRangeKernel(commandQueues[0], tileClKernel, 2, NULL, gs, wgsize, numEventsInWaitList, eventWaitList, &events[0]); if (err == 0) @@ -726,7 +725,7 @@ clblasStatus DGEMM_BIG_MOD48( bool &specialCaseHandled) { const char *tileKernelSource = NULL; - cl_kernel *tileClKernel = NULL; + cl_kernel tileClKernel = NULL; size_t tileKernelBinarySize = 0; cl_int err; @@ -761,39 +760,39 @@ clblasStatus DGEMM_BIG_MOD48( } tileKernelSource = dgemm_Col_NT_B1_MX048_NX048_KX08_src; - tileClKernel = &dgemm_Col_NT_B1_MX048_NX048_KX08_clKernel; + tileClKernel = dgemm_Col_NT_B1_MX048_NX048_KX08_clKernel; tileKernelBinary = dgemm_Col_NT_B1_MX048_NX048_KX08_bin; tileKernelBinarySize = dgemm_Col_NT_B1_MX048_NX048_KX08_binSize; - makeGemmKernel(tileClKernel, commandQueues[0], tileKernelSource, User_srcBuildOptions, &tileKernelBinary, &tileKernelBinarySize, User_binBuildOptions); + makeGemmKernel(&tileClKernel, commandQueues[0], tileKernelSource, User_srcBuildOptions, &tileKernelBinary, &tileKernelBinarySize, User_binBuildOptions); - err = clSetKernelArg(*tileClKernel, 0, sizeof(cl_mem), &A); + err = clSetKernelArg(tileClKernel, 0, sizeof(cl_mem), &A); CL_CHECK(err); - err = clSetKernelArg(*tileClKernel, 1, sizeof(cl_mem), &B); + err = clSetKernelArg(tileClKernel, 1, sizeof(cl_mem), &B); CL_CHECK(err); - err = clSetKernelArg(*tileClKernel, 2, sizeof(cl_mem), &C); + err = clSetKernelArg(tileClKernel, 2, sizeof(cl_mem), &C); CL_CHECK(err); - err = clSetKernelArg(*tileClKernel, 3, sizeof(cl_double), &alpha); + err = clSetKernelArg(tileClKernel, 3, sizeof(cl_double), &alpha); CL_CHECK(err); - err = clSetKernelArg(*tileClKernel, 4, sizeof(cl_double), &beta); + err = clSetKernelArg(tileClKernel, 4, sizeof(cl_double), &beta); CL_CHECK(err); - err = clSetKernelArg(*tileClKernel, 5, sizeof(cl_uint), &M); + err = clSetKernelArg(tileClKernel, 5, sizeof(cl_uint), &M); CL_CHECK(err); - err = clSetKernelArg(*tileClKernel, 6, sizeof(cl_uint), &N); + err = clSetKernelArg(tileClKernel, 6, sizeof(cl_uint), &N); CL_CHECK(err); - err = clSetKernelArg(*tileClKernel, 7, sizeof(cl_uint), &K); + err = clSetKernelArg(tileClKernel, 7, sizeof(cl_uint), &K); CL_CHECK(err); - err = clSetKernelArg(*tileClKernel, 8, sizeof(cl_uint), &lda); + err = clSetKernelArg(tileClKernel, 8, sizeof(cl_uint), &lda); CL_CHECK(err); - err = clSetKernelArg(*tileClKernel, 9, sizeof(cl_uint), &ldb); + err = clSetKernelArg(tileClKernel, 9, sizeof(cl_uint), &ldb); CL_CHECK(err); - err = clSetKernelArg(*tileClKernel, 10, sizeof(cl_uint), &ldc); + err = clSetKernelArg(tileClKernel, 10, sizeof(cl_uint), &ldc); CL_CHECK(err); - err = clSetKernelArg(*tileClKernel, 11, sizeof(cl_uint), &offA); + err = clSetKernelArg(tileClKernel, 11, sizeof(cl_uint), &offA); CL_CHECK(err); - err = clSetKernelArg(*tileClKernel, 12, sizeof(cl_uint), &offB); + err = clSetKernelArg(tileClKernel, 12, sizeof(cl_uint), &offB); CL_CHECK(err); - err = clSetKernelArg(*tileClKernel, 13, sizeof(cl_uint), &offC); + err = clSetKernelArg(tileClKernel, 13, sizeof(cl_uint), &offC); CL_CHECK(err); status = GEMM_SPLIT_CALLS( diff --git a/src/library/blas/xgemm.cc b/src/library/blas/xgemm.cc index a2c6cb00..c677726b 100644 --- a/src/library/blas/xgemm.cc +++ b/src/library/blas/xgemm.cc @@ -73,9 +73,13 @@ static void force_gemm_column_major( return static_cast(err); const static unsigned int numGemmKernelArgs = 14; -void *gemmKernelArgs[numGemmKernelArgs]; -size_t gemmKernelArgSizes[numGemmKernelArgs]; - +#if defined( _WIN32 ) +/*__declspec( thread )*/ static void *gemmKernelArgs[numGemmKernelArgs]; +/*__declspec( thread )*/ static size_t gemmKernelArgSizes[numGemmKernelArgs]; +#else +static /*__thread*/ void *gemmKernelArgs[numGemmKernelArgs]; +static /*__thread*/ size_t gemmKernelArgSizes[numGemmKernelArgs]; +#endif /****************************************************************************** * Is beta zero for optimization @@ -115,7 +119,8 @@ static char *getKernelName(cl_kernel clKernel) // The kernelNameLength turns out to be of proper length. // CL_CHECK(err) - char *kernelName = new char[kernelNameLength]; + char *kernelName; + kernelName = new char[kernelNameLength]; err = clGetKernelInfo( clKernel, CL_KERNEL_FUNCTION_NAME, @@ -128,20 +133,14 @@ static char *getKernelName(cl_kernel clKernel) } typedef struct kernel_map_key_ { - cl_context context; // address of context - cl_device_id device; // address of device + cl_command_queue queue; const char *kernelSource; // address of kernel source } kernel_map_key; bool operator<(const kernel_map_key & l, const kernel_map_key & r) { - if (l.context < r.context) { + if (l.queue < r.queue) { return true; - } else if (r.context < l.context) { - return false; - } - if (l.device < r.device) { - return true; - } else if (r.device < l.device) { + } else if (r.queue < l.queue) { return false; } if (l.kernelSource < r.kernelSource) { @@ -166,30 +165,21 @@ void makeGemmKernel( size_t *kernelBinarySize, const char *binaryBuildOptions) { + typedef std::map kernel_map_t; - #if defined( _WIN32 ) - __declspec( thread ) static kernel_map_t *kernel_map = 0; +#if defined( _WIN32 ) + /*__declspec( thread )*/ static kernel_map_t *kernel_map = 0; #else - static __thread kernel_map_t *kernel_map = 0; + static /*__thread*/ kernel_map_t *kernel_map = 0; #endif if (!kernel_map) { kernel_map = new kernel_map_t(); } - cl_context clContext; - cl_device_id clDevice; - cl_int err; - - err = clGetCommandQueueInfo( clQueue, CL_QUEUE_CONTEXT, sizeof(clContext), &clContext, NULL); - CL_CHECK(err) - err = clGetCommandQueueInfo( clQueue, CL_QUEUE_DEVICE, sizeof(clDevice), &clDevice, NULL); - CL_CHECK(err) - // is kernel already compiled? kernel_map_key key; + key.queue = clQueue; key.kernelSource = kernelSource; - key.context = clContext; - key.device = clDevice; kernel_map_t::iterator idx = kernel_map->find(key); if (idx == kernel_map->end()) { *clKernel = NULL; @@ -198,7 +188,16 @@ void makeGemmKernel( return; } - if (true /*!*clKernel*/) { // since kernel wasn't found in map + if (true/*!*clKernel*/) { // since kernel wasn't found in map + cl_context clContext; + cl_device_id clDevice; + cl_int err; + + err = clGetCommandQueueInfo( clQueue, CL_QUEUE_CONTEXT, sizeof(clContext), &clContext, NULL); + CL_CHECK(err) + err = clGetCommandQueueInfo( clQueue, CL_QUEUE_DEVICE, sizeof(clDevice), &clDevice, NULL); + CL_CHECK(err) + // kernel has not been built, so build it (from binary, preferably) cl_program clProgram; cl_int clBinaryStatus; @@ -258,15 +257,24 @@ void makeGemmKernel( printf("\nBuild Log:\n\n"); printf("%s\n", buildLog); //printf("\n\nKernel String:\n\n"); - //printf("%s\n", kernelSource); + //printf("%s\n", *kernelSource); //FIXME: The function should be exiting at this point } + cl_uint numKernels = 0; + err = clCreateKernelsInProgram( clProgram, 1, clKernel, - NULL ); + &numKernels ); + + if (err != 0) + { + printf("KERNEL ERROR:\n%s\n", kernelSource); + + } CL_CHECK(err) + err = clReleaseProgram(clProgram); CL_CHECK(err) @@ -356,30 +364,41 @@ clblasGemm( const cl_event *eventWaitList, cl_event *events) { - - // cast types to opencl types - cl_mem A = iA; - cl_mem B = iB; - cl_uint M = static_cast( iM ); - cl_uint N = static_cast( iN ); - cl_uint K = static_cast( iK ); - cl_uint offA = static_cast( iOffA ); - cl_uint offB = static_cast( iOffB ); - cl_uint offC = static_cast( iOffC ); - cl_uint lda = static_cast( iLda ); - cl_uint ldb = static_cast( iLdb ); - cl_uint ldc = static_cast( iLdc ); - - transA = correctTranspose(transA); - transB = correctTranspose(transB); + cl_mem A; + A = iA; + cl_mem B; + B = iB; + cl_uint M; + M = static_cast( iM ); + cl_uint N; + N = static_cast( iN ); + cl_uint K; + K = static_cast( iK ); + cl_uint offA; + offA = static_cast( iOffA ); + cl_uint offB; + offB = static_cast( iOffB ); + cl_uint offC; + offC = static_cast( iOffC ); + cl_uint lda; + lda = static_cast( iLda ); + cl_uint ldb; + ldb = static_cast( iLdb ); + cl_uint ldc; + ldc = static_cast( iLdc ); + + clblasTranspose tA; + tA = correctTranspose(transA); + clblasTranspose tB; + tB = correctTranspose(transB); // if debug build, validate input // CHECK_QUEUES(numCommandQueues, commandQueues); // CHECK_EVENTS(numEventsInWaitList, eventWaitList); - // CHECK_MATRIX_A(Precision, order, transA, A, M, K, offA, lda); - // CHECK_MATRIX_B(Precision, order, transB, B, K, N, offB, ldb); + // CHECK_MATRIX_A(Precision, order, tA, A, M, K, offA, lda); + // CHECK_MATRIX_B(Precision, order, tB, B, K, N, offB, ldb); // CHECK_MATRIX_C(Precision, order, clblasNoTrans, C, M, N, offC, ldc); - force_gemm_column_major( order, transA, transB, + force_gemm_column_major( order, tA, tB, M, N, offA, offB, lda, ldb, A, B ); @@ -397,8 +416,8 @@ clblasGemm( bool specialCaseHandled = false; clblasStatus SpecialCaseStatus = GemmSpecialCases(order, - transA, - transB, + tA, + tB, M, N, K, alpha, A, offA, lda, @@ -429,17 +448,20 @@ clblasGemm( err = clGetDeviceInfo( clDevice, CL_DEVICE_MAX_COMPUTE_UNITS, sizeof(clDeviceNumCUs), &clDeviceNumCUs, NULL); //CL_CHECK(err) returnIfErr(err); - unsigned int deviceIdealNumThreads = (8 /*waves per CU*/)*(64 /*threads per wave*/)*clDeviceNumCUs; - float optimalNumElementsPerThread = ((float)M*N) / deviceIdealNumThreads; + unsigned int deviceIdealNumThreads; + deviceIdealNumThreads = (8 /*waves per CU*/)*(64 /*threads per wave*/)*clDeviceNumCUs; + float optimalNumElementsPerThread; + optimalNumElementsPerThread = ((float)M*N) / deviceIdealNumThreads; //optimalNumElementsPerThread = 32; - bool betaNonZero = !isZero(beta); + bool betaNonZero; + betaNonZero = !isZero(beta); #ifdef AUTOGEMM_PRINT_DEBUG printf("%sgemm_%3s_%s%s_B%u_%llux%llux%llu\n", getPrecision(), order==clblasColumnMajor ? "Col" : "Row", - transA==clblasNoTrans ? "N" : transA==clblasTrans ? "T" : "C", - transB==clblasNoTrans ? "N" : transB==clblasTrans ? "T" : "C", + tA==clblasNoTrans ? "N" : tA==clblasTrans ? "T" : "C", + tB==clblasNoTrans ? "N" : tB==clblasTrans ? "T" : "C", betaNonZero ? 1 : 0, iM, iN, iK ); #endif @@ -470,8 +492,9 @@ clblasGemm( unsigned int microTileNumRows; unsigned int microTileNumCols; unsigned int unroll; + gemmSelectKernel( - order, transA, transB, + order, tA, tB, iM, iN, iK, betaNonZero, optimalNumElementsPerThread, @@ -502,15 +525,15 @@ clblasGemm( if (!tileKernelSource) { printf("ERROR: gemmSelectKernel() couldn't find kernel(s) for { order=%s, transA=%s, transB=%s, M=%u, N=%u, K=%u, beta=%u, onept=%f }\n", order==clblasColumnMajor ? "ColMajor" : "RowMajor", - transA==clblasNoTrans ? "N" : transA==clblasTrans ? "T" : "C", - transB==clblasNoTrans ? "N" : transB==clblasTrans ? "T" : "C", + tA==clblasNoTrans ? "N" : tA==clblasTrans ? "T" : "C", + tB==clblasNoTrans ? "N" : tB==clblasTrans ? "T" : "C", M, N, K, betaNonZero ? 1 : 0, optimalNumElementsPerThread ); gemmSelectKernel( order, - transA, - transB, + tA, + tB, M, N, K, @@ -542,14 +565,18 @@ clblasGemm( return clblasNotImplemented; } - - unsigned int macroTileNumRows = workGroupNumRows*microTileNumRows; - unsigned int macroTileNumCols = workGroupNumCols*microTileNumCols; - bool needTileKernel = M/macroTileNumRows > 0 - && N/macroTileNumCols > 0; - bool needRowKernel = M%macroTileNumRows > 0 && N/macroTileNumCols > 0; - bool needColKernel = N%macroTileNumCols > 0 && M/macroTileNumRows > 0; - bool needCornerKernel = M%macroTileNumRows > 0 && N%macroTileNumCols > 0; + unsigned int macroTileNumRows; + macroTileNumRows = workGroupNumRows*microTileNumRows; + unsigned int macroTileNumCols; + macroTileNumCols = workGroupNumCols*microTileNumCols; + bool needTileKernel; + needTileKernel = M/macroTileNumRows > 0 && N/macroTileNumCols > 0; + bool needRowKernel; + needRowKernel = M%macroTileNumRows > 0 && N/macroTileNumCols > 0; + bool needColKernel; + needColKernel = N%macroTileNumCols > 0 && M/macroTileNumRows > 0; + bool needCornerKernel; + needCornerKernel = M%macroTileNumRows > 0 && N%macroTileNumCols > 0; #if 0 printf("For M,N,K = %u,%u,%u and %u CUs selected tile is wg=%ux%u, microTile=%ux%u, macroTile=%ux%u kernelsNeeded=%u,%u,%u,%u\n", M, N, K, clDeviceNumCUs, @@ -603,11 +630,11 @@ clblasGemm( if (needTileKernel) { //printf("enqueueing tile kernel\n"); size_t globalWorkSize[2] = {(M/macroTileNumRows)*workGroupNumRows, (N/macroTileNumCols)*workGroupNumCols }; - err = enqueueGemmKernel( commandQueues[numKernelsEnqueued%numCommandQueues], tileClKernel, + err = enqueueGemmKernel( commandQueues[0], tileClKernel, gemmKernelArgs, gemmKernelArgSizes, numGemmKernelArgs, globalWorkSize, localWorkSize, numEventsInWaitList, eventWaitList, - &events[numKernelsEnqueued%numCommandQueues] ); + events ); returnIfErr(err); numKernelsEnqueued++; } @@ -618,11 +645,11 @@ clblasGemm( if (needRowKernel) { //printf("enqueueing row kernel\n"); size_t globalWorkSize[2] = {1*workGroupNumRows, (N/macroTileNumCols)*workGroupNumCols }; - err = enqueueGemmKernel( commandQueues[numKernelsEnqueued%numCommandQueues], rowClKernel, + err = enqueueGemmKernel( commandQueues[0], rowClKernel, gemmKernelArgs, gemmKernelArgSizes, numGemmKernelArgs, globalWorkSize, localWorkSize, numEventsInWaitList, eventWaitList, - &events[numKernelsEnqueued%numCommandQueues] ); + events ); returnIfErr(err); numKernelsEnqueued++; } @@ -633,11 +660,11 @@ clblasGemm( if (needColKernel) { //printf("enqueueing col kernel\n"); size_t globalWorkSize[2] = { (M/macroTileNumRows)*workGroupNumRows, 1*workGroupNumCols }; - err = enqueueGemmKernel( commandQueues[numKernelsEnqueued%numCommandQueues], colClKernel, + err = enqueueGemmKernel( commandQueues[0], colClKernel, gemmKernelArgs, gemmKernelArgSizes, numGemmKernelArgs, globalWorkSize, localWorkSize, numEventsInWaitList, eventWaitList, - &events[numKernelsEnqueued%numCommandQueues] ); + events ); returnIfErr(err); numKernelsEnqueued++; } @@ -648,11 +675,11 @@ clblasGemm( if (needCornerKernel) { //printf("enqueueing corner kernel\n"); size_t globalWorkSize[2] = { 1*workGroupNumRows, 1*workGroupNumCols }; - err = enqueueGemmKernel( commandQueues[numKernelsEnqueued%numCommandQueues], cornerClKernel, + err = enqueueGemmKernel( commandQueues[0], cornerClKernel, gemmKernelArgs, gemmKernelArgSizes, numGemmKernelArgs, globalWorkSize, localWorkSize, numEventsInWaitList, eventWaitList, - &events[numKernelsEnqueued%numCommandQueues] ); + events ); returnIfErr(err); numKernelsEnqueued++; } From 8bc42d899b3920cdc668cdeb9a050c39cc733af5 Mon Sep 17 00:00:00 2001 From: Piotr Sowa Date: Sun, 14 Mar 2021 13:27:31 +0100 Subject: [PATCH 2/3] Improving Stability to Check --- src/library/blas/xgemm.cc | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/library/blas/xgemm.cc b/src/library/blas/xgemm.cc index c677726b..8c66a71d 100644 --- a/src/library/blas/xgemm.cc +++ b/src/library/blas/xgemm.cc @@ -74,11 +74,11 @@ static void force_gemm_column_major( const static unsigned int numGemmKernelArgs = 14; #if defined( _WIN32 ) -/*__declspec( thread )*/ static void *gemmKernelArgs[numGemmKernelArgs]; -/*__declspec( thread )*/ static size_t gemmKernelArgSizes[numGemmKernelArgs]; +__declspec( thread ) void *gemmKernelArgs[numGemmKernelArgs]; +__declspec( thread ) size_t gemmKernelArgSizes[numGemmKernelArgs]; #else -static /*__thread*/ void *gemmKernelArgs[numGemmKernelArgs]; -static /*__thread*/ size_t gemmKernelArgSizes[numGemmKernelArgs]; +__thread void *gemmKernelArgs[numGemmKernelArgs]; +__thread size_t gemmKernelArgSizes[numGemmKernelArgs]; #endif /****************************************************************************** @@ -317,6 +317,8 @@ void makeGemmKernel( if (err != CL_SUCCESS) return err; + clFinish(clQueue); + return CL_SUCCESS; } From 2852281f1e3b5bfe94369e2e4c1fd473aad53b40 Mon Sep 17 00:00:00 2001 From: Piotr Sowa Date: Sun, 14 Mar 2021 17:44:12 +0100 Subject: [PATCH 3/3] Improving Stability to Check --- src/library/blas/xgemm.cc | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/library/blas/xgemm.cc b/src/library/blas/xgemm.cc index 8c66a71d..6e7e9588 100644 --- a/src/library/blas/xgemm.cc +++ b/src/library/blas/xgemm.cc @@ -267,12 +267,6 @@ void makeGemmKernel( clProgram, 1, clKernel, &numKernels ); - - if (err != 0) - { - printf("KERNEL ERROR:\n%s\n", kernelSource); - - } CL_CHECK(err) err = clReleaseProgram(clProgram);