Skip to content

Commit b3b37b8

Browse files
authored
Fix data race in ReconnectClient (#185)
1 parent 0b07cb5 commit b3b37b8

File tree

3 files changed

+90
-9
lines changed

3 files changed

+90
-9
lines changed

atomic.go

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
// Copyright 2021 The mqtt-go authors.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package mqtt
16+
17+
import (
18+
"sync"
19+
)
20+
21+
type firstError struct {
22+
mu sync.RWMutex
23+
err error
24+
}
25+
26+
func (e *firstError) Store(err error) {
27+
e.mu.Lock()
28+
if e.err == nil {
29+
e.err = err
30+
}
31+
e.mu.Unlock()
32+
}
33+
34+
func (e *firstError) Load() error {
35+
e.mu.RLock()
36+
defer e.mu.RUnlock()
37+
return e.err
38+
}

atomic_test.go

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
// Copyright 2021 The mqtt-go authors.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package mqtt
16+
17+
import (
18+
"errors"
19+
"testing"
20+
)
21+
22+
func TestFirstError(t *testing.T) {
23+
var fe firstError
24+
25+
if err := fe.Load(); err != nil {
26+
t.Fatalf("Initial value must be 'nil', got '%v'", err)
27+
}
28+
29+
errDummy0 := errors.New("dummy1")
30+
errDummy1 := errors.New("dummy1")
31+
32+
fe.Store(errDummy0)
33+
34+
if err := fe.Load(); err != errDummy0 {
35+
t.Fatalf("Expected '%v', got '%v'", errDummy0, err)
36+
}
37+
38+
fe.Store(errDummy1)
39+
40+
if err := fe.Load(); err != errDummy0 {
41+
t.Fatalf("Value is updated after first store. Expected '%v', got '%v'", errDummy0, err)
42+
}
43+
}

reconnclient.go

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ func (c *reconnectClient) Connect(ctx context.Context, clientID string, opts ...
6969
c.options.Timeout = c.options.PingInterval
7070
}
7171

72-
var errDial, errConnect error
72+
var errDial, errConnect firstError
7373

7474
done := make(chan struct{})
7575
var doneOnce sync.Once
@@ -136,12 +136,12 @@ func (c *reconnectClient) Connect(ctx context.Context, clientID string, opts ...
136136
case <-c.disconnected:
137137
return
138138
}
139-
} else if err != ctxTimeout.Err() || errConnect == nil {
140-
errConnect = err // Hold last connect error but avoid overwrote by context cancel.
139+
} else if err != ctxTimeout.Err() {
140+
errConnect.Store(err) // Hold first connect error excepting context cancel.
141141
}
142142
cancel()
143-
} else if err != ctx.Err() || errDial == nil {
144-
errDial = err // Hold last dial error but avoid overwrote by context cancel.
143+
} else if err != ctx.Err() {
144+
errDial.Store(err) // Hold first dial error excepting context cancel.
145145
}
146146
select {
147147
case <-time.After(reconnWait):
@@ -161,11 +161,11 @@ func (c *reconnectClient) Connect(ctx context.Context, clientID string, opts ...
161161
case <-done:
162162
case <-ctx.Done():
163163
var actualErrs []string
164-
if errDial != nil {
165-
actualErrs = append(actualErrs, fmt.Sprintf("dial: %v", errDial))
164+
if err := errDial.Load(); err != nil {
165+
actualErrs = append(actualErrs, fmt.Sprintf("dial: %v", err))
166166
}
167-
if errConnect != nil {
168-
actualErrs = append(actualErrs, fmt.Sprintf("connect: %v", errConnect))
167+
if err := errConnect.Load(); err != nil {
168+
actualErrs = append(actualErrs, fmt.Sprintf("connect: %v", err))
169169
}
170170
var errStr string
171171
if len(actualErrs) > 0 {

0 commit comments

Comments
 (0)