From dc81faead4bb48275492bea0d9e858ae33e69b22 Mon Sep 17 00:00:00 2001 From: Mikhail Bragin Date: Fri, 18 Feb 2022 13:10:04 +0100 Subject: [PATCH] Support UDP muxing for SRFLX candidates The original UDPMux only works for the HOST candidates. UniversalUDPMux adds support for SRFLX candidates and will later support Relay candidates. UniversalUDPMux embeds UDPMuxDefault and handles STUN server packets to discover XORMappedAddr forwarding the remaining packets for muxing to UDPMuxDefault. --- agent.go | 11 +- agent_config.go | 6 + errors.go | 4 + gather.go | 68 +++++++++- gather_test.go | 92 +++++++++++++ udp_mux_universal.go | 265 ++++++++++++++++++++++++++++++++++++++ udp_mux_universal_test.go | 127 ++++++++++++++++++ 7 files changed, 569 insertions(+), 4 deletions(-) create mode 100644 udp_mux_universal.go create mode 100644 udp_mux_universal_test.go diff --git a/agent.go b/agent.go index b4aba19f..aa2a1d6f 100644 --- a/agent.go +++ b/agent.go @@ -122,9 +122,10 @@ type Agent struct { loggerFactory logging.LoggerFactory log logging.LeveledLogger - net *vnet.Net - tcpMux TCPMux - udpMux UDPMux + net *vnet.Net + tcpMux TCPMux + udpMux UDPMux + udpMuxSrflx UniversalUDPMux interfaceFilter func(string) bool @@ -319,6 +320,7 @@ func NewAgent(config *AgentConfig) (*Agent, error) { //nolint:gocognit a.tcpMux = newInvalidTCPMux() } a.udpMux = config.UDPMux + a.udpMuxSrflx = config.UDPMuxSrflx if a.net == nil { a.net = vnet.NewNet(nil) @@ -892,6 +894,9 @@ func (a *Agent) removeUfragFromMux() { if a.udpMux != nil { a.udpMux.RemoveConnByUfrag(a.localUfrag) } + if a.udpMuxSrflx != nil { + a.udpMuxSrflx.RemoveConnByUfrag(a.localUfrag) + } } // Close cleans up the Agent diff --git a/agent_config.go b/agent_config.go index e577af10..b373939b 100644 --- a/agent_config.go +++ b/agent_config.go @@ -150,6 +150,12 @@ type AgentConfig struct { // defer to UDPMux for incoming connections UDPMux UDPMux + // UDPMuxSrflx is used for multiplexing multiple incoming UDP connections of server reflexive candidates + // on a single port when this is set, the agent ignores PortMin and PortMax configurations and will + // defer to UDPMuxSrflx for incoming connections + // It embeds UDPMux to do the actual connection multiplexing + UDPMuxSrflx UniversalUDPMux + // Proxy Dialer is a dialer that should be implemented by the user based on golang.org/x/net/proxy // dial interface in order to support corporate proxies ProxyDialer proxy.Dialer diff --git a/errors.go b/errors.go index 8ca9c2cd..2ef595e5 100644 --- a/errors.go +++ b/errors.go @@ -135,4 +135,8 @@ var ( errICEWriteSTUNMessage = errors.New("the ICE conn can't write STUN messages") errUDPMuxDisabled = errors.New("UDPMux is not enabled") errCandidateIPNotFound = errors.New("could not determine local IP for Mux candidate") + errNoXorAddrMapping = errors.New("no address mapping") + errSendSTUNPacket = errors.New("failed to send STUN packet") + errXORMappedAddrTimeout = errors.New("timeout while waiting for XORMappedAddr") + errNotImplemented = errors.New("not implemented yet") ) diff --git a/gather.go b/gather.go index e248c08d..8d2c40db 100644 --- a/gather.go +++ b/gather.go @@ -97,7 +97,11 @@ func (a *Agent) gatherCandidates(ctx context.Context) { case CandidateTypeServerReflexive: wg.Add(1) go func() { - a.gatherCandidatesSrflx(ctx, a.urls, a.networkTypes) + if a.udpMuxSrflx != nil { + a.gatherCandidatesSrflxUDPMux(ctx, a.urls, a.networkTypes) + } else { + a.gatherCandidatesSrflx(ctx, a.urls, a.networkTypes) + } wg.Done() }() if a.extIPMapper != nil && a.extIPMapper.candidateType == CandidateTypeServerReflexive { @@ -333,6 +337,68 @@ func (a *Agent) gatherCandidatesSrflxMapped(ctx context.Context, networkTypes [] } } +func (a *Agent) gatherCandidatesSrflxUDPMux(ctx context.Context, urls []*URL, networkTypes []NetworkType) { + var wg sync.WaitGroup + defer wg.Wait() + + for _, networkType := range networkTypes { + if networkType.IsTCP() { + continue + } + + for i := range urls { + wg.Add(1) + go func(url URL, network string) { + defer wg.Done() + + hostPort := fmt.Sprintf("%s:%d", url.Host, url.Port) + serverAddr, err := a.net.ResolveUDPAddr(network, hostPort) + if err != nil { + a.log.Warnf("failed to resolve stun host: %s: %v", hostPort, err) + return + } + + xoraddr, err := a.udpMuxSrflx.GetXORMappedAddr(serverAddr, stunGatherTimeout) + if err != nil { + a.log.Warnf("could not get server reflexive address %s %s: %v\n", network, url, err) + return + } + + conn, err := a.udpMuxSrflx.GetConnForURL(a.localUfrag, url.String()) + if err != nil { + a.log.Warnf("could not find connection in UDPMuxSrflx %s %s: %v\n", network, url, err) + return + } + + ip := xoraddr.IP + port := xoraddr.Port + + laddr := conn.LocalAddr().(*net.UDPAddr) + srflxConfig := CandidateServerReflexiveConfig{ + Network: network, + Address: ip.String(), + Port: port, + Component: ComponentRTP, + RelAddr: laddr.IP.String(), + RelPort: laddr.Port, + } + c, err := NewCandidateServerReflexive(&srflxConfig) + if err != nil { + closeConnAndLog(conn, a.log, fmt.Sprintf("Failed to create server reflexive candidate: %s %s %d: %v\n", network, ip, port, err)) + return + } + + if err := a.addCandidate(ctx, c, conn); err != nil { + if closeErr := c.close(); closeErr != nil { + a.log.Warnf("Failed to close candidate: %v", closeErr) + } + a.log.Warnf("Failed to append to localCandidates and run onCandidateHdlr: %v\n", err) + } + }(*urls[i], networkType.String()) + } + } +} + func (a *Agent) gatherCandidatesSrflx(ctx context.Context, urls []*URL, networkTypes []NetworkType) { var wg sync.WaitGroup defer wg.Wait() diff --git a/gather_test.go b/gather_test.go index e69d9875..23c8296a 100644 --- a/gather_test.go +++ b/gather_test.go @@ -11,12 +11,14 @@ import ( "reflect" "sort" "strconv" + "sync" "testing" "time" "github.com/pion/dtls/v2" "github.com/pion/dtls/v2/pkg/crypto/selfsign" "github.com/pion/logging" + "github.com/pion/stun" "github.com/pion/transport/test" "github.com/pion/turn/v2" "github.com/stretchr/testify/assert" @@ -484,3 +486,93 @@ func TestTURNProxyDialer(t *testing.T) { assert.NoError(t, a.Close()) } + +// Assert that UniversalUDPMux is used while gathering when configured in the Agent +func TestUniversalUDPMuxUsage(t *testing.T) { + report := test.CheckRoutines(t) + defer report() + + lim := test.TimeOut(time.Second * 30) + defer lim.Stop() + + conn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IP{127, 0, 0, 1}, Port: randomPort(t)}) + assert.NoError(t, err) + defer func() { + _ = conn.Close() + }() + + udpMuxSrflx := &universalUDPMuxMock{ + conn: conn, + } + + numSTUNS := 3 + urls := []*URL{} + for i := 0; i < numSTUNS; i++ { + urls = append(urls, &URL{ + Scheme: SchemeTypeSTUN, + Host: "127.0.0.1", + Port: 3478 + i, + }) + } + + a, err := NewAgent(&AgentConfig{ + NetworkTypes: supportedNetworkTypes(), + Urls: urls, + CandidateTypes: []CandidateType{CandidateTypeServerReflexive}, + UDPMuxSrflx: udpMuxSrflx, + }) + assert.NoError(t, err) + + candidateGathered, candidateGatheredFunc := context.WithCancel(context.Background()) + assert.NoError(t, a.OnCandidate(func(c Candidate) { + if c == nil { + candidateGatheredFunc() + return + } + t.Log(c.NetworkType(), c.Priority(), c) + })) + assert.NoError(t, a.GatherCandidates()) + + <-candidateGathered.Done() + + assert.NoError(t, a.Close()) + // twice because of 2 STUN servers configured + assert.Equal(t, numSTUNS, udpMuxSrflx.getXORMappedAddrUsedTimes, "expected times that GetXORMappedAddr should be called") + // one for Restart() when agent has been initialized and one time when Close() the agent + assert.Equal(t, 2, udpMuxSrflx.removeConnByUfragTimes, "expected times that RemoveConnByUfrag should be called") + // twice because of 2 STUN servers configured + assert.Equal(t, numSTUNS, udpMuxSrflx.getConnForURLTimes, "expected times that GetConnForURL should be called") +} + +type universalUDPMuxMock struct { + UDPMux + getXORMappedAddrUsedTimes int + removeConnByUfragTimes int + getConnForURLTimes int + mu sync.Mutex + conn *net.UDPConn +} + +func (m *universalUDPMuxMock) GetRelayedAddr(turnAddr net.Addr, deadline time.Duration) (*net.Addr, error) { + return nil, errNotImplemented +} + +func (m *universalUDPMuxMock) GetConnForURL(ufrag string, url string) (net.PacketConn, error) { + m.mu.Lock() + defer m.mu.Unlock() + m.getConnForURLTimes++ + return m.conn, nil +} + +func (m *universalUDPMuxMock) GetXORMappedAddr(serverAddr net.Addr, deadline time.Duration) (*stun.XORMappedAddress, error) { + m.mu.Lock() + defer m.mu.Unlock() + m.getXORMappedAddrUsedTimes++ + return &stun.XORMappedAddress{IP: net.IP{100, 64, 0, 1}, Port: 77878}, nil +} + +func (m *universalUDPMuxMock) RemoveConnByUfrag(ufrag string) { + m.mu.Lock() + defer m.mu.Unlock() + m.removeConnByUfragTimes++ +} diff --git a/udp_mux_universal.go b/udp_mux_universal.go new file mode 100644 index 00000000..c6198775 --- /dev/null +++ b/udp_mux_universal.go @@ -0,0 +1,265 @@ +package ice + +import ( + "fmt" + "net" + "time" + + "github.com/pion/logging" + "github.com/pion/stun" +) + +// UniversalUDPMux allows multiple connections to go over a single UDP port for +// host, server reflexive and relayed candidates. +// Actual connection muxing is happening in the UDPMux. +type UniversalUDPMux interface { + UDPMux + GetXORMappedAddr(stunAddr net.Addr, deadline time.Duration) (*stun.XORMappedAddress, error) + GetRelayedAddr(turnAddr net.Addr, deadline time.Duration) (*net.Addr, error) + GetConnForURL(ufrag string, url string) (net.PacketConn, error) +} + +// UniversalUDPMuxDefault handles STUN and TURN servers packets by wrapping the original UDPConn overriding ReadFrom. +// It the passes packets to the UDPMux that does the actual connection muxing. +type UniversalUDPMuxDefault struct { + *UDPMuxDefault + params UniversalUDPMuxParams + + // since we have a shared socket, for srflx candidates it makes sense to have a shared mapped address across all the agents + // stun.XORMappedAddress indexed by the STUN server addr + xorMappedMap map[string]*xorMapped +} + +// UniversalUDPMuxParams are parameters for UniversalUDPMux server reflexive. +type UniversalUDPMuxParams struct { + Logger logging.LeveledLogger + UDPConn net.PacketConn + XORMappedAddrCacheTTL time.Duration +} + +// NewUniversalUDPMuxDefault creates an implementation of UniversalUDPMux embedding UDPMux +func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDefault { + if params.Logger == nil { + params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice") + } + if params.XORMappedAddrCacheTTL == 0 { + params.XORMappedAddrCacheTTL = time.Second * 25 + } + + m := &UniversalUDPMuxDefault{ + params: params, + xorMappedMap: make(map[string]*xorMapped), + } + + // wrap UDP connection, process server reflexive messages + // before they are passed to the UDPMux connection handler (connWorker) + m.params.UDPConn = &udpConn{ + PacketConn: params.UDPConn, + mux: m, + logger: params.Logger, + } + + // embed UDPMux + udpMuxParams := UDPMuxParams{ + Logger: params.Logger, + UDPConn: m.params.UDPConn, + } + m.UDPMuxDefault = NewUDPMuxDefault(udpMuxParams) + + return m +} + +// udpConn is a wrapper around UDPMux conn that overrides ReadFrom and handles STUN/TURN packets +type udpConn struct { + net.PacketConn + mux *UniversalUDPMuxDefault + logger logging.LeveledLogger +} + +// GetRelayedAddr creates relayed connection to the given TURN service and returns the relayed addr. +// Not implemented yet. +func (m *UniversalUDPMuxDefault) GetRelayedAddr(turnAddr net.Addr, deadline time.Duration) (*net.Addr, error) { + return nil, errNotImplemented +} + +// GetConnForURL add uniques to the muxed connection by concatenating ufrag and URL (e.g. STUN URL) to be able to support multiple STUN/TURN servers +// and return a unique connection per server. +func (m *UniversalUDPMuxDefault) GetConnForURL(ufrag string, url string) (net.PacketConn, error) { + return m.UDPMuxDefault.GetConn(fmt.Sprintf("%s%s", ufrag, url)) +} + +// ReadFrom is called by UDPMux connWorker and handles packets coming from the STUN server discovering a mapped address. +// It passes processed packets further to the UDPMux (maybe this is not really necessary). +func (c *udpConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + n, addr, err = c.PacketConn.ReadFrom(p) + if err != nil { + return + } + + if stun.IsMessage(p[:n]) { + msg := &stun.Message{ + Raw: append([]byte{}, p[:n]...), + } + + if err = msg.Decode(); err != nil { + c.logger.Warnf("Failed to handle decode ICE from %s: %v\n", addr.String(), err) + return n, addr, nil + } + + udpAddr, ok := addr.(*net.UDPAddr) + if !ok { + // message about this err will be logged in the UDPMux + return + } + + if c.mux.isXORMappedResponse(msg, udpAddr.String()) { + err = c.mux.handleXORMappedResponse(udpAddr, msg) + if err != nil { + c.logger.Debugf("%w: %v", errGetXorMappedAddrResponse, err) + return n, addr, nil + } + return + } + } + return n, addr, err +} + +// isXORMappedResponse indicates whether the message is a XORMappedAddress and is coming from the known STUN server. +func (m *UniversalUDPMuxDefault) isXORMappedResponse(msg *stun.Message, stunAddr string) bool { + m.mu.Lock() + defer m.mu.Unlock() + // check first if it is a STUN server address because remote peer can also send similar messages but as a BindingSuccess + _, ok := m.xorMappedMap[stunAddr] + _, err := msg.Get(stun.AttrXORMappedAddress) + return err == nil && ok +} + +// handleXORMappedResponse parses response from the STUN server, extracts XORMappedAddress attribute +// and set the mapped address for the server +func (m *UniversalUDPMuxDefault) handleXORMappedResponse(stunAddr *net.UDPAddr, msg *stun.Message) error { + m.mu.Lock() + defer m.mu.Unlock() + + mappedAddr, ok := m.xorMappedMap[stunAddr.String()] + if !ok { + return errNoXorAddrMapping + } + + var addr stun.XORMappedAddress + if err := addr.GetFrom(msg); err != nil { + return err + } + + m.xorMappedMap[stunAddr.String()] = mappedAddr + mappedAddr.SetAddr(&addr) + + return nil +} + +// GetXORMappedAddr returns *stun.XORMappedAddress if already present for a given STUN server. +// Makes a STUN binding request to discover mapped address otherwise. +// Blocks until the stun.XORMappedAddress has been discovered or deadline. +// Method is safe for concurrent use. +func (m *UniversalUDPMuxDefault) GetXORMappedAddr(serverAddr net.Addr, deadline time.Duration) (*stun.XORMappedAddress, error) { + m.mu.Lock() + mappedAddr, ok := m.xorMappedMap[serverAddr.String()] + // if we already have a mapping for this STUN server (address already received) + // and if it is not too old we return it without making a new request to STUN server + if ok { + if mappedAddr.expired() { + mappedAddr.closeWaiters() + delete(m.xorMappedMap, serverAddr.String()) + ok = false + } else if mappedAddr.pending() { + ok = false + } + } + m.mu.Unlock() + if ok { + return mappedAddr.addr, nil + } + + // otherwise, make a STUN request to discover the address + // or wait for already sent request to complete + waitAddrReceived, err := m.sendStun(serverAddr) + if err != nil { + return nil, errSendSTUNPacket + } + + // block until response was handled by the connWorker routine and XORMappedAddress was updated + select { + case <-waitAddrReceived: + // when channel closed, addr was obtained + m.mu.Lock() + mappedAddr := *m.xorMappedMap[serverAddr.String()] + m.mu.Unlock() + if mappedAddr.addr == nil { + return nil, errNoXorAddrMapping + } + return mappedAddr.addr, nil + case <-time.After(deadline): + return nil, errXORMappedAddrTimeout + } +} + +// sendStun sends a STUN request via UDP conn. +// +// The returned channel is closed when the STUN response has been received. +// Method is safe for concurrent use. +func (m *UniversalUDPMuxDefault) sendStun(serverAddr net.Addr) (chan struct{}, error) { + m.mu.Lock() + defer m.mu.Unlock() + + // if record present in the map, we already sent a STUN request, + // just wait when waitAddrReceived will be closed + addrMap, ok := m.xorMappedMap[serverAddr.String()] + if !ok { + addrMap = &xorMapped{ + expiresAt: time.Now().Add(m.params.XORMappedAddrCacheTTL), + waitAddrReceived: make(chan struct{}), + } + m.xorMappedMap[serverAddr.String()] = addrMap + } + + req, err := stun.Build(stun.BindingRequest, stun.TransactionID) + if err != nil { + return nil, err + } + + if _, err = m.params.UDPConn.WriteTo(req.Raw, serverAddr); err != nil { + return nil, err + } + + return addrMap.waitAddrReceived, nil +} + +type xorMapped struct { + addr *stun.XORMappedAddress + waitAddrReceived chan struct{} + expiresAt time.Time +} + +func (a *xorMapped) closeWaiters() { + select { + case <-a.waitAddrReceived: + // notify was close, ok, that means we received duplicate response + // just exit + break + default: + // notify tha twe have a new addr + close(a.waitAddrReceived) + } +} + +func (a *xorMapped) pending() bool { + return a.addr == nil +} + +func (a *xorMapped) expired() bool { + return a.expiresAt.Before(time.Now()) +} + +func (a *xorMapped) SetAddr(addr *stun.XORMappedAddress) { + a.addr = addr + a.closeWaiters() +} diff --git a/udp_mux_universal_test.go b/udp_mux_universal_test.go new file mode 100644 index 00000000..c263bf61 --- /dev/null +++ b/udp_mux_universal_test.go @@ -0,0 +1,127 @@ +// +build !js + +package ice + +import ( + "net" + "sync" + "testing" + "time" + + "github.com/pion/stun" + "github.com/stretchr/testify/require" +) + +func TestUniversalUDPMux(t *testing.T) { + conn, err := net.ListenUDP(udp, &net.UDPAddr{}) + require.NoError(t, err) + + udpMux := NewUniversalUDPMuxDefault(UniversalUDPMuxParams{ + Logger: nil, + UDPConn: conn, + }) + + require.NoError(t, err) + defer func() { + _ = udpMux.Close() + _ = conn.Close() + }() + + require.NotNil(t, udpMux.LocalAddr(), "tcpMux.LocalAddr() is nil") + + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + defer wg.Done() + testMuxSrflxConnection(t, udpMux, "ufrag4", udp) + }() + + wg.Wait() +} + +func testMuxSrflxConnection(t *testing.T, udpMux *UniversalUDPMuxDefault, ufrag string, network string) { + pktConn, err := udpMux.GetConn(ufrag) + require.NoError(t, err, "error retrieving muxed connection for ufrag") + defer func() { + _ = pktConn.Close() + }() + + remoteConn, err := net.DialUDP(network, nil, &net.UDPAddr{ + Port: udpMux.LocalAddr().(*net.UDPAddr).Port, + }) + require.NoError(t, err, "error dialing test udp connection") + defer func() { + _ = remoteConn.Close() + }() + + // use small value for TTL to check expiration of the address + udpMux.params.XORMappedAddrCacheTTL = time.Millisecond * 20 + testXORIP := net.ParseIP("213.141.156.236") + testXORPort := 21254 + + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + defer wg.Done() + address, e := udpMux.GetXORMappedAddr(remoteConn.LocalAddr(), time.Second) + require.NoError(t, e) + require.NotNil(t, address) + require.True(t, address.IP.Equal(testXORIP)) + require.Equal(t, address.Port, testXORPort) + }() + + // wait until GetXORMappedAddr calls sendStun method + time.Sleep(time.Millisecond) + + // check that mapped address filled correctly after sent stun + udpMux.mu.Lock() + mappedAddr, ok := udpMux.xorMappedMap[remoteConn.LocalAddr().String()] + require.True(t, ok) + require.NotNil(t, mappedAddr) + require.True(t, mappedAddr.pending()) + require.False(t, mappedAddr.expired()) + udpMux.mu.Unlock() + + // clean receiver read buffer + buf := make([]byte, receiveMTU) + _, err = remoteConn.Read(buf) + require.NoError(t, err) + + // write back to udpMux XOR message with address + msg := stun.New() + msg.Type = stun.MessageType{Method: stun.MethodBinding, Class: stun.ClassRequest} + msg.Add(stun.AttrUsername, []byte(ufrag+":otherufrag")) + addr := &stun.XORMappedAddress{ + IP: testXORIP, + Port: testXORPort, + } + err = addr.AddTo(msg) + require.NoError(t, err) + + msg.Encode() + _, err = remoteConn.Write(msg.Raw) + require.NoError(t, err) + + // wait for the packet to be consumed and parsed by udpMux + wg.Wait() + + // we should get address immediately from the cached map + address, err := udpMux.GetXORMappedAddr(remoteConn.LocalAddr(), time.Second) + require.NoError(t, err) + require.NotNil(t, address) + + udpMux.mu.Lock() + // check mappedAddr is not pending, we didn't send stun twice + require.False(t, mappedAddr.pending()) + + // check expiration by TTL + time.Sleep(time.Millisecond * 21) + require.True(t, mappedAddr.expired()) + udpMux.mu.Unlock() + + // after expire, we send stun request again + // but we not receive response in 5 milliseconds and should get error here + address, err = udpMux.GetXORMappedAddr(remoteConn.LocalAddr(), time.Millisecond*5) + require.NotNil(t, err) + require.Nil(t, address) +}