diff --git a/api_request.go b/api_request.go index d3e97596..e565e7cf 100644 --- a/api_request.go +++ b/api_request.go @@ -24,20 +24,6 @@ func (c *deadlinedConn) Write(b []byte) (n int, err error) { return c.Conn.Write(b) } -func newDeadlineTransport(timeout time.Duration) *http.Transport { - transport := &http.Transport{ - DisableKeepAlives: true, - Dial: func(netw, addr string) (net.Conn, error) { - c, err := net.DialTimeout(netw, addr, timeout) - if err != nil { - return nil, err - } - return &deadlinedConn{timeout, c}, nil - }, - } - return transport -} - type wrappedResp struct { Status string `json:"status_txt"` StatusCode int `json:"status_code"` @@ -45,8 +31,7 @@ type wrappedResp struct { } // stores the result in the value pointed to by ret(must be a pointer) -func apiRequestNegotiateV1(method string, endpoint string, headers http.Header, ret interface{}) error { - httpclient := &http.Client{Transport: newDeadlineTransport(2 * time.Second)} +func apiRequestNegotiateV1(httpclient *http.Client, method string, endpoint string, headers http.Header, ret interface{}) error { req, err := http.NewRequest(method, endpoint, nil) if err != nil { return err diff --git a/config.go b/config.go index 644aaac8..1f7ea2cd 100644 --- a/config.go +++ b/config.go @@ -110,6 +110,7 @@ type Config struct { // reconnection attempts LookupdPollInterval time.Duration `opt:"lookupd_poll_interval" min:"10ms" max:"5m" default:"60s"` LookupdPollJitter float64 `opt:"lookupd_poll_jitter" min:"0" max:"1" default:"0.3"` + LookupdPollTimeout time.Duration `opt:"lookupd_poll_timeout" default:"1m"` // Maximum duration when REQueueing (for doubling of deferred requeue) MaxRequeueDelay time.Duration `opt:"max_requeue_delay" min:"0" max:"60m" default:"15m"` diff --git a/consumer.go b/consumer.go index 984c07c6..b4d7487b 100644 --- a/consumer.go +++ b/consumer.go @@ -128,6 +128,7 @@ type Consumer struct { lookupdRecheckChan chan int lookupdHTTPAddrs []string lookupdQueryIndex int + lookupdHttpClient *http.Client wg sync.WaitGroup runningHandlers int32 @@ -326,6 +327,11 @@ func (r *Consumer) ChangeMaxInFlight(maxInFlight int) { } } +// set lookupd http client +func (r *Consumer) SetLookupdHttpClient(httpclient *http.Client) { + r.lookupdHttpClient = httpclient +} + // ConnectToNSQLookupd adds an nsqlookupd address to the list for this Consumer instance. // // If it is the first to be added, it initiates an HTTP request to discover nsqd @@ -355,6 +361,23 @@ func (r *Consumer) ConnectToNSQLookupd(addr string) error { } } r.lookupdHTTPAddrs = append(r.lookupdHTTPAddrs, parsedAddr) + if r.lookupdHttpClient == nil { + transport := &http.Transport{ + DialContext: (&net.Dialer{ + Timeout: r.config.LookupdPollTimeout, + KeepAlive: 30 * time.Second, + }).DialContext, + ResponseHeaderTimeout: r.config.LookupdPollTimeout, + MaxIdleConns: 100, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + } + r.lookupdHttpClient = &http.Client{ + Transport: transport, + Timeout: r.config.LookupdPollTimeout, + } + } + numLookupd := len(r.lookupdHTTPAddrs) r.mtx.Unlock() @@ -468,7 +491,7 @@ retry: if r.config.AuthSecret != "" && r.config.LookupdAuthorization { headers.Set("Authorization", fmt.Sprintf("Bearer %s", r.config.AuthSecret)) } - err := apiRequestNegotiateV1("GET", endpoint, headers, &data) + err := apiRequestNegotiateV1(r.lookupdHttpClient, "GET", endpoint, headers, &data) if err != nil { r.log(LogLevelError, "error querying nsqlookupd (%s) - %s", endpoint, err) retries++