diff --git a/drove.go b/drove.go index 250617b..29a98ae 100644 --- a/drove.go +++ b/drove.go @@ -11,6 +11,7 @@ import ( "net/url" "strconv" "strings" + "sync" "time" ) @@ -26,10 +27,11 @@ type IDroveClient interface { PollEvents(callback func(event *DroveEventSummary)) } type DroveClient struct { - Endpoint []EndpointStatus - Leader *LeaderController - AuthConfig *DroveAuthConfig - client *http.Client + EndpointMutex sync.RWMutex + Endpoint []EndpointStatus + Leader *LeaderController + AuthConfig *DroveAuthConfig + client *http.Client } func NewDroveClient(config DroveConfig) DroveClient { @@ -167,6 +169,8 @@ func leaderController(endpoint string) (*LeaderController, error) { } func (c *DroveClient) endpoint() (string, error) { + c.EndpointMutex.RLock() + defer c.EndpointMutex.RUnlock() var err error = nil if c.Leader == nil || c.Leader.Endpoint == "" { return "", errors.New("all endpoints are down") @@ -190,7 +194,8 @@ func (c *DroveClient) refreshLeaderData() { log.Errorf("Leader struct generation failed %+v", err) return } - + c.EndpointMutex.Lock() + defer c.EndpointMutex.Unlock() c.Leader = newLeader log.Infof("New leader being set leader %+v", c.Leader) } diff --git a/drove_test.go b/drove_test.go index 2699413..6ccb693 100644 --- a/drove_test.go +++ b/drove_test.go @@ -4,6 +4,7 @@ import ( "fmt" "net/http" "net/http/httptest" + "sync/atomic" "testing" "time" @@ -67,29 +68,32 @@ func TestLeaderElection(t *testing.T) { client.Init() assert.Nil(t, client.Leader) - client = NewDroveClient(DroveConfig{Endpoint: "http://random.blah.endpoint.non-existent", AuthConfig: DroveAuthConfig{AccessToken: ""}}) - client.Init() - assert.Nil(t, client.Leader) + client1 := NewDroveClient(DroveConfig{Endpoint: "http://random.blah.endpoint.non-existent", AuthConfig: DroveAuthConfig{AccessToken: ""}}) + client1.Init() + assert.Nil(t, client1.Leader) - client = NewDroveClient(DroveConfig{Endpoint: fmt.Sprintf("%s,%s", server.URL, server2.URL), AuthConfig: DroveAuthConfig{AccessToken: ""}}) - client.Init() - assert.NotNil(t, client.Leader) - assert.Equal(t, server2.URL, client.Leader.Endpoint) + client2 := NewDroveClient(DroveConfig{Endpoint: fmt.Sprintf("%s,%s", server.URL, server2.URL), AuthConfig: DroveAuthConfig{AccessToken: ""}}) + client2.Init() + assert.NotNil(t, client2.Leader) + assert.Equal(t, server2.URL, client2.Leader.Endpoint) time.Sleep(2 * time.Second) - assert.NotNil(t, client.Leader) - assert.Equal(t, server2.URL, client.Leader.Endpoint) + assert.NotNil(t, client2.Leader) + assert.Equal(t, server2.URL, client2.Leader.Endpoint) } func TestLeaderFailover(t *testing.T) { mux := http.NewServeMux() server := httptest.NewServer(mux) - status1, status2 := 200, 400 + var status1, status2 atomic.Int64 + status1.Store(200) + status2.Store(400) + mux.HandleFunc("/apis/v1/ping", func(rw http.ResponseWriter, req *http.Request) { // Test request parameters // Send response to be tested // rw.Write([]byte(`OK`)) - rw.WriteHeader(status1) + rw.WriteHeader(int(status1.Load())) }) mux2 := http.NewServeMux() @@ -98,15 +102,19 @@ func TestLeaderFailover(t *testing.T) { // Test request parameters // Send response to be tested // rw.Write([]byte(`OK`)) - rw.WriteHeader(status2) + rw.WriteHeader(int(status2.Load())) }) client := NewDroveClient(DroveConfig{Endpoint: fmt.Sprintf("%s,%s", server.URL, server2.URL), AuthConfig: DroveAuthConfig{AccessToken: ""}}) client.Init() assert.NotNil(t, client.Leader) - assert.Equal(t, server.URL, client.Leader.Endpoint) - status1, status2 = 400, 200 + endpoint, err := client.endpoint() + assert.Equal(t, server.URL, endpoint) + status1.Store(400) + status2.Store(200) time.Sleep(4 * time.Second) - assert.NotNil(t, client.Leader) - assert.Equal(t, server2.URL, client.Leader.Endpoint) + endpoint, err = client.endpoint() + assert.Nil(t, err) + + assert.Equal(t, server2.URL, endpoint) } diff --git a/endpoints.go b/endpoints.go index 2fc0d68..0722780 100644 --- a/endpoints.go +++ b/endpoints.go @@ -6,39 +6,48 @@ import ( ) type DroveEndpoints struct { - appsMutext *sync.RWMutex + appsMutex *sync.RWMutex AppsDB *DroveAppsResponse DroveClient IDroveClient + AppsByVhost map[string]DroveApp } func (dr *DroveEndpoints) setApps(appDB *DroveAppsResponse) { - dr.appsMutext.Lock() + var appsByVhost map[string]DroveApp = make(map[string]DroveApp) + if appDB != nil { + for _, app := range appDB.Apps { + appsByVhost[app.Vhost+"."] = app + } + } + dr.appsMutex.Lock() dr.AppsDB = appDB - dr.appsMutext.Unlock() + dr.AppsByVhost = appsByVhost + dr.appsMutex.Unlock() } -func (dr *DroveEndpoints) getApps() DroveAppsResponse { - dr.appsMutext.RLock() - defer dr.appsMutext.RUnlock() +func (dr *DroveEndpoints) getApps() *DroveAppsResponse { + dr.appsMutex.RLock() + defer dr.appsMutex.RUnlock() if dr.AppsDB == nil { - return DroveAppsResponse{} + return nil } - return *dr.AppsDB + return dr.AppsDB } func (dr *DroveEndpoints) searchApps(questionName string) *DroveApp { - dr.appsMutext.RLock() - defer dr.appsMutext.RUnlock() - for _, app := range dr.AppsDB.Apps { - if app.Vhost+"." == questionName { - return &app - } + dr.appsMutex.RLock() + defer dr.appsMutex.RUnlock() + if dr.AppsByVhost == nil { + return nil + } + if app, ok := dr.AppsByVhost[questionName]; ok { + return &app } return nil } func newDroveEndpoints(client IDroveClient) *DroveEndpoints { - endpoints := DroveEndpoints{DroveClient: client, appsMutext: &sync.RWMutex{}} + endpoints := DroveEndpoints{DroveClient: client, appsMutex: &sync.RWMutex{}} ticker := time.NewTicker(10 * time.Second) done := make(chan bool) reload := make(chan bool) @@ -62,7 +71,7 @@ func newDroveEndpoints(client IDroveClient) *DroveEndpoints { apps, err := endpoints.DroveClient.FetchApps() if err != nil { DroveQueryFailure.Inc() - log.Errorf("Error refreshing nodes data %+v", endpoints.AppsDB) + log.Errorf("Error refreshing nodes data") return } diff --git a/endpoints_test.go b/endpoints_test.go index ff93632..2c445ad 100644 --- a/endpoints_test.go +++ b/endpoints_test.go @@ -49,6 +49,9 @@ func TestRaceCondidtion(t *testing.T) { for i := 0; i < 100; i++ { go func() { apps := underTest.getApps() + if apps == nil { + return + } for i, _ := range apps.Apps { t.Logf("%+v", apps.Apps[i].Hosts) } diff --git a/handler.go b/handler.go index 89f29a3..5d79c25 100644 --- a/handler.go +++ b/handler.go @@ -27,7 +27,7 @@ func (e *DroveHandler) Name() string { return "drove" } func (e *DroveHandler) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { a := new(dns.Msg) - if e.DroveEndpoints.AppsDB == nil { + if e.DroveEndpoints.getApps() == nil { return dns.RcodeServerFailure, fmt.Errorf("Drove DNS not ready") } app := e.DroveEndpoints.searchApps(r.Question[0].Name) diff --git a/handler_test.go b/handler_test.go index ab7d72a..042e9f6 100644 --- a/handler_test.go +++ b/handler_test.go @@ -43,8 +43,8 @@ func (w *MockResponseWriter) WriteMsg(res *dns.Msg) error { } func TestServeDNSNotReady(t *testing.T) { - handler := DroveHandler{DroveEndpoints: &DroveEndpoints{DroveClient: &MockDroveClient{}}} + handler := DroveHandler{DroveEndpoints: newDroveEndpoints(&MockDroveClient{})} writer := &MockResponseWriter{ validator: func(res *dns.Msg) { assert.Equal(t, 1, len(res.Answer), "One Answer should be returned") diff --git a/ready.go b/ready.go index de899a7..43c368c 100644 --- a/ready.go +++ b/ready.go @@ -2,5 +2,5 @@ package drovedns // Checks if apps data could be synced from drove cluster func (e *DroveHandler) Ready() bool { - return e.DroveEndpoints.AppsDB != nil + return e.DroveEndpoints.getApps() != nil }