Skip to content

Commit

Permalink
Enable channel binding
Browse files Browse the repository at this point in the history
Further threa-safety improvements
Optimized inbound packet demuxing
Resolves #74
  • Loading branch information
enobufs committed Jul 12, 2019
1 parent 78d96fa commit 6440fe0
Show file tree
Hide file tree
Showing 7 changed files with 491 additions and 131 deletions.
156 changes: 114 additions & 42 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,15 @@ type Client struct {
conn net.PacketConn // read-only
stunServ net.Addr // read-only
turnServ net.Addr // read-only
stunServStr string // read-only, used for dmuxing
turnServStr string // read-only, used for dmuxing
username stun.Username // read-only
password string // read-only
realm stun.Realm // read-only
integrity stun.MessageIntegrity // read-only
trMap *client.TransactionMap // thread-safe
rto time.Duration // read-only
relayedConn *client.UDPConn // protected by mutex
relayedConn *client.UDPConn // protected by mutex ***
allocTryLock client.TryLock // thread-safe
listenTryLock client.TryLock // thread-safe
net *vnet.Net // read-only
Expand All @@ -79,35 +81,40 @@ func NewClient(config *ClientConfig) (*Client, error) {
}

var stunServ, turnServ net.Addr
var stunServStr, turnServStr string
var err error
if len(config.STUNServerAddr) > 0 {
log.Debugf("resolving %s", config.STUNServerAddr)
stunServ, err = config.Net.ResolveUDPAddr("udp4", config.STUNServerAddr)
if err != nil {
return nil, err
}
log.Debugf("stunServ: %s", stunServ.String())
stunServStr = stunServ.String()
log.Debugf("stunServ: %s", stunServStr)
}
if len(config.TURNServerAddr) > 0 {
log.Debugf("resolving %s", config.TURNServerAddr)
turnServ, err = config.Net.ResolveUDPAddr("udp4", config.TURNServerAddr)
if err != nil {
return nil, err
}
log.Debugf("turnServ: %s", stunServ.String())
turnServStr = turnServ.String()
log.Debugf("turnServ: %s", turnServStr)
}

c := &Client{
conn: config.Conn,
stunServ: stunServ,
turnServ: turnServ,
username: stun.NewUsername(config.Username),
password: config.Password,
realm: stun.NewRealm(config.Realm),
net: config.Net,
trMap: client.NewTransactionMap(),
rto: defaultRTO,
log: log,
conn: config.Conn,
stunServ: stunServ,
turnServ: turnServ,
stunServStr: stunServStr,
turnServStr: turnServStr,
username: stun.NewUsername(config.Username),
password: config.Password,
realm: stun.NewRealm(config.Realm),
net: config.Net,
trMap: client.NewTransactionMap(),
rto: defaultRTO,
log: log,
}

return c, nil
Expand Down Expand Up @@ -213,9 +220,7 @@ func (c *Client) Allocate() (net.PacketConn, error) {
}
defer c.allocTryLock.Unlock()

c.mutex.RLock()
relayedConn := c.relayedConn
c.mutex.RUnlock()
relayedConn := c.relayedUDPConn()
if relayedConn != nil {
return nil, fmt.Errorf("already allocated at %s", relayedConn.LocalAddr().String())
}
Expand Down Expand Up @@ -283,6 +288,10 @@ func (c *Client) Allocate() (net.PacketConn, error) {
if err := relayed.GetFrom(res); err != nil {
return nil, err
}
relayedAddr := &net.UDPAddr{
IP: relayed.IP,
Port: relayed.Port,
}

// Getting lifetime from response
var lifetime turn.Lifetime
Expand All @@ -291,17 +300,15 @@ func (c *Client) Allocate() (net.PacketConn, error) {
}

relayedConn = client.NewUDPConn(&client.UDPConnConfig{
Observer: c,
Relayed: relayed,
Integrity: c.integrity,
Nonce: nonce,
Lifetime: lifetime.Duration,
Log: c.log,
Observer: c,
RelayedAddr: relayedAddr,
Integrity: c.integrity,
Nonce: nonce,
Lifetime: lifetime.Duration,
Log: c.log,
})

c.mutex.Lock()
c.relayedConn = relayedConn
c.mutex.Unlock()
c.setRelayedUDPConn(relayedConn)

return relayedConn, nil
}
Expand Down Expand Up @@ -338,27 +345,60 @@ func (c *Client) PerformTransaction(msg *stun.Message, to net.Addr, dontWait boo
return res, nil
}

// OnDeallocated is called when deallocation of relay address has been complete.
// (Called by UDPConn)
func (c *Client) OnDeallocated(relayedAddr net.Addr) {
c.setRelayedUDPConn(nil)
}

// HandleInbound handles data received.
// This method handles packets received only from the turn server address.
// If the source (from) address does not match, it would return (false, nil)
// to indicate the caller that the packet was not handled.
// This method handles incoming packet demultiplex it by the source address
// and the types of the message.
// This return a booleen (handled or not) and if there was an error.
// Caller should check if the packet was handled by this client or not.
// If not handled, it is assumed that the packet is application data.
// If an error is returned, the caller should discard the packet regardless.
func (c *Client) HandleInbound(data []byte, from net.Addr) (bool, error) {
c.log.Debug("HandleInbound: in")
defer c.log.Debug("HandleInbound: out")
var handled bool
var err error

switch {
case stun.IsMessage(data):
if stun.IsMessage(data) {
handled = true
err = c.handleSTUNMessage(data, from)
case turn.IsChannelData(data):
err = c.handleChannelData(data)
default:
} else if len(c.turnServStr) != 0 && from.String() == c.turnServStr {
handled = true
// received from TURN server
if turn.IsChannelData(data) {
err = c.handleChannelData(data)
} else {
err = fmt.Errorf("unexpected packet from TURN server")
}
} else if len(c.stunServStr) != 0 && from.String() == c.stunServStr {
handled = true
// received from STUN server but it is not a STUN message
err = fmt.Errorf("non-STUN message from STUN server")
} else {
// assume, this is an application data
c.log.Tracef("non-STUN/TURN packect, unhandled")
return false, nil // unhandled
}

return true, err
// +---------+---------+-------------------------------+
// | handled | err | Meaning / Action |
// |=========+=========+===============================+
// | false | nil | Handle the packet as app data |
// |---------+---------+-------------------------------+
// | true | nil | Nothing to do |
// |---------+---------+-------------------------------+
// | false | error | (shouldn't happen) |
// |---------+---------+-------------------------------+
// | true | error | Error occurred while handling |
// +---------+---------+-------------------------------+
// Possible causes of the error:
// - Malformed packet (parse error)
// - STUN message was a request
// - Non-STUN message from the STUN server

return handled, err
}

func (c *Client) handleSTUNMessage(data []byte, from net.Addr) error {
Expand Down Expand Up @@ -389,9 +429,7 @@ func (c *Client) handleSTUNMessage(data []byte, from net.Addr) error {

c.log.Debugf("data indication received from %s", from.String())

c.mutex.RLock()
relayedConn := c.relayedConn
c.mutex.RUnlock()
relayedConn := c.relayedUDPConn()
if relayedConn == nil {
c.log.Debug("no relayed conn allocated")
return nil // silently discard
Expand All @@ -414,6 +452,7 @@ func (c *Client) handleSTUNMessage(data []byte, from net.Addr) error {
return nil
}

// End the transaction
tr.StopRtxTimer()
c.trMap.Delete(trKey)

Expand All @@ -425,8 +464,27 @@ func (c *Client) handleSTUNMessage(data []byte, from net.Addr) error {
}

func (c *Client) handleChannelData(data []byte) error {
c.log.Debugf("handle %d bytes of ChannelData", len(data))
return fmt.Errorf("not implemented yet")
chData := &turn.ChannelData{
Raw: make([]byte, len(data)),
}
copy(chData.Raw, data)
if err := chData.Decode(); err != nil {
return err
}

relayedConn := c.relayedUDPConn()
if relayedConn != nil {
c.log.Debug("no relayed conn allocated")
return nil // silently discard
}

addr, ok := relayedConn.FindAddrByChannelNumber(uint16(chData.Number))
if !ok {
return fmt.Errorf("binding with channel %d not found", int(chData.Number))
}

relayedConn.HandleInbound(chData.Data, addr)
return nil
}

func (c *Client) onRtxTimeout(trKey string, nRtx int32) {
Expand Down Expand Up @@ -460,3 +518,17 @@ func (c *Client) onRtxTimeout(trKey string, nRtx int32) {
}
tr.StartRtxTimer(c.onRtxTimeout)
}

func (c *Client) setRelayedUDPConn(conn *client.UDPConn) {
c.mutex.Lock()
defer c.mutex.Unlock()

c.relayedConn = conn
}

func (c *Client) relayedUDPConn() *client.UDPConn {
c.mutex.RLock()
defer c.mutex.RUnlock()

return c.relayedConn
}
131 changes: 131 additions & 0 deletions internal/client/binding.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
package client

import (
"net"
"sync"
"sync/atomic"
)

// Chanel number:
// 0x4000 through 0x7FFF: These values are the allowed channel
// numbers (16,383 possible values).
const (
minChannelNumber uint16 = 0x4000
maxChannelNumber uint16 = 0x7fff
)

type bindingState int32

const (
bindingStateIdle bindingState = iota
bindingStateReady
bindingStateFailed
)

type binding struct {
st bindingState // thread-safe (atomic op)
addr net.Addr // read-only
number uint16 // read-only
mgr *bindingManager // read-only
mutex sync.Mutex // thread-safe, used in UDPConn
}

func (b *binding) setState(state bindingState) {
atomic.StoreInt32((*int32)(&b.st), int32(state))
}

func (b *binding) state() bindingState {
return bindingState(atomic.LoadInt32((*int32)(&b.st)))
}

// Thread-safe binding map
type bindingManager struct {
chanMap map[uint16]*binding
addrMap map[string]*binding
next uint16
mutex sync.RWMutex
}

func newBindingManager() *bindingManager {
return &bindingManager{
chanMap: map[uint16]*binding{},
addrMap: map[string]*binding{},
next: minChannelNumber,
}
}

func (mgr *bindingManager) assignChannelNumber() uint16 {
n := mgr.next
if mgr.next == maxChannelNumber {
mgr.next = minChannelNumber
} else {
mgr.next++
}
return n
}

func (mgr *bindingManager) create(addr net.Addr) *binding {
mgr.mutex.Lock()
defer mgr.mutex.Unlock()

b := &binding{
addr: addr,
number: mgr.assignChannelNumber(),
mgr: mgr,
}

mgr.chanMap[b.number] = b
mgr.addrMap[b.addr.String()] = b
return b
}

func (mgr *bindingManager) findByAddr(addr net.Addr) (*binding, bool) {
mgr.mutex.RLock()
defer mgr.mutex.RUnlock()

b, ok := mgr.addrMap[addr.String()]
return b, ok
}

func (mgr *bindingManager) findByNumber(number uint16) (*binding, bool) {
mgr.mutex.RLock()
defer mgr.mutex.RUnlock()

b, ok := mgr.chanMap[number]
return b, ok
}

func (mgr *bindingManager) deleteByAddr(addr net.Addr) bool {
mgr.mutex.Lock()
defer mgr.mutex.Unlock()

b, ok := mgr.addrMap[addr.String()]
if !ok {
return false
}

delete(mgr.addrMap, addr.String())
delete(mgr.chanMap, b.number)
return true
}

func (mgr *bindingManager) deleteByNumber(number uint16) bool {
mgr.mutex.Lock()
defer mgr.mutex.Unlock()

b, ok := mgr.chanMap[number]
if !ok {
return false
}

delete(mgr.addrMap, b.addr.String())
delete(mgr.chanMap, number)
return true
}

func (mgr *bindingManager) size() int {
mgr.mutex.RLock()
defer mgr.mutex.RUnlock()

return len(mgr.chanMap)
}
Loading

0 comments on commit 6440fe0

Please sign in to comment.