Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add pion/dtls/internal/net/dpipe package #236

Merged
merged 2 commits into from
Apr 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 146 additions & 0 deletions dpipe/dpipe.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
// Package dpipe provides the pipe works like datagram protocol on memory.
stv0g marked this conversation as resolved.
Show resolved Hide resolved
//
// This package is mainly intended for testing and not for production!
package dpipe

import (
"context"
"io"
"net"
"sync"
"time"

"github.com/pion/transport/v2/deadline"
)

// Pipe creates pair of non-stream conn on memory.
// Close of the one end doesn't make effect to the other end.
func Pipe() (net.Conn, net.Conn) {
ch0 := make(chan []byte, 1000)
ch1 := make(chan []byte, 1000)
return &conn{
rCh: ch0,
wCh: ch1,
closed: make(chan struct{}),
closing: make(chan struct{}),
readDeadline: deadline.New(),
writeDeadline: deadline.New(),
}, &conn{
rCh: ch1,
wCh: ch0,
closed: make(chan struct{}),
closing: make(chan struct{}),
readDeadline: deadline.New(),
writeDeadline: deadline.New(),
}
}

type pipeAddr struct{}

func (pipeAddr) Network() string { return "pipe" }
func (pipeAddr) String() string { return ":1" }

type conn struct {
rCh chan []byte
wCh chan []byte
closed chan struct{}
closing chan struct{}
closeOnce sync.Once

readDeadline *deadline.Deadline
writeDeadline *deadline.Deadline
}

func (*conn) LocalAddr() net.Addr { return pipeAddr{} }
func (*conn) RemoteAddr() net.Addr { return pipeAddr{} }

func (c *conn) SetDeadline(t time.Time) error {
c.readDeadline.Set(t)
c.writeDeadline.Set(t)
return nil
}

func (c *conn) SetReadDeadline(t time.Time) error {
c.readDeadline.Set(t)
return nil
}

func (c *conn) SetWriteDeadline(t time.Time) error {
c.writeDeadline.Set(t)
return nil
}

func (c *conn) Read(data []byte) (n int, err error) {
select {
case <-c.closed:
return 0, io.EOF
case <-c.closing:
if len(c.rCh) == 0 {
return 0, io.EOF
}
case <-c.readDeadline.Done():
return 0, context.DeadlineExceeded
default:
}

for {
select {
case d := <-c.rCh:
if len(d) <= len(data) {
copy(data, d)
return len(d), nil
}
copy(data, d[:len(data)])
return len(data), nil
case <-c.closed:
return 0, io.EOF
case <-c.closing:
if len(c.rCh) == 0 {
return 0, io.EOF
}
case <-c.readDeadline.Done():
return 0, context.DeadlineExceeded
}
}
}

func (c *conn) cleanWriteBuffer() {
for {
select {
case <-c.wCh:
default:
return
}
}
}

func (c *conn) Write(data []byte) (n int, err error) {
select {
case <-c.closed:
return 0, io.ErrClosedPipe
case <-c.writeDeadline.Done():
c.cleanWriteBuffer()
return 0, context.DeadlineExceeded
default:
}

cData := make([]byte, len(data))
copy(cData, data)

select {
case <-c.closed:
return 0, io.ErrClosedPipe
case <-c.writeDeadline.Done():
c.cleanWriteBuffer()
return 0, context.DeadlineExceeded
case c.wCh <- cData:
return len(cData), nil
}
}

func (c *conn) Close() error {
c.closeOnce.Do(func() {
close(c.closed)
})
return nil
}
120 changes: 120 additions & 0 deletions dpipe/dpipe_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
//go:build !js
// +build !js

package dpipe

import (
"bytes"
"errors"
"fmt"
"io"
"net"
"testing"
"time"

"golang.org/x/net/nettest"
)

var errFailedToCast = fmt.Errorf("failed to cast net.Conn to conn")

func TestNetTest(t *testing.T) {
nettest.TestConn(t, func() (net.Conn, net.Conn, func(), error) {
ca, cb := Pipe()
caConn, ok := ca.(*conn)
if !ok {
return nil, nil, nil, errFailedToCast
}

cbConn, ok := cb.(*conn)
if !ok {
return nil, nil, nil, errFailedToCast
}

return &closePropagator{caConn, cbConn},
&closePropagator{cbConn, caConn},
func() {
_ = ca.Close()
_ = cb.Close()
}, nil
})
}

type closePropagator struct {
*conn
otherEnd *conn
}

func (c *closePropagator) Close() error {
close(c.otherEnd.closing)
return c.conn.Close()
}

func TestPipe(t *testing.T) {
ca, cb := Pipe()

testData := []byte{0x01, 0x02}

for name, cond := range map[string]struct {
ca net.Conn
cb net.Conn
}{
"AtoB": {ca, cb},
"BtoA": {cb, ca},
} {
c0 := cond.ca
c1 := cond.cb
t.Run(name, func(t *testing.T) {
switch n, err := c0.Write(testData); {
case err != nil:
t.Errorf("Unexpected error on Write: %v", err)
case n != len(testData):
t.Errorf("Expected to write %d bytes, wrote %d bytes", len(testData), n)
}

readData := make([]byte, 4)
switch n, err := c1.Read(readData); {
case err != nil:
t.Errorf("Unexpected error on Write: %v", err)
case n != len(testData):
t.Errorf("Expected to read %d bytes, got %d bytes", len(testData), n)
case !bytes.Equal(testData, readData[0:n]):
t.Errorf("Expected to read %v, got %v", testData, readData[0:n])
}
})
}

if err := ca.Close(); err != nil {
t.Errorf("Unexpected error on Close: %v", err)
}
if _, err := ca.Write(testData); !errors.Is(err, io.ErrClosedPipe) {
t.Errorf("Write to closed conn should fail with %v, got %v", io.ErrClosedPipe, err)
}

// Other side should be writable.
if _, err := cb.Write(testData); err != nil {
t.Errorf("Unexpected error on Write: %v", err)
}

readData := make([]byte, 4)
if _, err := ca.Read(readData); !errors.Is(err, io.EOF) {
t.Errorf("Read from closed conn should fail with %v, got %v", io.EOF, err)
}

// Other side should be readable.
readDone := make(chan struct{})
go func() {
readData := make([]byte, 4)
if n, err := cb.Read(readData); err == nil {
t.Errorf("Unexpected data %v was arrived to orphaned conn", readData[:n])
}
close(readDone)
}()
select {
case <-readDone:
t.Errorf("Read should be blocked if the other side is closed")
case <-time.After(10 * time.Millisecond):
}
if err := cb.Close(); err != nil {
t.Errorf("Unexpected error on Close: %v", err)
}
}