diff --git a/agent.go b/agent.go index b4aba19f..aa2a1d6f 100644 --- a/agent.go +++ b/agent.go @@ -122,9 +122,10 @@ type Agent struct { loggerFactory logging.LoggerFactory log logging.LeveledLogger - net *vnet.Net - tcpMux TCPMux - udpMux UDPMux + net *vnet.Net + tcpMux TCPMux + udpMux UDPMux + udpMuxSrflx UniversalUDPMux interfaceFilter func(string) bool @@ -319,6 +320,7 @@ func NewAgent(config *AgentConfig) (*Agent, error) { //nolint:gocognit a.tcpMux = newInvalidTCPMux() } a.udpMux = config.UDPMux + a.udpMuxSrflx = config.UDPMuxSrflx if a.net == nil { a.net = vnet.NewNet(nil) @@ -892,6 +894,9 @@ func (a *Agent) removeUfragFromMux() { if a.udpMux != nil { a.udpMux.RemoveConnByUfrag(a.localUfrag) } + if a.udpMuxSrflx != nil { + a.udpMuxSrflx.RemoveConnByUfrag(a.localUfrag) + } } // Close cleans up the Agent diff --git a/agent_config.go b/agent_config.go index e577af10..b373939b 100644 --- a/agent_config.go +++ b/agent_config.go @@ -150,6 +150,12 @@ type AgentConfig struct { // defer to UDPMux for incoming connections UDPMux UDPMux + // UDPMuxSrflx is used for multiplexing multiple incoming UDP connections of server reflexive candidates + // on a single port when this is set, the agent ignores PortMin and PortMax configurations and will + // defer to UDPMuxSrflx for incoming connections + // It embeds UDPMux to do the actual connection multiplexing + UDPMuxSrflx UniversalUDPMux + // Proxy Dialer is a dialer that should be implemented by the user based on golang.org/x/net/proxy // dial interface in order to support corporate proxies ProxyDialer proxy.Dialer diff --git a/errors.go b/errors.go index 8ca9c2cd..2ef595e5 100644 --- a/errors.go +++ b/errors.go @@ -135,4 +135,8 @@ var ( errICEWriteSTUNMessage = errors.New("the ICE conn can't write STUN messages") errUDPMuxDisabled = errors.New("UDPMux is not enabled") errCandidateIPNotFound = errors.New("could not determine local IP for Mux candidate") + errNoXorAddrMapping = errors.New("no address mapping") + errSendSTUNPacket = errors.New("failed to send STUN packet") + errXORMappedAddrTimeout = errors.New("timeout while waiting for XORMappedAddr") + errNotImplemented = errors.New("not implemented yet") ) diff --git a/gather.go b/gather.go index e248c08d..8d2c40db 100644 --- a/gather.go +++ b/gather.go @@ -97,7 +97,11 @@ func (a *Agent) gatherCandidates(ctx context.Context) { case CandidateTypeServerReflexive: wg.Add(1) go func() { - a.gatherCandidatesSrflx(ctx, a.urls, a.networkTypes) + if a.udpMuxSrflx != nil { + a.gatherCandidatesSrflxUDPMux(ctx, a.urls, a.networkTypes) + } else { + a.gatherCandidatesSrflx(ctx, a.urls, a.networkTypes) + } wg.Done() }() if a.extIPMapper != nil && a.extIPMapper.candidateType == CandidateTypeServerReflexive { @@ -333,6 +337,68 @@ func (a *Agent) gatherCandidatesSrflxMapped(ctx context.Context, networkTypes [] } } +func (a *Agent) gatherCandidatesSrflxUDPMux(ctx context.Context, urls []*URL, networkTypes []NetworkType) { + var wg sync.WaitGroup + defer wg.Wait() + + for _, networkType := range networkTypes { + if networkType.IsTCP() { + continue + } + + for i := range urls { + wg.Add(1) + go func(url URL, network string) { + defer wg.Done() + + hostPort := fmt.Sprintf("%s:%d", url.Host, url.Port) + serverAddr, err := a.net.ResolveUDPAddr(network, hostPort) + if err != nil { + a.log.Warnf("failed to resolve stun host: %s: %v", hostPort, err) + return + } + + xoraddr, err := a.udpMuxSrflx.GetXORMappedAddr(serverAddr, stunGatherTimeout) + if err != nil { + a.log.Warnf("could not get server reflexive address %s %s: %v\n", network, url, err) + return + } + + conn, err := a.udpMuxSrflx.GetConnForURL(a.localUfrag, url.String()) + if err != nil { + a.log.Warnf("could not find connection in UDPMuxSrflx %s %s: %v\n", network, url, err) + return + } + + ip := xoraddr.IP + port := xoraddr.Port + + laddr := conn.LocalAddr().(*net.UDPAddr) + srflxConfig := CandidateServerReflexiveConfig{ + Network: network, + Address: ip.String(), + Port: port, + Component: ComponentRTP, + RelAddr: laddr.IP.String(), + RelPort: laddr.Port, + } + c, err := NewCandidateServerReflexive(&srflxConfig) + if err != nil { + closeConnAndLog(conn, a.log, fmt.Sprintf("Failed to create server reflexive candidate: %s %s %d: %v\n", network, ip, port, err)) + return + } + + if err := a.addCandidate(ctx, c, conn); err != nil { + if closeErr := c.close(); closeErr != nil { + a.log.Warnf("Failed to close candidate: %v", closeErr) + } + a.log.Warnf("Failed to append to localCandidates and run onCandidateHdlr: %v\n", err) + } + }(*urls[i], networkType.String()) + } + } +} + func (a *Agent) gatherCandidatesSrflx(ctx context.Context, urls []*URL, networkTypes []NetworkType) { var wg sync.WaitGroup defer wg.Wait() diff --git a/gather_test.go b/gather_test.go index e69d9875..23c8296a 100644 --- a/gather_test.go +++ b/gather_test.go @@ -11,12 +11,14 @@ import ( "reflect" "sort" "strconv" + "sync" "testing" "time" "github.com/pion/dtls/v2" "github.com/pion/dtls/v2/pkg/crypto/selfsign" "github.com/pion/logging" + "github.com/pion/stun" "github.com/pion/transport/test" "github.com/pion/turn/v2" "github.com/stretchr/testify/assert" @@ -484,3 +486,93 @@ func TestTURNProxyDialer(t *testing.T) { assert.NoError(t, a.Close()) } + +// Assert that UniversalUDPMux is used while gathering when configured in the Agent +func TestUniversalUDPMuxUsage(t *testing.T) { + report := test.CheckRoutines(t) + defer report() + + lim := test.TimeOut(time.Second * 30) + defer lim.Stop() + + conn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IP{127, 0, 0, 1}, Port: randomPort(t)}) + assert.NoError(t, err) + defer func() { + _ = conn.Close() + }() + + udpMuxSrflx := &universalUDPMuxMock{ + conn: conn, + } + + numSTUNS := 3 + urls := []*URL{} + for i := 0; i < numSTUNS; i++ { + urls = append(urls, &URL{ + Scheme: SchemeTypeSTUN, + Host: "127.0.0.1", + Port: 3478 + i, + }) + } + + a, err := NewAgent(&AgentConfig{ + NetworkTypes: supportedNetworkTypes(), + Urls: urls, + CandidateTypes: []CandidateType{CandidateTypeServerReflexive}, + UDPMuxSrflx: udpMuxSrflx, + }) + assert.NoError(t, err) + + candidateGathered, candidateGatheredFunc := context.WithCancel(context.Background()) + assert.NoError(t, a.OnCandidate(func(c Candidate) { + if c == nil { + candidateGatheredFunc() + return + } + t.Log(c.NetworkType(), c.Priority(), c) + })) + assert.NoError(t, a.GatherCandidates()) + + <-candidateGathered.Done() + + assert.NoError(t, a.Close()) + // twice because of 2 STUN servers configured + assert.Equal(t, numSTUNS, udpMuxSrflx.getXORMappedAddrUsedTimes, "expected times that GetXORMappedAddr should be called") + // one for Restart() when agent has been initialized and one time when Close() the agent + assert.Equal(t, 2, udpMuxSrflx.removeConnByUfragTimes, "expected times that RemoveConnByUfrag should be called") + // twice because of 2 STUN servers configured + assert.Equal(t, numSTUNS, udpMuxSrflx.getConnForURLTimes, "expected times that GetConnForURL should be called") +} + +type universalUDPMuxMock struct { + UDPMux + getXORMappedAddrUsedTimes int + removeConnByUfragTimes int + getConnForURLTimes int + mu sync.Mutex + conn *net.UDPConn +} + +func (m *universalUDPMuxMock) GetRelayedAddr(turnAddr net.Addr, deadline time.Duration) (*net.Addr, error) { + return nil, errNotImplemented +} + +func (m *universalUDPMuxMock) GetConnForURL(ufrag string, url string) (net.PacketConn, error) { + m.mu.Lock() + defer m.mu.Unlock() + m.getConnForURLTimes++ + return m.conn, nil +} + +func (m *universalUDPMuxMock) GetXORMappedAddr(serverAddr net.Addr, deadline time.Duration) (*stun.XORMappedAddress, error) { + m.mu.Lock() + defer m.mu.Unlock() + m.getXORMappedAddrUsedTimes++ + return &stun.XORMappedAddress{IP: net.IP{100, 64, 0, 1}, Port: 77878}, nil +} + +func (m *universalUDPMuxMock) RemoveConnByUfrag(ufrag string) { + m.mu.Lock() + defer m.mu.Unlock() + m.removeConnByUfragTimes++ +} diff --git a/udp_mux_universal.go b/udp_mux_universal.go new file mode 100644 index 00000000..c6198775 --- /dev/null +++ b/udp_mux_universal.go @@ -0,0 +1,265 @@ +package ice + +import ( + "fmt" + "net" + "time" + + "github.com/pion/logging" + "github.com/pion/stun" +) + +// UniversalUDPMux allows multiple connections to go over a single UDP port for +// host, server reflexive and relayed candidates. +// Actual connection muxing is happening in the UDPMux. +type UniversalUDPMux interface { + UDPMux + GetXORMappedAddr(stunAddr net.Addr, deadline time.Duration) (*stun.XORMappedAddress, error) + GetRelayedAddr(turnAddr net.Addr, deadline time.Duration) (*net.Addr, error) + GetConnForURL(ufrag string, url string) (net.PacketConn, error) +} + +// UniversalUDPMuxDefault handles STUN and TURN servers packets by wrapping the original UDPConn overriding ReadFrom. +// It the passes packets to the UDPMux that does the actual connection muxing. +type UniversalUDPMuxDefault struct { + *UDPMuxDefault + params UniversalUDPMuxParams + + // since we have a shared socket, for srflx candidates it makes sense to have a shared mapped address across all the agents + // stun.XORMappedAddress indexed by the STUN server addr + xorMappedMap map[string]*xorMapped +} + +// UniversalUDPMuxParams are parameters for UniversalUDPMux server reflexive. +type UniversalUDPMuxParams struct { + Logger logging.LeveledLogger + UDPConn net.PacketConn + XORMappedAddrCacheTTL time.Duration +} + +// NewUniversalUDPMuxDefault creates an implementation of UniversalUDPMux embedding UDPMux +func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDefault { + if params.Logger == nil { + params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice") + } + if params.XORMappedAddrCacheTTL == 0 { + params.XORMappedAddrCacheTTL = time.Second * 25 + } + + m := &UniversalUDPMuxDefault{ + params: params, + xorMappedMap: make(map[string]*xorMapped), + } + + // wrap UDP connection, process server reflexive messages + // before they are passed to the UDPMux connection handler (connWorker) + m.params.UDPConn = &udpConn{ + PacketConn: params.UDPConn, + mux: m, + logger: params.Logger, + } + + // embed UDPMux + udpMuxParams := UDPMuxParams{ + Logger: params.Logger, + UDPConn: m.params.UDPConn, + } + m.UDPMuxDefault = NewUDPMuxDefault(udpMuxParams) + + return m +} + +// udpConn is a wrapper around UDPMux conn that overrides ReadFrom and handles STUN/TURN packets +type udpConn struct { + net.PacketConn + mux *UniversalUDPMuxDefault + logger logging.LeveledLogger +} + +// GetRelayedAddr creates relayed connection to the given TURN service and returns the relayed addr. +// Not implemented yet. +func (m *UniversalUDPMuxDefault) GetRelayedAddr(turnAddr net.Addr, deadline time.Duration) (*net.Addr, error) { + return nil, errNotImplemented +} + +// GetConnForURL add uniques to the muxed connection by concatenating ufrag and URL (e.g. STUN URL) to be able to support multiple STUN/TURN servers +// and return a unique connection per server. +func (m *UniversalUDPMuxDefault) GetConnForURL(ufrag string, url string) (net.PacketConn, error) { + return m.UDPMuxDefault.GetConn(fmt.Sprintf("%s%s", ufrag, url)) +} + +// ReadFrom is called by UDPMux connWorker and handles packets coming from the STUN server discovering a mapped address. +// It passes processed packets further to the UDPMux (maybe this is not really necessary). +func (c *udpConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + n, addr, err = c.PacketConn.ReadFrom(p) + if err != nil { + return + } + + if stun.IsMessage(p[:n]) { + msg := &stun.Message{ + Raw: append([]byte{}, p[:n]...), + } + + if err = msg.Decode(); err != nil { + c.logger.Warnf("Failed to handle decode ICE from %s: %v\n", addr.String(), err) + return n, addr, nil + } + + udpAddr, ok := addr.(*net.UDPAddr) + if !ok { + // message about this err will be logged in the UDPMux + return + } + + if c.mux.isXORMappedResponse(msg, udpAddr.String()) { + err = c.mux.handleXORMappedResponse(udpAddr, msg) + if err != nil { + c.logger.Debugf("%w: %v", errGetXorMappedAddrResponse, err) + return n, addr, nil + } + return + } + } + return n, addr, err +} + +// isXORMappedResponse indicates whether the message is a XORMappedAddress and is coming from the known STUN server. +func (m *UniversalUDPMuxDefault) isXORMappedResponse(msg *stun.Message, stunAddr string) bool { + m.mu.Lock() + defer m.mu.Unlock() + // check first if it is a STUN server address because remote peer can also send similar messages but as a BindingSuccess + _, ok := m.xorMappedMap[stunAddr] + _, err := msg.Get(stun.AttrXORMappedAddress) + return err == nil && ok +} + +// handleXORMappedResponse parses response from the STUN server, extracts XORMappedAddress attribute +// and set the mapped address for the server +func (m *UniversalUDPMuxDefault) handleXORMappedResponse(stunAddr *net.UDPAddr, msg *stun.Message) error { + m.mu.Lock() + defer m.mu.Unlock() + + mappedAddr, ok := m.xorMappedMap[stunAddr.String()] + if !ok { + return errNoXorAddrMapping + } + + var addr stun.XORMappedAddress + if err := addr.GetFrom(msg); err != nil { + return err + } + + m.xorMappedMap[stunAddr.String()] = mappedAddr + mappedAddr.SetAddr(&addr) + + return nil +} + +// GetXORMappedAddr returns *stun.XORMappedAddress if already present for a given STUN server. +// Makes a STUN binding request to discover mapped address otherwise. +// Blocks until the stun.XORMappedAddress has been discovered or deadline. +// Method is safe for concurrent use. +func (m *UniversalUDPMuxDefault) GetXORMappedAddr(serverAddr net.Addr, deadline time.Duration) (*stun.XORMappedAddress, error) { + m.mu.Lock() + mappedAddr, ok := m.xorMappedMap[serverAddr.String()] + // if we already have a mapping for this STUN server (address already received) + // and if it is not too old we return it without making a new request to STUN server + if ok { + if mappedAddr.expired() { + mappedAddr.closeWaiters() + delete(m.xorMappedMap, serverAddr.String()) + ok = false + } else if mappedAddr.pending() { + ok = false + } + } + m.mu.Unlock() + if ok { + return mappedAddr.addr, nil + } + + // otherwise, make a STUN request to discover the address + // or wait for already sent request to complete + waitAddrReceived, err := m.sendStun(serverAddr) + if err != nil { + return nil, errSendSTUNPacket + } + + // block until response was handled by the connWorker routine and XORMappedAddress was updated + select { + case <-waitAddrReceived: + // when channel closed, addr was obtained + m.mu.Lock() + mappedAddr := *m.xorMappedMap[serverAddr.String()] + m.mu.Unlock() + if mappedAddr.addr == nil { + return nil, errNoXorAddrMapping + } + return mappedAddr.addr, nil + case <-time.After(deadline): + return nil, errXORMappedAddrTimeout + } +} + +// sendStun sends a STUN request via UDP conn. +// +// The returned channel is closed when the STUN response has been received. +// Method is safe for concurrent use. +func (m *UniversalUDPMuxDefault) sendStun(serverAddr net.Addr) (chan struct{}, error) { + m.mu.Lock() + defer m.mu.Unlock() + + // if record present in the map, we already sent a STUN request, + // just wait when waitAddrReceived will be closed + addrMap, ok := m.xorMappedMap[serverAddr.String()] + if !ok { + addrMap = &xorMapped{ + expiresAt: time.Now().Add(m.params.XORMappedAddrCacheTTL), + waitAddrReceived: make(chan struct{}), + } + m.xorMappedMap[serverAddr.String()] = addrMap + } + + req, err := stun.Build(stun.BindingRequest, stun.TransactionID) + if err != nil { + return nil, err + } + + if _, err = m.params.UDPConn.WriteTo(req.Raw, serverAddr); err != nil { + return nil, err + } + + return addrMap.waitAddrReceived, nil +} + +type xorMapped struct { + addr *stun.XORMappedAddress + waitAddrReceived chan struct{} + expiresAt time.Time +} + +func (a *xorMapped) closeWaiters() { + select { + case <-a.waitAddrReceived: + // notify was close, ok, that means we received duplicate response + // just exit + break + default: + // notify tha twe have a new addr + close(a.waitAddrReceived) + } +} + +func (a *xorMapped) pending() bool { + return a.addr == nil +} + +func (a *xorMapped) expired() bool { + return a.expiresAt.Before(time.Now()) +} + +func (a *xorMapped) SetAddr(addr *stun.XORMappedAddress) { + a.addr = addr + a.closeWaiters() +} diff --git a/udp_mux_universal_test.go b/udp_mux_universal_test.go new file mode 100644 index 00000000..c263bf61 --- /dev/null +++ b/udp_mux_universal_test.go @@ -0,0 +1,127 @@ +// +build !js + +package ice + +import ( + "net" + "sync" + "testing" + "time" + + "github.com/pion/stun" + "github.com/stretchr/testify/require" +) + +func TestUniversalUDPMux(t *testing.T) { + conn, err := net.ListenUDP(udp, &net.UDPAddr{}) + require.NoError(t, err) + + udpMux := NewUniversalUDPMuxDefault(UniversalUDPMuxParams{ + Logger: nil, + UDPConn: conn, + }) + + require.NoError(t, err) + defer func() { + _ = udpMux.Close() + _ = conn.Close() + }() + + require.NotNil(t, udpMux.LocalAddr(), "tcpMux.LocalAddr() is nil") + + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + defer wg.Done() + testMuxSrflxConnection(t, udpMux, "ufrag4", udp) + }() + + wg.Wait() +} + +func testMuxSrflxConnection(t *testing.T, udpMux *UniversalUDPMuxDefault, ufrag string, network string) { + pktConn, err := udpMux.GetConn(ufrag) + require.NoError(t, err, "error retrieving muxed connection for ufrag") + defer func() { + _ = pktConn.Close() + }() + + remoteConn, err := net.DialUDP(network, nil, &net.UDPAddr{ + Port: udpMux.LocalAddr().(*net.UDPAddr).Port, + }) + require.NoError(t, err, "error dialing test udp connection") + defer func() { + _ = remoteConn.Close() + }() + + // use small value for TTL to check expiration of the address + udpMux.params.XORMappedAddrCacheTTL = time.Millisecond * 20 + testXORIP := net.ParseIP("213.141.156.236") + testXORPort := 21254 + + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + defer wg.Done() + address, e := udpMux.GetXORMappedAddr(remoteConn.LocalAddr(), time.Second) + require.NoError(t, e) + require.NotNil(t, address) + require.True(t, address.IP.Equal(testXORIP)) + require.Equal(t, address.Port, testXORPort) + }() + + // wait until GetXORMappedAddr calls sendStun method + time.Sleep(time.Millisecond) + + // check that mapped address filled correctly after sent stun + udpMux.mu.Lock() + mappedAddr, ok := udpMux.xorMappedMap[remoteConn.LocalAddr().String()] + require.True(t, ok) + require.NotNil(t, mappedAddr) + require.True(t, mappedAddr.pending()) + require.False(t, mappedAddr.expired()) + udpMux.mu.Unlock() + + // clean receiver read buffer + buf := make([]byte, receiveMTU) + _, err = remoteConn.Read(buf) + require.NoError(t, err) + + // write back to udpMux XOR message with address + msg := stun.New() + msg.Type = stun.MessageType{Method: stun.MethodBinding, Class: stun.ClassRequest} + msg.Add(stun.AttrUsername, []byte(ufrag+":otherufrag")) + addr := &stun.XORMappedAddress{ + IP: testXORIP, + Port: testXORPort, + } + err = addr.AddTo(msg) + require.NoError(t, err) + + msg.Encode() + _, err = remoteConn.Write(msg.Raw) + require.NoError(t, err) + + // wait for the packet to be consumed and parsed by udpMux + wg.Wait() + + // we should get address immediately from the cached map + address, err := udpMux.GetXORMappedAddr(remoteConn.LocalAddr(), time.Second) + require.NoError(t, err) + require.NotNil(t, address) + + udpMux.mu.Lock() + // check mappedAddr is not pending, we didn't send stun twice + require.False(t, mappedAddr.pending()) + + // check expiration by TTL + time.Sleep(time.Millisecond * 21) + require.True(t, mappedAddr.expired()) + udpMux.mu.Unlock() + + // after expire, we send stun request again + // but we not receive response in 5 milliseconds and should get error here + address, err = udpMux.GetXORMappedAddr(remoteConn.LocalAddr(), time.Millisecond*5) + require.NotNil(t, err) + require.Nil(t, address) +}