Skip to content

Commit

Permalink
sshTunnelGateway refactor (#3784)
Browse files Browse the repository at this point in the history
  • Loading branch information
fatedier committed Nov 22, 2023
1 parent 8b432e1 commit d5b41f1
Show file tree
Hide file tree
Showing 34 changed files with 1,036 additions and 1,255 deletions.
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@ vet:
go vet ./...

frps:
env CGO_ENABLED=0 go build -trimpath -ldflags "$(LDFLAGS)" -o bin/frps ./cmd/frps
env CGO_ENABLED=0 go build -trimpath -ldflags "$(LDFLAGS)" -tags frps -o bin/frps ./cmd/frps

frpc:
env CGO_ENABLED=0 go build -trimpath -ldflags "$(LDFLAGS)" -o bin/frpc ./cmd/frpc
env CGO_ENABLED=0 go build -trimpath -ldflags "$(LDFLAGS)" -tags frpc -o bin/frpc ./cmd/frpc

test: gotest

Expand Down
223 changes: 223 additions & 0 deletions client/connector.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,223 @@
// Copyright 2023 The frp Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package client

import (
"context"
"crypto/tls"
"io"
"net"
"strconv"
"strings"
"time"

libdial "github.com/fatedier/golib/net/dial"
fmux "github.com/hashicorp/yamux"
quic "github.com/quic-go/quic-go"
"github.com/samber/lo"

v1 "github.com/fatedier/frp/pkg/config/v1"
"github.com/fatedier/frp/pkg/transport"
utilnet "github.com/fatedier/frp/pkg/util/net"
"github.com/fatedier/frp/pkg/util/xlog"
)

// Connector is a interface for establishing connections to the server.
type Connector interface {
Open() error
Connect() (net.Conn, error)
Close() error
}

// defaultConnectorImpl is the default implementation of Connector for normal frpc.
type defaultConnectorImpl struct {
ctx context.Context
cfg *v1.ClientCommonConfig

muxSession *fmux.Session
quicConn quic.Connection
}

func NewConnector(ctx context.Context, cfg *v1.ClientCommonConfig) Connector {
return &defaultConnectorImpl{
ctx: ctx,
cfg: cfg,
}
}

// Open opens a underlying connection to the server.
// The underlying connection is either a TCP connection or a QUIC connection.
// After the underlying connection is established, you can call Connect() to get a stream.
// If TCPMux isn't enabled, the underlying connection is nil, you will get a new real TCP connection every time you call Connect().
func (c *defaultConnectorImpl) Open() error {
xl := xlog.FromContextSafe(c.ctx)

// special for quic
if strings.EqualFold(c.cfg.Transport.Protocol, "quic") {
var tlsConfig *tls.Config
var err error
sn := c.cfg.Transport.TLS.ServerName
if sn == "" {
sn = c.cfg.ServerAddr
}
if lo.FromPtr(c.cfg.Transport.TLS.Enable) {
tlsConfig, err = transport.NewClientTLSConfig(
c.cfg.Transport.TLS.CertFile,
c.cfg.Transport.TLS.KeyFile,
c.cfg.Transport.TLS.TrustedCaFile,
sn)
} else {
tlsConfig, err = transport.NewClientTLSConfig("", "", "", sn)
}
if err != nil {
xl.Warn("fail to build tls configuration, err: %v", err)
return err
}
tlsConfig.NextProtos = []string{"frp"}

conn, err := quic.DialAddr(
c.ctx,
net.JoinHostPort(c.cfg.ServerAddr, strconv.Itoa(c.cfg.ServerPort)),
tlsConfig, &quic.Config{
MaxIdleTimeout: time.Duration(c.cfg.Transport.QUIC.MaxIdleTimeout) * time.Second,
MaxIncomingStreams: int64(c.cfg.Transport.QUIC.MaxIncomingStreams),
KeepAlivePeriod: time.Duration(c.cfg.Transport.QUIC.KeepalivePeriod) * time.Second,
})
if err != nil {
return err
}
c.quicConn = conn
return nil
}

if !lo.FromPtr(c.cfg.Transport.TCPMux) {
return nil
}

conn, err := c.realConnect()
if err != nil {
return err
}

fmuxCfg := fmux.DefaultConfig()
fmuxCfg.KeepAliveInterval = time.Duration(c.cfg.Transport.TCPMuxKeepaliveInterval) * time.Second
fmuxCfg.LogOutput = io.Discard
fmuxCfg.MaxStreamWindowSize = 6 * 1024 * 1024
session, err := fmux.Client(conn, fmuxCfg)
if err != nil {
return err
}
c.muxSession = session
return nil
}

// Connect returns a stream from the underlying connection, or a new TCP connection if TCPMux isn't enabled.
func (c *defaultConnectorImpl) Connect() (net.Conn, error) {
if c.quicConn != nil {
stream, err := c.quicConn.OpenStreamSync(context.Background())
if err != nil {
return nil, err
}
return utilnet.QuicStreamToNetConn(stream, c.quicConn), nil
} else if c.muxSession != nil {
stream, err := c.muxSession.OpenStream()
if err != nil {
return nil, err
}
return stream, nil
}

return c.realConnect()
}

func (c *defaultConnectorImpl) realConnect() (net.Conn, error) {
xl := xlog.FromContextSafe(c.ctx)
var tlsConfig *tls.Config
var err error
tlsEnable := lo.FromPtr(c.cfg.Transport.TLS.Enable)
if c.cfg.Transport.Protocol == "wss" {
tlsEnable = true
}
if tlsEnable {
sn := c.cfg.Transport.TLS.ServerName
if sn == "" {
sn = c.cfg.ServerAddr
}

tlsConfig, err = transport.NewClientTLSConfig(
c.cfg.Transport.TLS.CertFile,
c.cfg.Transport.TLS.KeyFile,
c.cfg.Transport.TLS.TrustedCaFile,
sn)
if err != nil {
xl.Warn("fail to build tls configuration, err: %v", err)
return nil, err
}
}

proxyType, addr, auth, err := libdial.ParseProxyURL(c.cfg.Transport.ProxyURL)
if err != nil {
xl.Error("fail to parse proxy url")
return nil, err
}
dialOptions := []libdial.DialOption{}
protocol := c.cfg.Transport.Protocol
switch protocol {
case "websocket":
protocol = "tcp"
dialOptions = append(dialOptions, libdial.WithAfterHook(libdial.AfterHook{Hook: utilnet.DialHookWebsocket(protocol, "")}))
dialOptions = append(dialOptions, libdial.WithAfterHook(libdial.AfterHook{
Hook: utilnet.DialHookCustomTLSHeadByte(tlsConfig != nil, lo.FromPtr(c.cfg.Transport.TLS.DisableCustomTLSFirstByte)),
}))
dialOptions = append(dialOptions, libdial.WithTLSConfig(tlsConfig))
case "wss":
protocol = "tcp"
dialOptions = append(dialOptions, libdial.WithTLSConfigAndPriority(100, tlsConfig))
// Make sure that if it is wss, the websocket hook is executed after the tls hook.
dialOptions = append(dialOptions, libdial.WithAfterHook(libdial.AfterHook{Hook: utilnet.DialHookWebsocket(protocol, tlsConfig.ServerName), Priority: 110}))
default:
dialOptions = append(dialOptions, libdial.WithAfterHook(libdial.AfterHook{
Hook: utilnet.DialHookCustomTLSHeadByte(tlsConfig != nil, lo.FromPtr(c.cfg.Transport.TLS.DisableCustomTLSFirstByte)),
}))
dialOptions = append(dialOptions, libdial.WithTLSConfig(tlsConfig))
}

if c.cfg.Transport.ConnectServerLocalIP != "" {
dialOptions = append(dialOptions, libdial.WithLocalAddr(c.cfg.Transport.ConnectServerLocalIP))
}
dialOptions = append(dialOptions,
libdial.WithProtocol(protocol),
libdial.WithTimeout(time.Duration(c.cfg.Transport.DialServerTimeout)*time.Second),
libdial.WithKeepAlive(time.Duration(c.cfg.Transport.DialServerKeepAlive)*time.Second),
libdial.WithProxy(proxyType, addr),
libdial.WithProxyAuth(auth),
)
conn, err := libdial.DialContext(
c.ctx,
net.JoinHostPort(c.cfg.ServerAddr, strconv.Itoa(c.cfg.ServerPort)),
dialOptions...,
)
return conn, err
}

func (c *defaultConnectorImpl) Close() error {
if c.quicConn != nil {
_ = c.quicConn.CloseWithError(0, "")
}
if c.muxSession != nil {
_ = c.muxSession.Close()
}
return nil
}
18 changes: 11 additions & 7 deletions client/control.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ type Control struct {
// control connection. Once conn is closed, the msgDispatcher and the entire Control will exit.
conn net.Conn

// use cm to create new connections, which could be real TCP connections or virtual streams.
cm *ConnectionManager
// use connector to create new connections, which could be real TCP connections or virtual streams.
connector Connector

doneCh chan struct{}

Expand All @@ -77,7 +77,7 @@ type Control struct {
}

func NewControl(
ctx context.Context, runID string, conn net.Conn, cm *ConnectionManager,
ctx context.Context, runID string, conn net.Conn, connector Connector,
clientCfg *v1.ClientCommonConfig,
pxyCfgs []v1.ProxyConfigurer,
visitorCfgs []v1.VisitorConfigurer,
Expand All @@ -92,7 +92,7 @@ func NewControl(
runID: runID,
pxyCfgs: pxyCfgs,
conn: conn,
cm: cm,
connector: connector,
doneCh: make(chan struct{}),
}
ctl.lastPong.Store(time.Now())
Expand Down Expand Up @@ -122,6 +122,10 @@ func (ctl *Control) Run() {
go ctl.vm.Run()
}

func (ctl *Control) SetInWorkConnCallback(cb func(*v1.ProxyBaseConfig, net.Conn, *msg.StartWorkConn) bool) {
ctl.pm.SetInWorkConnCallback(cb)
}

func (ctl *Control) handleReqWorkConn(_ msg.Message) {
xl := ctl.xl
workConn, err := ctl.connectServer()
Expand Down Expand Up @@ -207,7 +211,7 @@ func (ctl *Control) GracefulClose(d time.Duration) error {
time.Sleep(d)

ctl.conn.Close()
ctl.cm.Close()
ctl.connector.Close()
return nil
}

Expand All @@ -218,7 +222,7 @@ func (ctl *Control) Done() <-chan struct{} {

// connectServer return a new connection to frps
func (ctl *Control) connectServer() (conn net.Conn, err error) {
return ctl.cm.Connect()
return ctl.connector.Connect()
}

func (ctl *Control) registerMsgHandlers() {
Expand Down Expand Up @@ -282,7 +286,7 @@ func (ctl *Control) worker() {

ctl.pm.Close()
ctl.vm.Close()
ctl.cm.Close()
ctl.connector.Close()

close(ctl.doneCh)
}
Expand Down
17 changes: 13 additions & 4 deletions client/proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,9 @@ func RegisterProxyFactory(proxyConfType reflect.Type, factory func(*BaseProxy, v
// Proxy defines how to handle work connections for different proxy type.
type Proxy interface {
Run() error

// InWorkConn accept work connections registered to server.
InWorkConn(net.Conn, *msg.StartWorkConn)

SetInWorkConnCallback(func(*v1.ProxyBaseConfig, net.Conn, *msg.StartWorkConn) /* continue */ bool)
Close()
}

Expand Down Expand Up @@ -89,7 +88,8 @@ type BaseProxy struct {
limiter *rate.Limiter
// proxyPlugin is used to handle connections instead of dialing to local service.
// It's only validate for TCP protocol now.
proxyPlugin plugin.Plugin
proxyPlugin plugin.Plugin
inWorkConnCallback func(*v1.ProxyBaseConfig, net.Conn, *msg.StartWorkConn) /* continue */ bool

mu sync.RWMutex
xl *xlog.Logger
Expand All @@ -113,7 +113,16 @@ func (pxy *BaseProxy) Close() {
}
}

func (pxy *BaseProxy) SetInWorkConnCallback(cb func(*v1.ProxyBaseConfig, net.Conn, *msg.StartWorkConn) bool) {
pxy.inWorkConnCallback = cb
}

func (pxy *BaseProxy) InWorkConn(conn net.Conn, m *msg.StartWorkConn) {
if pxy.inWorkConnCallback != nil {
if !pxy.inWorkConnCallback(pxy.baseCfg, conn, m) {
return
}
}
pxy.HandleTCPWorkConnection(conn, m, []byte(pxy.clientCfg.Auth.Token))
}

Expand All @@ -132,7 +141,7 @@ func (pxy *BaseProxy) HandleTCPWorkConnection(workConn net.Conn, m *msg.StartWor
})
}

xl.Trace("handle tcp work connection, use_encryption: %t, use_compression: %t",
xl.Trace("handle tcp work connection, useEncryption: %t, useCompression: %t",
baseCfg.Transport.UseEncryption, baseCfg.Transport.UseCompression)
if baseCfg.Transport.UseEncryption {
remote, err = libio.WithEncryption(remote, encKey)
Expand Down
12 changes: 10 additions & 2 deletions client/proxy/proxy_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,9 @@ import (
)

type Manager struct {
proxies map[string]*Wrapper
msgTransporter transport.MessageTransporter
proxies map[string]*Wrapper
msgTransporter transport.MessageTransporter
inWorkConnCallback func(*v1.ProxyBaseConfig, net.Conn, *msg.StartWorkConn) bool

closed bool
mu sync.RWMutex
Expand Down Expand Up @@ -71,6 +72,10 @@ func (pm *Manager) StartProxy(name string, remoteAddr string, serverRespErr stri
return nil
}

func (pm *Manager) SetInWorkConnCallback(cb func(*v1.ProxyBaseConfig, net.Conn, *msg.StartWorkConn) bool) {
pm.inWorkConnCallback = cb
}

func (pm *Manager) Close() {
pm.mu.Lock()
defer pm.mu.Unlock()
Expand Down Expand Up @@ -146,6 +151,9 @@ func (pm *Manager) Reload(pxyCfgs []v1.ProxyConfigurer) {
name := cfg.GetBaseConfig().Name
if _, ok := pm.proxies[name]; !ok {
pxy := NewWrapper(pm.ctx, cfg, pm.clientCfg, pm.HandleEvent, pm.msgTransporter)
if pm.inWorkConnCallback != nil {
pxy.SetInWorkConnCallback(pm.inWorkConnCallback)
}
pm.proxies[name] = pxy
addPxyNames = append(addPxyNames, name)

Expand Down
4 changes: 4 additions & 0 deletions client/proxy/proxy_wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,10 @@ func NewWrapper(
return pw
}

func (pw *Wrapper) SetInWorkConnCallback(cb func(*v1.ProxyBaseConfig, net.Conn, *msg.StartWorkConn) bool) {
pw.pxy.SetInWorkConnCallback(cb)
}

func (pw *Wrapper) SetRunningStatus(remoteAddr string, respErr string) error {
pw.mu.Lock()
defer pw.mu.Unlock()
Expand Down
Loading

0 comments on commit d5b41f1

Please sign in to comment.