@@ -1965,10 +1965,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
1965
1965
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1966
1966
}
1967
1967
1968
- if (device->coopmat_int_support) {
1969
- CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
1970
- CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
1971
- }
1968
+ // if (device->coopmat_int_support) {
1969
+ // CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
1970
+ // CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
1971
+ // }
1972
1972
1973
1973
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
1974
1974
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
@@ -2078,6 +2078,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
2078
2078
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
2079
2079
2080
2080
CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
2081
+ CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
2082
+ CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
2083
+ CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
2081
2084
CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
2082
2085
2083
2086
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
@@ -2123,6 +2126,14 @@ static void ggml_vk_load_shaders(vk_device& device) {
2123
2126
if (device->mul_mat ## ID ## _s[TYPE]) \
2124
2127
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \
2125
2128
2129
+ #define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
2130
+ if (device->mul_mat ## ID ## _l[TYPE]) \
2131
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
2132
+ if (device->mul_mat ## ID ## _m[TYPE]) \
2133
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
2134
+ if (device->mul_mat ## ID ## _s[TYPE]) \
2135
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
2136
+
2126
2137
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
2127
2138
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
2128
2139
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
@@ -2149,6 +2160,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
2149
2160
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
2150
2161
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
2151
2162
2163
+ CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
2164
+ CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
2165
+ CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
2166
+ CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
2167
+ CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
2168
+
2152
2169
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
2153
2170
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
2154
2171
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
@@ -3386,6 +3403,9 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
3386
3403
if (src1_type == GGML_TYPE_Q8_1) {
3387
3404
switch (src0_type) {
3388
3405
case GGML_TYPE_Q4_0:
3406
+ case GGML_TYPE_Q4_1:
3407
+ case GGML_TYPE_Q5_0:
3408
+ case GGML_TYPE_Q5_1:
3389
3409
case GGML_TYPE_Q8_0:
3390
3410
break;
3391
3411
default:
@@ -3687,8 +3707,6 @@ static vk_submission ggml_vk_begin_submission(vk_device& device, vk_queue& q, bo
3687
3707
return s;
3688
3708
}
3689
3709
3690
-
3691
-
3692
3710
static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context& subctx, vk_pipeline& pipeline, std::initializer_list<vk::DescriptorBufferInfo> const& descriptor_buffer_infos, size_t push_constant_size, const void* push_constants, std::array<uint32_t, 3> elements) {
3693
3711
const uint32_t wg0 = CEIL_DIV(elements[0], pipeline->wg_denoms[0]);
3694
3712
const uint32_t wg1 = CEIL_DIV(elements[1], pipeline->wg_denoms[1]);
@@ -7500,16 +7518,18 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
7500
7518
pipelines = ctx->device->pipeline_dequant_mul_mat_mat;
7501
7519
}
7502
7520
7521
+ const bool fp16acc = ctx->device->fp16;
7522
+
7503
7523
vk_pipeline p;
7504
7524
std::string shname;
7505
7525
if (shader_size == 0) {
7506
- p = ctx->device->fp16 ? pipelines[quant].f16acc->a_s : pipelines[quant].f32acc->a_s;
7526
+ p = fp16acc ? pipelines[quant].f16acc->a_s : pipelines[quant].f32acc->a_s;
7507
7527
shname = std::string(ggml_type_name(quant)) + "_ALIGNED_S";
7508
7528
} else if (shader_size == 1) {
7509
- p = ctx->device->fp16 ? pipelines[quant].f16acc->a_m : pipelines[quant].f32acc->a_m;
7529
+ p = fp16acc ? pipelines[quant].f16acc->a_m : pipelines[quant].f32acc->a_m;
7510
7530
shname = std::string(ggml_type_name(quant)) + "_ALIGNED_M";
7511
7531
} else if (shader_size == 2) {
7512
- p = ctx->device->fp16 ? pipelines[quant].f16acc->a_l : pipelines[quant].f32acc->a_l;
7532
+ p = fp16acc ? pipelines[quant].f16acc->a_l : pipelines[quant].f32acc->a_l;
7513
7533
shname = std::string(ggml_type_name(quant)) + "_ALIGNED_L";
7514
7534
} else {
7515
7535
GGML_ASSERT(0);
@@ -7519,13 +7539,13 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
7519
7539
7520
7540
if (mmq || k != kpad) {
7521
7541
if (shader_size == 0) {
7522
- p = ctx->device->fp16 ? pipelines[quant].f16acc->s : pipelines[quant].f32acc->s;
7542
+ p = fp16acc ? pipelines[quant].f16acc->s : pipelines[quant].f32acc->s;
7523
7543
shname = std::string(ggml_type_name(quant)) + "_S";
7524
7544
} else if (shader_size == 1) {
7525
- p = ctx->device->fp16 ? pipelines[quant].f16acc->m : pipelines[quant].f32acc->m;
7545
+ p = fp16acc ? pipelines[quant].f16acc->m : pipelines[quant].f32acc->m;
7526
7546
shname = std::string(ggml_type_name(quant)) + "_M";
7527
7547
} else if (shader_size == 2) {
7528
- p = ctx->device->fp16 ? pipelines[quant].f16acc->l : pipelines[quant].f32acc->l;
7548
+ p = fp16acc ? pipelines[quant].f16acc->l : pipelines[quant].f32acc->l;
7529
7549
shname = std::string(ggml_type_name(quant)) + "_L";
7530
7550
} else {
7531
7551
GGML_ASSERT(0);
@@ -7553,16 +7573,16 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
7553
7573
float * d_chk = (float *) malloc(d_sz);
7554
7574
7555
7575
for (size_t i = 0; i < x_ne; i++) {
7556
- // x[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
7557
- x[i] = (i % k == i / k) ? 1.0f : 0.0f;
7576
+ x[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
7577
+ // x[i] = (i % k == i / k) ? 1.0f : 0.0f;
7558
7578
// x[i] = i % k;
7559
7579
}
7560
7580
7561
7581
ggml_vk_quantize_data(x, qx, x_ne, quant);
7562
7582
7563
7583
for (size_t i = 0; i < y_ne; i++) {
7564
- // y[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
7565
- y[i] = (i % k == i / k) ? 1.0f : 0.0f;
7584
+ y[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
7585
+ // y[i] = (i % k == i / k) ? 1.0f : 0.0f;
7566
7586
// y[i] = i % k;
7567
7587
}
7568
7588
@@ -7593,14 +7613,25 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
7593
7613
7594
7614
vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
7595
7615
ggml_vk_ctx_begin(ctx->device, subctx);
7596
- for (size_t i = 0; i < num_it; i++) {
7597
- ggml_vk_quantize_q8_1(ctx, subctx, { y_buf, 0, y_sz }, { qy_buf, 0, qy_sz }, y_ne);
7598
- ggml_vk_matmul(
7599
- ctx, subctx, p, { qx_buf, 0, qx_sz }, { qy_buf, 0, qy_sz }, { d_buf, 0, d_sz }, { ctx->prealloc_split_k, 0, ctx->prealloc_size_split_k },
7600
- m, n, k,
7601
- k, k, m, k*m, k*n, m*n,
7602
- split_k, batch, batch, batch, 1, 1, n
7603
- );
7616
+ if (mmq) {
7617
+ for (size_t i = 0; i < num_it; i++) {
7618
+ ggml_vk_quantize_q8_1(ctx, subctx, { y_buf, 0, y_sz }, { qy_buf, 0, qy_sz }, y_ne);
7619
+ ggml_vk_matmul(
7620
+ ctx, subctx, p, { qx_buf, 0, qx_sz }, { qy_buf, 0, qy_sz }, { d_buf, 0, d_sz }, { ctx->prealloc_split_k, 0, ctx->prealloc_size_split_k },
7621
+ m, n, k,
7622
+ k, k, m, k*m, k*n, m*n,
7623
+ split_k, batch, batch, batch, 1, 1, n
7624
+ );
7625
+ }
7626
+ } else {
7627
+ for (size_t i = 0; i < num_it; i++) {
7628
+ ggml_vk_matmul(
7629
+ ctx, subctx, p, { qx_buf, 0, qx_sz }, { y_buf, 0, y_sz }, { d_buf, 0, d_sz }, { ctx->prealloc_split_k, 0, ctx->prealloc_size_split_k },
7630
+ m, n, k,
7631
+ k, k, m, k*m, k*n, m*n,
7632
+ split_k, batch, batch, batch, 1, 1, n
7633
+ );
7634
+ }
7604
7635
}
7605
7636
ggml_vk_ctx_end(subctx);
7606
7637
@@ -7735,11 +7766,23 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
7735
7766
128, 49, 49,
7736
7767
4096, 49, 4096,
7737
7768
};
7738
- const size_t num_it = 100;
7769
+ const size_t num_it = 1;
7770
+
7771
+ ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q4_0);
7772
+ ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q4_0);
7773
+ ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q4_0);
7774
+
7775
+ ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q4_0, true);
7776
+ ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q4_0, true);
7777
+ ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q4_0, true);
7778
+
7779
+ ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q8_0);
7780
+ ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q8_0);
7781
+ ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q8_0);
7739
7782
7740
- ggml_vk_test_dequant_matmul(ctx, 16, 16, 32 , 2, 1 , 1, 0, GGML_TYPE_Q8_0, true);
7741
- ggml_vk_test_dequant_matmul(ctx, 16, 16, 32 , 2, 1 , 1, 1, GGML_TYPE_Q8_0, true);
7742
- ggml_vk_test_dequant_matmul(ctx, 16, 16, 32 , 2, 1 , 1, 2, GGML_TYPE_Q8_0, true);
7783
+ ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096 , 2, num_it , 1, 0, GGML_TYPE_Q8_0, true);
7784
+ ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096 , 2, num_it , 1, 1, GGML_TYPE_Q8_0, true);
7785
+ ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096 , 2, num_it , 1, 2, GGML_TYPE_Q8_0, true);
7743
7786
7744
7787
abort();
7745
7788
0 commit comments