1
1
package memory
2
2
3
3
import (
4
+ "bytes"
4
5
"errors"
5
6
"io"
6
7
"net"
8
+ "sync"
9
+ "sync/atomic"
7
10
"time"
8
11
9
12
"github.com/libp2p/go-libp2p/core/network"
@@ -14,76 +17,164 @@ type stream struct {
14
17
id int64
15
18
conn * conn
16
19
17
- read * io.PipeReader
18
- write * io.PipeWriter
19
- writeC chan []byte
20
+ wrMu sync.Mutex // Serialize Write operations
21
+ buf * bytes.Buffer // Buffer for partial reads
20
22
21
- reset chan struct {}
22
- close chan struct {}
23
- closed chan struct {}
23
+ // Used by local Read to interact with remote Write.
24
+ rdRx <- chan []byte
24
25
25
- writeErr error
26
+ // Used by local Write to interact with remote Read.
27
+ wrTx chan <- []byte
28
+
29
+ once sync.Once // Protects closing localDone
30
+ localDone chan struct {}
31
+ remoteDone <- chan struct {}
32
+
33
+ reset chan struct {}
34
+ close chan struct {}
35
+ readClosed atomic.Bool
36
+ writeClosed atomic.Bool
26
37
}
27
38
28
39
var ErrClosed = errors .New ("stream closed" )
29
40
30
41
func newStreamPair () (* stream , * stream ) {
31
- ra , wb := io .Pipe ()
32
- rb , wa := io .Pipe ()
33
-
34
- sa := newStream (wa , ra , network .DirOutbound )
35
- sb := newStream (wb , rb , network .DirInbound )
42
+ io .Pipe ()
43
+
44
+ cb1 := make (chan []byte , 1 )
45
+ cb2 := make (chan []byte , 1 )
46
+
47
+ done1 := make (chan struct {})
48
+ done2 := make (chan struct {})
49
+
50
+ sa := & stream {
51
+ id : streamCounter .Add (1 ),
52
+ rdRx : cb1 ,
53
+ wrTx : cb2 ,
54
+ buf : new (bytes.Buffer ),
55
+ localDone : done1 , remoteDone : done2 ,
56
+ reset : make (chan struct {}, 1 ),
57
+ close : make (chan struct {}, 1 ),
58
+ }
59
+ sb := & stream {
60
+ rdRx : cb2 ,
61
+ wrTx : cb1 ,
62
+ buf : new (bytes.Buffer ),
63
+ localDone : done2 , remoteDone : done1 ,
64
+ reset : make (chan struct {}, 1 ),
65
+ close : make (chan struct {}, 1 ),
66
+ }
36
67
37
68
return sa , sb
38
69
}
39
70
40
- func newStream (w * io. PipeWriter , r * io. PipeReader , _ network. Direction ) * stream {
71
+ func newStream (rdRx <- chan [] byte , wrTx chan <- [] byte , localDone chan struct {}, remoteDone <- chan struct {} ) * stream {
41
72
s := & stream {
42
- id : streamCounter .Add (1 ),
43
- read : r ,
44
- write : w ,
45
- writeC : make (chan []byte ),
46
- reset : make (chan struct {}, 1 ),
47
- close : make (chan struct {}, 1 ),
48
- closed : make (chan struct {}),
73
+ rdRx : rdRx ,
74
+ wrTx : wrTx ,
75
+ localDone : localDone ,
76
+ remoteDone : remoteDone ,
77
+ reset : make (chan struct {}, 1 ),
78
+ close : make (chan struct {}, 1 ),
49
79
}
50
80
51
- go s .writeLoop ()
52
81
return s
53
82
}
54
83
55
- func (s * stream ) Write (p []byte ) (int , error ) {
56
- cpy := make ([]byte , len (p ))
57
- copy (cpy , p )
84
+ func (p * stream ) Write (b []byte ) (int , error ) {
85
+ if p .writeClosed .Load () {
86
+ return 0 , ErrClosed
87
+ }
88
+
89
+ n , err := p .write (b )
90
+ if err != nil && err != io .ErrClosedPipe {
91
+ err = & net.OpError {Op : "write" , Net : "pipe" , Err : err }
92
+ }
93
+ return n , err
94
+ }
95
+
96
+ func (p * stream ) write (b []byte ) (n int , err error ) {
97
+ switch {
98
+ case isClosedChan (p .localDone ):
99
+ return 0 , io .ErrClosedPipe
100
+ case isClosedChan (p .remoteDone ):
101
+ return 0 , io .ErrClosedPipe
102
+ }
103
+
104
+ p .wrMu .Lock () // Ensure entirety of b is written together
105
+ defer p .wrMu .Unlock ()
58
106
59
107
select {
60
- case <- s .closed :
61
- return 0 , s .writeErr
62
- case s .writeC <- cpy :
108
+ case <- p .close :
109
+ return n , ErrClosed
110
+ case <- p .reset :
111
+ return n , network .ErrReset
112
+ case p .wrTx <- b :
113
+ n += len (b )
114
+ case <- p .localDone :
115
+ return n , io .ErrClosedPipe
116
+ case <- p .remoteDone :
117
+ return n , io .ErrClosedPipe
63
118
}
64
119
65
- return len ( p ) , nil
120
+ return n , nil
66
121
}
67
122
68
- func (s * stream ) Read (p []byte ) (int , error ) {
69
- return s .read .Read (p )
123
+ func (p * stream ) Read (b []byte ) (int , error ) {
124
+ if p .readClosed .Load () {
125
+ return 0 , ErrClosed
126
+ }
127
+
128
+ n , err := p .read (b )
129
+ if err != nil && err != io .EOF && err != io .ErrClosedPipe {
130
+ err = & net.OpError {Op : "read" , Net : "pipe" , Err : err }
131
+ }
132
+
133
+ return n , err
134
+ }
135
+
136
+ func (p * stream ) read (b []byte ) (n int , err error ) {
137
+ switch {
138
+ case isClosedChan (p .localDone ):
139
+ return 0 , io .ErrClosedPipe
140
+ case isClosedChan (p .remoteDone ):
141
+ return 0 , io .EOF
142
+ }
143
+
144
+ select {
145
+ case <- p .reset :
146
+ return n , network .ErrReset
147
+ case bw , ok := <- p .rdRx :
148
+ if ! ok {
149
+ p .readClosed .Store (true )
150
+ return 0 , io .EOF
151
+ }
152
+
153
+ p .buf .Write (bw )
154
+ case <- p .localDone :
155
+ return 0 , io .ErrClosedPipe
156
+ case <- p .remoteDone :
157
+ return 0 , io .EOF
158
+ default :
159
+ n , err = p .buf .Read (b )
160
+ }
161
+
162
+ return n , err
70
163
}
71
164
72
165
func (s * stream ) CloseWrite () error {
73
166
select {
74
167
case s .close <- struct {}{}:
75
168
default :
76
169
}
77
- <- s .closed
78
- if ! errors .Is (s .writeErr , ErrClosed ) {
79
- return s .writeErr
80
- }
81
- return nil
82
170
171
+ s .writeClosed .Store (true )
172
+ return nil
83
173
}
84
174
85
175
func (s * stream ) CloseRead () error {
86
- return s .read .CloseWithError (ErrClosed )
176
+ s .readClosed .Store (true )
177
+ return nil
87
178
}
88
179
89
180
func (s * stream ) Close () error {
@@ -92,15 +183,15 @@ func (s *stream) Close() error {
92
183
}
93
184
94
185
func (s * stream ) Reset () error {
95
- // Cancel any pending reads/writes with an error.
96
- s .write .CloseWithError (network .ErrReset )
97
- s .read .CloseWithError (network .ErrReset )
98
-
99
186
select {
100
187
case s .reset <- struct {}{}:
101
188
default :
102
189
}
103
- <- s .closed
190
+
191
+ s .once .Do (func () {
192
+ close (s .localDone )
193
+ })
194
+
104
195
// No meaningful error case here.
105
196
return nil
106
197
}
@@ -117,48 +208,11 @@ func (s *stream) SetWriteDeadline(t time.Time) error {
117
208
return & net.OpError {Op : "set" , Net : "pipe" , Source : nil , Addr : nil , Err : errors .New ("deadline not supported" )}
118
209
}
119
210
120
- func (s * stream ) writeLoop () {
121
- defer s .teardown ()
122
-
123
- for {
124
- // Reset takes precedent.
125
- select {
126
- case <- s .reset :
127
- s .writeErr = network .ErrReset
128
- return
129
- default :
130
- }
131
-
132
- select {
133
- case <- s .reset :
134
- s .writeErr = network .ErrReset
135
- return
136
- case <- s .close :
137
- s .writeErr = s .write .Close ()
138
- if s .writeErr == nil {
139
- s .writeErr = ErrClosed
140
- }
141
- return
142
- case p := <- s .writeC :
143
- if _ , err := s .write .Write (p ); err != nil {
144
- s .cancelWrite (err )
145
- return
146
- }
147
- }
148
- }
149
- }
150
-
151
- func (s * stream ) cancelWrite (err error ) {
152
- s .write .CloseWithError (err )
153
- s .writeErr = err
154
- }
155
-
156
- func (s * stream ) teardown () {
157
- // at this point, no streams are writing.
158
- if s .conn != nil {
159
- s .conn .removeStream (s .id )
211
+ func isClosedChan (c <- chan struct {}) bool {
212
+ select {
213
+ case <- c :
214
+ return true
215
+ default :
216
+ return false
160
217
}
161
-
162
- // Mark as closed.
163
- close (s .closed )
164
218
}
0 commit comments