Skip to content

Commit 31e5123

Browse files
committed
fix(transport): make connection multiaddrs match the full multiaddr including sni and certhash components
1 parent b198a51 commit 31e5123

File tree

10 files changed

+293
-37
lines changed

10 files changed

+293
-37
lines changed

p2p/test/transport/gating_test.go

+6-8
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ func TestInterceptSecuredOutgoing(t *testing.T) {
101101
connGater.EXPECT().InterceptAddrDial(h2.ID(), gomock.Any()).Return(true),
102102
connGater.EXPECT().InterceptSecured(network.DirOutbound, h2.ID(), gomock.Any()).Do(func(_ network.Direction, _ peer.ID, addrs network.ConnMultiaddrs) {
103103
// remove the certhash component from WebTransport and WebRTC addresses
104-
require.Equal(t, stripCertHash(h2.Addrs()[0]).String(), addrs.RemoteMultiaddr().String())
104+
require.Equal(t, h2.Addrs()[0].String(), addrs.RemoteMultiaddr().String())
105105
}),
106106
)
107107
err := h1.Connect(ctx, peer.AddrInfo{ID: h2.ID(), Addrs: h2.Addrs()})
@@ -135,8 +135,7 @@ func TestInterceptUpgradedOutgoing(t *testing.T) {
135135
connGater.EXPECT().InterceptAddrDial(h2.ID(), gomock.Any()).Return(true),
136136
connGater.EXPECT().InterceptSecured(network.DirOutbound, h2.ID(), gomock.Any()).Return(true),
137137
connGater.EXPECT().InterceptUpgraded(gomock.Any()).Do(func(c network.Conn) {
138-
// remove the certhash component from WebTransport addresses
139-
require.Equal(t, stripCertHash(h2.Addrs()[0]), c.RemoteMultiaddr())
138+
require.Equal(t, h2.Addrs()[0], c.RemoteMultiaddr())
140139
require.Equal(t, h1.ID(), c.LocalPeer())
141140
require.Equal(t, h2.ID(), c.RemotePeer())
142141
}))
@@ -170,12 +169,12 @@ func TestInterceptAccept(t *testing.T) {
170169
// In WebRTC, retransmissions of the STUN packet might cause us to create multiple connections,
171170
// if the first connection attempt is rejected.
172171
connGater.EXPECT().InterceptAccept(gomock.Any()).Do(func(addrs network.ConnMultiaddrs) {
173-
// remove the certhash component from WebTransport addresses
172+
// remove the certhash component from WebRTC and WebTransport addresses
174173
require.Equal(t, stripCertHash(h2.Addrs()[0]), addrs.LocalMultiaddr())
175174
}).AnyTimes()
176175
} else {
177176
connGater.EXPECT().InterceptAccept(gomock.Any()).Do(func(addrs network.ConnMultiaddrs) {
178-
// remove the certhash component from WebTransport addresses
177+
// remove the certhash component from WebRTC and WebTransport addresses
179178
require.Equal(t, stripCertHash(h2.Addrs()[0]), addrs.LocalMultiaddr())
180179
})
181180
}
@@ -213,8 +212,7 @@ func TestInterceptSecuredIncoming(t *testing.T) {
213212
gomock.InOrder(
214213
connGater.EXPECT().InterceptAccept(gomock.Any()).Return(true),
215214
connGater.EXPECT().InterceptSecured(network.DirInbound, h1.ID(), gomock.Any()).Do(func(_ network.Direction, _ peer.ID, addrs network.ConnMultiaddrs) {
216-
// remove the certhash component from WebTransport addresses
217-
require.Equal(t, stripCertHash(h2.Addrs()[0]), addrs.LocalMultiaddr())
215+
require.Equal(t, h2.Addrs()[0], addrs.LocalMultiaddr())
218216
}),
219217
)
220218
h1.Peerstore().AddAddrs(h2.ID(), h2.Addrs(), time.Hour)
@@ -248,7 +246,7 @@ func TestInterceptUpgradedIncoming(t *testing.T) {
248246
connGater.EXPECT().InterceptSecured(network.DirInbound, h1.ID(), gomock.Any()).Return(true),
249247
connGater.EXPECT().InterceptUpgraded(gomock.Any()).Do(func(c network.Conn) {
250248
// remove the certhash component from WebTransport addresses
251-
require.Equal(t, stripCertHash(h2.Addrs()[0]), c.LocalMultiaddr())
249+
require.Equal(t, h2.Addrs()[0], c.LocalMultiaddr())
252250
require.Equal(t, h1.ID(), c.RemotePeer())
253251
require.Equal(t, h2.ID(), c.LocalPeer())
254252
}),

p2p/test/transport/transport_test.go

+97-2
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,17 @@ package transport_integration
33
import (
44
"bytes"
55
"context"
6+
"crypto/ecdsa"
7+
"crypto/elliptic"
68
"crypto/rand"
9+
"crypto/tls"
10+
"crypto/x509"
11+
"crypto/x509/pkix"
12+
"encoding/pem"
713
"errors"
814
"fmt"
915
"io"
16+
"math/big"
1017
"net"
1118
"runtime"
1219
"strings"
@@ -30,8 +37,9 @@ import (
3037
"github.com/libp2p/go-libp2p/p2p/net/swarm"
3138
"github.com/libp2p/go-libp2p/p2p/protocol/ping"
3239
"github.com/libp2p/go-libp2p/p2p/security/noise"
33-
tls "github.com/libp2p/go-libp2p/p2p/security/tls"
40+
sectls "github.com/libp2p/go-libp2p/p2p/security/tls"
3441
libp2pwebrtc "github.com/libp2p/go-libp2p/p2p/transport/webrtc"
42+
"github.com/libp2p/go-libp2p/p2p/transport/websocket"
3543
"go.uber.org/mock/gomock"
3644

3745
ma "github.com/multiformats/go-multiaddr"
@@ -48,6 +56,7 @@ type TransportTestCaseOpts struct {
4856
NoRcmgr bool
4957
ConnGater connmgr.ConnectionGater
5058
ResourceManager network.ResourceManager
59+
HostSeed string
5160
}
5261

5362
func transformOpts(opts TransportTestCaseOpts) []config.Option {
@@ -87,7 +96,7 @@ var transportsToTest = []TransportTestCase{
8796
Name: "TCP / TLS / Yamux",
8897
HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host {
8998
libp2pOpts := transformOpts(opts)
90-
libp2pOpts = append(libp2pOpts, libp2p.Security(tls.ID, tls.New))
99+
libp2pOpts = append(libp2pOpts, libp2p.Security(sectls.ID, sectls.New))
91100
libp2pOpts = append(libp2pOpts, libp2p.Muxer(yamux.ID, yamux.DefaultTransport))
92101
if opts.NoListen {
93102
libp2pOpts = append(libp2pOpts, libp2p.NoListenAddrs)
@@ -113,6 +122,26 @@ var transportsToTest = []TransportTestCase{
113122
return h
114123
},
115124
},
125+
{
126+
Name: "Secure WebSocket with CA Certificate",
127+
HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host {
128+
libp2pOpts := transformOpts(opts)
129+
wsOpts := []interface{}{websocket.WithTLSClientConfig(&tls.Config{InsecureSkipVerify: true})}
130+
if opts.NoListen {
131+
libp2pOpts = append(libp2pOpts, libp2p.NoListenAddrs)
132+
} else {
133+
dnsName := fmt.Sprintf("example%s.com", opts.HostSeed)
134+
cert, err := generateSelfSignedCert(dnsName)
135+
require.NoError(t, err)
136+
libp2pOpts = append(libp2pOpts, libp2p.ListenAddrStrings(fmt.Sprintf("/ip4/127.0.0.1/tcp/0/tls/sni/%s/ws", dnsName)))
137+
wsOpts = append(wsOpts, websocket.WithTLSConfig(&tls.Config{Certificates: []tls.Certificate{cert}}))
138+
}
139+
libp2pOpts = append(libp2pOpts, libp2p.Transport(websocket.New, wsOpts...))
140+
h, err := libp2p.New(libp2pOpts...)
141+
require.NoError(t, err)
142+
return h
143+
},
144+
},
116145
{
117146
Name: "QUIC",
118147
HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host {
@@ -158,6 +187,46 @@ var transportsToTest = []TransportTestCase{
158187
},
159188
}
160189

190+
func generateSelfSignedCert(dnsName string) (tls.Certificate, error) {
191+
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
192+
if err != nil {
193+
return tls.Certificate{}, err
194+
}
195+
196+
serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
197+
if err != nil {
198+
return tls.Certificate{}, err
199+
}
200+
201+
template := x509.Certificate{
202+
SerialNumber: serialNumber,
203+
Subject: pkix.Name{
204+
Organization: []string{"My Organization"},
205+
},
206+
NotBefore: time.Now(),
207+
NotAfter: time.Now().Add(365 * 24 * time.Hour), // Valid for 1 year
208+
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
209+
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
210+
BasicConstraintsValid: true,
211+
DNSNames: []string{dnsName},
212+
}
213+
214+
certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
215+
if err != nil {
216+
return tls.Certificate{}, err
217+
}
218+
219+
certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})
220+
privDER, err := x509.MarshalECPrivateKey(priv)
221+
if err != nil {
222+
return tls.Certificate{}, err
223+
}
224+
privPEM := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: privDER})
225+
226+
// Load the certificate and key into tls.Certificate
227+
return tls.X509KeyPair(certPEM, privPEM)
228+
}
229+
161230
func TestPing(t *testing.T) {
162231
for _, tc := range transportsToTest {
163232
t.Run(tc.Name, func(t *testing.T) {
@@ -798,3 +867,29 @@ func TestConnClosedWhenRemoteCloses(t *testing.T) {
798867
})
799868
}
800869
}
870+
871+
func TestConnMatchingAddress(t *testing.T) {
872+
for _, tc := range transportsToTest {
873+
t.Run(tc.Name, func(t *testing.T) {
874+
server := tc.HostGenerator(t, TransportTestCaseOpts{})
875+
client1 := tc.HostGenerator(t, TransportTestCaseOpts{NoListen: true})
876+
client2 := tc.HostGenerator(t, TransportTestCaseOpts{NoListen: true})
877+
defer server.Close()
878+
defer client1.Close()
879+
defer client2.Close()
880+
881+
client1.Peerstore().AddAddrs(server.ID(), server.Addrs(), peerstore.PermanentAddrTTL)
882+
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
883+
defer cancel()
884+
err := client1.Connect(ctx, peer.AddrInfo{ID: server.ID(), Addrs: server.Addrs()})
885+
require.NoError(t, err)
886+
887+
client1Conns := client1.Network().ConnsToPeer(server.ID())
888+
require.Equal(t, 1, len(client1Conns))
889+
remoteMA := client1Conns[0].RemoteMultiaddr()
890+
891+
err = client2.Connect(ctx, peer.AddrInfo{ID: server.ID(), Addrs: []ma.Multiaddr{remoteMA}})
892+
require.NoError(t, err)
893+
})
894+
}
895+
}

p2p/transport/webrtc/listener.go

+1-2
Original file line numberDiff line numberDiff line change
@@ -264,14 +264,13 @@ func (l *listener) setupConnection(
264264
return nil, err
265265
}
266266

267-
localMultiaddrWithoutCerthash, _ := ma.SplitFunc(l.localMultiaddr, func(c ma.Component) bool { return c.Protocol().Code == ma.P_CERTHASH })
268267
conn, err := newConnection(
269268
network.DirInbound,
270269
w.PeerConnection,
271270
l.transport,
272271
scope,
273272
l.transport.localPeerId,
274-
localMultiaddrWithoutCerthash,
273+
l.localMultiaddr,
275274
remotePeer,
276275
remotePubKey,
277276
remoteMultiaddr,

p2p/transport/webrtc/transport.go

+1-2
Original file line numberDiff line numberDiff line change
@@ -387,7 +387,6 @@ func (t *WebRTCTransport) dial(ctx context.Context, scope network.ConnManagement
387387
if err != nil {
388388
return nil, err
389389
}
390-
remoteMultiaddrWithoutCerthash, _ := ma.SplitFunc(remoteMultiaddr, func(c ma.Component) bool { return c.Protocol().Code == ma.P_CERTHASH })
391390

392391
conn, err := newConnection(
393392
network.DirOutbound,
@@ -398,7 +397,7 @@ func (t *WebRTCTransport) dial(ctx context.Context, scope network.ConnManagement
398397
localAddr,
399398
p,
400399
remotePubKey,
401-
remoteMultiaddrWithoutCerthash,
400+
remoteMultiaddr,
402401
w.IncomingDataChannels,
403402
w.PeerConnectionClosedCh,
404403
)

p2p/transport/websocket/conn.go

+91-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
package websocket
22

33
import (
4+
"fmt"
5+
ma "github.com/multiformats/go-multiaddr"
6+
manet "github.com/multiformats/go-multiaddr/net"
47
"io"
58
"net"
69
"sync"
@@ -25,10 +28,88 @@ type Conn struct {
2528
closeOnce sync.Once
2629

2730
readLock, writeLock sync.Mutex
31+
32+
laddr, raddr *Addr
33+
laddrma, raddrma ma.Multiaddr
2834
}
2935

3036
var _ net.Conn = (*Conn)(nil)
3137

38+
// NewConn creates a Conn given a regular gorilla/websocket Conn.
39+
func NewOutboundConn(raw *ws.Conn, secure bool, sni string) (*Conn, error) {
40+
laddr := NewAddrWithScheme(raw.LocalAddr().String(), secure)
41+
raddr := NewAddrWithScheme(raw.RemoteAddr().String(), secure)
42+
43+
laddrma, err := manet.FromNetAddr(laddr)
44+
if err != nil {
45+
return nil, fmt.Errorf("failed to convert connection address to multiaddr: %s", err)
46+
}
47+
48+
raddrma, err := manet.FromNetAddr(raddr)
49+
if err != nil {
50+
return nil, fmt.Errorf("failed to convert connection address to multiaddr: %s", err)
51+
}
52+
53+
if secure {
54+
if withoutWSS := raddrma.Decapsulate(ma.StringCast("/wss")); withoutWSS.Equal(raddrma) {
55+
return nil, fmt.Errorf("missing wss component from converted multiaddr")
56+
} else {
57+
tlsSniWsMa, err := ma.NewMultiaddr(fmt.Sprintf("/tls/sni/%s/ws", sni))
58+
if err != nil {
59+
return nil, fmt.Errorf("failed to convert connection address to multiaddr: %s", err)
60+
}
61+
raddrma = withoutWSS.Encapsulate(tlsSniWsMa)
62+
}
63+
}
64+
65+
return &Conn{
66+
Conn: raw,
67+
secure: secure,
68+
DefaultMessageType: ws.BinaryMessage,
69+
laddr: laddr,
70+
raddr: raddr,
71+
laddrma: laddrma,
72+
raddrma: raddrma,
73+
}, nil
74+
}
75+
76+
func NewInboundConn(raw *ws.Conn, secure bool, sni string) (*Conn, error) {
77+
laddr := NewAddrWithScheme(raw.LocalAddr().String(), secure)
78+
raddr := NewAddrWithScheme(raw.RemoteAddr().String(), secure)
79+
80+
laddrma, err := manet.FromNetAddr(laddr)
81+
if err != nil {
82+
return nil, fmt.Errorf("failed to convert connection address to multiaddr: %s", err)
83+
}
84+
85+
raddrma, err := manet.FromNetAddr(raddr)
86+
if err != nil {
87+
return nil, fmt.Errorf("failed to convert connection address to multiaddr: %s", err)
88+
}
89+
90+
if secure {
91+
if withoutWSS := laddrma.Decapsulate(ma.StringCast("/wss")); withoutWSS.Equal(laddrma) {
92+
return nil, fmt.Errorf("missing wss component from converted multiaddr")
93+
} else {
94+
tlsSniWsMa, err := ma.NewMultiaddr(fmt.Sprintf("/tls/sni/%s/ws", sni))
95+
if err != nil {
96+
return nil, fmt.Errorf("failed to convert connection address to multiaddr: %s", err)
97+
}
98+
laddrma = withoutWSS.Encapsulate(tlsSniWsMa)
99+
}
100+
}
101+
102+
return &Conn{
103+
Conn: raw,
104+
secure: secure,
105+
DefaultMessageType: ws.BinaryMessage,
106+
laddr: laddr,
107+
raddr: raddr,
108+
laddrma: laddrma,
109+
raddrma: raddrma,
110+
}, nil
111+
}
112+
32113
// NewConn creates a Conn given a regular gorilla/websocket Conn.
33114
func NewConn(raw *ws.Conn, secure bool) *Conn {
34115
return &Conn{
@@ -122,11 +203,19 @@ func (c *Conn) Close() error {
122203
}
123204

124205
func (c *Conn) LocalAddr() net.Addr {
125-
return NewAddrWithScheme(c.Conn.LocalAddr().String(), c.secure)
206+
return c.laddr
126207
}
127208

128209
func (c *Conn) RemoteAddr() net.Addr {
129-
return NewAddrWithScheme(c.Conn.RemoteAddr().String(), c.secure)
210+
return c.raddr
211+
}
212+
213+
func (c *Conn) LocalMultiaddr() ma.Multiaddr {
214+
return c.laddrma
215+
}
216+
217+
func (c *Conn) RemoteMultiaddr() ma.Multiaddr {
218+
return c.raddrma
130219
}
131220

132221
func (c *Conn) SetDeadline(t time.Time) error {

p2p/transport/websocket/listener.go

+13-9
Original file line numberDiff line numberDiff line change
@@ -112,10 +112,20 @@ func (l *listener) ServeHTTP(w http.ResponseWriter, r *http.Request) {
112112
return
113113
}
114114

115+
var sni string
116+
if r.TLS != nil {
117+
sni = r.TLS.ServerName
118+
}
119+
mnc, err := NewInboundConn(c, l.isWss, sni)
120+
if err != nil {
121+
_ = c.Close()
122+
return
123+
}
124+
115125
select {
116-
case l.incoming <- NewConn(c, l.isWss):
126+
case l.incoming <- mnc:
117127
case <-l.closed:
118-
c.Close()
128+
mnc.Close()
119129
}
120130
// The connection has been hijacked, it's safe to return.
121131
}
@@ -126,13 +136,7 @@ func (l *listener) Accept() (manet.Conn, error) {
126136
if !ok {
127137
return nil, transport.ErrListenerClosed
128138
}
129-
130-
mnc, err := manet.WrapNetConn(c)
131-
if err != nil {
132-
c.Close()
133-
return nil, err
134-
}
135-
return mnc, nil
139+
return c, nil
136140
case <-l.closed:
137141
return nil, transport.ErrListenerClosed
138142
}

0 commit comments

Comments
 (0)