diff --git a/conf/config.go b/conf/config.go index d629958e5..2c14d9882 100644 --- a/conf/config.go +++ b/conf/config.go @@ -10,43 +10,41 @@ import ( "github.com/expr-lang/expr/vm/runtime" ) -const ( - // DefaultMemoryBudget represents an upper limit of memory usage +var ( + // DefaultMemoryBudget represents default maximum allowed memory usage by the vm.VM. DefaultMemoryBudget uint = 1e6 - // DefaultMaxNodes represents an upper limit of AST nodes - DefaultMaxNodes uint = 10000 + // DefaultMaxNodes represents default maximum allowed AST nodes by the compiler. + DefaultMaxNodes uint = 1e4 ) type FunctionsTable map[string]*builtin.Function type Config struct { - EnvObject any - Env nature.Nature - Expect reflect.Kind - ExpectAny bool - Optimize bool - Strict bool - Profile bool - MaxNodes uint - MemoryBudget uint - ConstFns map[string]reflect.Value - Visitors []ast.Visitor - Functions FunctionsTable - Builtins FunctionsTable - Disabled map[string]bool // disabled builtins + EnvObject any + Env nature.Nature + Expect reflect.Kind + ExpectAny bool + Optimize bool + Strict bool + Profile bool + MaxNodes uint + ConstFns map[string]reflect.Value + Visitors []ast.Visitor + Functions FunctionsTable + Builtins FunctionsTable + Disabled map[string]bool // disabled builtins } // CreateNew creates new config with default values. func CreateNew() *Config { c := &Config{ - Optimize: true, - MaxNodes: DefaultMaxNodes, - MemoryBudget: DefaultMemoryBudget, - ConstFns: make(map[string]reflect.Value), - Functions: make(map[string]*builtin.Function), - Builtins: make(map[string]*builtin.Function), - Disabled: make(map[string]bool), + Optimize: true, + MaxNodes: DefaultMaxNodes, + ConstFns: make(map[string]reflect.Value), + Functions: make(map[string]*builtin.Function), + Builtins: make(map[string]*builtin.Function), + Disabled: make(map[string]bool), } for _, f := range builtin.Builtins { c.Builtins[f.Name] = f diff --git a/expr.go b/expr.go index 33b7cf354..48298fe7e 100644 --- a/expr.go +++ b/expr.go @@ -195,6 +195,15 @@ func Timezone(name string) Option { }) } +// MaxNodes sets the maximum number of nodes allowed in the expression. +// By default, the maximum number of nodes is conf.DefaultMaxNodes. +// If MaxNodes is set to 0, the node budget check is disabled. +func MaxNodes(n uint) Option { + return func(c *conf.Config) { + c.MaxNodes = n + } +} + // Compile parses and compiles given input expression to bytecode program. func Compile(input string, ops ...Option) (*vm.Program, error) { config := conf.CreateNew() diff --git a/expr_test.go b/expr_test.go index 241767a09..57d2cfbbb 100644 --- a/expr_test.go +++ b/expr_test.go @@ -10,9 +10,11 @@ import ( "testing" "time" + "github.com/expr-lang/expr/conf" "github.com/expr-lang/expr/internal/testify/assert" "github.com/expr-lang/expr/internal/testify/require" "github.com/expr-lang/expr/types" + "github.com/expr-lang/expr/vm" "github.com/expr-lang/expr" "github.com/expr-lang/expr/ast" @@ -2225,26 +2227,6 @@ func TestEval_slices_out_of_bound(t *testing.T) { } } -func TestMemoryBudget(t *testing.T) { - tests := []struct { - code string - }{ - {`map(1..100, {map(1..100, {map(1..100, {0})})})`}, - {`len(1..10000000)`}, - } - - for _, tt := range tests { - t.Run(tt.code, func(t *testing.T) { - program, err := expr.Compile(tt.code) - require.NoError(t, err, "compile error") - - _, err = expr.Run(program, nil) - assert.Error(t, err, "run error") - assert.Contains(t, err.Error(), "memory budget exceeded") - }) - } -} - func TestExpr_custom_tests(t *testing.T) { f, err := os.Open("custom_tests.json") if os.IsNotExist(err) { @@ -2731,3 +2713,55 @@ func TestIssue785_get_nil(t *testing.T) { }) } } + +func TestMaxNodes(t *testing.T) { + maxNodes := uint(100) + + code := "" + for i := 0; i < int(maxNodes); i++ { + code += "1; " + } + + _, err := expr.Compile(code, expr.MaxNodes(maxNodes)) + require.Error(t, err) + assert.Contains(t, err.Error(), "exceeds maximum allowed nodes") + + _, err = expr.Compile(code, expr.MaxNodes(maxNodes+1)) + require.NoError(t, err) +} + +func TestMaxNodesDisabled(t *testing.T) { + code := "" + for i := 0; i < 2*int(conf.DefaultMaxNodes); i++ { + code += "1; " + } + + _, err := expr.Compile(code, expr.MaxNodes(0)) + require.NoError(t, err) +} + +func TestMemoryBudget(t *testing.T) { + tests := []struct { + code string + max int + }{ + {`map(1..100, {map(1..100, {map(1..100, {0})})})`, -1}, + {`len(1..10000000)`, -1}, + {`1..100`, 100}, + } + + for _, tt := range tests { + t.Run(tt.code, func(t *testing.T) { + program, err := expr.Compile(tt.code) + require.NoError(t, err, "compile error") + + vm := vm.VM{} + if tt.max > 0 { + vm.MemoryBudget = uint(tt.max) + } + _, err = vm.Run(program, nil) + require.Error(t, err, "run error") + assert.Contains(t, err.Error(), "memory budget exceeded") + }) + } +} diff --git a/vm/vm.go b/vm/vm.go index de13cade1..3018619d9 100644 --- a/vm/vm.go +++ b/vm/vm.go @@ -75,7 +75,6 @@ func (vm *VM) Run(program *Program, env any) (_ any, err error) { if len(vm.Variables) < program.variables { vm.Variables = make([]any, program.variables) } - if vm.MemoryBudget == 0 { vm.MemoryBudget = conf.DefaultMemoryBudget }