diff --git a/ast/node.go b/ast/node.go index 198efa59..fbb9ae82 100644 --- a/ast/node.go +++ b/ast/node.go @@ -187,6 +187,7 @@ type BuiltinNode struct { Arguments []Node // Arguments of the builtin function. Throws bool // If true then accessing a field or array index can throw an error. Used by optimizer. Map Node // Used by optimizer to fold filter() and map() builtins. + Threshold *int // Used by optimizer for count() early termination. } // PredicateNode represents a predicate. diff --git a/compiler/compiler.go b/compiler/compiler.go index 951385cd..ed8942c9 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -937,6 +937,7 @@ func (c *compiler) BuiltinNode(node *ast.BuiltinNode) { c.compile(node.Arguments[0]) c.derefInNeeded(node.Arguments[0]) c.emit(OpBegin) + var loopBreak int c.emitLoop(func() { if len(node.Arguments) == 2 { c.compile(node.Arguments[1]) @@ -945,9 +946,25 @@ func (c *compiler) BuiltinNode(node *ast.BuiltinNode) { } c.emitCond(func() { c.emit(OpIncrementCount) + // Early termination if threshold is set + if node.Threshold != nil { + c.emit(OpGetCount) + c.emit(OpInt, *node.Threshold) + c.emit(OpMoreOrEqual) + loopBreak = c.emit(OpJumpIfTrue, placeholder) + c.emit(OpPop) + } }) }) c.emit(OpGetCount) + if node.Threshold != nil { + end := c.emit(OpJump, placeholder) + c.patchJump(loopBreak) + // Early exit path: pop the bool comparison result, push count + c.emit(OpPop) + c.emit(OpGetCount) + c.patchJump(end) + } c.emit(OpEnd) return diff --git a/optimizer/count_threshold.go b/optimizer/count_threshold.go new file mode 100644 index 00000000..d045760b --- /dev/null +++ b/optimizer/count_threshold.go @@ -0,0 +1,54 @@ +package optimizer + +import ( + . "github.com/expr-lang/expr/ast" +) + +// countThreshold optimizes count comparisons by setting a threshold for early termination. +// The threshold allows the count loop to exit early once enough matches are found. +// Patterns: +// - count(arr, pred) > N → threshold = N + 1 (exit proves > N is true) +// - count(arr, pred) >= N → threshold = N (exit proves >= N is true) +// - count(arr, pred) < N → threshold = N (exit proves < N is false) +// - count(arr, pred) <= N → threshold = N + 1 (exit proves <= N is false) +type countThreshold struct{} + +func (*countThreshold) Visit(node *Node) { + binary, ok := (*node).(*BinaryNode) + if !ok { + return + } + + count, ok := binary.Left.(*BuiltinNode) + if !ok || count.Name != "count" || len(count.Arguments) != 2 { + return + } + + integer, ok := binary.Right.(*IntegerNode) + if !ok || integer.Value < 0 { + return + } + + var threshold int + switch binary.Operator { + case ">": + threshold = integer.Value + 1 + case ">=": + threshold = integer.Value + case "<": + threshold = integer.Value + case "<=": + threshold = integer.Value + 1 + default: + return + } + + // Skip if threshold is 0 or 1 (handled by count_any optimizer) + if threshold <= 1 { + return + } + + // Set threshold on the count node for early termination + // The original comparison remains unchanged + count.Threshold = &threshold +} diff --git a/optimizer/count_threshold_test.go b/optimizer/count_threshold_test.go new file mode 100644 index 00000000..3bac6fc3 --- /dev/null +++ b/optimizer/count_threshold_test.go @@ -0,0 +1,278 @@ +package optimizer_test + +import ( + "testing" + + "github.com/expr-lang/expr" + . "github.com/expr-lang/expr/ast" + "github.com/expr-lang/expr/internal/testify/assert" + "github.com/expr-lang/expr/internal/testify/require" + "github.com/expr-lang/expr/optimizer" + "github.com/expr-lang/expr/parser" + "github.com/expr-lang/expr/vm" +) + +func TestOptimize_count_threshold_gt(t *testing.T) { + tree, err := parser.Parse(`count(items, .active) > 100`) + require.NoError(t, err) + + err = optimizer.Optimize(&tree.Node, nil) + require.NoError(t, err) + + // Operator should remain >, but count should have threshold set + binary, ok := tree.Node.(*BinaryNode) + require.True(t, ok, "expected BinaryNode, got %T", tree.Node) + assert.Equal(t, ">", binary.Operator) + + count, ok := binary.Left.(*BuiltinNode) + require.True(t, ok, "expected BuiltinNode, got %T", binary.Left) + assert.Equal(t, "count", count.Name) + require.NotNil(t, count.Threshold) + assert.Equal(t, 101, *count.Threshold) // threshold = N + 1 for > operator +} + +func TestOptimize_count_threshold_gte(t *testing.T) { + tree, err := parser.Parse(`count(items, .active) >= 50`) + require.NoError(t, err) + + err = optimizer.Optimize(&tree.Node, nil) + require.NoError(t, err) + + // Operator should remain >=, but count should have threshold set + binary, ok := tree.Node.(*BinaryNode) + require.True(t, ok, "expected BinaryNode, got %T", tree.Node) + assert.Equal(t, ">=", binary.Operator) + + count, ok := binary.Left.(*BuiltinNode) + require.True(t, ok, "expected BuiltinNode, got %T", binary.Left) + assert.Equal(t, "count", count.Name) + require.NotNil(t, count.Threshold) + assert.Equal(t, 50, *count.Threshold) // threshold = N for >= operator +} + +func TestOptimize_count_threshold_lt(t *testing.T) { + tree, err := parser.Parse(`count(items, .active) < 100`) + require.NoError(t, err) + + err = optimizer.Optimize(&tree.Node, nil) + require.NoError(t, err) + + // Operator should remain <, but count should have threshold set + binary, ok := tree.Node.(*BinaryNode) + require.True(t, ok, "expected BinaryNode, got %T", tree.Node) + assert.Equal(t, "<", binary.Operator) + + count, ok := binary.Left.(*BuiltinNode) + require.True(t, ok, "expected BuiltinNode, got %T", binary.Left) + assert.Equal(t, "count", count.Name) + require.NotNil(t, count.Threshold) + assert.Equal(t, 100, *count.Threshold) // threshold = N for < operator +} + +func TestOptimize_count_threshold_lte(t *testing.T) { + tree, err := parser.Parse(`count(items, .active) <= 50`) + require.NoError(t, err) + + err = optimizer.Optimize(&tree.Node, nil) + require.NoError(t, err) + + // Operator should remain <=, but count should have threshold set + binary, ok := tree.Node.(*BinaryNode) + require.True(t, ok, "expected BinaryNode, got %T", tree.Node) + assert.Equal(t, "<=", binary.Operator) + + count, ok := binary.Left.(*BuiltinNode) + require.True(t, ok, "expected BuiltinNode, got %T", binary.Left) + assert.Equal(t, "count", count.Name) + require.NotNil(t, count.Threshold) + assert.Equal(t, 51, *count.Threshold) // threshold = N + 1 for <= operator +} + +func TestOptimize_count_threshold_correctness(t *testing.T) { + tests := []struct { + expr string + want bool + }{ + // count > N (threshold = N + 1) + {`count(1..1000, # <= 100) > 50`, true}, // 100 matches > 50 + {`count(1..1000, # <= 100) > 100`, false}, // 100 matches not > 100 + {`count(1..1000, # <= 100) > 99`, true}, // 100 matches > 99 + {`count(1..100, # > 0) > 50`, true}, // 100 matches > 50 + {`count(1..100, # > 0) > 100`, false}, // 100 matches not > 100 + + // count >= N (threshold = N) + {`count(1..1000, # <= 100) >= 100`, true}, // 100 matches >= 100 + {`count(1..1000, # <= 100) >= 101`, false}, // 100 matches not >= 101 + {`count(1..100, # > 0) >= 50`, true}, // 100 matches >= 50 + {`count(1..100, # > 0) >= 100`, true}, // 100 matches >= 100 + + // count < N (threshold = N) + {`count(1..1000, # <= 100) < 101`, true}, // 100 matches < 101 + {`count(1..1000, # <= 100) < 100`, false}, // 100 matches not < 100 + {`count(1..1000, # <= 100) < 50`, false}, // 100 matches not < 50 + {`count(1..100, # > 0) < 101`, true}, // 100 matches < 101 + {`count(1..100, # > 0) < 100`, false}, // 100 matches not < 100 + + // count <= N (threshold = N + 1) + {`count(1..1000, # <= 100) <= 100`, true}, // 100 matches <= 100 + {`count(1..1000, # <= 100) <= 99`, false}, // 100 matches not <= 99 + {`count(1..1000, # <= 100) <= 50`, false}, // 100 matches not <= 50 + {`count(1..100, # > 0) <= 100`, true}, // 100 matches <= 100 + {`count(1..100, # > 0) <= 99`, false}, // 100 matches not <= 99 + } + + for _, tt := range tests { + t.Run(tt.expr, func(t *testing.T) { + program, err := expr.Compile(tt.expr) + require.NoError(t, err) + + output, err := expr.Run(program, nil) + require.NoError(t, err) + assert.Equal(t, tt.want, output) + }) + } +} + +func TestOptimize_count_threshold_no_optimization(t *testing.T) { + // These should NOT get a threshold (handled by count_any or not optimizable) + tests := []struct { + code string + threshold bool + }{ + {`count(items, .active) > 0`, false}, // handled by count_any + {`count(items, .active) >= 1`, false}, // handled by count_any + {`count(items, .active) < 1`, false}, // threshold = 1, skipped + {`count(items, .active) <= 0`, false}, // threshold = 1, skipped + {`count(items, .active) == 10`, false}, // not supported + } + + for _, tt := range tests { + t.Run(tt.code, func(t *testing.T) { + tree, err := parser.Parse(tt.code) + require.NoError(t, err) + + err = optimizer.Optimize(&tree.Node, nil) + require.NoError(t, err) + + // Check if count has threshold set + var count *BuiltinNode + if binary, ok := tree.Node.(*BinaryNode); ok { + count, _ = binary.Left.(*BuiltinNode) + } else if builtin, ok := tree.Node.(*BuiltinNode); ok { + count = builtin + } + + if count != nil && count.Name == "count" { + if tt.threshold { + assert.NotNil(t, count.Threshold, "expected threshold to be set") + } else { + assert.Nil(t, count.Threshold, "expected threshold to be nil") + } + } + }) + } +} + +// Benchmark: count > 100 with early match (element 101 matches early) +func BenchmarkCountThresholdEarlyMatch(b *testing.B) { + // Array of 10000 elements, all match predicate, threshold is 101 + // Should exit after ~101 iterations + program, _ := expr.Compile(`count(1..10000, # > 0) > 100`) + var out any + b.ResetTimer() + for n := 0; n < b.N; n++ { + out, _ = vm.Run(program, nil) + } + _ = out +} + +// Benchmark: count >= 50 with early match +func BenchmarkCountThresholdGteEarlyMatch(b *testing.B) { + // All elements match, threshold is 50 + // Should exit after ~50 iterations + program, _ := expr.Compile(`count(1..10000, # > 0) >= 50`) + var out any + b.ResetTimer() + for n := 0; n < b.N; n++ { + out, _ = vm.Run(program, nil) + } + _ = out +} + +// Benchmark: count > 100 with no early exit (not enough matches) +func BenchmarkCountThresholdNoEarlyExit(b *testing.B) { + // Only 100 elements match (# <= 100), threshold is 101 + // Must scan entire array + program, _ := expr.Compile(`count(1..10000, # <= 100) > 100`) + var out any + b.ResetTimer() + for n := 0; n < b.N; n++ { + out, _ = vm.Run(program, nil) + } + _ = out +} + +// Benchmark: Large threshold with early match +func BenchmarkCountThresholdLargeEarlyMatch(b *testing.B) { + // All 10000 match, threshold is 1000 + // Should exit after ~1000 iterations + program, _ := expr.Compile(`count(1..10000, # > 0) > 999`) + var out any + b.ResetTimer() + for n := 0; n < b.N; n++ { + out, _ = vm.Run(program, nil) + } + _ = out +} + +// Benchmark: count < N with early exit (result is false) +func BenchmarkCountThresholdLtEarlyExit(b *testing.B) { + // All 10000 match, threshold is 100 + // Should exit after ~100 iterations with result = false + program, _ := expr.Compile(`count(1..10000, # > 0) < 100`) + var out any + b.ResetTimer() + for n := 0; n < b.N; n++ { + out, _ = vm.Run(program, nil) + } + _ = out +} + +// Benchmark: count <= N with early exit (result is false) +func BenchmarkCountThresholdLteEarlyExit(b *testing.B) { + // All 10000 match, threshold is 51 + // Should exit after ~51 iterations with result = false + program, _ := expr.Compile(`count(1..10000, # > 0) <= 50`) + var out any + b.ResetTimer() + for n := 0; n < b.N; n++ { + out, _ = vm.Run(program, nil) + } + _ = out +} + +// Benchmark: count < N without early exit (result is true) +func BenchmarkCountThresholdLtNoEarlyExit(b *testing.B) { + // Only 100 elements match (# <= 100), threshold is 200 + // Must scan entire array, result = true + program, _ := expr.Compile(`count(1..10000, # <= 100) < 200`) + var out any + b.ResetTimer() + for n := 0; n < b.N; n++ { + out, _ = vm.Run(program, nil) + } + _ = out +} + +// Benchmark: count <= N without early exit (result is true) +func BenchmarkCountThresholdLteNoEarlyExit(b *testing.B) { + // Only 100 elements match (# <= 100), threshold is 101 + // Must scan entire array, result = true + program, _ := expr.Compile(`count(1..10000, # <= 100) <= 100`) + var out any + b.ResetTimer() + for n := 0; n < b.N; n++ { + out, _ = vm.Run(program, nil) + } + _ = out +} diff --git a/optimizer/optimizer.go b/optimizer/optimizer.go index fedf0208..9e4c75d3 100644 --- a/optimizer/optimizer.go +++ b/optimizer/optimizer.go @@ -44,6 +44,7 @@ func Optimize(node *Node, config *conf.Config) error { Walk(node, &sumArray{}) Walk(node, &sumMap{}) Walk(node, &countAny{}) + Walk(node, &countThreshold{}) return nil }