Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ast/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ type BuiltinNode struct {
Arguments []Node // Arguments of the builtin function.
Throws bool // If true then accessing a field or array index can throw an error. Used by optimizer.
Map Node // Used by optimizer to fold filter() and map() builtins.
Threshold *int // Used by optimizer for count() early termination.
}

// PredicateNode represents a predicate.
Expand Down
17 changes: 17 additions & 0 deletions compiler/compiler.go
Original file line number Diff line number Diff line change
Expand Up @@ -931,6 +931,7 @@ func (c *compiler) BuiltinNode(node *ast.BuiltinNode) {
c.compile(node.Arguments[0])
c.derefInNeeded(node.Arguments[0])
c.emit(OpBegin)
var loopBreak int
c.emitLoop(func() {
if len(node.Arguments) == 2 {
c.compile(node.Arguments[1])
Expand All @@ -939,9 +940,25 @@ func (c *compiler) BuiltinNode(node *ast.BuiltinNode) {
}
c.emitCond(func() {
c.emit(OpIncrementCount)
// Early termination if threshold is set
if node.Threshold != nil {
c.emit(OpGetCount)
c.emit(OpInt, *node.Threshold)
c.emit(OpMoreOrEqual)
loopBreak = c.emit(OpJumpIfTrue, placeholder)
c.emit(OpPop)
}
})
})
c.emit(OpGetCount)
if node.Threshold != nil {
end := c.emit(OpJump, placeholder)
c.patchJump(loopBreak)
// Early exit path: pop the bool comparison result, push count
c.emit(OpPop)
c.emit(OpGetCount)
c.patchJump(end)
}
c.emit(OpEnd)
return

Expand Down
48 changes: 48 additions & 0 deletions optimizer/count_threshold.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package optimizer

import (
. "github.com/expr-lang/expr/ast"
)

// countThreshold annotates count() builtins that are compared against a
// non-negative integer literal with the number of matches after which the
// comparison's outcome is already known, so the count loop can stop early.
//
// Recognized shapes (count on the left, integer literal on the right):
//   - count(arr, pred) >  N  → threshold = N + 1 (strictly more than N matches)
//   - count(arr, pred) >= N  → threshold = N     (at least N matches)
//
// The comparison node itself is never rewritten; only BuiltinNode.Threshold
// is set.
type countThreshold struct{}

// Visit inspects a single node and, when it matches one of the recognized
// shapes, stores the early-termination threshold on the count node.
func (*countThreshold) Visit(node *Node) {
	comparison, isBinary := (*node).(*BinaryNode)
	if !isBinary {
		return
	}

	// Only the two-argument form count(array, predicate) is considered.
	call, isBuiltin := comparison.Left.(*BuiltinNode)
	if !(isBuiltin && call.Name == "count" && len(call.Arguments) == 2) {
		return
	}

	limit, isInt := comparison.Right.(*IntegerNode)
	if !isInt || limit.Value < 0 {
		return
	}

	needed := limit.Value
	if comparison.Operator == ">" {
		needed++
	} else if comparison.Operator != ">=" {
		return
	}

	// Thresholds of 0 and 1 are left alone (handled by count_any optimizer).
	if needed > 1 {
		call.Threshold = &needed
	}
}
175 changes: 175 additions & 0 deletions optimizer/count_threshold_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
package optimizer_test

import (
"testing"

"github.com/expr-lang/expr"
. "github.com/expr-lang/expr/ast"
"github.com/expr-lang/expr/internal/testify/assert"
"github.com/expr-lang/expr/internal/testify/require"
"github.com/expr-lang/expr/optimizer"
"github.com/expr-lang/expr/parser"
"github.com/expr-lang/expr/vm"
)

// TestOptimize_count_threshold_gt verifies that optimizing `count(...) > N`
// leaves the `>` comparison in place and records N + 1 as the threshold on
// the count builtin.
func TestOptimize_count_threshold_gt(t *testing.T) {
	tree, err := parser.Parse(`count(items, .active) > 100`)
	require.NoError(t, err)

	require.NoError(t, optimizer.Optimize(&tree.Node, nil))

	// The root must still be the original comparison.
	comparison, isBinary := tree.Node.(*BinaryNode)
	require.True(t, isBinary, "expected BinaryNode, got %T", tree.Node)
	assert.Equal(t, ">", comparison.Operator)

	// Its left operand is the count call, now carrying the threshold.
	call, isBuiltin := comparison.Left.(*BuiltinNode)
	require.True(t, isBuiltin, "expected BuiltinNode, got %T", comparison.Left)
	assert.Equal(t, "count", call.Name)
	require.NotNil(t, call.Threshold)
	assert.Equal(t, 101, *call.Threshold) // `> 100` needs 101 matches
}

// TestOptimize_count_threshold_gte verifies that optimizing `count(...) >= N`
// leaves the `>=` comparison in place and records N itself as the threshold
// on the count builtin.
func TestOptimize_count_threshold_gte(t *testing.T) {
	tree, err := parser.Parse(`count(items, .active) >= 50`)
	require.NoError(t, err)

	require.NoError(t, optimizer.Optimize(&tree.Node, nil))

	// The root must still be the original comparison.
	comparison, isBinary := tree.Node.(*BinaryNode)
	require.True(t, isBinary, "expected BinaryNode, got %T", tree.Node)
	assert.Equal(t, ">=", comparison.Operator)

	// Its left operand is the count call, now carrying the threshold.
	call, isBuiltin := comparison.Left.(*BuiltinNode)
	require.True(t, isBuiltin, "expected BuiltinNode, got %T", comparison.Left)
	assert.Equal(t, "count", call.Name)
	require.NotNil(t, call.Threshold)
	assert.Equal(t, 50, *call.Threshold) // `>= 50` needs 50 matches
}

// TestOptimize_count_threshold_correctness compiles and runs count
// comparisons end to end and checks the boolean result, covering values
// just below, at, and just above the number of matching elements.
func TestOptimize_count_threshold_correctness(t *testing.T) {
	type scenario struct {
		input    string
		expected bool
	}

	scenarios := []scenario{
		// Strict comparison: count > N (threshold = N + 1).
		{input: `count(1..1000, # <= 100) > 50`, expected: true},   // 100 matches > 50
		{input: `count(1..1000, # <= 100) > 100`, expected: false}, // 100 matches not > 100
		{input: `count(1..1000, # <= 100) > 99`, expected: true},   // 100 matches > 99
		{input: `count(1..100, # > 0) > 50`, expected: true},       // 100 matches > 50
		{input: `count(1..100, # > 0) > 100`, expected: false},     // 100 matches not > 100

		// Inclusive comparison: count >= N (threshold = N).
		{input: `count(1..1000, # <= 100) >= 100`, expected: true},  // 100 matches >= 100
		{input: `count(1..1000, # <= 100) >= 101`, expected: false}, // 100 matches not >= 101
		{input: `count(1..100, # > 0) >= 50`, expected: true},       // 100 matches >= 50
		{input: `count(1..100, # > 0) >= 100`, expected: true},      // 100 matches >= 100
	}

	for _, sc := range scenarios {
		t.Run(sc.input, func(t *testing.T) {
			program, err := expr.Compile(sc.input)
			require.NoError(t, err)

			got, err := expr.Run(program, nil)
			require.NoError(t, err)
			assert.Equal(t, sc.expected, got)
		})
	}
}

// TestOptimize_count_threshold_no_optimization checks expressions for which
// no threshold is expected on the count builtin after optimization.
//
// If the optimized tree no longer exposes a count builtin (at the root or as
// the left operand of a root comparison), the case has nothing to assert and
// passes.
func TestOptimize_count_threshold_no_optimization(t *testing.T) {
	type scenario struct {
		source        string
		wantThreshold bool
	}

	scenarios := []scenario{
		{source: `count(items, .active) > 0`, wantThreshold: false},   // handled by count_any
		{source: `count(items, .active) >= 1`, wantThreshold: false},  // handled by count_any
		{source: `count(items, .active) < 10`, wantThreshold: false},  // not supported yet
		{source: `count(items, .active) <= 10`, wantThreshold: false}, // not supported yet
		{source: `count(items, .active) == 10`, wantThreshold: false}, // not supported
	}

	for _, sc := range scenarios {
		t.Run(sc.source, func(t *testing.T) {
			tree, err := parser.Parse(sc.source)
			require.NoError(t, err)

			require.NoError(t, optimizer.Optimize(&tree.Node, nil))

			// Locate the count builtin, if one is still present.
			var call *BuiltinNode
			switch root := tree.Node.(type) {
			case *BinaryNode:
				call, _ = root.Left.(*BuiltinNode)
			case *BuiltinNode:
				call = root
			}
			if call == nil || call.Name != "count" {
				return
			}

			if sc.wantThreshold {
				assert.NotNil(t, call.Threshold, "expected threshold to be set")
			} else {
				assert.Nil(t, call.Threshold, "expected threshold to be nil")
			}
		})
	}
}

// BenchmarkCountThresholdEarlyMatch measures `count(...) > 100` over 10000
// elements that all satisfy the predicate. The threshold is 101, so the loop
// should exit after ~101 iterations.
//
// Compile and run errors abort the benchmark instead of being discarded, so
// a broken expression can never be timed by accident.
func BenchmarkCountThresholdEarlyMatch(b *testing.B) {
	program, err := expr.Compile(`count(1..10000, # > 0) > 100`)
	if err != nil {
		b.Fatal(err)
	}

	var out any
	b.ResetTimer() // exclude compilation from the measurement
	for n := 0; n < b.N; n++ {
		out, err = vm.Run(program, nil)
		if err != nil {
			b.Fatal(err)
		}
	}
	_ = out // keep the result alive so the call is not optimized away
}

// BenchmarkCountThresholdGteEarlyMatch measures `count(...) >= 50` over 10000
// elements that all satisfy the predicate. The threshold is 50, so the loop
// should exit after ~50 iterations.
//
// Compile and run errors abort the benchmark instead of being discarded, so
// a broken expression can never be timed by accident.
func BenchmarkCountThresholdGteEarlyMatch(b *testing.B) {
	program, err := expr.Compile(`count(1..10000, # > 0) >= 50`)
	if err != nil {
		b.Fatal(err)
	}

	var out any
	b.ResetTimer() // exclude compilation from the measurement
	for n := 0; n < b.N; n++ {
		out, err = vm.Run(program, nil)
		if err != nil {
			b.Fatal(err)
		}
	}
	_ = out // keep the result alive so the call is not optimized away
}

// BenchmarkCountThresholdNoEarlyExit measures `count(...) > 100` when only
// 100 of the 10000 elements match (# <= 100). The threshold of 101 is never
// reached, so the entire array must be scanned.
//
// Compile and run errors abort the benchmark instead of being discarded, so
// a broken expression can never be timed by accident.
func BenchmarkCountThresholdNoEarlyExit(b *testing.B) {
	program, err := expr.Compile(`count(1..10000, # <= 100) > 100`)
	if err != nil {
		b.Fatal(err)
	}

	var out any
	b.ResetTimer() // exclude compilation from the measurement
	for n := 0; n < b.N; n++ {
		out, err = vm.Run(program, nil)
		if err != nil {
			b.Fatal(err)
		}
	}
	_ = out // keep the result alive so the call is not optimized away
}

// BenchmarkCountThresholdLargeEarlyMatch measures `count(...) > 999` over
// 10000 elements that all satisfy the predicate. The threshold is 1000, so
// the loop should exit after ~1000 iterations.
//
// Compile and run errors abort the benchmark instead of being discarded, so
// a broken expression can never be timed by accident.
func BenchmarkCountThresholdLargeEarlyMatch(b *testing.B) {
	program, err := expr.Compile(`count(1..10000, # > 0) > 999`)
	if err != nil {
		b.Fatal(err)
	}

	var out any
	b.ResetTimer() // exclude compilation from the measurement
	for n := 0; n < b.N; n++ {
		out, err = vm.Run(program, nil)
		if err != nil {
			b.Fatal(err)
		}
	}
	_ = out // keep the result alive so the call is not optimized away
}

1 change: 1 addition & 0 deletions optimizer/optimizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ func Optimize(node *Node, config *conf.Config) error {
Walk(node, &sumArray{})
Walk(node, &sumMap{})
Walk(node, &countAny{})
Walk(node, &countThreshold{})
return nil
}

Expand Down
Loading