Skip to content

Commit ae34b48

Browse files
committed
feat: dns lookup function
1 parent d7737e9 commit ae34b48

File tree

4 files changed

+115
-60
lines changed

4 files changed

+115
-60
lines changed

go.mod

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ go 1.21
55
require (
66
github.com/bwmarrin/snowflake v0.3.0
77
github.com/coreos/go-iptables v0.7.0
8-
github.com/expr-lang/expr v1.15.7
8+
github.com/expr-lang/expr v1.16.3
99
github.com/florianl/go-nfqueue v1.3.2-0.20231218173729-f2bdeb033acf
1010
github.com/google/gopacket v1.1.20-0.20220810144506-32ee38206866
1111
github.com/hashicorp/golang-lru/v2 v2.0.7

go.sum

+2-2
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs
77
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
88
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
99
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
10-
github.com/expr-lang/expr v1.15.7 h1:BK0JcWUkoW6nrbLBo6xCKhz4BvH5DSOOu1Gx5lucyZo=
11-
github.com/expr-lang/expr v1.15.7/go.mod h1:uCkhfG+x7fcZ5A5sXHKuQ07jGZRl6J0FCAaf2k4PtVQ=
10+
github.com/expr-lang/expr v1.16.3 h1:NLldf786GffptcXNxxJx5dQ+FzeWDKChBDqOOwyK8to=
11+
github.com/expr-lang/expr v1.16.3/go.mod h1:uCkhfG+x7fcZ5A5sXHKuQ07jGZRl6J0FCAaf2k4PtVQ=
1212
github.com/florianl/go-nfqueue v1.3.2-0.20231218173729-f2bdeb033acf h1:NqGS3vTHzVENbIfd87cXZwdpO6MB2R1PjHMJLi4Z3ow=
1313
github.com/florianl/go-nfqueue v1.3.2-0.20231218173729-f2bdeb033acf/go.mod h1:eSnAor2YCfMCVYrVNEhkLGN/r1L+J4uDjc0EUy0tfq4=
1414
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=

io/nfqueue.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ func generateNftRules(local, rst bool) (*nftTableSpec, error) {
5555
c.Rules = append(c.Rules, "ip protocol tcp ct mark $DROP_CTMARK counter reject with tcp reset")
5656
}
5757
c.Rules = append(c.Rules, "ct mark $DROP_CTMARK counter drop")
58-
c.Rules = append(c.Rules, "counter queue num $QUEUE_NUM bypass")
58+
c.Rules = append(c.Rules, "ip protocol tcp counter queue num $QUEUE_NUM bypass")
5959
}
6060
return table, nil
6161
}

ruleset/expr.go

+111-56
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
11
package ruleset
22

33
import (
4+
"context"
45
"fmt"
56
"net"
67
"os"
78
"reflect"
89
"strings"
10+
"time"
11+
12+
"github.com/expr-lang/expr/builtin"
913

1014
"github.com/expr-lang/expr"
1115
"github.com/expr-lang/expr/ast"
@@ -104,6 +108,7 @@ func CompileExprRules(rules []ExprRule, ans []analyzer.Analyzer, mods []modifier
104108
if err != nil {
105109
return nil, err
106110
}
111+
funcMap := buildFunctionMap(geoMatcher)
107112
// Compile all rules and build a map of analyzers that are used by the rules.
108113
for _, rule := range rules {
109114
if rule.Action == "" && !rule.Log {
@@ -118,13 +123,19 @@ func CompileExprRules(rules []ExprRule, ans []analyzer.Analyzer, mods []modifier
118123
action = &a
119124
}
120125
visitor := &idVisitor{Variables: make(map[string]bool), Identifiers: make(map[string]bool)}
121-
patcher := &idPatcher{}
126+
patcher := &idPatcher{FuncMap: funcMap}
122127
program, err := expr.Compile(rule.Expr,
123128
func(c *conf.Config) {
124129
c.Strict = false
125130
c.Expect = reflect.Bool
126131
c.Visitors = append(c.Visitors, visitor, patcher)
127-
registerBuiltinFunctions(c.Functions, geoMatcher)
132+
for name, f := range funcMap {
133+
c.Functions[name] = &builtin.Function{
134+
Name: name,
135+
Func: f.Func,
136+
Types: f.Types,
137+
}
138+
}
128139
},
129140
)
130141
if err != nil {
@@ -138,24 +149,15 @@ func CompileExprRules(rules []ExprRule, ans []analyzer.Analyzer, mods []modifier
138149
if isBuiltInAnalyzer(name) || visitor.Variables[name] {
139150
continue
140151
}
141-
// Check if it's one of the built-in functions, and if so,
142-
// skip it as an analyzer & do initialization if necessary.
143-
switch name {
144-
case "geoip":
145-
if err := geoMatcher.LoadGeoIP(); err != nil {
146-
return nil, fmt.Errorf("rule %q failed to load geoip: %w", rule.Name, err)
147-
}
148-
case "geosite":
149-
if err := geoMatcher.LoadGeoSite(); err != nil {
150-
return nil, fmt.Errorf("rule %q failed to load geosite: %w", rule.Name, err)
151-
}
152-
case "cidr":
153-
// No initialization needed for CIDR.
154-
default:
155-
a, ok := fullAnMap[name]
156-
if !ok {
157-
return nil, fmt.Errorf("rule %q uses unknown analyzer %q", rule.Name, name)
152+
if f, ok := funcMap[name]; ok {
153+
// Built-in function, initialize if necessary
154+
if f.InitFunc != nil {
155+
if err := f.InitFunc(); err != nil {
156+
return nil, fmt.Errorf("rule %q failed to initialize function %q: %w", rule.Name, name, err)
157+
}
158158
}
159+
} else if a, ok := fullAnMap[name]; ok {
160+
// Analyzer, add to dependency map
159161
depAnMap[name] = a
160162
}
161163
}
@@ -191,30 +193,6 @@ func CompileExprRules(rules []ExprRule, ans []analyzer.Analyzer, mods []modifier
191193
}, nil
192194
}
193195

194-
func registerBuiltinFunctions(funcMap map[string]*ast.Function, geoMatcher *geo.GeoMatcher) {
195-
funcMap["geoip"] = &ast.Function{
196-
Name: "geoip",
197-
Func: func(params ...any) (any, error) {
198-
return geoMatcher.MatchGeoIp(params[0].(string), params[1].(string)), nil
199-
},
200-
Types: []reflect.Type{reflect.TypeOf(geoMatcher.MatchGeoIp)},
201-
}
202-
funcMap["geosite"] = &ast.Function{
203-
Name: "geosite",
204-
Func: func(params ...any) (any, error) {
205-
return geoMatcher.MatchGeoSite(params[0].(string), params[1].(string)), nil
206-
},
207-
Types: []reflect.Type{reflect.TypeOf(geoMatcher.MatchGeoSite)},
208-
}
209-
funcMap["cidr"] = &ast.Function{
210-
Name: "cidr",
211-
Func: func(params ...any) (any, error) {
212-
return builtins.MatchCIDR(params[0].(string), params[1].(*net.IPNet)), nil
213-
},
214-
Types: []reflect.Type{reflect.TypeOf((func(string, string) bool)(nil)), reflect.TypeOf(builtins.MatchCIDR)},
215-
}
216-
}
217-
218196
func streamInfoToExprEnv(info StreamInfo) map[string]interface{} {
219197
m := map[string]interface{}{
220198
"id": info.ID,
@@ -299,29 +277,106 @@ func (v *idVisitor) Visit(node *ast.Node) {
299277
// idPatcher patches the AST during expr compilation, replacing certain values with
300278
// their internal representations for better runtime performance.
301279
type idPatcher struct {
302-
Err error
280+
FuncMap map[string]*Function
281+
Err error
303282
}
304283

305284
func (p *idPatcher) Visit(node *ast.Node) {
306285
switch (*node).(type) {
307286
case *ast.CallNode:
308287
callNode := (*node).(*ast.CallNode)
309-
if callNode.Func == nil {
288+
if callNode.Callee == nil {
310289
// Ignore invalid call nodes
311290
return
312291
}
313-
switch callNode.Func.Name {
314-
case "cidr":
315-
cidrStringNode, ok := callNode.Arguments[1].(*ast.StringNode)
316-
if !ok {
317-
return
318-
}
319-
cidr, err := builtins.CompileCIDR(cidrStringNode.Value)
320-
if err != nil {
321-
p.Err = err
322-
return
292+
if f, ok := p.FuncMap[callNode.Callee.String()]; ok {
293+
if f.PatchFunc != nil {
294+
if err := f.PatchFunc(&callNode.Arguments); err != nil {
295+
p.Err = err
296+
return
297+
}
323298
}
324-
callNode.Arguments[1] = &ast.ConstantNode{Value: cidr}
325299
}
326300
}
327301
}
302+
303+
type Function struct {
304+
InitFunc func() error
305+
PatchFunc func(args *[]ast.Node) error
306+
Func func(params ...any) (any, error)
307+
Types []reflect.Type
308+
}
309+
310+
func buildFunctionMap(geoMatcher *geo.GeoMatcher) map[string]*Function {
311+
return map[string]*Function{
312+
"geoip": {
313+
InitFunc: geoMatcher.LoadGeoIP,
314+
PatchFunc: nil,
315+
Func: func(params ...any) (any, error) {
316+
return geoMatcher.MatchGeoIp(params[0].(string), params[1].(string)), nil
317+
},
318+
Types: []reflect.Type{reflect.TypeOf(geoMatcher.MatchGeoIp)},
319+
},
320+
"geosite": {
321+
InitFunc: geoMatcher.LoadGeoSite,
322+
PatchFunc: nil,
323+
Func: func(params ...any) (any, error) {
324+
return geoMatcher.MatchGeoSite(params[0].(string), params[1].(string)), nil
325+
},
326+
Types: []reflect.Type{reflect.TypeOf(geoMatcher.MatchGeoSite)},
327+
},
328+
"cidr": {
329+
InitFunc: nil,
330+
PatchFunc: func(args *[]ast.Node) error {
331+
cidrStringNode, ok := (*args)[1].(*ast.StringNode)
332+
if !ok {
333+
return fmt.Errorf("cidr: invalid argument type")
334+
}
335+
cidr, err := builtins.CompileCIDR(cidrStringNode.Value)
336+
if err != nil {
337+
return err
338+
}
339+
(*args)[1] = &ast.ConstantNode{Value: cidr}
340+
return nil
341+
},
342+
Func: func(params ...any) (any, error) {
343+
return builtins.MatchCIDR(params[0].(string), params[1].(*net.IPNet)), nil
344+
},
345+
Types: []reflect.Type{reflect.TypeOf((func(string, string) bool)(nil)), reflect.TypeOf(builtins.MatchCIDR)},
346+
},
347+
"lookup": {
348+
InitFunc: nil,
349+
PatchFunc: func(args *[]ast.Node) error {
350+
if len(*args) < 2 {
351+
// Second argument (DNS server) is optional
352+
return nil
353+
}
354+
serverStr, ok := (*args)[1].(*ast.StringNode)
355+
if !ok {
356+
return fmt.Errorf("lookup: invalid argument type")
357+
}
358+
r := &net.Resolver{
359+
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
360+
return net.Dial(network, serverStr.Value)
361+
},
362+
}
363+
(*args)[1] = &ast.ConstantNode{Value: r}
364+
return nil
365+
},
366+
Func: func(params ...any) (any, error) {
367+
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
368+
defer cancel()
369+
if len(params) < 2 {
370+
return net.DefaultResolver.LookupHost(ctx, params[0].(string))
371+
} else {
372+
return params[1].(*net.Resolver).LookupHost(ctx, params[0].(string))
373+
}
374+
},
375+
Types: []reflect.Type{
376+
reflect.TypeOf((func(string, string) []string)(nil)),
377+
reflect.TypeOf((func(string) []string)(nil)),
378+
reflect.TypeOf((func(string, *net.Resolver) []string)(nil)),
379+
},
380+
},
381+
}
382+
}

0 commit comments

Comments
 (0)