9
9
"sync"
10
10
"time"
11
11
12
+ "github.com/fsnotify/fsnotify"
12
13
"github.com/miekg/dns"
13
14
)
14
15
@@ -21,23 +22,33 @@ type record struct {
21
22
}
22
23
23
24
var (
24
- file string
25
- tcp bool
26
- udp bool
27
- addr string
28
-
29
- recordMap map [string ]* record
25
+ file string
26
+ tcp bool
27
+ udp bool
28
+ addr string
29
+ autoload bool
30
+
31
+ recordLock sync.Mutex
32
+ recordMap map [string ]* record
30
33
)
31
34
32
35
func init () {
33
36
flag .StringVar (& file , "f" , "record.json" , "record file" )
34
37
flag .BoolVar (& tcp , "t" , false , "listen tcp" )
35
38
flag .BoolVar (& udp , "u" , true , "listen udp" )
36
39
flag .StringVar (& addr , "l" , ":53" , "listen address" )
40
+ flag .BoolVar (& autoload , "a" , true , "auto reload record file" )
37
41
38
42
recordMap = make (map [string ]* record )
39
43
}
40
44
45
+ func lookUpRecord (k string ) (r * record , ok bool ) {
46
+ recordLock .Lock ()
47
+ r , ok = recordMap [k ]
48
+ recordLock .Unlock ()
49
+ return
50
+ }
51
+
41
52
type handler struct {
42
53
isTcpServer bool
43
54
}
@@ -51,7 +62,7 @@ func (this *handler) packAnswers(msg *dns.Msg, qtype uint16, domain string) {
51
62
return
52
63
}
53
64
54
- r , ok := recordMap [ domain ]
65
+ r , ok := lookUpRecord ( domain )
55
66
if ! ok {
56
67
log .Println ("domain" , domain , "not found" )
57
68
return
@@ -124,18 +135,55 @@ func runUdpServer() {
124
135
}
125
136
}
126
137
127
- func main () {
128
- flag .Parse ()
129
-
138
+ func loadFile () error {
130
139
records , err := readRecords (file )
131
140
if err != nil {
132
- log . Fatal ( err )
141
+ return err
133
142
}
134
143
144
+ recordLock .Lock ()
145
+ defer recordLock .Unlock ()
146
+
147
+ recordMap = make (map [string ]* record )
148
+
135
149
for _ , v := range records {
136
150
recordMap [v .Host + "." ] = v
137
151
}
138
152
153
+ return nil
154
+ }
155
+
156
+ func updater () {
157
+ watcher , err := fsnotify .NewWatcher ()
158
+ if err != nil {
159
+ log .Fatal (err )
160
+ }
161
+ defer watcher .Close ()
162
+
163
+ err = watcher .Add (file )
164
+ if err != nil {
165
+ log .Fatal (err )
166
+ }
167
+ defer watcher .Remove (file )
168
+
169
+ for event := range watcher .Events {
170
+ log .Println ("updater go event" , event )
171
+ err = loadFile ()
172
+ if err != nil {
173
+ log .Println ("try to load file, but failed" , err )
174
+ }
175
+ watcher .Remove (file )
176
+ watcher .Add (file )
177
+ }
178
+ }
179
+
180
+ func main () {
181
+ flag .Parse ()
182
+
183
+ if err := loadFile (); err != nil {
184
+ log .Fatal (err )
185
+ }
186
+
139
187
var wg sync.WaitGroup
140
188
141
189
if tcp {
@@ -154,5 +202,7 @@ func main() {
154
202
}()
155
203
}
156
204
205
+ go updater ()
206
+
157
207
wg .Wait ()
158
208
}
0 commit comments