41
41
GGML_METAL_KERNEL_TYPE_TANH,
42
42
GGML_METAL_KERNEL_TYPE_RELU,
43
43
GGML_METAL_KERNEL_TYPE_GELU,
44
+ GGML_METAL_KERNEL_TYPE_GELU_4,
44
45
GGML_METAL_KERNEL_TYPE_GELU_QUICK,
46
+ GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,
45
47
GGML_METAL_KERNEL_TYPE_SILU,
48
+ GGML_METAL_KERNEL_TYPE_SILU_4,
46
49
GGML_METAL_KERNEL_TYPE_SOFT_MAX,
47
50
GGML_METAL_KERNEL_TYPE_SOFT_MAX_4,
48
51
GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF,
@@ -473,8 +476,11 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
473
476
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_TANH, tanh , true );
474
477
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_RELU, relu, true );
475
478
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GELU, gelu, true );
479
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true );
476
480
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true );
481
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true );
477
482
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SILU, silu, true );
483
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true );
478
484
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SOFT_MAX, soft_max, ctx->support_simdgroup_reduction );
479
485
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SOFT_MAX_4, soft_max_4, ctx->support_simdgroup_reduction );
480
486
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true );
@@ -1178,6 +1184,9 @@ static enum ggml_status ggml_metal_graph_compute(
1178
1184
} break ;
1179
1185
case GGML_OP_UNARY:
1180
1186
switch (ggml_get_unary_op (gf->nodes [i])) {
1187
+ // we are not taking into account the strides, so for now require contiguous tensors
1188
+ GGML_ASSERT (ggml_is_contiguous (src0));
1189
+
1181
1190
case GGML_UNARY_OP_TANH:
1182
1191
{
1183
1192
id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_TANH].pipeline ;
@@ -1204,42 +1213,60 @@ static enum ggml_status ggml_metal_graph_compute(
1204
1213
} break ;
1205
1214
case GGML_UNARY_OP_GELU:
1206
1215
{
1207
- id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_GELU].pipeline ;
1216
+ int64_t n = ggml_nelements (dst);
1217
+
1218
+ id <MTLComputePipelineState > pipeline = nil ;
1219
+
1220
+ if (n % 4 == 0 ) {
1221
+ pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_GELU_4].pipeline ;
1222
+ n /= 4 ;
1223
+ } else {
1224
+ pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_GELU].pipeline ;
1225
+ }
1208
1226
1209
1227
[encoder setComputePipelineState: pipeline];
1210
1228
[encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1211
1229
[encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
1212
1230
1213
- const int64_t n = ggml_nelements (dst);
1214
- GGML_ASSERT (n % 4 == 0 );
1215
-
1216
- [encoder dispatchThreadgroups: MTLSizeMake (n/4 , 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (1 , 1 , 1 )];
1231
+ [encoder dispatchThreadgroups: MTLSizeMake (n, 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (1 , 1 , 1 )];
1217
1232
} break ;
1218
1233
case GGML_UNARY_OP_GELU_QUICK:
1219
1234
{
1220
- id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_GELU_QUICK].pipeline ;
1235
+ int64_t n = ggml_nelements (dst);
1236
+
1237
+ id <MTLComputePipelineState > pipeline = nil ;
1238
+
1239
+ if (n % 4 == 0 ) {
1240
+ pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_GELU_QUICK_4].pipeline ;
1241
+ n /= 4 ;
1242
+ } else {
1243
+ pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_GELU_QUICK].pipeline ;
1244
+ }
1221
1245
1222
1246
[encoder setComputePipelineState: pipeline];
1223
1247
[encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1224
1248
[encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
1225
1249
1226
- const int64_t n = ggml_nelements (dst);
1227
- GGML_ASSERT (n % 4 == 0 );
1228
-
1229
- [encoder dispatchThreadgroups: MTLSizeMake (n/4 , 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (1 , 1 , 1 )];
1250
+ [encoder dispatchThreadgroups: MTLSizeMake (n, 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (1 , 1 , 1 )];
1230
1251
} break ;
1231
1252
case GGML_UNARY_OP_SILU:
1232
1253
{
1233
- id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SILU].pipeline ;
1254
+ int64_t n = ggml_nelements (dst);
1255
+
1256
+ id <MTLComputePipelineState > pipeline = nil ;
1257
+
1258
+ if (n % 4 == 0 ) {
1259
+ pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SILU_4].pipeline ;
1260
+ n /= 4 ;
1261
+ } else {
1262
+ pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SILU].pipeline ;
1263
+ }
1234
1264
1235
1265
[encoder setComputePipelineState: pipeline];
1236
1266
[encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1237
1267
[encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
1238
1268
1239
- const int64_t n = ggml_nelements (dst);
1240
- GGML_ASSERT (n % 4 == 0 );
1241
-
1242
- [encoder dispatchThreadgroups: MTLSizeMake (n/4 , 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (1 , 1 , 1 )];
1269
+ [encoder dispatchThreadgroups: MTLSizeMake (n, 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (1 , 1 , 1 )];
1243
1270
} break ;
1244
1271
default :
1245
1272
{
0 commit comments