@@ -577,6 +577,7 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
577
577
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_cpy (lib, op->src [0 ]->type , op->type );
578
578
579
579
ggml_metal_kargs_cpy args = {
580
+ /* .nk0 =*/ ne00,
580
581
/* .ne00 =*/ ne00,
581
582
/* .ne01 =*/ ne01,
582
583
/* .ne02 =*/ ne02,
@@ -906,23 +907,31 @@ int ggml_metal_op_get_rows(ggml_metal_op_t ctx, int idx) {
906
907
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_get_rows (lib, op->src [0 ]->type );
907
908
908
909
ggml_metal_kargs_get_rows args = {
909
- /* .ne00 =*/ ne00,
910
- /* .nb01 =*/ nb01,
911
- /* .nb02 =*/ nb02,
912
- /* .ne10 =*/ ne10,
913
- /* .nb10 =*/ nb10,
914
- /* .nb11 =*/ nb11,
915
- /* .nb1 =*/ nb1,
916
- /* .nb2 =*/ nb2,
910
+ /* .ne00t =*/ ggml_is_quantized (op->src [0 ]->type ) ? ne00/16 : ne00,
911
+ /* .ne00 =*/ ne00,
912
+ /* .nb01 =*/ nb01,
913
+ /* .nb02 =*/ nb02,
914
+ /* .nb03 =*/ nb03,
915
+ /* .ne10 =*/ ne10,
916
+ /* .nb10 =*/ nb10,
917
+ /* .nb11 =*/ nb11,
918
+ /* .nb12 =*/ nb12,
919
+ /* .nb1 =*/ nb1,
920
+ /* .nb2 =*/ nb2,
921
+ /* .nb3 =*/ nb3,
917
922
};
918
923
924
+ const int nth = std::min (args.ne00t , ggml_metal_pipeline_max_theads_per_threadgroup (pipeline));
925
+
926
+ const int nw0 = (args.ne00t + nth - 1 )/nth;
927
+
919
928
ggml_metal_encoder_set_pipeline (enc, pipeline);
920
929
ggml_metal_encoder_set_bytes (enc, &args, sizeof (args), 0 );
921
930
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op->src [0 ]), 1 );
922
931
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op->src [1 ]), 2 );
923
932
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op), 3 );
924
933
925
- ggml_metal_encoder_dispatch_threadgroups (enc, ne10, ne11, ne12, 32 , 1 , 1 );
934
+ ggml_metal_encoder_dispatch_threadgroups (enc, nw0* ne10, ne11, ne12, nth , 1 , 1 );
926
935
927
936
return 1 ;
928
937
}
@@ -1117,7 +1126,7 @@ int ggml_metal_op_ssm_conv(ggml_metal_op_t ctx, int idx) {
1117
1126
ggml_metal_encoder_set_bytes (enc, &args, sizeof (args), 0 );
1118
1127
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op->src [0 ]), 1 );
1119
1128
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op->src [1 ]), 2 );
1120
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op), 3 );
1129
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op), 3 );
1121
1130
1122
1131
ggml_metal_encoder_dispatch_threadgroups (enc, ne01, ne1, ne02, 1 , 1 , 1 );
1123
1132
@@ -1172,25 +1181,36 @@ int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) {
1172
1181
/* .n_seq_tokens =*/ n_seq_tokens,
1173
1182
/* .n_seqs =*/ n_seqs,
1174
1183
/* .s_off =*/ ggml_nelements (op->src [1 ]) * sizeof (float ),
1184
+ /* .nb00 =*/ nb00,
1175
1185
/* .nb01 =*/ nb01,
1176
1186
/* .nb02 =*/ nb02,
1177
1187
/* .nb03 =*/ nb03,
1188
+ /* .nb10 =*/ nb10,
1178
1189
/* .nb11 =*/ nb11,
1179
1190
/* .nb12 =*/ nb12,
1191
+ /* .ns12 =*/ nb12/nb10,
1180
1192
/* .nb13 =*/ nb13,
1193
+ /* .nb20 =*/ nb20,
1181
1194
/* .nb21 =*/ nb21,
1195
+ /* .ns21 =*/ nb21/nb20,
1182
1196
/* .nb22 =*/ nb22,
1197
+ /* .ne30 =*/ ne30,
1183
1198
/* .nb31 =*/ nb31,
1184
1199
/* .nb41 =*/ nb41,
1185
1200
/* .nb42 =*/ nb42,
1201
+ /* .ns42 =*/ nb42/nb40,
1186
1202
/* .nb43 =*/ nb43,
1187
1203
/* .nb51 =*/ nb51,
1188
1204
/* .nb52 =*/ nb52,
1205
+ /* .ns52 =*/ nb52/nb50,
1189
1206
/* .nb53 =*/ nb53,
1207
+ /* .nb0 =*/ nb0,
1190
1208
};
1191
1209
1192
1210
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_ssm_scan (lib, op);
1193
1211
1212
+ GGML_ASSERT (d_state <= ggml_metal_pipeline_max_theads_per_threadgroup (pipeline));
1213
+
1194
1214
const size_t sms = ggml_metal_pipeline_get_smem (pipeline);
1195
1215
1196
1216
ggml_metal_encoder_set_pipeline (enc, pipeline);
@@ -1206,13 +1226,7 @@ int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) {
1206
1226
1207
1227
ggml_metal_encoder_set_threadgroup_memory_size (enc, sms, 0 );
1208
1228
1209
- if (ne30 == 1 ) {
1210
- // Mamba-2
1211
- ggml_metal_encoder_dispatch_threadgroups (enc, d_inner, n_head, n_seqs, d_state, 1 , 1 );
1212
- } else {
1213
- GGML_ASSERT (d_inner == 1 );
1214
- ggml_metal_encoder_dispatch_threadgroups (enc, n_head, n_seqs, 1 , d_state, 1 , 1 );
1215
- }
1229
+ ggml_metal_encoder_dispatch_threadgroups (enc, d_inner, n_head, n_seqs, d_state, 1 , 1 );
1216
1230
1217
1231
return 1 ;
1218
1232
}
@@ -1273,37 +1287,35 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
1273
1287
1274
1288
GGML_ASSERT (ne00 % ggml_blck_size (op->src [0 ]->type ) == 0 );
1275
1289
1276
- // TODO: support
1277
- // const int32_t nk00 = ne00/ggml_blck_size(op->type);
1278
- const int32_t nk00 = ne00;
1279
-
1280
- int nth = 32 ; // SIMD width
1281
-
1282
- while (nth < nk00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup (pipeline)) {
1283
- nth *= 2 ;
1290
+ int64_t nk0 = ne00;
1291
+ if (ggml_is_quantized (op->src [0 ]->type )) {
1292
+ nk0 = ne00/16 ;
1293
+ } else if (ggml_is_quantized (op->type )) {
1294
+ nk0 = ne00/ggml_blck_size (op->type );
1284
1295
}
1285
1296
1286
- nth = std::min (nth , ggml_metal_pipeline_max_theads_per_threadgroup (pipeline));
1297
+ int nth = std::min< int >(nk0 , ggml_metal_pipeline_max_theads_per_threadgroup (pipeline));
1287
1298
1288
1299
// when rows are small, we can batch them together in a single threadgroup
1289
1300
int nrptg = 1 ;
1290
1301
1291
1302
// TODO: relax this constraint in the future
1292
1303
if (ggml_blck_size (op->src [0 ]->type ) == 1 && ggml_blck_size (op->type ) == 1 ) {
1293
- if (nth > nk00 ) {
1294
- nrptg = (nth + nk00 - 1 )/nk00 ;
1295
- nth = nk00 ;
1304
+ if (nth > nk0 ) {
1305
+ nrptg = (nth + nk0 - 1 )/nk0 ;
1306
+ nth = nk0 ;
1296
1307
1297
1308
if (nrptg*nth > ggml_metal_pipeline_max_theads_per_threadgroup (pipeline)) {
1298
1309
nrptg--;
1299
1310
}
1300
1311
}
1301
1312
}
1302
1313
1303
- nth = std::min (nth, nk00 );
1314
+ nth = std::min< int > (nth, nk0 );
1304
1315
1305
1316
ggml_metal_kargs_cpy args = {
1306
- /* .ne00 =*/ nk00,
1317
+ /* .nk0 =*/ nk0,
1318
+ /* .ne00 =*/ ne00,
1307
1319
/* .ne01 =*/ ne01,
1308
1320
/* .ne02 =*/ ne02,
1309
1321
/* .ne03 =*/ ne03,
@@ -1321,12 +1333,14 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
1321
1333
/* .nb3 =*/ nb3,
1322
1334
};
1323
1335
1336
+ const int nw0 = nrptg == 1 ? (nk0 + nth - 1 )/nth : 1 ;
1337
+
1324
1338
ggml_metal_encoder_set_pipeline (enc, pipeline);
1325
1339
ggml_metal_encoder_set_bytes (enc, &args, sizeof (args), 0 );
1326
1340
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op->src [0 ]), 1 );
1327
1341
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op), 2 );
1328
1342
1329
- ggml_metal_encoder_dispatch_threadgroups (enc, ne01, ne02, ne03, nth, nrptg, 1 );
1343
+ ggml_metal_encoder_dispatch_threadgroups (enc, nw0*( ne01 + nrptg - 1 )/nrptg , ne02, ne03, nth, nrptg, 1 );
1330
1344
1331
1345
return 1 ;
1332
1346
}
0 commit comments