@@ -10,7 +10,6 @@ import (
10
10
"errors"
11
11
"fmt"
12
12
"io"
13
- "sync"
14
13
"time"
15
14
16
15
"github.com/klauspost/compress/flate"
@@ -71,7 +70,7 @@ type msgWriterState struct {
71
70
c * Conn
72
71
73
72
mu * mu
74
- writeMu sync. Mutex
73
+ writeMu * mu
75
74
76
75
ctx context.Context
77
76
opcode opcode
@@ -83,8 +82,9 @@ type msgWriterState struct {
83
82
84
83
func newMsgWriterState (c * Conn ) * msgWriterState {
85
84
mw := & msgWriterState {
86
- c : c ,
87
- mu : newMu (c ),
85
+ c : c ,
86
+ mu : newMu (c ),
87
+ writeMu : newMu (c ),
88
88
}
89
89
return mw
90
90
}
@@ -155,12 +155,15 @@ func (mw *msgWriterState) reset(ctx context.Context, typ MessageType) error {
155
155
156
156
// Write writes the given bytes to the WebSocket connection.
157
157
func (mw * msgWriterState ) Write (p []byte ) (_ int , err error ) {
158
- mw .writeMu .Lock ()
159
- defer mw .writeMu .Unlock ()
158
+ err = mw .writeMu .lock (mw .ctx )
159
+ if err != nil {
160
+ return 0 , fmt .Errorf ("failed to write: %w" , err )
161
+ }
162
+ defer mw .writeMu .unlock ()
160
163
161
164
defer func () {
162
- err = fmt .Errorf ("failed to write: %w" , err )
163
165
if err != nil {
166
+ err = fmt .Errorf ("failed to write: %w" , err )
164
167
mw .c .close (err )
165
168
}
166
169
}()
@@ -198,8 +201,11 @@ func (mw *msgWriterState) write(p []byte) (int, error) {
198
201
func (mw * msgWriterState ) Close () (err error ) {
199
202
defer errd .Wrap (& err , "failed to close writer" )
200
203
201
- mw .writeMu .Lock ()
202
- defer mw .writeMu .Unlock ()
204
+ err = mw .writeMu .lock (mw .ctx )
205
+ if err != nil {
206
+ return err
207
+ }
208
+ defer mw .writeMu .unlock ()
203
209
204
210
_ , err = mw .c .writeFrame (mw .ctx , true , mw .flate , mw .opcode , nil )
205
211
if err != nil {
@@ -219,7 +225,7 @@ func (mw *msgWriterState) close() {
219
225
putBufioWriter (mw .c .bw )
220
226
}
221
227
222
- mw .writeMu .Lock ()
228
+ mw .writeMu .forceLock ()
223
229
mw .dict .close ()
224
230
}
225
231
@@ -250,7 +256,8 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opco
250
256
251
257
defer func () {
252
258
if err != nil {
253
- c .close (fmt .Errorf ("failed to write frame: %w" , err ))
259
+ err = fmt .Errorf ("failed to write frame: %w" , err )
260
+ c .close (err )
254
261
}
255
262
}()
256
263
0 commit comments