Skip to content

Commit 76e141f

Browse files
committed
Revert to stream using channels instead of io.Pipe
1 parent 5236ff2 commit 76e141f

File tree

2 files changed

+155
-103
lines changed

2 files changed

+155
-103
lines changed

p2p/transport/memory/stream.go

Lines changed: 138 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
package memory
22

33
import (
4+
"bytes"
45
"errors"
56
"io"
67
"net"
8+
"sync"
9+
"sync/atomic"
710
"time"
811

912
"github.com/libp2p/go-libp2p/core/network"
@@ -14,76 +17,164 @@ type stream struct {
1417
id int64
1518
conn *conn
1619

17-
read *io.PipeReader
18-
write *io.PipeWriter
19-
writeC chan []byte
20+
wrMu sync.Mutex // Serialize Write operations
21+
buf *bytes.Buffer // Buffer for partial reads
2022

21-
reset chan struct{}
22-
close chan struct{}
23-
closed chan struct{}
23+
// Used by local Read to interact with remote Write.
24+
rdRx <-chan []byte
2425

25-
writeErr error
26+
// Used by local Write to interact with remote Read.
27+
wrTx chan<- []byte
28+
29+
once sync.Once // Protects closing localDone
30+
localDone chan struct{}
31+
remoteDone <-chan struct{}
32+
33+
reset chan struct{}
34+
close chan struct{}
35+
readClosed atomic.Bool
36+
writeClosed atomic.Bool
2637
}
2738

2839
var ErrClosed = errors.New("stream closed")
2940

3041
func newStreamPair() (*stream, *stream) {
31-
ra, wb := io.Pipe()
32-
rb, wa := io.Pipe()
33-
34-
sa := newStream(wa, ra, network.DirOutbound)
35-
sb := newStream(wb, rb, network.DirInbound)
42+
io.Pipe()
43+
44+
cb1 := make(chan []byte, 1)
45+
cb2 := make(chan []byte, 1)
46+
47+
done1 := make(chan struct{})
48+
done2 := make(chan struct{})
49+
50+
sa := &stream{
51+
id: streamCounter.Add(1),
52+
rdRx: cb1,
53+
wrTx: cb2,
54+
buf: new(bytes.Buffer),
55+
localDone: done1, remoteDone: done2,
56+
reset: make(chan struct{}, 1),
57+
close: make(chan struct{}, 1),
58+
}
59+
sb := &stream{
60+
rdRx: cb2,
61+
wrTx: cb1,
62+
buf: new(bytes.Buffer),
63+
localDone: done2, remoteDone: done1,
64+
reset: make(chan struct{}, 1),
65+
close: make(chan struct{}, 1),
66+
}
3667

3768
return sa, sb
3869
}
3970

40-
func newStream(w *io.PipeWriter, r *io.PipeReader, _ network.Direction) *stream {
71+
func newStream(rdRx <-chan []byte, wrTx chan<- []byte, localDone chan struct{}, remoteDone <-chan struct{}) *stream {
4172
s := &stream{
42-
id: streamCounter.Add(1),
43-
read: r,
44-
write: w,
45-
writeC: make(chan []byte),
46-
reset: make(chan struct{}, 1),
47-
close: make(chan struct{}, 1),
48-
closed: make(chan struct{}),
73+
rdRx: rdRx,
74+
wrTx: wrTx,
75+
localDone: localDone,
76+
remoteDone: remoteDone,
77+
reset: make(chan struct{}, 1),
78+
close: make(chan struct{}, 1),
4979
}
5080

51-
go s.writeLoop()
5281
return s
5382
}
5483

55-
func (s *stream) Write(p []byte) (int, error) {
56-
cpy := make([]byte, len(p))
57-
copy(cpy, p)
84+
func (p *stream) Write(b []byte) (int, error) {
85+
if p.writeClosed.Load() {
86+
return 0, ErrClosed
87+
}
88+
89+
n, err := p.write(b)
90+
if err != nil && err != io.ErrClosedPipe {
91+
err = &net.OpError{Op: "write", Net: "pipe", Err: err}
92+
}
93+
return n, err
94+
}
95+
96+
func (p *stream) write(b []byte) (n int, err error) {
97+
switch {
98+
case isClosedChan(p.localDone):
99+
return 0, io.ErrClosedPipe
100+
case isClosedChan(p.remoteDone):
101+
return 0, io.ErrClosedPipe
102+
}
103+
104+
p.wrMu.Lock() // Ensure entirety of b is written together
105+
defer p.wrMu.Unlock()
58106

59107
select {
60-
case <-s.closed:
61-
return 0, s.writeErr
62-
case s.writeC <- cpy:
108+
case <-p.close:
109+
return n, ErrClosed
110+
case <-p.reset:
111+
return n, network.ErrReset
112+
case p.wrTx <- b:
113+
n += len(b)
114+
case <-p.localDone:
115+
return n, io.ErrClosedPipe
116+
case <-p.remoteDone:
117+
return n, io.ErrClosedPipe
63118
}
64119

65-
return len(p), nil
120+
return n, nil
66121
}
67122

68-
func (s *stream) Read(p []byte) (int, error) {
69-
return s.read.Read(p)
123+
func (p *stream) Read(b []byte) (int, error) {
124+
if p.readClosed.Load() {
125+
return 0, ErrClosed
126+
}
127+
128+
n, err := p.read(b)
129+
if err != nil && err != io.EOF && err != io.ErrClosedPipe {
130+
err = &net.OpError{Op: "read", Net: "pipe", Err: err}
131+
}
132+
133+
return n, err
134+
}
135+
136+
func (p *stream) read(b []byte) (n int, err error) {
137+
switch {
138+
case isClosedChan(p.localDone):
139+
return 0, io.ErrClosedPipe
140+
case isClosedChan(p.remoteDone):
141+
return 0, io.EOF
142+
}
143+
144+
select {
145+
case <-p.reset:
146+
return n, network.ErrReset
147+
case bw, ok := <-p.rdRx:
148+
if !ok {
149+
p.readClosed.Store(true)
150+
return 0, io.EOF
151+
}
152+
153+
p.buf.Write(bw)
154+
case <-p.localDone:
155+
return 0, io.ErrClosedPipe
156+
case <-p.remoteDone:
157+
return 0, io.EOF
158+
default:
159+
n, err = p.buf.Read(b)
160+
}
161+
162+
return n, err
70163
}
71164

72165
func (s *stream) CloseWrite() error {
73166
select {
74167
case s.close <- struct{}{}:
75168
default:
76169
}
77-
<-s.closed
78-
if !errors.Is(s.writeErr, ErrClosed) {
79-
return s.writeErr
80-
}
81-
return nil
82170

171+
s.writeClosed.Store(true)
172+
return nil
83173
}
84174

85175
func (s *stream) CloseRead() error {
86-
return s.read.CloseWithError(ErrClosed)
176+
s.readClosed.Store(true)
177+
return nil
87178
}
88179

89180
func (s *stream) Close() error {
@@ -92,15 +183,15 @@ func (s *stream) Close() error {
92183
}
93184

94185
func (s *stream) Reset() error {
95-
// Cancel any pending reads/writes with an error.
96-
s.write.CloseWithError(network.ErrReset)
97-
s.read.CloseWithError(network.ErrReset)
98-
99186
select {
100187
case s.reset <- struct{}{}:
101188
default:
102189
}
103-
<-s.closed
190+
191+
s.once.Do(func() {
192+
close(s.localDone)
193+
})
194+
104195
// No meaningful error case here.
105196
return nil
106197
}
@@ -117,48 +208,11 @@ func (s *stream) SetWriteDeadline(t time.Time) error {
117208
return &net.OpError{Op: "set", Net: "pipe", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
118209
}
119210

120-
func (s *stream) writeLoop() {
121-
defer s.teardown()
122-
123-
for {
124-
// Reset takes precedent.
125-
select {
126-
case <-s.reset:
127-
s.writeErr = network.ErrReset
128-
return
129-
default:
130-
}
131-
132-
select {
133-
case <-s.reset:
134-
s.writeErr = network.ErrReset
135-
return
136-
case <-s.close:
137-
s.writeErr = s.write.Close()
138-
if s.writeErr == nil {
139-
s.writeErr = ErrClosed
140-
}
141-
return
142-
case p := <-s.writeC:
143-
if _, err := s.write.Write(p); err != nil {
144-
s.cancelWrite(err)
145-
return
146-
}
147-
}
148-
}
149-
}
150-
151-
func (s *stream) cancelWrite(err error) {
152-
s.write.CloseWithError(err)
153-
s.writeErr = err
154-
}
155-
156-
func (s *stream) teardown() {
157-
// at this point, no streams are writing.
158-
if s.conn != nil {
159-
s.conn.removeStream(s.id)
211+
func isClosedChan(c <-chan struct{}) bool {
212+
select {
213+
case <-c:
214+
return true
215+
default:
216+
return false
160217
}
161-
162-
// Mark as closed.
163-
close(s.closed)
164218
}

p2p/transport/memory/stream_test.go

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -8,62 +8,60 @@ import (
88

99
func TestStreamSimpleReadWriteClose(t *testing.T) {
1010
t.Parallel()
11-
clientStr, serverStr := newStreamPair()
11+
streamLocal, streamRemote := newStreamPair()
1212

1313
// send a foobar from the client
14-
n, err := clientStr.Write([]byte("foobar"))
14+
n, err := streamLocal.Write([]byte("foobar"))
1515
require.NoError(t, err)
1616
require.Equal(t, 6, n)
17-
require.NoError(t, clientStr.CloseWrite())
17+
require.NoError(t, streamLocal.CloseWrite())
1818

1919
// writing after closing should error
20-
_, err = clientStr.Write([]byte("foobar"))
20+
_, err = streamLocal.Write([]byte("foobar"))
2121
require.Error(t, err)
2222

2323
// now read all the data on the server side
24-
b, err := io.ReadAll(serverStr)
24+
b, err := io.ReadAll(streamRemote)
2525
require.NoError(t, err)
2626
require.Equal(t, []byte("foobar"), b)
2727

2828
// reading again should give another io.EOF
29-
n, err = serverStr.Read(make([]byte, 10))
29+
n, err = streamRemote.Read(make([]byte, 10))
3030
require.Zero(t, n)
3131
require.ErrorIs(t, err, io.EOF)
3232

3333
// send something back
34-
_, err = serverStr.Write([]byte("lorem ipsum"))
34+
_, err = streamRemote.Write([]byte("lorem ipsum"))
3535
require.NoError(t, err)
36-
require.NoError(t, serverStr.CloseWrite())
36+
require.NoError(t, streamRemote.CloseWrite())
3737

3838
// and read it at the client
39-
b, err = io.ReadAll(clientStr)
39+
b, err = io.ReadAll(streamLocal)
4040
require.NoError(t, err)
4141
require.Equal(t, []byte("lorem ipsum"), b)
4242

4343
// stream is only cleaned up on calling Close or Reset
44-
clientStr.Close()
45-
serverStr.Close()
46-
// Need to call Close for cleanup. Otherwise the FIN_ACK is never read
47-
require.NoError(t, serverStr.Close())
44+
require.NoError(t, streamLocal.Close())
45+
require.NoError(t, streamRemote.Close())
4846
}
4947

5048
func TestStreamPartialReads(t *testing.T) {
5149
t.Parallel()
52-
clientStr, serverStr := newStreamPair()
50+
streamLocal, streamRemote := newStreamPair()
5351

54-
_, err := serverStr.Write([]byte("foobar"))
52+
_, err := streamRemote.Write([]byte("foobar"))
5553
require.NoError(t, err)
56-
require.NoError(t, serverStr.CloseWrite())
54+
require.NoError(t, streamRemote.CloseWrite())
5755

58-
n, err := clientStr.Read([]byte{}) // empty read
56+
n, err := streamLocal.Read([]byte{}) // empty read
5957
require.NoError(t, err)
6058
require.Zero(t, n)
6159
b := make([]byte, 3)
62-
n, err = clientStr.Read(b)
60+
n, err = streamLocal.Read(b)
6361
require.Equal(t, 3, n)
6462
require.NoError(t, err)
6563
require.Equal(t, []byte("foo"), b)
66-
b, err = io.ReadAll(clientStr)
64+
b, err = io.ReadAll(streamLocal)
6765
require.NoError(t, err)
6866
require.Equal(t, []byte("bar"), b)
6967
}

0 commit comments

Comments
 (0)