@@ -469,14 +469,16 @@ static void launch_tile(
469469 int M, int N, int K,
470470 int blocksize,
471471 int quant_type,
472+ GpuProps gpu,
472473 cudaStream_t stream
473474 // clang-format on
474475) {
475476 constexpr int smem = smem_bytes_for<T, MT, NT, KC>();
476- static bool cfg = false ;
477- if (!cfg) {
477+ static bool cfg[ 16 ] = {} ;
478+ if (gpu. device_index >= 16 || !cfg[gpu. device_index ] ) {
478479 cudaFuncSetAttribute (gemm_4bit_sm80_m16n8k16<T, MT, NT, KC>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem);
479- cfg = true ;
480+ if (gpu.device_index < 16 )
481+ cfg[gpu.device_index ] = true ;
480482 }
481483 dim3 grid ((M + MT - 1 ) / MT, (N + NT - 1 ) / NT);
482484 gemm_4bit_sm80_m16n8k16<T, MT, NT, KC><<<grid, dim3 (CTA_SIZE), smem, stream>>> (
@@ -662,7 +664,7 @@ void launch_gemm_4bit_sm80_m16n8k16(
662664
663665 // clang-format off
664666#define LAUNCH_SM80 (MT, NT, KC ) \
665- launch_tile<T, MT, NT, KC>(A, B, absmax, absmax_8bit, absmax_code, absmax_offset, C, bias, M, N, K, blocksize, quant_type, stream)
667+ launch_tile<T, MT, NT, KC>(A, B, absmax, absmax_8bit, absmax_code, absmax_offset, C, bias, M, N, K, blocksize, quant_type, gpu, stream)
666668
667669 if (kc == 64 ) {
668670 if (mt == 32 && nt == 64 ) LAUNCH_SM80 ( 32 , 64 , 64 );
0 commit comments