Commit d78a82c

instead of assert, check in should_fuse

1 parent 0ec69fb
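
This commit replaces a hard runtime assertion with an up-front fusion check: instead of GGML_ASSERTing inside ggml_cuda_op_topk_moe that the clamp's max bound is +INF, ggml_cuda_should_use_topk_moe now receives the clamp node and returns false for any clamp the fused path cannot handle, so the graph falls back to the unfused ops rather than aborting.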

File tree

2 files changed (+14, −4 lines)


ggml/src/ggml-cuda/topk-moe.cu

Lines changed: 13 additions & 3 deletions
@@ -254,8 +254,6 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
     if (with_norm) {
         if (clamp) {
             clamp_val = ggml_get_op_params_f32(clamp, 0);
-            float max_val = ggml_get_op_params_f32(clamp, 1);
-            GGML_ASSERT(max_val == INFINITY);
         }
         launch_topk_moe_cuda<true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used, clamp_val);
     } else {
@@ -269,7 +267,7 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
     }
 }

-bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights) {
+bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights, const ggml_tensor * clamp) {
     float scale    = 1.0f;
     float max_bias = 0.0f;

@@ -295,6 +293,18 @@ bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tenso
         return false;
     }

+    if (clamp) {
+        if (clamp->op != GGML_OP_CLAMP) {
+            return false;
+        }
+        float max_val = ggml_get_op_params_f32(clamp, 1);
+
+        if (max_val != INFINITY) {
+            return false;
+        }
+    }
+
+
     return true;
 }
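
To make the new guard concrete, here is a small illustration of what it accepts and rejects. The ggml_clamp call is ggml's public graph-building API (it stores its min/max bounds as op params 0 and 1, the same slots the check reads back); the surrounding ctx and t are assumed for the sketch and are not part of this patch.

    // Illustration only -- `ctx` (a ggml_context) and `t` are assumed, not
    // part of the patch. ggml_clamp stores (min, max) as op params 0 and 1,
    // which is what the new guard reads with ggml_get_op_params_f32.
    ggml_tensor * min_only = ggml_clamp(ctx, t, 1e-6f, INFINITY); // max == +INF: fusable
    ggml_tensor * bounded  = ggml_clamp(ctx, t, 1e-6f, 1.0f);     // finite max: rejected

    // ggml_cuda_should_use_topk_moe(softmax, weights, min_only) can still
    // return true; passing `bounded` now returns false up front, instead of
    // the old GGML_ASSERT firing later inside ggml_cuda_op_topk_moe.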

ggml/src/ggml-cuda/topk-moe.cuh

Lines changed: 1 addition & 1 deletion
@@ -11,6 +11,6 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
                            const bool delayed_softmax = false,
                            ggml_tensor * weight_clamp = nullptr);

-bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights);
+bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights, const ggml_tensor * clamp = nullptr);

 std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool with_norm, bool delayed_softmax = false);
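
Note that the new parameter defaults to nullptr, so existing call sites that pass only (softmax, weights) keep compiling unchanged; with no clamp supplied, the new guard is skipped entirely, matching the old behavior for graphs without a clamp.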
