Commit

cuda/vulkan: specify fp32-only support for some operations in supports_op (#1129)

* cuda: restrict SILU_BACK to fp32, since fp16 exceeds the desired test threshold

* vulkan: specify fp32-only support for certain ops (that are now tested for fp16 as well)

* f32 sigmoid in vulkan supports op

* Revert "f32 sigmoid in vulkan supports op"

This reverts commit c6f04b3.
cmdr2 authored Feb 28, 2025
1 parent 0d1ea2e commit ff90529
Showing 3 changed files with 8 additions and 10 deletions.
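
Both backend changes follow the same pattern: the device's supports_op callback now reports the affected ops as unsupported unless the source tensor is f32, so a caller that probes support (test-backend-ops, for instance) can skip the fp16 variant rather than run it past the test threshold. A minimal, self-contained sketch of that guard, using simplified stand-in types rather than the real ggml definitions:

// Sketch only: a simplified stand-in for the supports_op type guard added in this
// commit; Type, Op and Tensor here are illustrative, not the real ggml structures.
#include <cstdio>

enum class Type { F32, F16 };
enum class Op   { SILU_BACK, ADD, OTHER };

struct Tensor {
    Op   op;          // which operation this node computes
    Type src_type;    // type of the first source tensor (op->src[0] in ggml)
    bool contiguous;  // whether the source is laid out contiguously
};

// The backend reports support only for contiguous f32 inputs of the restricted ops.
static bool supports_op(const Tensor & t) {
    switch (t.op) {
        case Op::SILU_BACK:
            return t.contiguous && t.src_type == Type::F32;  // fp16 is now rejected
        case Op::ADD:
            return t.src_type == Type::F32;
        default:
            return false;
    }
}

int main() {
    const Tensor f16_silu{Op::SILU_BACK, Type::F16, true};
    const Tensor f32_silu{Op::SILU_BACK, Type::F32, true};
    std::printf("f16 SILU_BACK supported: %d\n", supports_op(f16_silu));  // prints 0
    std::printf("f32 SILU_BACK supported: %d\n", supports_op(f32_silu));  // prints 1
    return 0;
}
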
src/ggml-cuda/ggml-cuda.cu: 1 addition, 1 deletion
@@ -3153,7 +3153,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                return false;
            } break;
        case GGML_OP_SILU_BACK:
-           return ggml_is_contiguous(op->src[0]);
+           return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
            break;
        case GGML_OP_NORM:
        case GGML_OP_RMS_NORM:
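
In ggml, this answer is reached through the backend device interface, so a caller can build the op node without allocating any data and ask the device whether it would run it. A rough sketch of such a probe, assuming the ggml_backend_dev_supports_op and ggml_backend_dev_by_type helpers and the ggml_silu_back graph builder from recent ggml headers (names and argument order should be checked against the actual headers):

// Sketch under the assumptions above: builds an f16 SILU_BACK node with no_alloc
// and asks the device whether it supports it; with this commit, CUDA answers false.
#include <cstdio>
#include "ggml.h"
#include "ggml-backend.h"

static bool device_supports_f16_silu_back(ggml_backend_dev_t dev) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024,
        /*.mem_buffer =*/ nullptr,
        /*.no_alloc   =*/ true,   // only tensor metadata is needed for supports_op
    };
    struct ggml_context * ctx = ggml_init(params);

    struct ggml_tensor * a  = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 64);
    struct ggml_tensor * b  = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 64);
    struct ggml_tensor * op = ggml_silu_back(ctx, a, b);  // op node, never executed

    const bool ok = ggml_backend_dev_supports_op(dev, op);
    ggml_free(ctx);
    return ok;
}

int main() {
    // First GPU-class device registered with ggml, if any (assumed helper name).
    ggml_backend_dev_t dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_GPU);
    if (dev != nullptr) {
        std::printf("f16 SILU_BACK supported: %d\n", device_supports_f16_silu_back(dev));
    }
    return 0;
}

In effect this is what the test harness relies on: an op the device reports as unsupported is skipped for that backend instead of being executed and compared.
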
src/ggml-vulkan/ggml-vulkan.cpp: 6 additions, 5 deletions
@@ -8371,7 +8371,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                case GGML_UNARY_OP_SILU:
                case GGML_UNARY_OP_RELU:
                case GGML_UNARY_OP_TANH:
-                   return ggml_is_contiguous(op->src[0]);
+                   return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
                default:
                    return false;
            }
@@ -8571,17 +8571,18 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
        case GGML_OP_RMS_NORM:
            return ggml_is_contiguous(op->src[0]);
        case GGML_OP_ADD:
-       case GGML_OP_ACC:
        case GGML_OP_SUB:
        case GGML_OP_MUL:
        case GGML_OP_DIV:
-       case GGML_OP_CONCAT:
-       case GGML_OP_UPSCALE:
-       case GGML_OP_SCALE:
        case GGML_OP_SQR:
        case GGML_OP_SIN:
        case GGML_OP_COS:
        case GGML_OP_CLAMP:
+           return op->src[0]->type == GGML_TYPE_F32;
+       case GGML_OP_ACC:
+       case GGML_OP_CONCAT:
+       case GGML_OP_UPSCALE:
+       case GGML_OP_SCALE:
        case GGML_OP_PAD:
        case GGML_OP_DIAG_MASK_INF:
        case GGML_OP_SOFT_MAX:
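
Read as a whole, the hunk leaves the switch grouped like this (reconstructed from the lines above; everything after GGML_OP_SOFT_MAX continues into the unchanged remainder of the function):

        case GGML_OP_ADD:
        case GGML_OP_SUB:
        case GGML_OP_MUL:
        case GGML_OP_DIV:
        case GGML_OP_SQR:
        case GGML_OP_SIN:
        case GGML_OP_COS:
        case GGML_OP_CLAMP:
            return op->src[0]->type == GGML_TYPE_F32;  // fp32-only from this commit on
        case GGML_OP_ACC:
        case GGML_OP_CONCAT:
        case GGML_OP_UPSCALE:
        case GGML_OP_SCALE:
        case GGML_OP_PAD:
        case GGML_OP_DIAG_MASK_INF:
        case GGML_OP_SOFT_MAX:
            // ... handled by the pre-existing logic below this hunk
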
tests/test-backend-ops.cpp: 1 addition, 4 deletions
@@ -3980,10 +3980,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {

    test_cases.emplace_back(new test_add1());
    test_cases.emplace_back(new test_scale());
-
-    for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) {
-        test_cases.emplace_back(new test_silu_back());
-    }
+    test_cases.emplace_back(new test_silu_back());

    for (float eps : {0.0f, 1e-6f, 1e-4f, 1e-1f}) {
        for (bool v : {false, true}) {
