diff --git a/dpipe/dpipe.go b/dpipe/dpipe.go new file mode 100644 index 0000000..9d9660d --- /dev/null +++ b/dpipe/dpipe.go @@ -0,0 +1,146 @@ +// Package dpipe provides the pipe works like datagram protocol on memory. +// +// This package is mainly intended for testing and not for production! +package dpipe + +import ( + "context" + "io" + "net" + "sync" + "time" + + "github.com/pion/transport/v2/deadline" +) + +// Pipe creates pair of non-stream conn on memory. +// Close of the one end doesn't make effect to the other end. +func Pipe() (net.Conn, net.Conn) { + ch0 := make(chan []byte, 1000) + ch1 := make(chan []byte, 1000) + return &conn{ + rCh: ch0, + wCh: ch1, + closed: make(chan struct{}), + closing: make(chan struct{}), + readDeadline: deadline.New(), + writeDeadline: deadline.New(), + }, &conn{ + rCh: ch1, + wCh: ch0, + closed: make(chan struct{}), + closing: make(chan struct{}), + readDeadline: deadline.New(), + writeDeadline: deadline.New(), + } +} + +type pipeAddr struct{} + +func (pipeAddr) Network() string { return "pipe" } +func (pipeAddr) String() string { return ":1" } + +type conn struct { + rCh chan []byte + wCh chan []byte + closed chan struct{} + closing chan struct{} + closeOnce sync.Once + + readDeadline *deadline.Deadline + writeDeadline *deadline.Deadline +} + +func (*conn) LocalAddr() net.Addr { return pipeAddr{} } +func (*conn) RemoteAddr() net.Addr { return pipeAddr{} } + +func (c *conn) SetDeadline(t time.Time) error { + c.readDeadline.Set(t) + c.writeDeadline.Set(t) + return nil +} + +func (c *conn) SetReadDeadline(t time.Time) error { + c.readDeadline.Set(t) + return nil +} + +func (c *conn) SetWriteDeadline(t time.Time) error { + c.writeDeadline.Set(t) + return nil +} + +func (c *conn) Read(data []byte) (n int, err error) { + select { + case <-c.closed: + return 0, io.EOF + case <-c.closing: + if len(c.rCh) == 0 { + return 0, io.EOF + } + case <-c.readDeadline.Done(): + return 0, context.DeadlineExceeded + default: + } + + for { + select { + case d := <-c.rCh: + if len(d) <= len(data) { + copy(data, d) + return len(d), nil + } + copy(data, d[:len(data)]) + return len(data), nil + case <-c.closed: + return 0, io.EOF + case <-c.closing: + if len(c.rCh) == 0 { + return 0, io.EOF + } + case <-c.readDeadline.Done(): + return 0, context.DeadlineExceeded + } + } +} + +func (c *conn) cleanWriteBuffer() { + for { + select { + case <-c.wCh: + default: + return + } + } +} + +func (c *conn) Write(data []byte) (n int, err error) { + select { + case <-c.closed: + return 0, io.ErrClosedPipe + case <-c.writeDeadline.Done(): + c.cleanWriteBuffer() + return 0, context.DeadlineExceeded + default: + } + + cData := make([]byte, len(data)) + copy(cData, data) + + select { + case <-c.closed: + return 0, io.ErrClosedPipe + case <-c.writeDeadline.Done(): + c.cleanWriteBuffer() + return 0, context.DeadlineExceeded + case c.wCh <- cData: + return len(cData), nil + } +} + +func (c *conn) Close() error { + c.closeOnce.Do(func() { + close(c.closed) + }) + return nil +} diff --git a/dpipe/dpipe_test.go b/dpipe/dpipe_test.go new file mode 100644 index 0000000..c2e9f78 --- /dev/null +++ b/dpipe/dpipe_test.go @@ -0,0 +1,120 @@ +//go:build !js +// +build !js + +package dpipe + +import ( + "bytes" + "errors" + "fmt" + "io" + "net" + "testing" + "time" + + "golang.org/x/net/nettest" +) + +var errFailedToCast = fmt.Errorf("failed to cast net.Conn to conn") + +func TestNetTest(t *testing.T) { + nettest.TestConn(t, func() (net.Conn, net.Conn, func(), error) { + ca, cb := Pipe() + caConn, ok := ca.(*conn) + if !ok { + return nil, nil, nil, errFailedToCast + } + + cbConn, ok := cb.(*conn) + if !ok { + return nil, nil, nil, errFailedToCast + } + + return &closePropagator{caConn, cbConn}, + &closePropagator{cbConn, caConn}, + func() { + _ = ca.Close() + _ = cb.Close() + }, nil + }) +} + +type closePropagator struct { + *conn + otherEnd *conn +} + +func (c *closePropagator) Close() error { + close(c.otherEnd.closing) + return c.conn.Close() +} + +func TestPipe(t *testing.T) { + ca, cb := Pipe() + + testData := []byte{0x01, 0x02} + + for name, cond := range map[string]struct { + ca net.Conn + cb net.Conn + }{ + "AtoB": {ca, cb}, + "BtoA": {cb, ca}, + } { + c0 := cond.ca + c1 := cond.cb + t.Run(name, func(t *testing.T) { + switch n, err := c0.Write(testData); { + case err != nil: + t.Errorf("Unexpected error on Write: %v", err) + case n != len(testData): + t.Errorf("Expected to write %d bytes, wrote %d bytes", len(testData), n) + } + + readData := make([]byte, 4) + switch n, err := c1.Read(readData); { + case err != nil: + t.Errorf("Unexpected error on Write: %v", err) + case n != len(testData): + t.Errorf("Expected to read %d bytes, got %d bytes", len(testData), n) + case !bytes.Equal(testData, readData[0:n]): + t.Errorf("Expected to read %v, got %v", testData, readData[0:n]) + } + }) + } + + if err := ca.Close(); err != nil { + t.Errorf("Unexpected error on Close: %v", err) + } + if _, err := ca.Write(testData); !errors.Is(err, io.ErrClosedPipe) { + t.Errorf("Write to closed conn should fail with %v, got %v", io.ErrClosedPipe, err) + } + + // Other side should be writable. + if _, err := cb.Write(testData); err != nil { + t.Errorf("Unexpected error on Write: %v", err) + } + + readData := make([]byte, 4) + if _, err := ca.Read(readData); !errors.Is(err, io.EOF) { + t.Errorf("Read from closed conn should fail with %v, got %v", io.EOF, err) + } + + // Other side should be readable. + readDone := make(chan struct{}) + go func() { + readData := make([]byte, 4) + if n, err := cb.Read(readData); err == nil { + t.Errorf("Unexpected data %v was arrived to orphaned conn", readData[:n]) + } + close(readDone) + }() + select { + case <-readDone: + t.Errorf("Read should be blocked if the other side is closed") + case <-time.After(10 * time.Millisecond): + } + if err := cb.Close(); err != nil { + t.Errorf("Unexpected error on Close: %v", err) + } +}