
Commit 34ff5e1

Vulkan: Add q4_1, q5_0 and q5_1 quants, improve integer dot code
1 parent 249595d commit 34ff5e1
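
What the new q8_1 pipelines compute (context, not part of the diff): the integer-dot path quantizes the activations to q8_1 blocks and keeps each per-block inner product in int32; floating point only enters through the per-block constants. A minimal C++ sketch under the standard ggml block layouts — ggml stores the constants as fp16, and the helper name vec_dot_q4_1_q8_1_block is a stand-in, not code from this commit:

    #include <cstdint>

    constexpr int QK = 32;   // elements per block, as in ggml

    struct block_q4_1 {
        float   d;           // scale (fp16 in ggml)
        float   m;           // min   (fp16 in ggml)
        uint8_t qs[QK / 2];  // nibbles: elements 0..15 low, 16..31 high
    };                       // x_i = d * q_i + m, with q_i in [0, 15]

    struct block_q8_1 {
        float  d;            // scale (fp16 in ggml)
        float  s;            // d * sum(qs), precomputed at quantization time
        int8_t qs[QK];       // y_i = d * qs[i]
    };

    // sum_i x_i*y_i = d_x*d_y * sum_i q_i*q8_i + m_x * (d_y * sum_i q8_i),
    // so the whole inner sum runs in integer arithmetic and the second
    // term collapses to m_x * y.s.
    float vec_dot_q4_1_q8_1_block(const block_q4_1 & x, const block_q8_1 & y) {
        int32_t sumi = 0;
        for (int j = 0; j < QK / 2; ++j) {
            const int q0 = x.qs[j] & 0x0F;  // element j
            const int q1 = x.qs[j] >> 4;    // element j + 16
            sumi += q0 * y.qs[j] + q1 * y.qs[j + QK / 2];
        }
        return x.d * y.d * float(sumi) + x.m * y.s;
    }

The m_x * y.s term is why the newly added q4_1 and q5_1 kernels need q8_1's precomputed block sum; the symmetric q4_0, q5_0 and q8_0 kernels center their quantized values around zero instead of carrying a min.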

6 files changed: +250 -155 lines changed


ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 71 additions & 28 deletions
@@ -1965,10 +1965,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
             CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
         }

-        if (device->coopmat_int_support) {
-            CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
-            CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
-        }
+        // if (device->coopmat_int_support) {
+        //     CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+        //     CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+        // }

         CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
         CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
@@ -2078,6 +2078,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
         CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );

         CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+        CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+        CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+        CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
         CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );

         CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
@@ -2123,6 +2126,14 @@ static void ggml_vk_load_shaders(vk_device& device) {
         if (device->mul_mat ## ID ## _s[TYPE]) \
             ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \

+#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
+        if (device->mul_mat ## ID ## _l[TYPE]) \
+            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
+        if (device->mul_mat ## ID ## _m[TYPE]) \
+            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
+        if (device->mul_mat ## ID ## _s[TYPE]) \
+            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
+
         CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
         CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
         CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
@@ -2149,6 +2160,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
         CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
         CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );

+        CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+        CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+        CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+        CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+        CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+
         CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
         CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
         CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
@@ -3386,6 +3403,9 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
     if (src1_type == GGML_TYPE_Q8_1) {
         switch (src0_type) {
             case GGML_TYPE_Q4_0:
+            case GGML_TYPE_Q4_1:
+            case GGML_TYPE_Q5_0:
+            case GGML_TYPE_Q5_1:
             case GGML_TYPE_Q8_0:
                 break;
             default:
@@ -3687,8 +3707,6 @@ static vk_submission ggml_vk_begin_submission(vk_device& device, vk_queue& q, bo
     return s;
 }

-
-
 static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context& subctx, vk_pipeline& pipeline, std::initializer_list<vk::DescriptorBufferInfo> const& descriptor_buffer_infos, size_t push_constant_size, const void* push_constants, std::array<uint32_t, 3> elements) {
     const uint32_t wg0 = CEIL_DIV(elements[0], pipeline->wg_denoms[0]);
     const uint32_t wg1 = CEIL_DIV(elements[1], pipeline->wg_denoms[1]);
@@ -7500,16 +7518,18 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
         pipelines = ctx->device->pipeline_dequant_mul_mat_mat;
     }

+    const bool fp16acc = ctx->device->fp16;
+
     vk_pipeline p;
     std::string shname;
     if (shader_size == 0) {
-        p = ctx->device->fp16 ? pipelines[quant].f16acc->a_s : pipelines[quant].f32acc->a_s;
+        p = fp16acc ? pipelines[quant].f16acc->a_s : pipelines[quant].f32acc->a_s;
         shname = std::string(ggml_type_name(quant)) + "_ALIGNED_S";
     } else if (shader_size == 1) {
-        p = ctx->device->fp16 ? pipelines[quant].f16acc->a_m : pipelines[quant].f32acc->a_m;
+        p = fp16acc ? pipelines[quant].f16acc->a_m : pipelines[quant].f32acc->a_m;
         shname = std::string(ggml_type_name(quant)) + "_ALIGNED_M";
     } else if (shader_size == 2) {
-        p = ctx->device->fp16 ? pipelines[quant].f16acc->a_l : pipelines[quant].f32acc->a_l;
+        p = fp16acc ? pipelines[quant].f16acc->a_l : pipelines[quant].f32acc->a_l;
         shname = std::string(ggml_type_name(quant)) + "_ALIGNED_L";
     } else {
         GGML_ASSERT(0);
@@ -7519,13 +7539,13 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,

     if (mmq || k != kpad) {
         if (shader_size == 0) {
-            p = ctx->device->fp16 ? pipelines[quant].f16acc->s : pipelines[quant].f32acc->s;
+            p = fp16acc ? pipelines[quant].f16acc->s : pipelines[quant].f32acc->s;
             shname = std::string(ggml_type_name(quant)) + "_S";
         } else if (shader_size == 1) {
-            p = ctx->device->fp16 ? pipelines[quant].f16acc->m : pipelines[quant].f32acc->m;
+            p = fp16acc ? pipelines[quant].f16acc->m : pipelines[quant].f32acc->m;
             shname = std::string(ggml_type_name(quant)) + "_M";
         } else if (shader_size == 2) {
-            p = ctx->device->fp16 ? pipelines[quant].f16acc->l : pipelines[quant].f32acc->l;
+            p = fp16acc ? pipelines[quant].f16acc->l : pipelines[quant].f32acc->l;
             shname = std::string(ggml_type_name(quant)) + "_L";
         } else {
             GGML_ASSERT(0);
@@ -7553,16 +7573,16 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
     float * d_chk = (float *) malloc(d_sz);

     for (size_t i = 0; i < x_ne; i++) {
-        // x[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
-        x[i] = (i % k == i / k) ? 1.0f : 0.0f;
+        x[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
+        // x[i] = (i % k == i / k) ? 1.0f : 0.0f;
         // x[i] = i % k;
     }

     ggml_vk_quantize_data(x, qx, x_ne, quant);

     for (size_t i = 0; i < y_ne; i++) {
-        // y[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
-        y[i] = (i % k == i / k) ? 1.0f : 0.0f;
+        y[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
+        // y[i] = (i % k == i / k) ? 1.0f : 0.0f;
         // y[i] = i % k;
     }

@@ -7593,14 +7613,25 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,

     vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
     ggml_vk_ctx_begin(ctx->device, subctx);
-    for (size_t i = 0; i < num_it; i++) {
-        ggml_vk_quantize_q8_1(ctx, subctx, { y_buf, 0, y_sz }, { qy_buf, 0, qy_sz }, y_ne);
-        ggml_vk_matmul(
-            ctx, subctx, p, { qx_buf, 0, qx_sz }, { qy_buf, 0, qy_sz }, { d_buf, 0, d_sz }, { ctx->prealloc_split_k, 0, ctx->prealloc_size_split_k },
-            m, n, k,
-            k, k, m, k*m, k*n, m*n,
-            split_k, batch, batch, batch, 1, 1, n
-        );
+    if (mmq) {
+        for (size_t i = 0; i < num_it; i++) {
+            ggml_vk_quantize_q8_1(ctx, subctx, { y_buf, 0, y_sz }, { qy_buf, 0, qy_sz }, y_ne);
+            ggml_vk_matmul(
+                ctx, subctx, p, { qx_buf, 0, qx_sz }, { qy_buf, 0, qy_sz }, { d_buf, 0, d_sz }, { ctx->prealloc_split_k, 0, ctx->prealloc_size_split_k },
+                m, n, k,
+                k, k, m, k*m, k*n, m*n,
+                split_k, batch, batch, batch, 1, 1, n
+            );
+        }
+    } else {
+        for (size_t i = 0; i < num_it; i++) {
+            ggml_vk_matmul(
+                ctx, subctx, p, { qx_buf, 0, qx_sz }, { y_buf, 0, y_sz }, { d_buf, 0, d_sz }, { ctx->prealloc_split_k, 0, ctx->prealloc_size_split_k },
+                m, n, k,
+                k, k, m, k*m, k*n, m*n,
+                split_k, batch, batch, batch, 1, 1, n
+            );
+        }
     }
     ggml_vk_ctx_end(subctx);
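
Context for the mmq branch above (not part of the diff): ggml_vk_quantize_q8_1 converts the f32 activations to q8_1 on the GPU before each integer-dot matmul. A CPU sketch of the equivalent per-block computation, reusing the block_q8_1 layout sketched earlier (the helper name is hypothetical):

    #include <algorithm>
    #include <cmath>

    // One q8_1 block from QK floats: symmetric absmax scaling, plus the
    // precomputed s = d * sum(qs) that the asymmetric kernels consume.
    block_q8_1 quantize_q8_1_block(const float * src) {
        float amax = 0.0f;
        for (int i = 0; i < QK; ++i) {
            amax = std::max(amax, std::fabs(src[i]));
        }
        block_q8_1 b;
        b.d = amax / 127.0f;
        const float id = b.d != 0.0f ? 1.0f / b.d : 0.0f;
        int32_t sum = 0;
        for (int i = 0; i < QK; ++i) {
            b.qs[i] = (int8_t) std::lround(src[i] * id);
            sum += b.qs[i];
        }
        b.s = b.d * (float) sum;  // lets q4_1/q5_1 add m * s in O(1) per block
        return b;
    }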

@@ -7735,11 +7766,23 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
         128, 49, 49,
         4096, 49, 4096,
     };
-    const size_t num_it = 100;
+    const size_t num_it = 1;
+
+    ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q4_0);
+    ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q4_0);
+    ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q4_0);
+
+    ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q4_0, true);
+    ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q4_0, true);
+    ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q4_0, true);
+
+    ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q8_0);
+    ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q8_0);
+    ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q8_0);

-    ggml_vk_test_dequant_matmul(ctx, 16, 16, 32, 2, 1, 1, 0, GGML_TYPE_Q8_0, true);
-    ggml_vk_test_dequant_matmul(ctx, 16, 16, 32, 2, 1, 1, 1, GGML_TYPE_Q8_0, true);
-    ggml_vk_test_dequant_matmul(ctx, 16, 16, 32, 2, 1, 1, 2, GGML_TYPE_Q8_0, true);
+    ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q8_0, true);
+    ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q8_0, true);
+    ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q8_0, true);

     abort();
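
Reading the benchmark calls above (argument order inferred from how the values flow into ggml_vk_matmul in the earlier hunks; treat it as an assumption, not the function's documented signature):

    // ggml_vk_test_dequant_matmul(ctx,
    //     4096,            // m: rows of the quantized weight matrix
    //     512,             // n: columns of the activation matrix
    //     4096,            // k: shared dimension
    //     2,               // batch count
    //     num_it,          // iterations per timing run
    //     1,               // split_k
    //     0,               // shader size variant: 0 = S, 1 = M, 2 = L
    //     GGML_TYPE_Q4_0,  // weight quantization type
    //     true);           // mmq: quantize activations to q8_1 and use
    //                      //      the integer-dot pipeline

Each triplet runs the S, M and L shader variants; the mmq=true runs exercise the new integer-dot kernels against the existing dequantize path at the same sizes.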

ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_mmq.comp

Lines changed: 0 additions & 37 deletions
This file was deleted.
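
Judging by its name, the deleted file held the per-quant dequantization helpers used by the q8_1 (mmq) matmul shaders; this truncated page does not show where their replacements live. For reference, the unpacking the new q5_1 kernel has to perform, as a C++ sketch under the standard ggml layout (helper name hypothetical):

    #include <cstdint>

    // q5_1: 4 low bits per element in qs nibbles plus 1 high bit per
    // element packed into the 32-bit qh field; x_i = d * q_i + m with
    // q_i in [0, 31].
    struct block_q5_1 {
        float    d;       // scale (fp16 in ggml)
        float    m;       // min   (fp16 in ggml)
        uint32_t qh;      // bit i = high bit of element i
        uint8_t  qs[16];  // nibbles: elements 0..15 low, 16..31 high
    };

    // Dequantize element i of one block, mirroring ggml's CPU reference.
    float dequant_q5_1(const block_q5_1 & b, int i) {
        const int nib = (i < 16) ? (b.qs[i] & 0x0F) : (b.qs[i - 16] >> 4);
        const int hi  = (b.qh >> i) & 1;
        return b.d * float(nib | (hi << 4)) + b.m;
    }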
