Skip to content

Commit 8ae32dc

Browse files
authored
metal : various optimizations + refactoring (#16446)
* metal : ssm_scan minor opts
* metal : get_rows optimize
* metal : cpy optimize
* metal : ssm_conv opt
* metal : ssm_scan simplify
* metal : ssm_scan opt
1 parent 3df2244 commit 8ae32dc

File tree

5 files changed

+251
-445
lines changed

5 files changed

+251
-445
lines changed

ggml/src/ggml-metal/ggml-metal-device.cpp

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -338,7 +338,13 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv(ggml_metal_librar
338338
char base[256];
339339
char name[256];
340340

341-
snprintf(base, 256, "kernel_ssm_conv_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type));
341+
const char * suffix = "";
342+
343+
if (op->src[1]->ne[0] % 4 == 0) {
344+
suffix = "_4";
345+
}
346+
347+
snprintf(base, 256, "kernel_ssm_conv_%s_%s%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type), suffix);
342348
snprintf(name, 256, "%s", base);
343349

344350
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
@@ -352,15 +358,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv(ggml_metal_librar
352358
}
353359

354360
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan(ggml_metal_library_t lib, const ggml_tensor * op) {
361+
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
362+
355363
char base[256];
356364
char name[256];
357365

358-
if (op->src[3]->ne[0] == 1) {
359-
snprintf(base, 256, "kernel_ssm_scan_group_%s", ggml_type_name(op->src[0]->type));
360-
} else {
361-
snprintf(base, 256, "kernel_ssm_scan_%s", ggml_type_name(op->src[0]->type));
362-
}
363-
snprintf(name, 256, "%s", base);
366+
const int nsg = (ne00 + 31)/32;
367+
368+
snprintf(base, 256, "kernel_ssm_scan_%s", ggml_type_name(op->src[0]->type));
369+
snprintf(name, 256, "%s_nsg=%d", base, nsg);
364370

365371
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
366372
if (res) {
@@ -369,7 +375,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan(ggml_metal_librar
369375

370376
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
371377

372-
ggml_metal_pipeline_set_smem(res, 32*sizeof(float));
378+
ggml_metal_pipeline_set_smem(res, 32*sizeof(float)*nsg);
373379

374380
return res;
375381
}

ggml/src/ggml-metal/ggml-metal-device.m

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -776,9 +776,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
776776
};
777777
}
778778
case GGML_OP_GET_ROWS:
779-
{
780-
return op->ne[3] == 1;
781-
}
779+
return true;
782780
case GGML_OP_SET_ROWS:
783781
{
784782
if (op->src[0]->type != GGML_TYPE_F32) {

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,7 @@ typedef struct {
178178
} ggml_metal_kargs_clamp;
179179

180180
typedef struct {
181+
int64_t nk0;
181182
int64_t ne00;
182183
int64_t ne01;
183184
int64_t ne02;
@@ -572,32 +573,45 @@ typedef struct {
572573
int64_t n_seq_tokens;
573574
int64_t n_seqs;
574575
uint64_t s_off;
576+
uint64_t nb00;
575577
uint64_t nb01;
576578
uint64_t nb02;
577579
uint64_t nb03;
580+
uint64_t nb10;
578581
uint64_t nb11;
579582
uint64_t nb12;
583+
uint64_t ns12;
580584
uint64_t nb13;
585+
uint64_t nb20;
581586
uint64_t nb21;
587+
uint64_t ns21;
582588
uint64_t nb22;
589+
int64_t ne30;
583590
uint64_t nb31;
584591
uint64_t nb41;
585592
uint64_t nb42;
593+
uint64_t ns42;
586594
uint64_t nb43;
587595
uint64_t nb51;
588596
uint64_t nb52;
597+
uint64_t ns52;
589598
uint64_t nb53;
599+
uint64_t nb0;
590600
} ggml_metal_kargs_ssm_scan;
591601

592602
typedef struct {
593-
int64_t ne00;
603+
int32_t ne00t;
604+
int32_t ne00;
594605
uint64_t nb01;
595606
uint64_t nb02;
596-
int64_t ne10;
607+
uint64_t nb03;
608+
int32_t ne10;
597609
uint64_t nb10;
598610
uint64_t nb11;
611+
uint64_t nb12;
599612
uint64_t nb1;
600613
uint64_t nb2;
614+
uint64_t nb3;
601615
} ggml_metal_kargs_get_rows;
602616

603617
typedef struct {

ggml/src/ggml-metal/ggml-metal-ops.cpp

Lines changed: 46 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -577,6 +577,7 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
577577
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
578578

579579
ggml_metal_kargs_cpy args = {
580+
/*.nk0 =*/ ne00,
580581
/*.ne00 =*/ ne00,
581582
/*.ne01 =*/ ne01,
582583
/*.ne02 =*/ ne02,
@@ -906,23 +907,31 @@ int ggml_metal_op_get_rows(ggml_metal_op_t ctx, int idx) {
906907
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_get_rows(lib, op->src[0]->type);
907908

908909
ggml_metal_kargs_get_rows args = {
909-
/*.ne00 =*/ ne00,
910-
/*.nb01 =*/ nb01,
911-
/*.nb02 =*/ nb02,
912-
/*.ne10 =*/ ne10,
913-
/*.nb10 =*/ nb10,
914-
/*.nb11 =*/ nb11,
915-
/*.nb1 =*/ nb1,
916-
/*.nb2 =*/ nb2,
910+
/*.ne00t =*/ ggml_is_quantized(op->src[0]->type) ? ne00/16 : ne00,
911+
/*.ne00 =*/ ne00,
912+
/*.nb01 =*/ nb01,
913+
/*.nb02 =*/ nb02,
914+
/*.nb03 =*/ nb03,
915+
/*.ne10 =*/ ne10,
916+
/*.nb10 =*/ nb10,
917+
/*.nb11 =*/ nb11,
918+
/*.nb12 =*/ nb12,
919+
/*.nb1 =*/ nb1,
920+
/*.nb2 =*/ nb2,
921+
/*.nb3 =*/ nb3,
917922
};
918923

924+
const int nth = std::min(args.ne00t, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
925+
926+
const int nw0 = (args.ne00t + nth - 1)/nth;
927+
919928
ggml_metal_encoder_set_pipeline(enc, pipeline);
920929
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
921930
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
922931
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
923932
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
924933

925-
ggml_metal_encoder_dispatch_threadgroups(enc, ne10, ne11, ne12, 32, 1, 1);
934+
ggml_metal_encoder_dispatch_threadgroups(enc, nw0*ne10, ne11, ne12, nth, 1, 1);
926935

927936
return 1;
928937
}
@@ -1117,7 +1126,7 @@ int ggml_metal_op_ssm_conv(ggml_metal_op_t ctx, int idx) {
11171126
ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
11181127
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
11191128
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2);
1120-
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3);
1129+
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3);
11211130

11221131
ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne1, ne02, 1, 1, 1);
11231132

@@ -1172,25 +1181,36 @@ int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) {
11721181
/*.n_seq_tokens =*/ n_seq_tokens,
11731182
/*.n_seqs =*/ n_seqs,
11741183
/*.s_off =*/ ggml_nelements(op->src[1]) * sizeof(float),
1184+
/*.nb00 =*/ nb00,
11751185
/*.nb01 =*/ nb01,
11761186
/*.nb02 =*/ nb02,
11771187
/*.nb03 =*/ nb03,
1188+
/*.nb10 =*/ nb10,
11781189
/*.nb11 =*/ nb11,
11791190
/*.nb12 =*/ nb12,
1191+
/*.ns12 =*/ nb12/nb10,
11801192
/*.nb13 =*/ nb13,
1193+
/*.nb20 =*/ nb20,
11811194
/*.nb21 =*/ nb21,
1195+
/*.ns21 =*/ nb21/nb20,
11821196
/*.nb22 =*/ nb22,
1197+
/*.ne30 =*/ ne30,
11831198
/*.nb31 =*/ nb31,
11841199
/*.nb41 =*/ nb41,
11851200
/*.nb42 =*/ nb42,
1201+
/*.ns42 =*/ nb42/nb40,
11861202
/*.nb43 =*/ nb43,
11871203
/*.nb51 =*/ nb51,
11881204
/*.nb52 =*/ nb52,
1205+
/*.ns52 =*/ nb52/nb50,
11891206
/*.nb53 =*/ nb53,
1207+
/*.nb0 =*/ nb0,
11901208
};
11911209

11921210
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_ssm_scan(lib, op);
11931211

1212+
GGML_ASSERT(d_state <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
1213+
11941214
const size_t sms = ggml_metal_pipeline_get_smem(pipeline);
11951215

11961216
ggml_metal_encoder_set_pipeline(enc, pipeline);
@@ -1206,13 +1226,7 @@ int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) {
12061226

12071227
ggml_metal_encoder_set_threadgroup_memory_size(enc, sms, 0);
12081228

1209-
if (ne30 == 1) {
1210-
// Mamba-2
1211-
ggml_metal_encoder_dispatch_threadgroups(enc, d_inner, n_head, n_seqs, d_state, 1, 1);
1212-
} else {
1213-
GGML_ASSERT(d_inner == 1);
1214-
ggml_metal_encoder_dispatch_threadgroups(enc, n_head, n_seqs, 1, d_state, 1, 1);
1215-
}
1229+
ggml_metal_encoder_dispatch_threadgroups(enc, d_inner, n_head, n_seqs, d_state, 1, 1);
12161230

12171231
return 1;
12181232
}
@@ -1273,37 +1287,35 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
12731287

12741288
GGML_ASSERT(ne00 % ggml_blck_size(op->src[0]->type) == 0);
12751289

1276-
// TODO: support
1277-
//const int32_t nk00 = ne00/ggml_blck_size(op->type);
1278-
const int32_t nk00 = ne00;
1279-
1280-
int nth = 32; // SIMD width
1281-
1282-
while (nth < nk00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
1283-
nth *= 2;
1290+
int64_t nk0 = ne00;
1291+
if (ggml_is_quantized(op->src[0]->type)) {
1292+
nk0 = ne00/16;
1293+
} else if (ggml_is_quantized(op->type)) {
1294+
nk0 = ne00/ggml_blck_size(op->type);
12841295
}
12851296

1286-
nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
1297+
int nth = std::min<int>(nk0, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
12871298

12881299
// when rows are small, we can batch them together in a single threadgroup
12891300
int nrptg = 1;
12901301

12911302
// TODO: relax this constraint in the future
12921303
if (ggml_blck_size(op->src[0]->type) == 1 && ggml_blck_size(op->type) == 1) {
1293-
if (nth > nk00) {
1294-
nrptg = (nth + nk00 - 1)/nk00;
1295-
nth = nk00;
1304+
if (nth > nk0) {
1305+
nrptg = (nth + nk0 - 1)/nk0;
1306+
nth = nk0;
12961307

12971308
if (nrptg*nth > ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
12981309
nrptg--;
12991310
}
13001311
}
13011312
}
13021313

1303-
nth = std::min(nth, nk00);
1314+
nth = std::min<int>(nth, nk0);
13041315

13051316
ggml_metal_kargs_cpy args = {
1306-
/*.ne00 =*/ nk00,
1317+
/*.nk0 =*/ nk0,
1318+
/*.ne00 =*/ ne00,
13071319
/*.ne01 =*/ ne01,
13081320
/*.ne02 =*/ ne02,
13091321
/*.ne03 =*/ ne03,
@@ -1321,12 +1333,14 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
13211333
/*.nb3 =*/ nb3,
13221334
};
13231335

1336+
const int nw0 = nrptg == 1 ? (nk0 + nth - 1)/nth : 1;
1337+
13241338
ggml_metal_encoder_set_pipeline(enc, pipeline);
13251339
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
13261340
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
13271341
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
13281342

1329-
ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, nrptg, 1);
1343+
ggml_metal_encoder_dispatch_threadgroups(enc, nw0*(ne01 + nrptg - 1)/nrptg, ne02, ne03, nth, nrptg, 1);
13301344

13311345
return 1;
13321346
}

0 commit comments

Comments
 (0)