From 6440fe01452660da177af0ed2275b05dbdf86448 Mon Sep 17 00:00:00 2001 From: Yutaka Takeda Date: Fri, 12 Jul 2019 14:59:17 -0700 Subject: [PATCH] Enable channel binding Further threa-safety improvements Optimized inbound packet demuxing Resolves #74 --- client.go | 156 ++++++++++++++++++------- internal/client/binding.go | 131 +++++++++++++++++++++ internal/client/binding_test.go | 75 ++++++++++++ internal/client/channel.go | 18 --- internal/client/conn.go | 200 +++++++++++++++++++++++--------- internal/client/permission.go | 15 ++- server_vnet_test.go | 27 +++-- 7 files changed, 491 insertions(+), 131 deletions(-) create mode 100644 internal/client/binding.go create mode 100644 internal/client/binding_test.go delete mode 100644 internal/client/channel.go diff --git a/client.go b/client.go index 0bd64366..4da7e5ab 100644 --- a/client.go +++ b/client.go @@ -50,13 +50,15 @@ type Client struct { conn net.PacketConn // read-only stunServ net.Addr // read-only turnServ net.Addr // read-only + stunServStr string // read-only, used for dmuxing + turnServStr string // read-only, used for dmuxing username stun.Username // read-only password string // read-only realm stun.Realm // read-only integrity stun.MessageIntegrity // read-only trMap *client.TransactionMap // thread-safe rto time.Duration // read-only - relayedConn *client.UDPConn // protected by mutex + relayedConn *client.UDPConn // protected by mutex *** allocTryLock client.TryLock // thread-safe listenTryLock client.TryLock // thread-safe net *vnet.Net // read-only @@ -79,6 +81,7 @@ func NewClient(config *ClientConfig) (*Client, error) { } var stunServ, turnServ net.Addr + var stunServStr, turnServStr string var err error if len(config.STUNServerAddr) > 0 { log.Debugf("resolving %s", config.STUNServerAddr) @@ -86,7 +89,8 @@ func NewClient(config *ClientConfig) (*Client, error) { if err != nil { return nil, err } - log.Debugf("stunServ: %s", stunServ.String()) + stunServStr = stunServ.String() + log.Debugf("stunServ: %s", stunServStr) } if len(config.TURNServerAddr) > 0 { log.Debugf("resolving %s", config.TURNServerAddr) @@ -94,20 +98,23 @@ func NewClient(config *ClientConfig) (*Client, error) { if err != nil { return nil, err } - log.Debugf("turnServ: %s", stunServ.String()) + turnServStr = turnServ.String() + log.Debugf("turnServ: %s", turnServStr) } c := &Client{ - conn: config.Conn, - stunServ: stunServ, - turnServ: turnServ, - username: stun.NewUsername(config.Username), - password: config.Password, - realm: stun.NewRealm(config.Realm), - net: config.Net, - trMap: client.NewTransactionMap(), - rto: defaultRTO, - log: log, + conn: config.Conn, + stunServ: stunServ, + turnServ: turnServ, + stunServStr: stunServStr, + turnServStr: turnServStr, + username: stun.NewUsername(config.Username), + password: config.Password, + realm: stun.NewRealm(config.Realm), + net: config.Net, + trMap: client.NewTransactionMap(), + rto: defaultRTO, + log: log, } return c, nil @@ -213,9 +220,7 @@ func (c *Client) Allocate() (net.PacketConn, error) { } defer c.allocTryLock.Unlock() - c.mutex.RLock() - relayedConn := c.relayedConn - c.mutex.RUnlock() + relayedConn := c.relayedUDPConn() if relayedConn != nil { return nil, fmt.Errorf("already allocated at %s", relayedConn.LocalAddr().String()) } @@ -283,6 +288,10 @@ func (c *Client) Allocate() (net.PacketConn, error) { if err := relayed.GetFrom(res); err != nil { return nil, err } + relayedAddr := &net.UDPAddr{ + IP: relayed.IP, + Port: relayed.Port, + } // Getting lifetime from response var lifetime turn.Lifetime @@ -291,17 +300,15 @@ func (c *Client) Allocate() (net.PacketConn, error) { } relayedConn = client.NewUDPConn(&client.UDPConnConfig{ - Observer: c, - Relayed: relayed, - Integrity: c.integrity, - Nonce: nonce, - Lifetime: lifetime.Duration, - Log: c.log, + Observer: c, + RelayedAddr: relayedAddr, + Integrity: c.integrity, + Nonce: nonce, + Lifetime: lifetime.Duration, + Log: c.log, }) - c.mutex.Lock() - c.relayedConn = relayedConn - c.mutex.Unlock() + c.setRelayedUDPConn(relayedConn) return relayedConn, nil } @@ -338,27 +345,60 @@ func (c *Client) PerformTransaction(msg *stun.Message, to net.Addr, dontWait boo return res, nil } +// OnDeallocated is called when deallocation of relay address has been complete. +// (Called by UDPConn) +func (c *Client) OnDeallocated(relayedAddr net.Addr) { + c.setRelayedUDPConn(nil) +} + // HandleInbound handles data received. -// This method handles packets received only from the turn server address. -// If the source (from) address does not match, it would return (false, nil) -// to indicate the caller that the packet was not handled. +// This method handles incoming packet demultiplex it by the source address +// and the types of the message. +// This return a booleen (handled or not) and if there was an error. +// Caller should check if the packet was handled by this client or not. +// If not handled, it is assumed that the packet is application data. +// If an error is returned, the caller should discard the packet regardless. func (c *Client) HandleInbound(data []byte, from net.Addr) (bool, error) { - c.log.Debug("HandleInbound: in") - defer c.log.Debug("HandleInbound: out") + var handled bool var err error - switch { - case stun.IsMessage(data): + if stun.IsMessage(data) { + handled = true err = c.handleSTUNMessage(data, from) - case turn.IsChannelData(data): - err = c.handleChannelData(data) - default: + } else if len(c.turnServStr) != 0 && from.String() == c.turnServStr { + handled = true + // received from TURN server + if turn.IsChannelData(data) { + err = c.handleChannelData(data) + } else { + err = fmt.Errorf("unexpected packet from TURN server") + } + } else if len(c.stunServStr) != 0 && from.String() == c.stunServStr { + handled = true + // received from STUN server but it is not a STUN message + err = fmt.Errorf("non-STUN message from STUN server") + } else { // assume, this is an application data c.log.Tracef("non-STUN/TURN packect, unhandled") - return false, nil // unhandled } - return true, err + // +---------+---------+-------------------------------+ + // | handled | err | Meaning / Action | + // |=========+=========+===============================+ + // | false | nil | Handle the packet as app data | + // |---------+---------+-------------------------------+ + // | true | nil | Nothing to do | + // |---------+---------+-------------------------------+ + // | false | error | (shouldn't happen) | + // |---------+---------+-------------------------------+ + // | true | error | Error occurred while handling | + // +---------+---------+-------------------------------+ + // Possible causes of the error: + // - Malformed packet (parse error) + // - STUN message was a request + // - Non-STUN message from the STUN server + + return handled, err } func (c *Client) handleSTUNMessage(data []byte, from net.Addr) error { @@ -389,9 +429,7 @@ func (c *Client) handleSTUNMessage(data []byte, from net.Addr) error { c.log.Debugf("data indication received from %s", from.String()) - c.mutex.RLock() - relayedConn := c.relayedConn - c.mutex.RUnlock() + relayedConn := c.relayedUDPConn() if relayedConn == nil { c.log.Debug("no relayed conn allocated") return nil // silently discard @@ -414,6 +452,7 @@ func (c *Client) handleSTUNMessage(data []byte, from net.Addr) error { return nil } + // End the transaction tr.StopRtxTimer() c.trMap.Delete(trKey) @@ -425,8 +464,27 @@ func (c *Client) handleSTUNMessage(data []byte, from net.Addr) error { } func (c *Client) handleChannelData(data []byte) error { - c.log.Debugf("handle %d bytes of ChannelData", len(data)) - return fmt.Errorf("not implemented yet") + chData := &turn.ChannelData{ + Raw: make([]byte, len(data)), + } + copy(chData.Raw, data) + if err := chData.Decode(); err != nil { + return err + } + + relayedConn := c.relayedUDPConn() + if relayedConn != nil { + c.log.Debug("no relayed conn allocated") + return nil // silently discard + } + + addr, ok := relayedConn.FindAddrByChannelNumber(uint16(chData.Number)) + if !ok { + return fmt.Errorf("binding with channel %d not found", int(chData.Number)) + } + + relayedConn.HandleInbound(chData.Data, addr) + return nil } func (c *Client) onRtxTimeout(trKey string, nRtx int32) { @@ -460,3 +518,17 @@ func (c *Client) onRtxTimeout(trKey string, nRtx int32) { } tr.StartRtxTimer(c.onRtxTimeout) } + +func (c *Client) setRelayedUDPConn(conn *client.UDPConn) { + c.mutex.Lock() + defer c.mutex.Unlock() + + c.relayedConn = conn +} + +func (c *Client) relayedUDPConn() *client.UDPConn { + c.mutex.RLock() + defer c.mutex.RUnlock() + + return c.relayedConn +} diff --git a/internal/client/binding.go b/internal/client/binding.go new file mode 100644 index 00000000..4a631d42 --- /dev/null +++ b/internal/client/binding.go @@ -0,0 +1,131 @@ +package client + +import ( + "net" + "sync" + "sync/atomic" +) + +// Chanel number: +// 0x4000 through 0x7FFF: These values are the allowed channel +// numbers (16,383 possible values). +const ( + minChannelNumber uint16 = 0x4000 + maxChannelNumber uint16 = 0x7fff +) + +type bindingState int32 + +const ( + bindingStateIdle bindingState = iota + bindingStateReady + bindingStateFailed +) + +type binding struct { + st bindingState // thread-safe (atomic op) + addr net.Addr // read-only + number uint16 // read-only + mgr *bindingManager // read-only + mutex sync.Mutex // thread-safe, used in UDPConn +} + +func (b *binding) setState(state bindingState) { + atomic.StoreInt32((*int32)(&b.st), int32(state)) +} + +func (b *binding) state() bindingState { + return bindingState(atomic.LoadInt32((*int32)(&b.st))) +} + +// Thread-safe binding map +type bindingManager struct { + chanMap map[uint16]*binding + addrMap map[string]*binding + next uint16 + mutex sync.RWMutex +} + +func newBindingManager() *bindingManager { + return &bindingManager{ + chanMap: map[uint16]*binding{}, + addrMap: map[string]*binding{}, + next: minChannelNumber, + } +} + +func (mgr *bindingManager) assignChannelNumber() uint16 { + n := mgr.next + if mgr.next == maxChannelNumber { + mgr.next = minChannelNumber + } else { + mgr.next++ + } + return n +} + +func (mgr *bindingManager) create(addr net.Addr) *binding { + mgr.mutex.Lock() + defer mgr.mutex.Unlock() + + b := &binding{ + addr: addr, + number: mgr.assignChannelNumber(), + mgr: mgr, + } + + mgr.chanMap[b.number] = b + mgr.addrMap[b.addr.String()] = b + return b +} + +func (mgr *bindingManager) findByAddr(addr net.Addr) (*binding, bool) { + mgr.mutex.RLock() + defer mgr.mutex.RUnlock() + + b, ok := mgr.addrMap[addr.String()] + return b, ok +} + +func (mgr *bindingManager) findByNumber(number uint16) (*binding, bool) { + mgr.mutex.RLock() + defer mgr.mutex.RUnlock() + + b, ok := mgr.chanMap[number] + return b, ok +} + +func (mgr *bindingManager) deleteByAddr(addr net.Addr) bool { + mgr.mutex.Lock() + defer mgr.mutex.Unlock() + + b, ok := mgr.addrMap[addr.String()] + if !ok { + return false + } + + delete(mgr.addrMap, addr.String()) + delete(mgr.chanMap, b.number) + return true +} + +func (mgr *bindingManager) deleteByNumber(number uint16) bool { + mgr.mutex.Lock() + defer mgr.mutex.Unlock() + + b, ok := mgr.chanMap[number] + if !ok { + return false + } + + delete(mgr.addrMap, b.addr.String()) + delete(mgr.chanMap, number) + return true +} + +func (mgr *bindingManager) size() int { + mgr.mutex.RLock() + defer mgr.mutex.RUnlock() + + return len(mgr.chanMap) +} diff --git a/internal/client/binding_test.go b/internal/client/binding_test.go new file mode 100644 index 00000000..31e4e386 --- /dev/null +++ b/internal/client/binding_test.go @@ -0,0 +1,75 @@ +package client + +import ( + "net" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestBindingManager(t *testing.T) { + t.Run("number assignment", func(t *testing.T) { + m := newBindingManager() + var n uint16 + for i := uint16(0); i < 10; i++ { + n = m.assignChannelNumber() + assert.Equal(t, minChannelNumber+i, n, "should match") + } + + m.next = uint16(0x7ff0) + for i := uint16(0); i < 16; i++ { + n = m.assignChannelNumber() + assert.Equal(t, 0x7ff0+i, n, "should match") + } + // back to min + n = m.assignChannelNumber() + assert.Equal(t, minChannelNumber, n, "should match") + }) + + t.Run("method test", func(t *testing.T) { + lo := net.IPv4(127, 0, 0, 1) + count := 100 + m := newBindingManager() + for i := 0; i < count; i++ { + addr := &net.UDPAddr{IP: lo, Port: 10000 + i} + b0 := m.create(addr) + b1, ok := m.findByAddr(addr) + assert.True(t, ok, "should succeed") + b2, ok := m.findByNumber(b0.number) + assert.True(t, ok, "should succeed") + + assert.Equal(t, b0, b1, "should match") + assert.Equal(t, b0, b2, "should match") + } + + assert.Equal(t, count, m.size(), "should match") + assert.Equal(t, count, len(m.addrMap), "should match") + + for i := 0; i < count; i++ { + addr := &net.UDPAddr{IP: lo, Port: 10000 + i} + if i%2 == 0 { + assert.True(t, m.deleteByAddr(addr), "should return true") + } else { + assert.True(t, m.deleteByNumber(minChannelNumber+uint16(i)), "should return true") + } + } + + assert.Equal(t, 0, m.size(), "should match") + assert.Equal(t, 0, len(m.addrMap), "should match") + }) + + t.Run("failure test", func(t *testing.T) { + addr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 7777} + m := newBindingManager() + var ok bool + _, ok = m.findByAddr(addr) + assert.False(t, ok, "should fail") + _, ok = m.findByNumber(uint16(5555)) + assert.False(t, ok, "should fail") + ok = m.deleteByAddr(addr) + assert.False(t, ok, "should fail") + ok = m.deleteByNumber(uint16(5555)) + assert.False(t, ok, "should fail") + + }) +} diff --git a/internal/client/channel.go b/internal/client/channel.go deleted file mode 100644 index fda43079..00000000 --- a/internal/client/channel.go +++ /dev/null @@ -1,18 +0,0 @@ -package client - -/* -import ( - "sync" - "github.com/gortc/turn" - "github.com/pion/logging" -) -*/ - -type channel struct { - /* - log logging.LeveledLogger - mutex sync.RWMutex - number turn.ChannelNumber - peerAddr turn.PeerAddress - */ -} diff --git a/internal/client/conn.go b/internal/client/conn.go index 762b8a4a..af191ba4 100644 --- a/internal/client/conn.go +++ b/internal/client/conn.go @@ -39,57 +39,54 @@ type UDPConnObserver interface { Realm() stun.Realm WriteTo(data []byte, to net.Addr) (int, error) PerformTransaction(msg *stun.Message, to net.Addr, dontWait bool) (TransactionResult, error) + OnDeallocated(relayedAddr net.Addr) } // UDPConnConfig is a set of configuration params use by NewUDPConn type UDPConnConfig struct { - Observer UDPConnObserver - Relayed turn.RelayedAddress - Integrity stun.MessageIntegrity - Nonce stun.Nonce - Lifetime time.Duration - Log logging.LeveledLogger + Observer UDPConnObserver + RelayedAddr net.Addr + Integrity stun.MessageIntegrity + Nonce stun.Nonce + Lifetime time.Duration + Log logging.LeveledLogger } // UDPConn is the implementation of the Conn and PacketConn interfaces for UDP network connections. // comatible with net.PacketConn and net.Conn type UDPConn struct { - obs UDPConnObserver // read-only - relayed turn.RelayedAddress // read-only - permMap *permissionMap // thread-safe - channelMap map[string]*channel // needs mutex ? - bindingMap map[turn.ChannelNumber]*channel // needs mutex ? - minBound turn.ChannelNumber // needs mutex ? - integrity stun.MessageIntegrity // read-only - nonce stun.Nonce // read-only - lifetime time.Duration // needs mutex x - readCh chan *inboundData // thread-safe - closeCh chan struct{} // thread-safe - closed *AtomicBool // thread-safe - readTimer *time.Timer // thread-safe - refreshAllocTimer *PeriodicTimer // thread-safe - refreshPermsTimer *PeriodicTimer // thread-safe - mutex sync.RWMutex // thread-safe - log logging.LeveledLogger // read-only + obs UDPConnObserver // read-only + relayedAddr net.Addr // read-only + permMap *permissionMap // thread-safe + bindingMgr *bindingManager // thread-safe + integrity stun.MessageIntegrity // read-only + nonce stun.Nonce // read-only + lifetime time.Duration // needs mutex x + readCh chan *inboundData // thread-safe + closeCh chan struct{} // thread-safe + closed *AtomicBool // thread-safe + readTimer *time.Timer // thread-safe + refreshAllocTimer *PeriodicTimer // thread-safe + refreshPermsTimer *PeriodicTimer // thread-safe + mutex sync.RWMutex // thread-safe + log logging.LeveledLogger // read-only } // NewUDPConn creates a new instance of UDPConn func NewUDPConn(config *UDPConnConfig) *UDPConn { c := &UDPConn{ - obs: config.Observer, - relayed: config.Relayed, - permMap: newPermissionMap(), - channelMap: map[string]*channel{}, - bindingMap: map[turn.ChannelNumber]*channel{}, - minBound: turn.MinChannelNumber, - integrity: config.Integrity, - nonce: config.Nonce, - lifetime: config.Lifetime, - readCh: make(chan *inboundData, maxReadQueueSize), - closeCh: make(chan struct{}), - closed: NewAtomicBool(false), - readTimer: time.NewTimer(time.Duration(math.MaxInt64)), - log: config.Log, + obs: config.Observer, + relayedAddr: config.RelayedAddr, + permMap: newPermissionMap(), + bindingMgr: newBindingManager(), + integrity: config.Integrity, + nonce: config.Nonce, + lifetime: config.Lifetime, + readCh: make(chan *inboundData, maxReadQueueSize), + closeCh: make(chan struct{}), + closed: NewAtomicBool(false), + readTimer: time.NewTimer(time.Duration(math.MaxInt64)), + log: config.Log, } c.log.Debugf("initial lifetime: %d seconds", int(c.lifetime.Seconds())) @@ -186,13 +183,13 @@ func (c *UDPConn) WriteTo(p []byte, addr net.Addr) (int, error) { perm.mutex.Lock() defer perm.mutex.Unlock() - if perm.state == permStateIdle { + if perm.state() == permStateIdle { // punch a hole! (this would block a bit..) if err := c.createPermissions(addr); err != nil { c.permMap.delete(addr) return err } - perm.state = permStatePermitted + perm.setState(permStatePermitted) } return nil }() @@ -200,26 +197,57 @@ func (c *UDPConn) WriteTo(p []byte, addr net.Addr) (int, error) { return 0, err } - // TODO: bind channel for better performance here + // bind channel - // send data using SendIndication - // TODO: send over channel when it becomes available - peerAddr := addr2PeerAddress(addr) - msg, err := stun.Build( - stun.TransactionID, - stun.NewType(stun.MethodSend, stun.ClassIndication), - turn.RequestedTransportUDP, - turn.Data(p), - peerAddr, - stun.Fingerprint, - ) - if err != nil { - return 0, err + b, ok := c.bindingMgr.findByAddr(addr) + if !ok { + b = c.bindingMgr.create(addr) } + if b.state() != bindingStateReady { + if b.state() == bindingStateIdle { + func() { + // block only callers with the same binding until + // the binding transaction has been complete + b.mutex.Lock() + defer b.mutex.Unlock() + + // binding state may have been changed while waiting. check again. + if b.state() == bindingStateIdle { + err = c.bind(b) + if err != nil { + c.log.Warnf("bind() failed: %s", err.Error()) + b.setState(bindingStateFailed) + // keep going... + // TODO: consider try binding again after a while + } else { + b.setState(bindingStateReady) + } + } + }() + } - // indication has no transaction (fire-and-forget) + // send data using SendIndication + // TODO: send over channel when it becomes available + peerAddr := addr2PeerAddress(addr) + msg, err := stun.Build( + stun.TransactionID, + stun.NewType(stun.MethodSend, stun.ClassIndication), + turn.RequestedTransportUDP, + turn.Data(p), + peerAddr, + stun.Fingerprint, + ) + if err != nil { + return 0, err + } + + // indication has no transaction (fire-and-forget) - return c.obs.WriteTo(msg.Raw, c.obs.TURNServerAddr()) + return c.obs.WriteTo(msg.Raw, c.obs.TURNServerAddr()) + } + + // send via ChannelData + return c.sendChannelData(p, b.number) } // Close closes the connection. @@ -236,12 +264,13 @@ func (c *UDPConn) Close() error { } c.refreshAllocation(0, true) // dontWait = true + c.obs.OnDeallocated(c.relayedAddr) return nil } // LocalAddr returns the local network address. func (c *UDPConn) LocalAddr() net.Addr { - return &net.UDPAddr{IP: c.relayed.IP, Port: c.relayed.Port} + return c.relayedAddr } // SetDeadline sets the read and write deadlines associated @@ -344,7 +373,7 @@ func (c *UDPConn) createPermissions(addrs ...net.Addr) error { return nil } -// HandleInbound passes inbound data to UDPConn +// HandleInbound passes inbound data in UDPConn func (c *UDPConn) HandleInbound(data []byte, from net.Addr) { select { case c.readCh <- &inboundData{data: data, from: from}: @@ -353,6 +382,16 @@ func (c *UDPConn) HandleInbound(data []byte, from net.Addr) { } } +// FindAddrByChannelNumber returns a peer address associated with the +// channel number on this UDPConn +func (c *UDPConn) FindAddrByChannelNumber(chNum uint16) (net.Addr, bool) { + b, ok := c.bindingMgr.findByNumber(chNum) + if !ok { + return nil, false + } + return b.addr, true +} + func (c *UDPConn) refreshAllocation(lifetime time.Duration, dontWait bool) { msg, err := stun.Build( stun.TransactionID, @@ -402,6 +441,53 @@ func (c *UDPConn) refreshPermissions() { c.log.Debug("refresh permissions successful") } +func (c *UDPConn) bind(b *binding) error { + setters := []stun.Setter{ + stun.TransactionID, + stun.NewType(stun.MethodChannelBind, stun.ClassRequest), + turn.RequestedTransportUDP, + addr2PeerAddress(b.addr), + turn.ChannelNumber(b.number), + c.obs.Username(), + c.obs.Realm(), + c.nonce, + c.integrity, + stun.Fingerprint, + } + + msg, err := stun.Build(setters...) + if err != nil { + return err + } + + trRes, err := c.obs.PerformTransaction(msg, c.obs.TURNServerAddr(), false) + if err != nil { + c.bindingMgr.deleteByAddr(b.addr) + } + + res := trRes.Msg + + if res.Type != stun.NewType(stun.MethodChannelBind, stun.ClassSuccessResponse) { + return fmt.Errorf("unexpected response type %s", res.Type) + } + + c.log.Debugf("channel binding successful: %s %d", + b.addr.String(), + b.number) + + // Success. + return nil +} + +func (c *UDPConn) sendChannelData(data []byte, chNum uint16) (int, error) { + chData := &turn.ChannelData{ + Data: data, + Number: turn.ChannelNumber(chNum), + } + chData.Encode() + return c.obs.WriteTo(chData.Raw, c.obs.TURNServerAddr()) +} + func (c *UDPConn) onRefreshTimers(id int) { c.log.Debugf("refresh timer %d expired", id) c.mutex.RLock() diff --git a/internal/client/permission.go b/internal/client/permission.go index 66642e84..5546a22e 100644 --- a/internal/client/permission.go +++ b/internal/client/permission.go @@ -3,9 +3,10 @@ package client import ( "net" "sync" + "sync/atomic" ) -type permState int +type permState int32 const ( permStateIdle permState = iota @@ -13,8 +14,16 @@ const ( ) type permission struct { - state permState - mutex sync.RWMutex + st permState // thread-safe (atomic op) + mutex sync.RWMutex // thread-safe +} + +func (p *permission) setState(state permState) { + atomic.StoreInt32((*int32)(&p.st), int32(state)) +} + +func (p *permission) state() permState { + return permState(atomic.LoadInt32((*int32)(&p.st))) } // Thread-safe permission map diff --git a/server_vnet_test.go b/server_vnet_test.go index 3d255d91..7e9f4372 100644 --- a/server_vnet_test.go +++ b/server_vnet_test.go @@ -182,7 +182,7 @@ func TestServerVNet(t *testing.T) { assert.True(t, udpAddr.IP.Equal(net.IPv4(5, 6, 7, 8)), "should match") }) - t.Run("Allocate", func(t *testing.T) { + t.Run("Echo via relay", func(t *testing.T) { v, err := buildVNet() if !assert.NoError(t, err, "should succeed") { return @@ -254,18 +254,23 @@ func TestServerVNet(t *testing.T) { } }() - log.Debug("sending \"Hello\"..") - _, err = conn.WriteTo([]byte("Hello"), echoConn.LocalAddr()) - if !assert.NoError(t, err, "should succeed") { - return - } - buf := make([]byte, 1500) - _, from, err := conn.ReadFrom(buf) - assert.NoError(t, err, "should succeed") - // verify the message was received from the relay address - assert.Equal(t, echoConn.LocalAddr().String(), from.String(), "should match") + for i := 0; i < 4; i++ { + log.Debug("sending \"Hello\"..") + _, err = conn.WriteTo([]byte("Hello"), echoConn.LocalAddr()) + if !assert.NoError(t, err, "should succeed") { + return + } + + _, from, err := conn.ReadFrom(buf) + assert.NoError(t, err, "should succeed") + + // verify the message was received from the relay address + assert.Equal(t, echoConn.LocalAddr().String(), from.String(), "should match") + + time.Sleep(200 * time.Millisecond) + } err = conn.Close() assert.NoError(t, err, "should succeed")