Skip to content

Commit c59334e

Browse files
4bit GEMM fix: per-device cudaFuncSetAttribute cache (#1952)
1 parent 5453368 commit c59334e

4 files changed

Lines changed: 28 additions & 16 deletions

File tree

csrc/gemm_4bit.cu

Lines changed: 15 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -15,15 +15,23 @@
1515
// 16-entry cache indexed by device ID. num_sms==0 means not yet populated.
1616
// Static storage is zero-initialized, so all entries start unpopulated (num_sms==0).
1717
GpuProps get_gpu_props() {
18-
static GpuProps cache[16];
18+
static GpuProps cache[16] = {};
1919
int dev = 0;
2020
cudaGetDevice(&dev);
21-
if (dev < 16 && cache[dev].num_sms == 0) {
22-
cudaDeviceGetAttribute(&cache[dev].num_sms, cudaDevAttrMultiProcessorCount, dev);
23-
cudaDeviceGetAttribute(&cache[dev].cc_major, cudaDevAttrComputeCapabilityMajor, dev);
24-
cudaDeviceGetAttribute(&cache[dev].cc_minor, cudaDevAttrComputeCapabilityMinor, dev);
25-
}
26-
return cache[dev];
21+
22+
if (dev < 16 && cache[dev].num_sms != 0)
23+
return cache[dev];
24+
25+
GpuProps props = {};
26+
props.device_index = dev;
27+
cudaDeviceGetAttribute(&props.num_sms, cudaDevAttrMultiProcessorCount, dev);
28+
cudaDeviceGetAttribute(&props.cc_major, cudaDevAttrComputeCapabilityMajor, dev);
29+
cudaDeviceGetAttribute(&props.cc_minor, cudaDevAttrComputeCapabilityMinor, dev);
30+
31+
if (dev < 16)
32+
cache[dev] = props;
33+
34+
return props;
2735
}
2836

2937
/// @brief Fused 4-bit dequantize + GEMM. Computes out[M,N] = A[M,K] @ B[N,K]^T + bias.

csrc/gemm_4bit_common.cuh

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@
55
// GPU properties queried once per device and cached in gemm_4bit.cu.
66
// Passed through dispatch into MMA launchers to avoid repeated cudaGetDevice calls.
77
struct GpuProps {
8-
int num_sms, cc_major, cc_minor;
8+
int device_index, num_sms, cc_major, cc_minor;
99
};
1010

1111
#include <cuda_bf16.h>

csrc/gemm_4bit_sm75.cu

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -317,14 +317,16 @@ static void launch_tile(
317317
int M, int N, int K,
318318
int blocksize,
319319
int quant_type,
320+
GpuProps gpu,
320321
cudaStream_t stream
321322
// clang-format on
322323
) {
323324
constexpr int smem = smem_bytes_for<T, MT, NT>();
324-
static bool cfg = false;
325-
if (!cfg) {
325+
static bool cfg[16] = {};
326+
if (gpu.device_index >= 16 || !cfg[gpu.device_index]) {
326327
cudaFuncSetAttribute(gemm_4bit_sm75_m16n8k8<T, MT, NT>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem);
327-
cfg = true;
328+
if (gpu.device_index < 16)
329+
cfg[gpu.device_index] = true;
328330
}
329331
dim3 grid((M + MT - 1) / MT, (N + NT - 1) / NT);
330332
gemm_4bit_sm75_m16n8k8<T, MT, NT><<<grid, dim3(CTA_SIZE), smem, stream>>>(
@@ -396,7 +398,7 @@ void launch_gemm_4bit_sm75_m16n8k8(
396398

397399
// clang-format off
398400
#define LAUNCH_SM75(MT, NT) \
399-
launch_tile<T, MT, NT>(A, B, absmax, absmax_8bit, absmax_code, absmax_offset, C, bias, M, N, K, blocksize, quant_type, stream)
401+
launch_tile<T, MT, NT>(A, B, absmax, absmax_8bit, absmax_code, absmax_offset, C, bias, M, N, K, blocksize, quant_type, gpu, stream)
400402

401403
if (mt == 32 && nt == 64) LAUNCH_SM75(32, 64);
402404
else if (mt == 32 && nt == 128) LAUNCH_SM75(32, 128);

csrc/gemm_4bit_sm80.cu

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -469,14 +469,16 @@ static void launch_tile(
469469
int M, int N, int K,
470470
int blocksize,
471471
int quant_type,
472+
GpuProps gpu,
472473
cudaStream_t stream
473474
// clang-format on
474475
) {
475476
constexpr int smem = smem_bytes_for<T, MT, NT, KC>();
476-
static bool cfg = false;
477-
if (!cfg) {
477+
static bool cfg[16] = {};
478+
if (gpu.device_index >= 16 || !cfg[gpu.device_index]) {
478479
cudaFuncSetAttribute(gemm_4bit_sm80_m16n8k16<T, MT, NT, KC>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem);
479-
cfg = true;
480+
if (gpu.device_index < 16)
481+
cfg[gpu.device_index] = true;
480482
}
481483
dim3 grid((M + MT - 1) / MT, (N + NT - 1) / NT);
482484
gemm_4bit_sm80_m16n8k16<T, MT, NT, KC><<<grid, dim3(CTA_SIZE), smem, stream>>>(
@@ -662,7 +664,7 @@ void launch_gemm_4bit_sm80_m16n8k16(
662664

663665
// clang-format off
664666
#define LAUNCH_SM80(MT, NT, KC) \
665-
launch_tile<T, MT, NT, KC>(A, B, absmax, absmax_8bit, absmax_code, absmax_offset, C, bias, M, N, K, blocksize, quant_type, stream)
667+
launch_tile<T, MT, NT, KC>(A, B, absmax, absmax_8bit, absmax_code, absmax_offset, C, bias, M, N, K, blocksize, quant_type, gpu, stream)
666668

667669
if (kc == 64) {
668670
if (mt == 32 && nt == 64) LAUNCH_SM80( 32, 64, 64);

0 commit comments

Comments
 (0)