Skip to content

Commit 2e0dd1c

Browse files
committed
Make writeMu a channel based mutex
Will prevent deadlock if a writer is used after close.
1 parent 1200707 commit 2e0dd1c

File tree

1 file changed

+18
-11
lines changed

1 file changed

+18
-11
lines changed

write.go

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ import (
1010
"errors"
1111
"fmt"
1212
"io"
13-
"sync"
1413
"time"
1514

1615
"github.com/klauspost/compress/flate"
@@ -71,7 +70,7 @@ type msgWriterState struct {
7170
c *Conn
7271

7372
mu *mu
74-
writeMu sync.Mutex
73+
writeMu *mu
7574

7675
ctx context.Context
7776
opcode opcode
@@ -83,8 +82,9 @@ type msgWriterState struct {
8382

8483
func newMsgWriterState(c *Conn) *msgWriterState {
8584
mw := &msgWriterState{
86-
c: c,
87-
mu: newMu(c),
85+
c: c,
86+
mu: newMu(c),
87+
writeMu: newMu(c),
8888
}
8989
return mw
9090
}
@@ -155,12 +155,15 @@ func (mw *msgWriterState) reset(ctx context.Context, typ MessageType) error {
155155

156156
// Write writes the given bytes to the WebSocket connection.
157157
func (mw *msgWriterState) Write(p []byte) (_ int, err error) {
158-
mw.writeMu.Lock()
159-
defer mw.writeMu.Unlock()
158+
err = mw.writeMu.lock(mw.ctx)
159+
if err != nil {
160+
return 0, fmt.Errorf("failed to write: %w", err)
161+
}
162+
defer mw.writeMu.unlock()
160163

161164
defer func() {
162-
err = fmt.Errorf("failed to write: %w", err)
163165
if err != nil {
166+
err = fmt.Errorf("failed to write: %w", err)
164167
mw.c.close(err)
165168
}
166169
}()
@@ -198,8 +201,11 @@ func (mw *msgWriterState) write(p []byte) (int, error) {
198201
func (mw *msgWriterState) Close() (err error) {
199202
defer errd.Wrap(&err, "failed to close writer")
200203

201-
mw.writeMu.Lock()
202-
defer mw.writeMu.Unlock()
204+
err = mw.writeMu.lock(mw.ctx)
205+
if err != nil {
206+
return err
207+
}
208+
defer mw.writeMu.unlock()
203209

204210
_, err = mw.c.writeFrame(mw.ctx, true, mw.flate, mw.opcode, nil)
205211
if err != nil {
@@ -219,7 +225,7 @@ func (mw *msgWriterState) close() {
219225
putBufioWriter(mw.c.bw)
220226
}
221227

222-
mw.writeMu.Lock()
228+
mw.writeMu.forceLock()
223229
mw.dict.close()
224230
}
225231

@@ -250,7 +256,8 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opco
250256

251257
defer func() {
252258
if err != nil {
253-
c.close(fmt.Errorf("failed to write frame: %w", err))
259+
err = fmt.Errorf("failed to write frame: %w", err)
260+
c.close(err)
254261
}
255262
}()
256263

0 commit comments

Comments
 (0)