From 069883c15a3f78c89ad65c38268eea8745011008 Mon Sep 17 00:00:00 2001 From: Sheen Capadngan Date: Sat, 30 Aug 2025 01:26:55 +0800 Subject: [PATCH 01/38] feat: gateway v2 scaffolding --- packages/api/api.go | 63 +++++ packages/api/model.go | 40 +++ packages/cmd/network.go | 263 ++++++++++++++++++ packages/gateway-v2/gateway.go | 428 +++++++++++++++++++++++++++++ packages/proxy/proxy.go | 486 +++++++++++++++++++++++++++++++++ 5 files changed, 1280 insertions(+) create mode 100644 packages/cmd/network.go create mode 100644 packages/gateway-v2/gateway.go create mode 100644 packages/proxy/proxy.go diff --git a/packages/api/api.go b/packages/api/api.go index a9b204b6..7d61eb8a 100644 --- a/packages/api/api.go +++ b/packages/api/api.go @@ -40,6 +40,9 @@ const ( operationCallExchangeRelayCertV1 = "CallExchangeRelayCertV1" operationCallGatewayHeartBeatV1 = "CallGatewayHeartBeatV1" operationCallBootstrapInstance = "CallBootstrapInstance" + operationCallRegisterInstanceProxy = "CallRegisterInstanceProxy" + operationCallRegisterOrgProxy = "CallRegisterOrgProxy" + operationCallRegisterGateway = "CallRegisterGateway" ) func CallGetEncryptedWorkspaceKey(httpClient *resty.Client, request GetEncryptedWorkspaceKeyRequest) (GetEncryptedWorkspaceKeyResponse, error) { @@ -671,3 +674,63 @@ func CallBootstrapInstance(httpClient *resty.Client, request BootstrapInstanceRe return resBody, nil } + +func CallRegisterInstanceProxy(httpClient *resty.Client, request RegisterProxyRequest) (RegisterProxyResponse, error) { + var resBody RegisterProxyResponse + response, err := httpClient. + R(). + SetResult(&resBody). + SetHeader("User-Agent", USER_AGENT). + SetBody(request). + Post(fmt.Sprintf("%v/v1/proxies/register-instance-proxy", config.INFISICAL_URL)) + + if err != nil { + return RegisterProxyResponse{}, NewGenericRequestError(operationCallRegisterInstanceProxy, err) + } + + if response.IsError() { + return RegisterProxyResponse{}, NewAPIErrorWithResponse(operationCallRegisterInstanceProxy, response, nil) + } + + return resBody, nil +} + +func CallRegisterProxy(httpClient *resty.Client, request RegisterProxyRequest) (RegisterProxyResponse, error) { + var resBody RegisterProxyResponse + response, err := httpClient. + R(). + SetResult(&resBody). + SetHeader("User-Agent", USER_AGENT). + SetBody(request). + Post(fmt.Sprintf("%v/v1/proxies/register-org-proxy", config.INFISICAL_URL)) + + if err != nil { + return RegisterProxyResponse{}, NewGenericRequestError(operationCallRegisterOrgProxy, err) + } + + if response.IsError() { + return RegisterProxyResponse{}, NewAPIErrorWithResponse(operationCallRegisterOrgProxy, response, nil) + } + + return resBody, nil +} + +func CallRegisterGateway(httpClient *resty.Client, request RegisterGatewayRequest) (RegisterGatewayResponse, error) { + var resBody RegisterGatewayResponse + response, err := httpClient. + R(). + SetResult(&resBody). + SetHeader("User-Agent", USER_AGENT). + SetBody(request). + Post(fmt.Sprintf("%v/v2/gateways", config.INFISICAL_URL)) + + if err != nil { + return RegisterGatewayResponse{}, NewGenericRequestError(operationCallRegisterGateway, err) + } + + if response.IsError() { + return RegisterGatewayResponse{}, NewAPIErrorWithResponse(operationCallRegisterGateway, response, nil) + } + + return resBody, nil +} diff --git a/packages/api/model.go b/packages/api/model.go index 3f10b4ca..ad172278 100644 --- a/packages/api/model.go +++ b/packages/api/model.go @@ -703,3 +703,43 @@ type BootstrapUser struct { Username string `json:"username"` SuperAdmin bool `json:"superAdmin"` } + +type RegisterProxyRequest struct { + IP string `json:"ip"` + Name string `json:"name"` +} + +type RegisterProxyResponse struct { + PKI struct { + ServerCertificate string `json:"serverCertificate"` + ServerCertificateChain string `json:"serverCertificateChain"` + ServerPrivateKey string `json:"serverPrivateKey"` + ClientCA string `json:"clientCA"` + } `json:"pki"` + SSH struct { + ServerCertificate string `json:"serverCertificate"` + ServerPrivateKey string `json:"serverPrivateKey"` + ClientCAPublicKey string `json:"clientCAPublicKey"` + } `json:"ssh"` +} + +type RegisterGatewayRequest struct { + ProxyName string `json:"proxyName"` + Name string `json:"name"` +} + +type RegisterGatewayResponse struct { + GatewayID string `json:"gatewayId"` + ProxyIP string `json:"proxyIp"` + PKI struct { + ServerCertificate string `json:"serverCertificate"` + ServerCertificateChain string `json:"serverCertificateChain"` + ServerPrivateKey string `json:"serverPrivateKey"` + ClientCA string `json:"clientCA"` + } `json:"pki"` + SSH struct { + ClientCertificate string `json:"clientCertificate"` + ClientPrivateKey string `json:"clientPrivateKey"` + ServerCAPublicKey string `json:"serverCAPublicKey"` + } `json:"ssh"` +} diff --git a/packages/cmd/network.go b/packages/cmd/network.go new file mode 100644 index 00000000..8ec9c6b3 --- /dev/null +++ b/packages/cmd/network.go @@ -0,0 +1,263 @@ +package cmd + +import ( + "context" + "errors" + "fmt" + "os" + "os/signal" + "sync/atomic" + "syscall" + "time" + + gatewayv2 "github.com/Infisical/infisical-merge/packages/gateway-v2" + "github.com/Infisical/infisical-merge/packages/proxy" + "github.com/Infisical/infisical-merge/packages/util" + "github.com/rs/zerolog/log" + "github.com/spf13/cobra" +) + +var networkCmd = &cobra.Command{ + Use: "network", + Short: "Network-related commands", + Long: "Network-related commands for Infisical", +} + +var networkProxyCmd = &cobra.Command{ + Use: "proxy", + Short: "Run the Infisical proxy component", + Long: "Run the Infisical proxy component", + Run: func(cmd *cobra.Command, args []string) { + + proxyName, err := cmd.Flags().GetString("name") + if err != nil || proxyName == "" { + util.HandleError(err, "unable to get name flag") + } + + ip, err := cmd.Flags().GetString("ip") + if err != nil || ip == "" { + util.HandleError(err, "unable to get ip flag") + } + + instanceType, err := cmd.Flags().GetString("type") + if err != nil { + util.HandleError(err, "unable to get type flag") + } + + proxyInstance, err := proxy.NewProxy(&proxy.ProxyConfig{ + ProxyName: proxyName, + SSHPort: "2222", + TLSPort: "443", + StaticIP: ip, + Type: instanceType, + }) + + if err != nil { + util.HandleError(err, "unable to create proxy instance") + } + + if instanceType == "instance" { + proxyAuthSecret := os.Getenv("PROXY_AUTH_SECRET") + if proxyAuthSecret == "" { + util.HandleError(fmt.Errorf("PROXY_AUTH_SECRET is not set"), "unable to get proxy auth secret") + } + + proxyInstance.SetToken(proxyAuthSecret) + } else { + infisicalClient, cancelSdk, err := getInfisicalSdkInstance(cmd) + if err != nil { + util.HandleError(err, "unable to get infisical client") + } + defer cancelSdk() + + var accessToken atomic.Value + accessToken.Store(infisicalClient.Auth().GetAccessToken()) + + if accessToken.Load().(string) == "" { + util.HandleError(errors.New("no access token found")) + } + + proxyInstance.SetToken(accessToken.Load().(string)) + + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) + + ctx, cancelCmd := context.WithCancel(cmd.Context()) + defer cancelCmd() + + go func() { + <-sigCh + log.Info().Msg("Received shutdown signal, shutting down proxy...") + cancelCmd() + cancelSdk() + + // If we get a second signal, force exit + <-sigCh + log.Warn().Msgf("Force exit triggered") + os.Exit(1) + }() + + // Token refresh goroutine - runs every 10 seconds + go func() { + tokenRefreshTicker := time.NewTicker(10 * time.Second) + defer tokenRefreshTicker.Stop() + + for { + select { + case <-tokenRefreshTicker.C: + if ctx.Err() != nil { + return + } + + newToken := infisicalClient.Auth().GetAccessToken() + if newToken != "" && newToken != accessToken.Load().(string) { + accessToken.Store(newToken) + proxyInstance.SetToken(newToken) + } + + case <-ctx.Done(): + return + } + } + }() + } + + // Use the same context for the proxy server + err = proxyInstance.Start(cmd.Context()) + if err != nil { + util.HandleError(err, "unable to start proxy instance") + } + }, +} + +var networkProxyInstallCmd = &cobra.Command{ + Use: "proxy install", + Short: "Install and enable systemd service for the proxy (requires sudo)", + Long: "Install and enable systemd service for the proxy. Must be run with sudo on Linux.", + Run: func(cmd *cobra.Command, args []string) { + // TODO: Implement this + }, +} + +var networkGatewayCmd = &cobra.Command{ + Use: "gateway", + Short: "Run the Infisical gateway component", + Long: "Run the Infisical gateway component", + Run: func(cmd *cobra.Command, args []string) { + + proxyName, err := cmd.Flags().GetString("proxy-name") + if err != nil || proxyName == "" { + util.HandleError(err, "unable to get proxy-name flag") + } + + gatewayName, err := cmd.Flags().GetString("name") + if err != nil || gatewayName == "" { + util.HandleError(err, "unable to get name flag") + } + + gatewayInstance, err := gatewayv2.NewGateway(&gatewayv2.GatewayConfig{ + Name: gatewayName, + ProxyName: proxyName, + ReconnectDelay: 10 * time.Second, + }) + + if err != nil { + util.HandleError(err, "unable to create gateway instance") + } + + infisicalClient, cancelSdk, err := getInfisicalSdkInstance(cmd) + if err != nil { + util.HandleError(err, "unable to get infisical client") + } + defer cancelSdk() + + var accessToken atomic.Value + accessToken.Store(infisicalClient.Auth().GetAccessToken()) + + if accessToken.Load().(string) == "" { + util.HandleError(errors.New("no access token found")) + } + + gatewayInstance.SetToken(accessToken.Load().(string)) + + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) + + ctx, cancelCmd := context.WithCancel(cmd.Context()) + defer cancelCmd() + + go func() { + <-sigCh + log.Info().Msg("Received shutdown signal, shutting down gateway...") + cancelCmd() + cancelSdk() + + // If we get a second signal, force exit + <-sigCh + log.Warn().Msgf("Force exit triggered") + os.Exit(1) + }() + + // Token refresh goroutine - runs every 10 seconds + go func() { + tokenRefreshTicker := time.NewTicker(10 * time.Second) + defer tokenRefreshTicker.Stop() + + for { + select { + case <-tokenRefreshTicker.C: + if ctx.Err() != nil { + return + } + + newToken := infisicalClient.Auth().GetAccessToken() + if newToken != "" && newToken != accessToken.Load().(string) { + accessToken.Store(newToken) + gatewayInstance.SetToken(newToken) + } + + case <-ctx.Done(): + return + } + } + }() + + err = gatewayInstance.Start(ctx) + if err != nil { + util.HandleError(err, "unable to start gateway instance") + } + + }, +} + +func init() { + networkGatewayCmd.Flags().String("proxy-name", "", "The name of the proxy to connect to") + networkGatewayCmd.Flags().String("name", "", "The name of the gateway") + networkGatewayCmd.Flags().String("token", "", "connect with Infisical using machine identity access token. if not provided, you must set the auth-method flag") + networkGatewayCmd.Flags().String("auth-method", "", "login method [universal-auth, kubernetes, azure, gcp-id-token, gcp-iam, aws-iam, oidc-auth]. if not provided, you must set the token flag") + networkGatewayCmd.Flags().String("client-id", "", "client id for universal auth") + networkGatewayCmd.Flags().String("client-secret", "", "client secret for universal auth") + networkGatewayCmd.Flags().String("machine-identity-id", "", "machine identity id for kubernetes, azure, gcp-id-token, gcp-iam, and aws-iam auth methods") + networkGatewayCmd.Flags().String("service-account-token-path", "", "service account token path for kubernetes auth") + networkGatewayCmd.Flags().String("service-account-key-file-path", "", "service account key file path for GCP IAM auth") + networkGatewayCmd.Flags().String("jwt", "", "JWT for jwt-based auth methods [oidc-auth, jwt-auth]") + + networkProxyCmd.Flags().String("type", "org", "The type of proxy to run. Must be either 'instance' or 'org'") + networkProxyCmd.Flags().String("ip", "", "The IP address of the proxy") + networkProxyCmd.Flags().String("name", "", "The name of the proxy") + networkProxyCmd.Flags().String("token", "", "connect with Infisical using machine identity access token. if not provided, you must set the auth-method flag") + networkProxyCmd.Flags().String("auth-method", "", "login method [universal-auth, kubernetes, azure, gcp-id-token, gcp-iam, aws-iam, oidc-auth]. if not provided, you must set the token flag") + networkProxyCmd.Flags().String("client-id", "", "client id for universal auth") + networkProxyCmd.Flags().String("client-secret", "", "client secret for universal auth") + networkProxyCmd.Flags().String("machine-identity-id", "", "machine identity id for kubernetes, azure, gcp-id-token, gcp-iam, and aws-iam auth methods") + networkProxyCmd.Flags().String("service-account-token-path", "", "service account token path for kubernetes auth") + networkProxyCmd.Flags().String("service-account-key-file-path", "", "service account key file path for GCP IAM auth") + networkProxyCmd.Flags().String("jwt", "", "JWT for jwt-based auth methods [oidc-auth, jwt-auth]") + + networkProxyCmd.AddCommand(networkProxyInstallCmd) + + networkCmd.AddCommand(networkProxyCmd) + networkCmd.AddCommand(networkGatewayCmd) + + rootCmd.AddCommand(networkCmd) +} diff --git a/packages/gateway-v2/gateway.go b/packages/gateway-v2/gateway.go new file mode 100644 index 00000000..2208e8cd --- /dev/null +++ b/packages/gateway-v2/gateway.go @@ -0,0 +1,428 @@ +package gatewayv2 + +import ( + "bytes" + "context" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "encoding/pem" + "fmt" + "io" + "log" + "net" + "sync" + "time" + + "github.com/Infisical/infisical-merge/packages/api" + "github.com/Infisical/infisical-merge/packages/util" + "github.com/go-resty/resty/v2" + "golang.org/x/crypto/ssh" +) + +type GatewayConfig struct { + Name string + ProxyName string + IdentityToken string + SSHPort int + ReconnectDelay time.Duration +} + +type Gateway struct { + GatewayID string + + httpClient *resty.Client + config *GatewayConfig + sshClient *ssh.Client + + // Certificate storage + certificates *api.RegisterGatewayResponse + + // mTLS server components + tlsConfig *tls.Config + tlsCACert []byte + tlsCAKey *rsa.PrivateKey + + // Connection management + mu sync.RWMutex + isConnected bool + ctx context.Context + cancel context.CancelFunc +} + +// NewGateway creates a new gateway instance +func NewGateway(config *GatewayConfig) (*Gateway, error) { + httpClient, err := util.GetRestyClientWithCustomHeaders() + if err != nil { + return nil, fmt.Errorf("unable to get client with custom headers [err=%v]", err) + } + + httpClient.SetAuthToken(config.IdentityToken) + + ctx, cancel := context.WithCancel(context.Background()) + + // Set default SSH port if not specified + if config.SSHPort == 0 { + config.SSHPort = 2222 + } + + return &Gateway{ + httpClient: httpClient, + config: config, + ctx: ctx, + cancel: cancel, + }, nil +} + +// Change the Start method to accept a context +func (g *Gateway) Start(ctx context.Context) error { + log.Printf("Starting gateway") + for { + select { + case <-ctx.Done(): + log.Printf("Gateway stopped by context cancellation") + return nil + default: + if err := g.connectAndServe(); err != nil { + log.Printf("Connection failed: %v, retrying in %v...", err, g.config.ReconnectDelay) + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(g.config.ReconnectDelay): + continue + } + } + // If we get here, the connection was closed gracefully + log.Printf("Connection closed, reconnecting in 10 seconds...") + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(10 * time.Second): + continue + } + } + } +} + +func (g *Gateway) SetToken(token string) { + g.httpClient.SetAuthToken(token) +} + +func (g *Gateway) Stop() { + g.cancel() + + g.mu.Lock() + if g.sshClient != nil { + g.sshClient.Close() + g.sshClient = nil + } + g.isConnected = false + g.mu.Unlock() +} + +func (g *Gateway) connectAndServe() error { + if err := g.registerGateway(); err != nil { + return fmt.Errorf("failed to register gateway: %v", err) + } + + // Create SSH client config + sshConfig, err := g.createSSHConfig() + if err != nil { + return fmt.Errorf("failed to create SSH config: %v", err) + } + + // Connect to Proxy server + log.Printf("Connecting to SSH server on %s:%d...", g.certificates.ProxyIP, g.config.SSHPort) + client, err := ssh.Dial("tcp", fmt.Sprintf("%s:%d", g.certificates.ProxyIP, g.config.SSHPort), sshConfig) + if err != nil { + return fmt.Errorf("failed to connect to SSH server: %v", err) + } + + g.mu.Lock() + g.sshClient = client + g.isConnected = true + g.mu.Unlock() + + defer func() { + g.mu.Lock() + g.sshClient = nil + g.isConnected = false + g.mu.Unlock() + client.Close() + }() + + log.Printf("SSH connection established for gateway") + + // Handle incoming channels from the server + channels := client.HandleChannelOpen("direct-tcpip") + if channels == nil { + return fmt.Errorf("failed to handle channel open") + } + + // Process incoming channels + for newChannel := range channels { + go g.handleIncomingChannel(newChannel) + } + + return nil // Connection closed +} + +func (g *Gateway) registerGateway() error { + body := api.RegisterGatewayRequest{ + ProxyName: g.config.ProxyName, + Name: g.config.Name, + } + + certResp, err := api.CallRegisterGateway(g.httpClient, body) + if err != nil { + return fmt.Errorf("failed to register gateway: %v", err) + } + + g.GatewayID = certResp.GatewayID + g.certificates = &certResp + log.Printf("Successfully registered gateway and received certificates") + return nil +} + +func (g *Gateway) createSSHConfig() (*ssh.ClientConfig, error) { + privateKey, err := ssh.ParsePrivateKey([]byte(g.certificates.SSH.ClientPrivateKey)) + if err != nil { + return nil, fmt.Errorf("failed to parse SSH private key: %v", err) + } + + // Parse certificate + cert, _, _, _, err := ssh.ParseAuthorizedKey([]byte(g.certificates.SSH.ClientCertificate)) + if err != nil { + return nil, fmt.Errorf("failed to parse certificate: %v", err) + } + + // Create certificate signer + certSigner, err := ssh.NewCertSigner(cert.(*ssh.Certificate), privateKey) + if err != nil { + return nil, fmt.Errorf("failed to create certificate signer: %v", err) + } + + // Create SSH client config + config := &ssh.ClientConfig{ + User: g.GatewayID, + Auth: []ssh.AuthMethod{ + ssh.PublicKeys(certSigner), + }, + HostKeyCallback: g.createHostKeyCallback(), + Timeout: 30 * time.Second, + Config: ssh.Config{ + KeyExchanges: []string{ + "diffie-hellman-group14-sha256", + "diffie-hellman-group16-sha512", + "diffie-hellman-group18-sha512", + }, + Ciphers: []string{ + "aes128-ctr", + "aes192-ctr", + "aes256-ctr", + }, + MACs: []string{ + "hmac-sha2-256", + "hmac-sha2-512", + }, + }, + } + + return config, nil +} + +func (g *Gateway) createHostKeyCallback() ssh.HostKeyCallback { + // Parse CA public key once when creating the callback + caKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(g.certificates.SSH.ServerCAPublicKey)) + if err != nil { + // Return a callback that always fails since we can't parse the CA key + return func(hostname string, remote net.Addr, key ssh.PublicKey) error { + return fmt.Errorf("failed to parse CA public key: %v", err) + } + } + + return func(hostname string, remote net.Addr, key ssh.PublicKey) error { + cert, ok := key.(*ssh.Certificate) + if !ok { + return fmt.Errorf("host certificates required, raw host keys not allowed") + } + + return g.validateHostCertificate(cert, hostname, caKey) + } +} + +func (g *Gateway) validateHostCertificate(cert *ssh.Certificate, hostname string, caKey ssh.PublicKey) error { + checker := &ssh.CertChecker{ + IsHostAuthority: func(auth ssh.PublicKey, address string) bool { + return bytes.Equal(auth.Marshal(), caKey.Marshal()) + }, + } + + if err := checker.CheckCert(hostname, cert); err != nil { + return fmt.Errorf("host certificate check failed: %v", err) + } + + log.Printf("Host certificate validated successfully for %s", hostname) + return nil +} + +func (g *Gateway) handleIncomingChannel(newChannel ssh.NewChannel) { + var req struct { + Host string + Port uint32 + OriginHost string + OriginPort uint32 + } + + if err := ssh.Unmarshal(newChannel.ExtraData(), &req); err != nil { + log.Printf("Failed to parse channel request: %v", err) + newChannel.Reject(ssh.Prohibited, "invalid request") + return + } + + log.Printf("Incoming connection request to %s:%d from %s:%d", + req.Host, req.Port, req.OriginHost, req.OriginPort) + + // Accept the channel + channel, requests, err := newChannel.Accept() + if err != nil { + log.Printf("Failed to accept channel: %v", err) + return + } + defer channel.Close() + + go ssh.DiscardRequests(requests) + + // Determine the target address + target := fmt.Sprintf("%s:%d", req.Host, req.Port) + log.Printf("Creating TCP tunnel to: %s", target) + + // Create mTLS server configuration + tlsConfig, err := g.createMTLSConfig() + if err != nil { + log.Printf("Failed to create mTLS config: %v", err) + return + } + + // Create a virtual connection that pipes data between SSH channel and TLS + virtualConn := &virtualConnection{ + channel: channel, + } + + // Wrap the virtual connection with TLS + tlsConn := tls.Server(virtualConn, tlsConfig) + + // Perform TLS handshake + if err := tlsConn.Handshake(); err != nil { + log.Printf("TLS handshake failed: %v", err) + return + } + + log.Printf("mTLS connection established with client: %s", tlsConn.ConnectionState().ServerName) + + // Connect to local service + localConn, err := net.Dial("tcp", target) + if err != nil { + log.Printf("Failed to connect to local service %s: %v", target, err) + return + } + defer localConn.Close() + + log.Printf("TCP tunnel established to %s", target) + + // Create bidirectional tunnel with TLS + // Forward data from TLS connection to local service + go func() { + io.Copy(localConn, tlsConn) + localConn.Close() + log.Printf("TLS -> local service tunnel closed") + }() + + // Forward data from local service to TLS connection + io.Copy(tlsConn, localConn) + log.Printf("Local service -> TLS tunnel closed") +} + +func (g *Gateway) createMTLSConfig() (*tls.Config, error) { + // Parse server certificate + serverCertBlock, _ := pem.Decode([]byte(g.certificates.PKI.ServerCertificate)) + if serverCertBlock == nil { + return nil, fmt.Errorf("failed to decode server certificate") + } + + // Parse server private key + serverKeyBlock, _ := pem.Decode([]byte(g.certificates.PKI.ServerPrivateKey)) + if serverKeyBlock == nil { + return nil, fmt.Errorf("failed to decode server private key") + } + + serverKey, err := x509.ParsePKCS8PrivateKey(serverKeyBlock.Bytes) + if err != nil { + return nil, fmt.Errorf("failed to parse server private key: %v", err) + } + + // Parse client CA certificate + clientCABlock, _ := pem.Decode([]byte(g.certificates.PKI.ClientCA)) + if clientCABlock == nil { + return nil, fmt.Errorf("failed to decode client CA certificate") + } + + clientCA, err := x509.ParseCertificate(clientCABlock.Bytes) + if err != nil { + return nil, fmt.Errorf("failed to parse client CA certificate: %v", err) + } + + // Create certificate pool for client CAs + clientCAPool := x509.NewCertPool() + clientCAPool.AddCert(clientCA) + + // Create TLS config + return &tls.Config{ + Certificates: []tls.Certificate{ + { + Certificate: [][]byte{serverCertBlock.Bytes}, + PrivateKey: serverKey, + }, + }, + ClientCAs: clientCAPool, + ClientAuth: tls.RequireAndVerifyClientCert, + MinVersion: tls.VersionTLS12, + }, nil +} + +// virtualConnection implements net.Conn to bridge SSH channel and TLS +type virtualConnection struct { + channel ssh.Channel +} + +func (vc *virtualConnection) Read(b []byte) (n int, err error) { + return vc.channel.Read(b) +} + +func (vc *virtualConnection) Write(b []byte) (n int, err error) { + return vc.channel.Write(b) +} + +func (vc *virtualConnection) Close() error { + return vc.channel.Close() +} + +func (vc *virtualConnection) LocalAddr() net.Addr { + return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0} +} + +func (vc *virtualConnection) RemoteAddr() net.Addr { + return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0} +} + +func (vc *virtualConnection) SetDeadline(t time.Time) error { + return nil +} + +func (vc *virtualConnection) SetReadDeadline(t time.Time) error { + return nil +} + +func (vc *virtualConnection) SetWriteDeadline(t time.Time) error { + return nil +} diff --git a/packages/proxy/proxy.go b/packages/proxy/proxy.go new file mode 100644 index 00000000..492dee61 --- /dev/null +++ b/packages/proxy/proxy.go @@ -0,0 +1,486 @@ +package proxy + +import ( + "bytes" + "context" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "encoding/pem" + "fmt" + "io" + "log" + "net" + + "strconv" + "strings" + "sync" + + "github.com/Infisical/infisical-merge/packages/api" + "github.com/Infisical/infisical-merge/packages/util" + "github.com/go-resty/resty/v2" + "golang.org/x/crypto/ssh" +) + +type ProxyConfig struct { + // API Configuration + Token string + ProxyName string + + Type string + + // Server Ports + SSHPort string + TLSPort string + + // Network Configuration + StaticIP string +} + +type Proxy struct { + httpClient *resty.Client + config *ProxyConfig + + // Certificate storage + certificates *api.RegisterProxyResponse + + // SSH server components + sshConfig *ssh.ServerConfig + sshCA ssh.Signer + + // TLS server components + tlsConfig *tls.Config + tlsCACert []byte + tlsCAKey *rsa.PrivateKey + + // Tunnel storage (Gateway ID -> SSH connection) + tunnels map[string]*ssh.ServerConn + mu sync.RWMutex + + // Server listeners + sshListener net.Listener + tlsListener net.Listener +} + +func NewProxy(config *ProxyConfig) (*Proxy, error) { + httpClient, err := util.GetRestyClientWithCustomHeaders() + if err != nil { + return nil, fmt.Errorf("unable to get client with custom headers [err=%v]", err) + } + + httpClient.SetAuthToken(config.Token) + + return &Proxy{ + httpClient: httpClient, + config: config, + tunnels: make(map[string]*ssh.ServerConn), + }, nil +} + +func (p *Proxy) SetToken(token string) { + p.httpClient.SetAuthToken(token) +} + +func (p *Proxy) Start(ctx context.Context) error { + // Register proxy and get certificates from API + if err := p.registerProxy(); err != nil { + return fmt.Errorf("failed to register proxy: %v", err) + } + + // Setup SSH server + if err := p.setupSSHServer(); err != nil { + return fmt.Errorf("failed to setup SSH server: %v", err) + } + + // Setup TLS server + if err := p.setupTLSServer(); err != nil { + return fmt.Errorf("failed to setup TLS server: %v", err) + } + + // Start SSH server + go p.startSSHServer() + + // Start TLS server + go p.startTLSServer() + + log.Printf("Proxy server started successfully") + + // Wait for context cancellation + <-ctx.Done() + + // Cleanup + p.cleanup() + return nil +} + +func (p *Proxy) registerProxy() error { + body := api.RegisterProxyRequest{ + IP: p.config.StaticIP, + Name: p.config.ProxyName, + } + + if p.config.Type == "instance" { + certResp, err := api.CallRegisterInstanceProxy(p.httpClient, body) + if err != nil { + return fmt.Errorf("failed to register instance proxy: %v", err) + } + p.certificates = &certResp + } else { + certResp, err := api.CallRegisterProxy(p.httpClient, body) + if err != nil { + return fmt.Errorf("failed to register org proxy: %v", err) + } + p.certificates = &certResp + } + + log.Printf("Successfully registered proxy and received certificates from API") + return nil +} + +func (p *Proxy) setupSSHServer() error { + // Parse SSH CA public key + sshCAPubKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(p.certificates.SSH.ClientCAPublicKey)) + if err != nil { + return fmt.Errorf("failed to parse SSH CA public key: %v", err) + } + + // Parse SSH server private key + sshServerKey, err := ssh.ParsePrivateKey([]byte(p.certificates.SSH.ServerPrivateKey)) + if err != nil { + return fmt.Errorf("failed to parse SSH server private key: %v", err) + } + + // Parse SSH server certificate + sshServerCert, _, _, _, err := ssh.ParseAuthorizedKey([]byte(p.certificates.SSH.ServerCertificate)) + if err != nil { + return fmt.Errorf("failed to parse SSH server certificate: %v", err) + } + + // Create certificate signer + certSigner, err := ssh.NewCertSigner(sshServerCert.(*ssh.Certificate), sshServerKey) + if err != nil { + return fmt.Errorf("failed to create SSH certificate signer: %v", err) + } + + // Setup SSH server config + p.sshConfig = &ssh.ServerConfig{ + PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { + // Check if this is an SSH certificate + cert, ok := key.(*ssh.Certificate) + if !ok { + log.Printf("Gateway '%s' tried to authenticate with raw public key (rejected)", conn.User()) + return nil, fmt.Errorf("certificates required, raw public keys not allowed") + } + + // Validate the certificate + if err := p.validateSSHCertificate(cert, conn.User(), sshCAPubKey); err != nil { + log.Printf("Gateway '%s' certificate validation failed: %v", conn.User(), err) + return nil, err + } + + gatewayId := "" + if len(cert.ValidPrincipals) > 0 { + gatewayId = cert.ValidPrincipals[0] + } + + if gatewayId == "" { + return nil, fmt.Errorf("gateway id is required") + } + + return &ssh.Permissions{ + Extensions: map[string]string{ + "gateway-id": gatewayId, + }, + }, nil + }, + } + + p.sshConfig.AddHostKey(certSigner) + return nil +} + +func (p *Proxy) setupTLSServer() error { + // Parse TLS server certificate + serverCertBlock, _ := pem.Decode([]byte(p.certificates.PKI.ServerCertificate)) + if serverCertBlock == nil { + return fmt.Errorf("failed to decode server certificate") + } + + // Note: serverCert is parsed for validation but not used in the TLS config + // since we use the raw bytes directly + _, err := x509.ParseCertificate(serverCertBlock.Bytes) + if err != nil { + return fmt.Errorf("failed to parse server certificate: %v", err) + } + + // Parse TLS server private key + serverKeyBlock, _ := pem.Decode([]byte(p.certificates.PKI.ServerPrivateKey)) + if serverKeyBlock == nil { + return fmt.Errorf("failed to decode server private key") + } + + serverKey, err := x509.ParsePKCS8PrivateKey(serverKeyBlock.Bytes) + if err != nil { + return fmt.Errorf("failed to parse server private key: %v", err) + } + + // Parse client CA certificate + clientCABlock, _ := pem.Decode([]byte(p.certificates.PKI.ClientCA)) + if clientCABlock == nil { + return fmt.Errorf("failed to decode client CA certificate") + } + + clientCA, err := x509.ParseCertificate(clientCABlock.Bytes) + if err != nil { + return fmt.Errorf("failed to parse client CA certificate: %v", err) + } + + // Create certificate pool for client CAs + clientCAPool := x509.NewCertPool() + clientCAPool.AddCert(clientCA) + + // Create TLS config + p.tlsConfig = &tls.Config{ + Certificates: []tls.Certificate{ + { + Certificate: [][]byte{serverCertBlock.Bytes}, + PrivateKey: serverKey, + }, + }, + ClientCAs: clientCAPool, + ClientAuth: tls.RequireAndVerifyClientCert, + MinVersion: tls.VersionTLS12, + } + + return nil +} + +func (p *Proxy) validateSSHCertificate(cert *ssh.Certificate, username string, caPubKey ssh.PublicKey) error { + // Check certificate type + if cert.CertType != ssh.UserCert { + return fmt.Errorf("invalid certificate type: %d", cert.CertType) + } + + // Check if certificate is signed by our CA + checker := &ssh.CertChecker{ + IsUserAuthority: func(auth ssh.PublicKey) bool { + return bytes.Equal(auth.Marshal(), caPubKey.Marshal()) + }, + } + + // Validate the certificate + if err := checker.CheckCert(username, cert); err != nil { + return fmt.Errorf("certificate check failed: %v", err) + } + + log.Printf("SSH certificate valid for user '%s', principals: %v", username, cert.ValidPrincipals) + return nil +} + +func (p *Proxy) startSSHServer() { + listener, err := net.Listen("tcp", ":"+p.config.SSHPort) + if err != nil { + log.Fatalf("Failed to start SSH server: %v", err) + } + p.sshListener = listener + + log.Printf("SSH server listening on :%s for gateways", p.config.SSHPort) + + for { + conn, err := listener.Accept() + if err != nil { + log.Printf("Failed to accept SSH connection: %v", err) + continue + } + go p.handleSSHAgent(conn) + } +} + +func (p *Proxy) handleSSHAgent(conn net.Conn) { + defer conn.Close() + + // SSH handshake + sshConn, chans, _, err := ssh.NewServerConn(conn, p.sshConfig) + if err != nil { + log.Printf("SSH handshake failed: %v", err) + return + } + + gatewayId := sshConn.Permissions.Extensions["gateway-id"] + log.Printf("SSH handshake successful for gateway: %s", gatewayId) + + // Store the connection + p.mu.Lock() + p.tunnels[gatewayId] = sshConn + p.mu.Unlock() + + // Clean up when agent disconnects + defer func() { + p.mu.Lock() + delete(p.tunnels, gatewayId) + p.mu.Unlock() + log.Printf("Gateway %s disconnected", gatewayId) + }() + + for newChannel := range chans { + switch newChannel.ChannelType() { + case "session": + newChannel.Reject(ssh.Prohibited, "no shell access") + case "x11": + newChannel.Reject(ssh.Prohibited, "no X11 forwarding") + case "auth-agent": + newChannel.Reject(ssh.Prohibited, "no agent forwarding") + } + } +} + +func (p *Proxy) startTLSServer() { + listener, err := net.Listen("tcp", ":"+p.config.TLSPort) + if err != nil { + log.Fatalf("Failed to start TLS server: %v", err) + } + p.tlsListener = listener + + log.Printf("TLS server listening on :%s for clients", p.config.TLSPort) + + for { + conn, err := listener.Accept() + if err != nil { + log.Printf("Failed to accept TLS connection: %v", err) + continue + } + go p.handleTLSClient(conn) + } +} + +func (p *Proxy) handleTLSClient(conn net.Conn) { + defer conn.Close() + + log.Printf("Client connected from %s", conn.RemoteAddr()) + + // Wrap connection with TLS + tlsConn := tls.Server(conn, p.tlsConfig) + if err := tlsConn.Handshake(); err != nil { + log.Printf("TLS handshake failed: %v", err) + return + } + + // Log client certificate info + if len(tlsConn.ConnectionState().PeerCertificates) > 0 { + cert := tlsConn.ConnectionState().PeerCertificates[0] + log.Printf("Client connected with certificate: %s", cert.Subject.CommonName) + } + + p.handleClient(tlsConn) +} + +func (p *Proxy) handleClient(clientConn net.Conn) { + defer clientConn.Close() + + // Read the first few bytes to determine which agent to connect to + // Format: "agent1:host:port\n" or "agent1:host:port" followed by data + buffer := make([]byte, 1024) + n, err := clientConn.Read(buffer) + if err != nil { + log.Printf("Failed to read from client: %v", err) + return + } + + // Find the first newline to separate agent info from data + data := buffer[:n] + log.Printf("Received %d bytes from client: %q", n, string(data)) + newlineIndex := bytes.IndexByte(data, '\n') + + var gatewayId, targetHost string + var targetPort uint32 + var remainingData []byte + + if newlineIndex != -1 { + // Agent info is everything before the newline + agentInfo := string(data[:newlineIndex]) + remainingData = data[newlineIndex+1:] + + // Parse agent info in format "agent:host:port" + parts := strings.Split(agentInfo, ":") + if len(parts) != 3 { + log.Printf("Invalid client data format, expected 'agent:host:port', got: %s", agentInfo) + clientConn.Write([]byte("ERROR: Invalid format. Expected 'agent:host:port'\n")) + return + } + + gatewayId = parts[0] + targetHost = parts[1] + portStr := parts[2] + + // Parse port number + port, err := strconv.ParseUint(portStr, 10, 32) + if err != nil { + log.Printf("Invalid port number: %s", portStr) + clientConn.Write([]byte("ERROR: Invalid port number\n")) + return + } + targetPort = uint32(port) + + log.Printf("Extracted gateway: %s, target: %s:%d", gatewayId, targetHost, targetPort) + } else { + log.Printf("Invalid client data format - no newline found") + clientConn.Write([]byte("ERROR: Please use format 'gatewayId:host:port'\n")) + return + } + + // Get the SSH connection for this agent + p.mu.RLock() + conn, exists := p.tunnels[gatewayId] + p.mu.RUnlock() + + if !exists { + log.Printf("Gateway '%s' not connected", gatewayId) + clientConn.Write([]byte("ERROR: Gateway not connected\n")) + return + } + + log.Printf("Routing TCP connection to gateway: %s", gatewayId) + + // Open SSH channel to connect to agent's local service through the tunnel + payload := struct { + Host string + Port uint32 + _ string + _ uint32 + }{targetHost, targetPort, "", 0} + + channel, _, err := conn.OpenChannel("direct-tcpip", ssh.Marshal(&payload)) + if err != nil { + log.Printf("Failed to connect to agent: %v", err) + clientConn.Write([]byte("ERROR: Failed to connect to agent\n")) + return + } + defer channel.Close() + + // If we have remaining data from the initial read, write it to the channel + if len(remainingData) > 0 { + channel.Write(remainingData) + } + + // Bidirectional forwarding + go func() { + io.Copy(channel, clientConn) + channel.CloseWrite() + }() + + io.Copy(clientConn, channel) + log.Printf("Client %s disconnected", clientConn.RemoteAddr()) +} + +func (p *Proxy) cleanup() { + log.Printf("Shutting down proxy server...") + + if p.sshListener != nil { + p.sshListener.Close() + } + if p.tlsListener != nil { + p.tlsListener.Close() + } + + log.Printf("Proxy server shutdown complete") +} From 1fb0a482b9b89f421e8982e6bf95e93afb6ed3eb Mon Sep 17 00:00:00 2001 From: Sheen Capadngan Date: Sat, 30 Aug 2025 01:58:09 +0800 Subject: [PATCH 02/38] misc: updated proxy to start tls server instead of tcp --- packages/proxy/proxy.go | 33 ++++++++++----------------------- 1 file changed, 10 insertions(+), 23 deletions(-) diff --git a/packages/proxy/proxy.go b/packages/proxy/proxy.go index 492dee61..e177e10c 100644 --- a/packages/proxy/proxy.go +++ b/packages/proxy/proxy.go @@ -335,7 +335,7 @@ func (p *Proxy) handleSSHAgent(conn net.Conn) { } func (p *Proxy) startTLSServer() { - listener, err := net.Listen("tcp", ":"+p.config.TLSPort) + listener, err := tls.Listen("tcp", ":"+p.config.TLSPort, p.tlsConfig) if err != nil { log.Fatalf("Failed to start TLS server: %v", err) } @@ -349,34 +349,21 @@ func (p *Proxy) startTLSServer() { log.Printf("Failed to accept TLS connection: %v", err) continue } - go p.handleTLSClient(conn) + go p.handleClient(conn) } } -func (p *Proxy) handleTLSClient(conn net.Conn) { - defer conn.Close() - - log.Printf("Client connected from %s", conn.RemoteAddr()) - - // Wrap connection with TLS - tlsConn := tls.Server(conn, p.tlsConfig) - if err := tlsConn.Handshake(); err != nil { - log.Printf("TLS handshake failed: %v", err) - return - } - - // Log client certificate info - if len(tlsConn.ConnectionState().PeerCertificates) > 0 { - cert := tlsConn.ConnectionState().PeerCertificates[0] - log.Printf("Client connected with certificate: %s", cert.Subject.CommonName) - } - - p.handleClient(tlsConn) -} - func (p *Proxy) handleClient(clientConn net.Conn) { defer clientConn.Close() + // Log client certificate info if this is a TLS connection + if tlsConn, ok := clientConn.(*tls.Conn); ok { + if len(tlsConn.ConnectionState().PeerCertificates) > 0 { + cert := tlsConn.ConnectionState().PeerCertificates[0] + log.Printf("Client connected with certificate: %s", cert.Subject.CommonName) + } + } + // Read the first few bytes to determine which agent to connect to // Format: "agent1:host:port\n" or "agent1:host:port" followed by data buffer := make([]byte, 1024) From cda3ac3e49a1d00ef8def690b9d009ede0e7a70f Mon Sep 17 00:00:00 2001 From: Sheen Capadngan Date: Sat, 30 Aug 2025 02:19:20 +0800 Subject: [PATCH 03/38] misc: added full server certificate chain to proxy tls --- packages/proxy/proxy.go | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/packages/proxy/proxy.go b/packages/proxy/proxy.go index e177e10c..d2c61039 100644 --- a/packages/proxy/proxy.go +++ b/packages/proxy/proxy.go @@ -213,6 +213,18 @@ func (p *Proxy) setupTLSServer() error { return fmt.Errorf("failed to parse server certificate: %v", err) } + // Parse all certificates from the chain (intermediate + root CAs) + var chainCerts [][]byte + chainData := []byte(p.certificates.PKI.ServerCertificateChain) + for { + block, rest := pem.Decode(chainData) + if block == nil { + break + } + chainCerts = append(chainCerts, block.Bytes) + chainData = rest + } + // Parse TLS server private key serverKeyBlock, _ := pem.Decode([]byte(p.certificates.PKI.ServerPrivateKey)) if serverKeyBlock == nil { @@ -239,11 +251,15 @@ func (p *Proxy) setupTLSServer() error { clientCAPool := x509.NewCertPool() clientCAPool.AddCert(clientCA) + // Create certificate chain: server cert + chain certs (intermediate + root) + certChain := [][]byte{serverCertBlock.Bytes} + certChain = append(certChain, chainCerts...) + // Create TLS config p.tlsConfig = &tls.Config{ Certificates: []tls.Certificate{ { - Certificate: [][]byte{serverCertBlock.Bytes}, + Certificate: certChain, PrivateKey: serverKey, }, }, From 6f7eda5af231ac2c2d188726dc2ede3b47f74fc0 Mon Sep 17 00:00:00 2001 From: Sheen Capadngan Date: Sat, 30 Aug 2025 02:31:03 +0800 Subject: [PATCH 04/38] misc: added log --- packages/proxy/proxy.go | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/packages/proxy/proxy.go b/packages/proxy/proxy.go index d2c61039..174e7ccd 100644 --- a/packages/proxy/proxy.go +++ b/packages/proxy/proxy.go @@ -255,6 +255,19 @@ func (p *Proxy) setupTLSServer() error { certChain := [][]byte{serverCertBlock.Bytes} certChain = append(certChain, chainCerts...) + // Debug: log the complete certificate chain as PEM + var chainPEM strings.Builder + for i, certBytes := range certChain { + chainPEM.WriteString(fmt.Sprintf("--- Certificate %d ---\n", i+1)) + certPEM := pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: certBytes, + }) + chainPEM.Write(certPEM) + chainPEM.WriteString("\n") + } + log.Printf("Complete certificate chain PEM:\n%s", chainPEM.String()) + // Create TLS config p.tlsConfig = &tls.Config{ Certificates: []tls.Certificate{ From 33692075b2455e443973e94a503daa5134fcf53f Mon Sep 17 00:00:00 2001 From: Sheen Capadngan Date: Sat, 30 Aug 2025 03:29:47 +0800 Subject: [PATCH 05/38] misc: updated proxy to fetch client pem chain --- packages/api/model.go | 3 +-- packages/proxy/proxy.go | 58 ++++++++++++++--------------------------- 2 files changed, 20 insertions(+), 41 deletions(-) diff --git a/packages/api/model.go b/packages/api/model.go index ad172278..f2128455 100644 --- a/packages/api/model.go +++ b/packages/api/model.go @@ -712,9 +712,8 @@ type RegisterProxyRequest struct { type RegisterProxyResponse struct { PKI struct { ServerCertificate string `json:"serverCertificate"` - ServerCertificateChain string `json:"serverCertificateChain"` ServerPrivateKey string `json:"serverPrivateKey"` - ClientCA string `json:"clientCA"` + ClientCertificateChain string `json:"clientCertificateChain"` } `json:"pki"` SSH struct { ServerCertificate string `json:"serverCertificate"` diff --git a/packages/proxy/proxy.go b/packages/proxy/proxy.go index 174e7ccd..4d8794e3 100644 --- a/packages/proxy/proxy.go +++ b/packages/proxy/proxy.go @@ -213,18 +213,6 @@ func (p *Proxy) setupTLSServer() error { return fmt.Errorf("failed to parse server certificate: %v", err) } - // Parse all certificates from the chain (intermediate + root CAs) - var chainCerts [][]byte - chainData := []byte(p.certificates.PKI.ServerCertificateChain) - for { - block, rest := pem.Decode(chainData) - if block == nil { - break - } - chainCerts = append(chainCerts, block.Bytes) - chainData = rest - } - // Parse TLS server private key serverKeyBlock, _ := pem.Decode([]byte(p.certificates.PKI.ServerPrivateKey)) if serverKeyBlock == nil { @@ -236,43 +224,35 @@ func (p *Proxy) setupTLSServer() error { return fmt.Errorf("failed to parse server private key: %v", err) } - // Parse client CA certificate - clientCABlock, _ := pem.Decode([]byte(p.certificates.PKI.ClientCA)) - if clientCABlock == nil { - return fmt.Errorf("failed to decode client CA certificate") - } - - clientCA, err := x509.ParseCertificate(clientCABlock.Bytes) - if err != nil { - return fmt.Errorf("failed to parse client CA certificate: %v", err) - } - // Create certificate pool for client CAs clientCAPool := x509.NewCertPool() - clientCAPool.AddCert(clientCA) - // Create certificate chain: server cert + chain certs (intermediate + root) - certChain := [][]byte{serverCertBlock.Bytes} - certChain = append(certChain, chainCerts...) + var chainCerts [][]byte + chainData := []byte(p.certificates.PKI.ClientCertificateChain) + for { + block, rest := pem.Decode(chainData) + if block == nil { + break + } + chainCerts = append(chainCerts, block.Bytes) + chainData = rest + } - // Debug: log the complete certificate chain as PEM - var chainPEM strings.Builder - for i, certBytes := range certChain { - chainPEM.WriteString(fmt.Sprintf("--- Certificate %d ---\n", i+1)) - certPEM := pem.EncodeToMemory(&pem.Block{ - Type: "CERTIFICATE", - Bytes: certBytes, - }) - chainPEM.Write(certPEM) - chainPEM.WriteString("\n") + for i, certBytes := range chainCerts { + cert, err := x509.ParseCertificate(certBytes) + if err != nil { + log.Printf("Failed to parse client chain certificate %d: %v", i+1, err) + continue + } + clientCAPool.AddCert(cert) + log.Printf("Added client CA certificate %d to pool: %s", i+1, cert.Subject.CommonName) } - log.Printf("Complete certificate chain PEM:\n%s", chainPEM.String()) // Create TLS config p.tlsConfig = &tls.Config{ Certificates: []tls.Certificate{ { - Certificate: certChain, + Certificate: [][]byte{serverCertBlock.Bytes}, PrivateKey: serverKey, }, }, From 97b9d174780c6f35a8f4d5b4d06f4a1ae0ed133f Mon Sep 17 00:00:00 2001 From: Sheen Capadngan Date: Sat, 30 Aug 2025 04:26:29 +0800 Subject: [PATCH 06/38] misc: added log point --- packages/proxy/proxy.go | 64 ++++++----------------------------------- 1 file changed, 8 insertions(+), 56 deletions(-) diff --git a/packages/proxy/proxy.go b/packages/proxy/proxy.go index 4d8794e3..56979266 100644 --- a/packages/proxy/proxy.go +++ b/packages/proxy/proxy.go @@ -12,8 +12,6 @@ import ( "log" "net" - "strconv" - "strings" "sync" "github.com/Infisical/infisical-merge/packages/api" @@ -365,64 +363,23 @@ func (p *Proxy) startTLSServer() { func (p *Proxy) handleClient(clientConn net.Conn) { defer clientConn.Close() + var gatewayId string + // Log client certificate info if this is a TLS connection if tlsConn, ok := clientConn.(*tls.Conn); ok { + fmt.Println(tlsConn.ConnectionState().PeerCertificates) if len(tlsConn.ConnectionState().PeerCertificates) > 0 { cert := tlsConn.ConnectionState().PeerCertificates[0] log.Printf("Client connected with certificate: %s", cert.Subject.CommonName) + gatewayId = cert.Subject.CommonName } } - // Read the first few bytes to determine which agent to connect to - // Format: "agent1:host:port\n" or "agent1:host:port" followed by data - buffer := make([]byte, 1024) - n, err := clientConn.Read(buffer) - if err != nil { - log.Printf("Failed to read from client: %v", err) - return - } - - // Find the first newline to separate agent info from data - data := buffer[:n] - log.Printf("Received %d bytes from client: %q", n, string(data)) - newlineIndex := bytes.IndexByte(data, '\n') - - var gatewayId, targetHost string - var targetPort uint32 - var remainingData []byte - - if newlineIndex != -1 { - // Agent info is everything before the newline - agentInfo := string(data[:newlineIndex]) - remainingData = data[newlineIndex+1:] - - // Parse agent info in format "agent:host:port" - parts := strings.Split(agentInfo, ":") - if len(parts) != 3 { - log.Printf("Invalid client data format, expected 'agent:host:port', got: %s", agentInfo) - clientConn.Write([]byte("ERROR: Invalid format. Expected 'agent:host:port'\n")) - return - } - - gatewayId = parts[0] - targetHost = parts[1] - portStr := parts[2] + fmt.Println("gatewayId", gatewayId) - // Parse port number - port, err := strconv.ParseUint(portStr, 10, 32) - if err != nil { - log.Printf("Invalid port number: %s", portStr) - clientConn.Write([]byte("ERROR: Invalid port number\n")) - return - } - targetPort = uint32(port) - - log.Printf("Extracted gateway: %s, target: %s:%d", gatewayId, targetHost, targetPort) - } else { - log.Printf("Invalid client data format - no newline found") - clientConn.Write([]byte("ERROR: Please use format 'gatewayId:host:port'\n")) - return - } + // TODO: extract these from the certificate + targetHost := "localhost" + targetPort := uint32(22) // Get the SSH connection for this agent p.mu.RLock() @@ -453,11 +410,6 @@ func (p *Proxy) handleClient(clientConn net.Conn) { } defer channel.Close() - // If we have remaining data from the initial read, write it to the channel - if len(remainingData) > 0 { - channel.Write(remainingData) - } - // Bidirectional forwarding go func() { io.Copy(channel, clientConn) From ef24451d49827862a1d30c1733449b2e5fac59d9 Mon Sep 17 00:00:00 2001 From: Sheen Capadngan Date: Sat, 30 Aug 2025 04:34:58 +0800 Subject: [PATCH 07/38] misc: added handshake forcing --- packages/proxy/proxy.go | 31 ++++++++++++++++++++++++++----- 1 file changed, 26 insertions(+), 5 deletions(-) diff --git a/packages/proxy/proxy.go b/packages/proxy/proxy.go index 56979266..b99a8f31 100644 --- a/packages/proxy/proxy.go +++ b/packages/proxy/proxy.go @@ -11,7 +11,7 @@ import ( "io" "log" "net" - + "strings" "sync" "github.com/Infisical/infisical-merge/packages/api" @@ -367,12 +367,33 @@ func (p *Proxy) handleClient(clientConn net.Conn) { // Log client certificate info if this is a TLS connection if tlsConn, ok := clientConn.(*tls.Conn); ok { - fmt.Println(tlsConn.ConnectionState().PeerCertificates) - if len(tlsConn.ConnectionState().PeerCertificates) > 0 { - cert := tlsConn.ConnectionState().PeerCertificates[0] + log.Printf("TLS connection detected, forcing handshake...") + err := tlsConn.Handshake() + if err != nil { + log.Printf("TLS handshake failed: %v", err) + return + } + + state := tlsConn.ConnectionState() + log.Printf("TLS handshake completed, peer certificates count: %d", len(state.PeerCertificates)) + + if len(state.PeerCertificates) > 0 { + cert := state.PeerCertificates[0] log.Printf("Client connected with certificate: %s", cert.Subject.CommonName) - gatewayId = cert.Subject.CommonName + parts := strings.Split(cert.Subject.CommonName, ":") + if len(parts) >= 2 { + gatewayId = parts[1] + } else { + log.Printf("Invalid CommonName format, expected 'part1:part2', got: %s", cert.Subject.CommonName) + return + } + } else { + log.Printf("No peer certificates found") + return } + } else { + log.Printf("Not a TLS connection, connection type: %T", clientConn) + return } fmt.Println("gatewayId", gatewayId) From a233a3f3cbf81639403d798302189260e20b8d34 Mon Sep 17 00:00:00 2001 From: Sheen Capadngan Date: Sun, 31 Aug 2025 23:05:39 +0800 Subject: [PATCH 08/38] misc: updated gateway to fetch client certificate chain --- packages/api/model.go | 3 +-- packages/gateway-v2/gateway.go | 30 +++++++++++++++++++----------- packages/proxy/proxy.go | 3 --- 3 files changed, 20 insertions(+), 16 deletions(-) diff --git a/packages/api/model.go b/packages/api/model.go index f2128455..c436d117 100644 --- a/packages/api/model.go +++ b/packages/api/model.go @@ -732,9 +732,8 @@ type RegisterGatewayResponse struct { ProxyIP string `json:"proxyIp"` PKI struct { ServerCertificate string `json:"serverCertificate"` - ServerCertificateChain string `json:"serverCertificateChain"` ServerPrivateKey string `json:"serverPrivateKey"` - ClientCA string `json:"clientCA"` + ClientCertificateChain string `json:"clientCertificateChain"` } `json:"pki"` SSH struct { ClientCertificate string `json:"clientCertificate"` diff --git a/packages/gateway-v2/gateway.go b/packages/gateway-v2/gateway.go index 2208e8cd..12ad4acd 100644 --- a/packages/gateway-v2/gateway.go +++ b/packages/gateway-v2/gateway.go @@ -361,21 +361,29 @@ func (g *Gateway) createMTLSConfig() (*tls.Config, error) { return nil, fmt.Errorf("failed to parse server private key: %v", err) } - // Parse client CA certificate - clientCABlock, _ := pem.Decode([]byte(g.certificates.PKI.ClientCA)) - if clientCABlock == nil { - return nil, fmt.Errorf("failed to decode client CA certificate") + // Create certificate pool for client CAs + clientCAPool := x509.NewCertPool() + var chainCerts [][]byte + chainData := []byte(g.certificates.PKI.ClientCertificateChain) + for { + block, rest := pem.Decode(chainData) + if block == nil { + break + } + chainCerts = append(chainCerts, block.Bytes) + chainData = rest } - clientCA, err := x509.ParseCertificate(clientCABlock.Bytes) - if err != nil { - return nil, fmt.Errorf("failed to parse client CA certificate: %v", err) + for i, certBytes := range chainCerts { + cert, err := x509.ParseCertificate(certBytes) + if err != nil { + log.Printf("Failed to parse client chain certificate %d: %v", i+1, err) + continue + } + clientCAPool.AddCert(cert) + log.Printf("Added client CA certificate %d to pool: %s", i+1, cert.Subject.CommonName) } - // Create certificate pool for client CAs - clientCAPool := x509.NewCertPool() - clientCAPool.AddCert(clientCA) - // Create TLS config return &tls.Config{ Certificates: []tls.Certificate{ diff --git a/packages/proxy/proxy.go b/packages/proxy/proxy.go index b99a8f31..f57aaea2 100644 --- a/packages/proxy/proxy.go +++ b/packages/proxy/proxy.go @@ -365,7 +365,6 @@ func (p *Proxy) handleClient(clientConn net.Conn) { var gatewayId string - // Log client certificate info if this is a TLS connection if tlsConn, ok := clientConn.(*tls.Conn); ok { log.Printf("TLS connection detected, forcing handshake...") err := tlsConn.Handshake() @@ -396,8 +395,6 @@ func (p *Proxy) handleClient(clientConn net.Conn) { return } - fmt.Println("gatewayId", gatewayId) - // TODO: extract these from the certificate targetHost := "localhost" targetPort := uint32(22) From 74db2f340c5bea456baa51d0f16dc0e206ad3b11 Mon Sep 17 00:00:00 2001 From: Sheen Capadngan Date: Sun, 31 Aug 2025 23:25:42 +0800 Subject: [PATCH 09/38] misc: set target host of proxy to gateway --- packages/proxy/proxy.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/proxy/proxy.go b/packages/proxy/proxy.go index f57aaea2..279dc2c8 100644 --- a/packages/proxy/proxy.go +++ b/packages/proxy/proxy.go @@ -396,7 +396,7 @@ func (p *Proxy) handleClient(clientConn net.Conn) { } // TODO: extract these from the certificate - targetHost := "localhost" + targetHost := "gateway" targetPort := uint32(22) // Get the SSH connection for this agent From 36c069dbeef68754f5438c85b0b8ab5622f4630d Mon Sep 17 00:00:00 2001 From: Sheen Capadngan Date: Mon, 1 Sep 2025 22:20:36 +0800 Subject: [PATCH 10/38] feat: added TCP and HTTP forward handling to gateway --- packages/cmd/network.go | 16 +- packages/gateway-v2/connection.go | 143 ++++++++++++++++ packages/gateway-v2/constants.go | 22 +++ packages/gateway-v2/gateway.go | 271 +++++++++++++++++++++--------- packages/proxy/proxy.go | 1 - 5 files changed, 360 insertions(+), 93 deletions(-) create mode 100644 packages/gateway-v2/connection.go create mode 100644 packages/gateway-v2/constants.go diff --git a/packages/cmd/network.go b/packages/cmd/network.go index 8ec9c6b3..2d6d67f4 100644 --- a/packages/cmd/network.go +++ b/packages/cmd/network.go @@ -57,9 +57,9 @@ var networkProxyCmd = &cobra.Command{ } if instanceType == "instance" { - proxyAuthSecret := os.Getenv("PROXY_AUTH_SECRET") + proxyAuthSecret := os.Getenv(gatewayv2.PROXY_AUTH_SECRET_ENV_NAME) if proxyAuthSecret == "" { - util.HandleError(fmt.Errorf("PROXY_AUTH_SECRET is not set"), "unable to get proxy auth secret") + util.HandleError(fmt.Errorf("%s is not set", gatewayv2.PROXY_AUTH_SECRET_ENV_NAME), "unable to get proxy auth secret") } proxyInstance.SetToken(proxyAuthSecret) @@ -145,14 +145,14 @@ var networkGatewayCmd = &cobra.Command{ Long: "Run the Infisical gateway component", Run: func(cmd *cobra.Command, args []string) { - proxyName, err := cmd.Flags().GetString("proxy-name") - if err != nil || proxyName == "" { - util.HandleError(err, "unable to get proxy-name flag") + proxyName, err := util.GetCmdFlagOrEnv(cmd, "proxy-name", []string{gatewayv2.PROXY_NAME_ENV_NAME}) + if err != nil { + util.HandleError(err, fmt.Sprintf("unable to get proxy-name flag or %s env", gatewayv2.PROXY_NAME_ENV_NAME)) } - gatewayName, err := cmd.Flags().GetString("name") - if err != nil || gatewayName == "" { - util.HandleError(err, "unable to get name flag") + gatewayName, err := util.GetCmdFlagOrEnv(cmd, "name", []string{gatewayv2.GATEWAY_NAME_ENV_NAME}) + if err != nil { + util.HandleError(err, fmt.Sprintf("unable to get name flag or %s env", gatewayv2.GATEWAY_NAME_ENV_NAME)) } gatewayInstance, err := gatewayv2.NewGateway(&gatewayv2.GatewayConfig{ diff --git a/packages/gateway-v2/connection.go b/packages/gateway-v2/connection.go new file mode 100644 index 00000000..ad521392 --- /dev/null +++ b/packages/gateway-v2/connection.go @@ -0,0 +1,143 @@ +package gatewayv2 + +import ( + "bufio" + "crypto/tls" + "crypto/x509" + "errors" + "fmt" + "io" + "net" + "net/http" + "strings" + "time" + + "github.com/rs/zerolog/log" +) + +func buildHttpInternalServerError(message string) string { + return fmt.Sprintf("HTTP/1.1 500 Internal Server Error\r\nContent-Type: application/json\r\n\r\n{\"message\": \"gateway: %s\"}", message) +} + +func handleHTTPProxy(conn *tls.Conn, reader *bufio.Reader, targetURL string, caCert []byte, verifyTLS bool) error { + transport := &http.Transport{ + DisableKeepAlives: false, + MaxIdleConns: 10, + IdleConnTimeout: 30 * time.Second, + } + + if strings.HasPrefix(targetURL, "https://") { + tlsConfig := &tls.Config{} + + if len(caCert) > 0 { + caCertPool := x509.NewCertPool() + if caCertPool.AppendCertsFromPEM(caCert) { + tlsConfig.RootCAs = caCertPool + log.Info().Msg("Using provided CA certificate from gateway client") + } else { + log.Error().Msg("Failed to parse provided CA certificate") + } + } + + tlsConfig.InsecureSkipVerify = !verifyTLS + log.Info().Msgf("TLS verification set to: %v", verifyTLS) + + transport.TLSClientConfig = tlsConfig + } + + // Loop to handle multiple HTTP requests on the same connection + for { + log.Info().Msg("Attempting to read HTTP request...") + req, err := http.ReadRequest(reader) + + if err != nil { + if errors.Is(err, io.EOF) { + log.Info().Msg("Client closed HTTP connection") + return nil + } + + log.Error().Msgf("Failed to read HTTP request: %v", err) + return fmt.Errorf("failed to read HTTP request: %v", err) + } + log.Info().Msgf("Received HTTP request: %s", req.URL.Path) + + // Build full target URL + var targetFullURL string + if strings.HasPrefix(targetURL, "http://") || strings.HasPrefix(targetURL, "https://") { + baseURL := strings.TrimSuffix(targetURL, "/") + targetFullURL = baseURL + req.URL.Path + if req.URL.RawQuery != "" { + targetFullURL += "?" + req.URL.RawQuery + } + } else { + baseURL := strings.TrimSuffix("http://"+targetURL, "/") + targetFullURL = baseURL + req.URL.Path + if req.URL.RawQuery != "" { + targetFullURL += "?" + req.URL.RawQuery + } + } + + // create the request to the target + proxyReq, err := http.NewRequest(req.Method, targetFullURL, req.Body) + if err != nil { + log.Error().Msgf("Failed to create proxy request: %v", err) + conn.Write([]byte(buildHttpInternalServerError("failed to create proxy request"))) + continue // Continue to next request + } + proxyReq.Header = req.Header.Clone() + + log.Info().Msgf("Proxying %s %s to %s", req.Method, req.URL.Path, targetFullURL) + + client := &http.Client{ + Transport: transport, + Timeout: 30 * time.Second, + } + + resp, err := client.Do(proxyReq) + if err != nil { + log.Error().Msgf("Failed to reach target: %v", err) + conn.Write([]byte(buildHttpInternalServerError(fmt.Sprintf("failed to reach target due to networking error: %s", err.Error())))) + continue // Continue to next request + } + + // Write the entire response (status line, headers, body) to the connection + resp.Header.Del("Connection") + + log.Info().Msgf("Writing response to connection: %s", resp.Status) + + if err := resp.Write(conn); err != nil { + log.Error().Err(err).Msg("Failed to write response to connection") + resp.Body.Close() + return fmt.Errorf("failed to write response to connection: %w", err) + } + + resp.Body.Close() + + // Check if client wants to close connection + if req.Header.Get("Connection") == "close" { + log.Info().Msg("Client requested connection close") + return nil + } + } +} + +func handleTCPProxy(conn *tls.Conn, target string) error { + localConn, err := net.Dial("tcp", target) + if err != nil { + log.Error().Msgf("Failed to connect to local service %s: %v", target, err) + return fmt.Errorf("failed to connect to local service %s: %v", target, err) + } + defer localConn.Close() + + // Create bidirectional tunnel with TLS + // Forward data from TLS connection to local service + go func() { + io.Copy(localConn, conn) + localConn.Close() + }() + + // Forward data from local service to TLS connection + io.Copy(conn, localConn) + + return nil +} diff --git a/packages/gateway-v2/constants.go b/packages/gateway-v2/constants.go new file mode 100644 index 00000000..996068d6 --- /dev/null +++ b/packages/gateway-v2/constants.go @@ -0,0 +1,22 @@ +package gatewayv2 + +const ( + KUBERNETES_SERVICE_HOST_ENV_NAME = "KUBERNETES_SERVICE_HOST" + KUBERNETES_SERVICE_PORT_HTTPS_ENV_NAME = "KUBERNETES_SERVICE_PORT_HTTPS" + KUBERNETES_SERVICE_ACCOUNT_CA_CERT_PATH = "/var/run/secrets/kubernetes.io/serviceaccount/ca.crt" + KUBERNETES_SERVICE_ACCOUNT_TOKEN_PATH = "/var/run/secrets/kubernetes.io/serviceaccount/token" + + PROXY_NAME_ENV_NAME = "PROXY_NAME" + GATEWAY_NAME_ENV_NAME = "GATEWAY_NAME" + + PROXY_AUTH_SECRET_ENV_NAME = "PROXY_AUTH_SECRET" + + INFISICAL_HTTP_PROXY_ACTION_HEADER = "x-infisical-action" +) + +type HttpProxyAction string + +const ( + HttpProxyActionInjectGatewayK8sServiceAccountToken HttpProxyAction = "inject-k8s-sa-auth-token" + HttpProxyActionUseGatewayK8sServiceAccount HttpProxyAction = "use-k8s-sa" +) diff --git a/packages/gateway-v2/gateway.go b/packages/gateway-v2/gateway.go index 12ad4acd..85f7b4ab 100644 --- a/packages/gateway-v2/gateway.go +++ b/packages/gateway-v2/gateway.go @@ -1,25 +1,52 @@ package gatewayv2 import ( + "bufio" "bytes" "context" "crypto/rsa" "crypto/tls" "crypto/x509" + "encoding/base64" + "encoding/json" "encoding/pem" "fmt" - "io" - "log" "net" + "strconv" + "strings" "sync" "time" "github.com/Infisical/infisical-merge/packages/api" "github.com/Infisical/infisical-merge/packages/util" "github.com/go-resty/resty/v2" + "github.com/rs/zerolog/log" "golang.org/x/crypto/ssh" ) +// ForwardMode represents the type of forwarding +type ForwardMode string + +const ( + ForwardModeHTTP ForwardMode = "HTTP" + ForwardModeTCP ForwardMode = "TCP" +) + +// ForwardConfig contains the configuration for forwarding +type ForwardConfig struct { + Mode ForwardMode + CACertificate []byte // Decoded CA certificate for HTTPS verification + VerifyTLS bool // Whether to verify TLS certificates + TargetHost string + TargetPort int +} + +// RoutingInfo represents the routing information embedded in client certificates +type RoutingInfo struct { + TargetHost string `json:"targetHost"` + TargetPort int `json:"targetPort"` +} + type GatewayConfig struct { Name string ProxyName string @@ -76,11 +103,11 @@ func NewGateway(config *GatewayConfig) (*Gateway, error) { // Change the Start method to accept a context func (g *Gateway) Start(ctx context.Context) error { - log.Printf("Starting gateway") + log.Info().Msgf("Starting gateway") for { select { case <-ctx.Done(): - log.Printf("Gateway stopped by context cancellation") + log.Info().Msgf("Gateway stopped by context cancellation") return nil default: if err := g.connectAndServe(); err != nil { @@ -93,7 +120,7 @@ func (g *Gateway) Start(ctx context.Context) error { } } // If we get here, the connection was closed gracefully - log.Printf("Connection closed, reconnecting in 10 seconds...") + log.Info().Msgf("Connection closed, reconnecting in 10 seconds...") select { case <-ctx.Done(): return ctx.Err() @@ -132,7 +159,7 @@ func (g *Gateway) connectAndServe() error { } // Connect to Proxy server - log.Printf("Connecting to SSH server on %s:%d...", g.certificates.ProxyIP, g.config.SSHPort) + log.Info().Msgf("Connecting to SSH server on %s:%d...", g.certificates.ProxyIP, g.config.SSHPort) client, err := ssh.Dial("tcp", fmt.Sprintf("%s:%d", g.certificates.ProxyIP, g.config.SSHPort), sshConfig) if err != nil { return fmt.Errorf("failed to connect to SSH server: %v", err) @@ -151,7 +178,7 @@ func (g *Gateway) connectAndServe() error { client.Close() }() - log.Printf("SSH connection established for gateway") + log.Info().Msgf("SSH connection established for gateway") // Handle incoming channels from the server channels := client.HandleChannelOpen("direct-tcpip") @@ -180,7 +207,57 @@ func (g *Gateway) registerGateway() error { g.GatewayID = certResp.GatewayID g.certificates = &certResp - log.Printf("Successfully registered gateway and received certificates") + log.Info().Msgf("Successfully registered gateway and received certificates") + + // Create mTLS config once during registration + serverCertBlock, _ := pem.Decode([]byte(g.certificates.PKI.ServerCertificate)) + if serverCertBlock == nil { + return fmt.Errorf("failed to decode server certificate") + } + + serverKeyBlock, _ := pem.Decode([]byte(g.certificates.PKI.ServerPrivateKey)) + if serverKeyBlock == nil { + return fmt.Errorf("failed to decode server private key") + } + + serverKey, err := x509.ParsePKCS8PrivateKey(serverKeyBlock.Bytes) + if err != nil { + return fmt.Errorf("failed to parse server private key: %v", err) + } + + clientCAPool := x509.NewCertPool() + var chainCerts [][]byte + chainData := []byte(g.certificates.PKI.ClientCertificateChain) + for { + block, rest := pem.Decode(chainData) + if block == nil { + break + } + chainCerts = append(chainCerts, block.Bytes) + chainData = rest + } + + for i, certBytes := range chainCerts { + cert, err := x509.ParseCertificate(certBytes) + if err != nil { + log.Info().Msgf("Failed to parse client chain certificate %d: %v", i+1, err) + continue + } + clientCAPool.AddCert(cert) + } + + g.tlsConfig = &tls.Config{ + Certificates: []tls.Certificate{ + { + Certificate: [][]byte{serverCertBlock.Bytes}, + PrivateKey: serverKey, + }, + }, + ClientCAs: clientCAPool, + ClientAuth: tls.RequireAndVerifyClientCert, + MinVersion: tls.VersionTLS12, + } + return nil } @@ -232,10 +309,8 @@ func (g *Gateway) createSSHConfig() (*ssh.ClientConfig, error) { } func (g *Gateway) createHostKeyCallback() ssh.HostKeyCallback { - // Parse CA public key once when creating the callback caKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(g.certificates.SSH.ServerCAPublicKey)) if err != nil { - // Return a callback that always fails since we can't parse the CA key return func(hostname string, remote net.Addr, key ssh.PublicKey) error { return fmt.Errorf("failed to parse CA public key: %v", err) } @@ -262,7 +337,7 @@ func (g *Gateway) validateHostCertificate(cert *ssh.Certificate, hostname string return fmt.Errorf("host certificate check failed: %v", err) } - log.Printf("Host certificate validated successfully for %s", hostname) + log.Info().Msgf("Host certificate validated successfully for %s", hostname) return nil } @@ -275,32 +350,24 @@ func (g *Gateway) handleIncomingChannel(newChannel ssh.NewChannel) { } if err := ssh.Unmarshal(newChannel.ExtraData(), &req); err != nil { - log.Printf("Failed to parse channel request: %v", err) + log.Info().Msgf("Failed to parse channel request: %v", err) newChannel.Reject(ssh.Prohibited, "invalid request") return } - log.Printf("Incoming connection request to %s:%d from %s:%d", - req.Host, req.Port, req.OriginHost, req.OriginPort) - - // Accept the channel channel, requests, err := newChannel.Accept() if err != nil { - log.Printf("Failed to accept channel: %v", err) + log.Info().Msgf("Failed to accept channel: %v", err) return } defer channel.Close() go ssh.DiscardRequests(requests) - // Determine the target address - target := fmt.Sprintf("%s:%d", req.Host, req.Port) - log.Printf("Creating TCP tunnel to: %s", target) - // Create mTLS server configuration - tlsConfig, err := g.createMTLSConfig() - if err != nil { - log.Printf("Failed to create mTLS config: %v", err) + tlsConfig := g.tlsConfig + if tlsConfig == nil { + log.Info().Msgf("TLS config not initialized, cannot create mTLS server") return } @@ -314,88 +381,124 @@ func (g *Gateway) handleIncomingChannel(newChannel ssh.NewChannel) { // Perform TLS handshake if err := tlsConn.Handshake(); err != nil { - log.Printf("TLS handshake failed: %v", err) + log.Info().Msgf("TLS handshake failed: %v", err) return } - log.Printf("mTLS connection established with client: %s", tlsConn.ConnectionState().ServerName) + log.Info().Msgf("mTLS connection established with client") - // Connect to local service - localConn, err := net.Dial("tcp", target) + // Create reader for the TLS connection + reader := bufio.NewReader(tlsConn) + + // Get the forward mode here + forwardConfig, err := g.parseForwardConfig(tlsConn, reader) if err != nil { - log.Printf("Failed to connect to local service %s: %v", target, err) + log.Info().Msgf("Failed to parse forward command: %v", err) return } - defer localConn.Close() - log.Printf("TCP tunnel established to %s", target) + // Use target from certificate + target := fmt.Sprintf("%s:%d", forwardConfig.TargetHost, forwardConfig.TargetPort) + log.Info().Msgf("Using target from certificate: %s", target) - // Create bidirectional tunnel with TLS - // Forward data from TLS connection to local service - go func() { - io.Copy(localConn, tlsConn) - localConn.Close() - log.Printf("TLS -> local service tunnel closed") - }() - - // Forward data from local service to TLS connection - io.Copy(tlsConn, localConn) - log.Printf("Local service -> TLS tunnel closed") + if forwardConfig.Mode == ForwardModeHTTP { + handleHTTPProxy(tlsConn, reader, target, forwardConfig.CACertificate, forwardConfig.VerifyTLS) + return + } else if forwardConfig.Mode == ForwardModeTCP { + handleTCPProxy(tlsConn, target) + return + } } -func (g *Gateway) createMTLSConfig() (*tls.Config, error) { - // Parse server certificate - serverCertBlock, _ := pem.Decode([]byte(g.certificates.PKI.ServerCertificate)) - if serverCertBlock == nil { - return nil, fmt.Errorf("failed to decode server certificate") - } +func (g *Gateway) parseForwardConfig(tlsConn *tls.Conn, reader *bufio.Reader) (*ForwardConfig, error) { + config := &ForwardConfig{} - // Parse server private key - serverKeyBlock, _ := pem.Decode([]byte(g.certificates.PKI.ServerPrivateKey)) - if serverKeyBlock == nil { - return nil, fmt.Errorf("failed to decode server private key") + if err := g.parseRoutingInfoFromCertificate(tlsConn, config); err != nil { + return nil, fmt.Errorf("failed to parse routing info from certificate: %v", err) } - serverKey, err := x509.ParsePKCS8PrivateKey(serverKeyBlock.Bytes) - if err != nil { - return nil, fmt.Errorf("failed to parse server private key: %v", err) + for { + msg, err := reader.ReadBytes('\n') + if err != nil { + return nil, fmt.Errorf("failed to read command: %v", err) + } + + cmd := strings.ToUpper(strings.TrimSpace(string(strings.Split(string(msg), " ")[0]))) + args := strings.TrimSpace(strings.TrimPrefix(string(msg), strings.Split(string(msg), " ")[0])) + + switch cmd { + case "FORWARD-TCP": + config.Mode = ForwardModeTCP + return config, nil + + case "FORWARD-HTTP": + config.Mode = ForwardModeHTTP + if args != "" { + if err := g.parseForwardHTTPParams(args, config); err != nil { + return nil, fmt.Errorf("failed to parse HTTP parameters: %v", err) + } + } + + return config, nil + + default: + return nil, fmt.Errorf("invalid forward command: %s", cmd) + } } +} - // Create certificate pool for client CAs - clientCAPool := x509.NewCertPool() - var chainCerts [][]byte - chainData := []byte(g.certificates.PKI.ClientCertificateChain) - for { - block, rest := pem.Decode(chainData) - if block == nil { - break +func (g *Gateway) parseForwardHTTPParams(params string, config *ForwardConfig) error { + parts := strings.Fields(params) + + for _, part := range parts { + if strings.HasPrefix(part, "ca=") { + caB64 := strings.TrimPrefix(part, "ca=") + caCert, err := base64.StdEncoding.DecodeString(caB64) + if err != nil { + return fmt.Errorf("invalid base64 CA certificate: %v", err) + } + config.CACertificate = caCert + } else if strings.HasPrefix(part, "verify=") { + verifyStr := strings.TrimPrefix(part, "verify=") + verify, err := strconv.ParseBool(verifyStr) + if err != nil { + return fmt.Errorf("invalid verify parameter: %s", verifyStr) + } + config.VerifyTLS = verify } - chainCerts = append(chainCerts, block.Bytes) - chainData = rest } - for i, certBytes := range chainCerts { - cert, err := x509.ParseCertificate(certBytes) - if err != nil { - log.Printf("Failed to parse client chain certificate %d: %v", i+1, err) - continue + return nil +} + +// parseRoutingInfoFromCertificate extracts target host and port from client certificate custom extension +func (g *Gateway) parseRoutingInfoFromCertificate(tlsConn *tls.Conn, config *ForwardConfig) error { + const GATEWAY_ROUTING_INFO_OID = "1.3.6.1.4.1.12345.100.1" + + // Get the peer certificates + state := tlsConn.ConnectionState() + if len(state.PeerCertificates) == 0 { + return fmt.Errorf("no peer certificates found") + } + + clientCert := state.PeerCertificates[0] + + // Look for the routing extension + for _, ext := range clientCert.Extensions { + if ext.Id.String() == GATEWAY_ROUTING_INFO_OID { + var routingInfo RoutingInfo + if err := json.Unmarshal(ext.Value, &routingInfo); err != nil { + return fmt.Errorf("failed to parse routing info JSON: %v", err) + } + + config.TargetHost = routingInfo.TargetHost + config.TargetPort = routingInfo.TargetPort + + return nil } - clientCAPool.AddCert(cert) - log.Printf("Added client CA certificate %d to pool: %s", i+1, cert.Subject.CommonName) } - // Create TLS config - return &tls.Config{ - Certificates: []tls.Certificate{ - { - Certificate: [][]byte{serverCertBlock.Bytes}, - PrivateKey: serverKey, - }, - }, - ClientCAs: clientCAPool, - ClientAuth: tls.RequireAndVerifyClientCert, - MinVersion: tls.VersionTLS12, - }, nil + return fmt.Errorf("routing extension with OID %s not found in client certificate", GATEWAY_ROUTING_INFO_OID) } // virtualConnection implements net.Conn to bridge SSH channel and TLS diff --git a/packages/proxy/proxy.go b/packages/proxy/proxy.go index 279dc2c8..7ae7589f 100644 --- a/packages/proxy/proxy.go +++ b/packages/proxy/proxy.go @@ -374,7 +374,6 @@ func (p *Proxy) handleClient(clientConn net.Conn) { } state := tlsConn.ConnectionState() - log.Printf("TLS handshake completed, peer certificates count: %d", len(state.PeerCertificates)) if len(state.PeerCertificates) > 0 { cert := state.PeerCertificates[0] From f7ed054857d4f203818197718c426b06d792f01d Mon Sep 17 00:00:00 2001 From: Sheen Capadngan Date: Tue, 2 Sep 2025 01:01:48 +0800 Subject: [PATCH 11/38] feat: added auth injection for k8 and platform checks --- packages/gateway-v2/connection.go | 61 +++++++++++++++++++++++++++- packages/gateway-v2/gateway.go | 46 ++++++++++++++------- packages/proxy/proxy.go | 66 ++++++++++++++----------------- 3 files changed, 119 insertions(+), 54 deletions(-) diff --git a/packages/gateway-v2/connection.go b/packages/gateway-v2/connection.go index ad521392..68517870 100644 --- a/packages/gateway-v2/connection.go +++ b/packages/gateway-v2/connection.go @@ -9,6 +9,7 @@ import ( "io" "net" "net/http" + "os" "strings" "time" @@ -19,7 +20,11 @@ func buildHttpInternalServerError(message string) string { return fmt.Sprintf("HTTP/1.1 500 Internal Server Error\r\nContent-Type: application/json\r\n\r\n{\"message\": \"gateway: %s\"}", message) } -func handleHTTPProxy(conn *tls.Conn, reader *bufio.Reader, targetURL string, caCert []byte, verifyTLS bool) error { +func handleHTTPProxy(conn *tls.Conn, reader *bufio.Reader, forwardConfig *ForwardConfig) error { + targetURL := fmt.Sprintf("%s:%d", forwardConfig.TargetHost, forwardConfig.TargetPort) + caCert := forwardConfig.CACertificate + verifyTLS := forwardConfig.VerifyTLS + transport := &http.Transport{ DisableKeepAlives: false, MaxIdleConns: 10, @@ -61,6 +66,57 @@ func handleHTTPProxy(conn *tls.Conn, reader *bufio.Reader, targetURL string, caC } log.Info().Msgf("Received HTTP request: %s", req.URL.Path) + actionHeader := HttpProxyAction(req.Header.Get(INFISICAL_HTTP_PROXY_ACTION_HEADER)) + + // Only platform actor can perform privileged actions + if actionHeader != "" && forwardConfig.ActorType == ActorTypePlatform { + if actionHeader == HttpProxyActionInjectGatewayK8sServiceAccountToken { + token, err := os.ReadFile(KUBERNETES_SERVICE_ACCOUNT_TOKEN_PATH) + if err != nil { + conn.Write([]byte(buildHttpInternalServerError("failed to read k8s sa auth token"))) + continue // Continue to next request instead of returning + } + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", string(token))) + log.Info().Msgf("Injected gateway k8s SA auth token in request to %s", targetURL) + } else if actionHeader == HttpProxyActionUseGatewayK8sServiceAccount { // will work without a target URL set + // set the ca cert to the pod's k8s service account ca cert: + caCert, err := os.ReadFile(KUBERNETES_SERVICE_ACCOUNT_CA_CERT_PATH) + if err != nil { + conn.Write([]byte(buildHttpInternalServerError("failed to read k8s sa ca cert"))) + continue + } + + caCertPool := x509.NewCertPool() + if ok := caCertPool.AppendCertsFromPEM(caCert); !ok { + conn.Write([]byte(buildHttpInternalServerError("failed to parse k8s sa ca cert"))) + continue + } + + transport.TLSClientConfig = &tls.Config{ + RootCAs: caCertPool, + } + + // set authorization header to the pod's k8s service account token: + token, err := os.ReadFile(KUBERNETES_SERVICE_ACCOUNT_TOKEN_PATH) + if err != nil { + conn.Write([]byte(buildHttpInternalServerError("failed to read k8s sa auth token"))) + continue + } + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", string(token))) + + // update the target URL to point to the kubernetes API server: + kubernetesServiceHost := os.Getenv(KUBERNETES_SERVICE_HOST_ENV_NAME) + kubernetesServicePort := os.Getenv(KUBERNETES_SERVICE_PORT_HTTPS_ENV_NAME) + + fullBaseUrl := fmt.Sprintf("https://%s:%s", kubernetesServiceHost, kubernetesServicePort) + targetURL = fullBaseUrl + + log.Info().Msgf("Redirected request to Kubernetes API server: %s", targetURL) + } + + req.Header.Del(INFISICAL_HTTP_PROXY_ACTION_HEADER) + } + // Build full target URL var targetFullURL string if strings.HasPrefix(targetURL, "http://") || strings.HasPrefix(targetURL, "https://") { @@ -121,7 +177,8 @@ func handleHTTPProxy(conn *tls.Conn, reader *bufio.Reader, targetURL string, caC } } -func handleTCPProxy(conn *tls.Conn, target string) error { +func handleTCPProxy(conn *tls.Conn, forwardConfig *ForwardConfig) error { + target := fmt.Sprintf("%s:%d", forwardConfig.TargetHost, forwardConfig.TargetPort) localConn, err := net.Dial("tcp", target) if err != nil { log.Error().Msgf("Failed to connect to local service %s: %v", target, err) diff --git a/packages/gateway-v2/gateway.go b/packages/gateway-v2/gateway.go index 85f7b4ab..881876b9 100644 --- a/packages/gateway-v2/gateway.go +++ b/packages/gateway-v2/gateway.go @@ -32,6 +32,16 @@ const ( ForwardModeTCP ForwardMode = "TCP" ) +type ActorType string + +const ( + ActorTypePlatform ActorType = "platform" + ActorTypeUser ActorType = "user" +) + +const GATEWAY_ROUTING_INFO_OID = "1.3.6.1.4.1.12345.100.1" +const GATEWAY_ACTOR_OID = "1.3.6.1.4.1.12345.100.2" + // ForwardConfig contains the configuration for forwarding type ForwardConfig struct { Mode ForwardMode @@ -39,6 +49,7 @@ type ForwardConfig struct { VerifyTLS bool // Whether to verify TLS certificates TargetHost string TargetPort int + ActorType ActorType } // RoutingInfo represents the routing information embedded in client certificates @@ -47,6 +58,10 @@ type RoutingInfo struct { TargetPort int `json:"targetPort"` } +type ActorDetails struct { + Type string `json:"type"` +} + type GatewayConfig struct { Name string ProxyName string @@ -111,7 +126,7 @@ func (g *Gateway) Start(ctx context.Context) error { return nil default: if err := g.connectAndServe(); err != nil { - log.Printf("Connection failed: %v, retrying in %v...", err, g.config.ReconnectDelay) + log.Error().Msgf("Connection failed: %v, retrying in %v...", err, g.config.ReconnectDelay) select { case <-ctx.Done(): return ctx.Err() @@ -397,15 +412,13 @@ func (g *Gateway) handleIncomingChannel(newChannel ssh.NewChannel) { return } - // Use target from certificate - target := fmt.Sprintf("%s:%d", forwardConfig.TargetHost, forwardConfig.TargetPort) - log.Info().Msgf("Using target from certificate: %s", target) + log.Info().Msgf("Forward config: %+v", forwardConfig) if forwardConfig.Mode == ForwardModeHTTP { - handleHTTPProxy(tlsConn, reader, target, forwardConfig.CACertificate, forwardConfig.VerifyTLS) + handleHTTPProxy(tlsConn, reader, forwardConfig) return } else if forwardConfig.Mode == ForwardModeTCP { - handleTCPProxy(tlsConn, target) + handleTCPProxy(tlsConn, forwardConfig) return } } @@ -413,7 +426,7 @@ func (g *Gateway) handleIncomingChannel(newChannel ssh.NewChannel) { func (g *Gateway) parseForwardConfig(tlsConn *tls.Conn, reader *bufio.Reader) (*ForwardConfig, error) { config := &ForwardConfig{} - if err := g.parseRoutingInfoFromCertificate(tlsConn, config); err != nil { + if err := g.parseDetailsFromCertificate(tlsConn, config); err != nil { return nil, fmt.Errorf("failed to parse routing info from certificate: %v", err) } @@ -471,10 +484,7 @@ func (g *Gateway) parseForwardHTTPParams(params string, config *ForwardConfig) e return nil } -// parseRoutingInfoFromCertificate extracts target host and port from client certificate custom extension -func (g *Gateway) parseRoutingInfoFromCertificate(tlsConn *tls.Conn, config *ForwardConfig) error { - const GATEWAY_ROUTING_INFO_OID = "1.3.6.1.4.1.12345.100.1" - +func (g *Gateway) parseDetailsFromCertificate(tlsConn *tls.Conn, config *ForwardConfig) error { // Get the peer certificates state := tlsConn.ConnectionState() if len(state.PeerCertificates) == 0 { @@ -483,8 +493,8 @@ func (g *Gateway) parseRoutingInfoFromCertificate(tlsConn *tls.Conn, config *For clientCert := state.PeerCertificates[0] - // Look for the routing extension for _, ext := range clientCert.Extensions { + // Extract target host and port from client certificate custom extension if ext.Id.String() == GATEWAY_ROUTING_INFO_OID { var routingInfo RoutingInfo if err := json.Unmarshal(ext.Value, &routingInfo); err != nil { @@ -493,12 +503,18 @@ func (g *Gateway) parseRoutingInfoFromCertificate(tlsConn *tls.Conn, config *For config.TargetHost = routingInfo.TargetHost config.TargetPort = routingInfo.TargetPort - - return nil + } + // Extract actor type from client certificate custom extension + if ext.Id.String() == GATEWAY_ACTOR_OID { + var actorDetails ActorDetails + if err := json.Unmarshal(ext.Value, &actorDetails); err != nil { + return fmt.Errorf("failed to parse actor details JSON: %v", err) + } + config.ActorType = ActorType(actorDetails.Type) } } - return fmt.Errorf("routing extension with OID %s not found in client certificate", GATEWAY_ROUTING_INFO_OID) + return nil } // virtualConnection implements net.Conn to bridge SSH channel and TLS diff --git a/packages/proxy/proxy.go b/packages/proxy/proxy.go index 7ae7589f..1b050a6d 100644 --- a/packages/proxy/proxy.go +++ b/packages/proxy/proxy.go @@ -9,14 +9,13 @@ import ( "encoding/pem" "fmt" "io" - "log" "net" - "strings" "sync" "github.com/Infisical/infisical-merge/packages/api" "github.com/Infisical/infisical-merge/packages/util" "github.com/go-resty/resty/v2" + "github.com/rs/zerolog/log" "golang.org/x/crypto/ssh" ) @@ -101,7 +100,7 @@ func (p *Proxy) Start(ctx context.Context) error { // Start TLS server go p.startTLSServer() - log.Printf("Proxy server started successfully") + log.Info().Msg("Proxy server started successfully") // Wait for context cancellation <-ctx.Done() @@ -131,7 +130,7 @@ func (p *Proxy) registerProxy() error { p.certificates = &certResp } - log.Printf("Successfully registered proxy and received certificates from API") + log.Info().Msg("Successfully registered proxy and received certificates from API") return nil } @@ -166,13 +165,13 @@ func (p *Proxy) setupSSHServer() error { // Check if this is an SSH certificate cert, ok := key.(*ssh.Certificate) if !ok { - log.Printf("Gateway '%s' tried to authenticate with raw public key (rejected)", conn.User()) + log.Warn().Msgf("Gateway '%s' tried to authenticate with raw public key (rejected)", conn.User()) return nil, fmt.Errorf("certificates required, raw public keys not allowed") } // Validate the certificate if err := p.validateSSHCertificate(cert, conn.User(), sshCAPubKey); err != nil { - log.Printf("Gateway '%s' certificate validation failed: %v", conn.User(), err) + log.Error().Msgf("Gateway '%s' certificate validation failed: %v", conn.User(), err) return nil, err } @@ -239,11 +238,10 @@ func (p *Proxy) setupTLSServer() error { for i, certBytes := range chainCerts { cert, err := x509.ParseCertificate(certBytes) if err != nil { - log.Printf("Failed to parse client chain certificate %d: %v", i+1, err) + log.Error().Msgf("Failed to parse client chain certificate %d: %v", i+1, err) continue } clientCAPool.AddCert(cert) - log.Printf("Added client CA certificate %d to pool: %s", i+1, cert.Subject.CommonName) } // Create TLS config @@ -268,7 +266,7 @@ func (p *Proxy) validateSSHCertificate(cert *ssh.Certificate, username string, c return fmt.Errorf("invalid certificate type: %d", cert.CertType) } - // Check if certificate is signed by our CA + // Check if certificate is signed expected CA checker := &ssh.CertChecker{ IsUserAuthority: func(auth ssh.PublicKey) bool { return bytes.Equal(auth.Marshal(), caPubKey.Marshal()) @@ -280,23 +278,23 @@ func (p *Proxy) validateSSHCertificate(cert *ssh.Certificate, username string, c return fmt.Errorf("certificate check failed: %v", err) } - log.Printf("SSH certificate valid for user '%s', principals: %v", username, cert.ValidPrincipals) + log.Debug().Msgf("SSH certificate valid for user '%s', principals: %v", username, cert.ValidPrincipals) return nil } func (p *Proxy) startSSHServer() { listener, err := net.Listen("tcp", ":"+p.config.SSHPort) if err != nil { - log.Fatalf("Failed to start SSH server: %v", err) + log.Fatal().Msgf("Failed to start SSH server: %v", err) } p.sshListener = listener - log.Printf("SSH server listening on :%s for gateways", p.config.SSHPort) + log.Info().Msgf("SSH server listening on :%s for gateways", p.config.SSHPort) for { conn, err := listener.Accept() if err != nil { - log.Printf("Failed to accept SSH connection: %v", err) + log.Error().Msgf("Failed to accept SSH connection: %v", err) continue } go p.handleSSHAgent(conn) @@ -309,12 +307,12 @@ func (p *Proxy) handleSSHAgent(conn net.Conn) { // SSH handshake sshConn, chans, _, err := ssh.NewServerConn(conn, p.sshConfig) if err != nil { - log.Printf("SSH handshake failed: %v", err) + log.Error().Msgf("SSH handshake failed: %v", err) return } gatewayId := sshConn.Permissions.Extensions["gateway-id"] - log.Printf("SSH handshake successful for gateway: %s", gatewayId) + log.Info().Msgf("SSH handshake successful for gateway: %s", gatewayId) // Store the connection p.mu.Lock() @@ -326,7 +324,7 @@ func (p *Proxy) handleSSHAgent(conn net.Conn) { p.mu.Lock() delete(p.tunnels, gatewayId) p.mu.Unlock() - log.Printf("Gateway %s disconnected", gatewayId) + log.Info().Msgf("Gateway %s disconnected", gatewayId) }() for newChannel := range chans { @@ -344,16 +342,16 @@ func (p *Proxy) handleSSHAgent(conn net.Conn) { func (p *Proxy) startTLSServer() { listener, err := tls.Listen("tcp", ":"+p.config.TLSPort, p.tlsConfig) if err != nil { - log.Fatalf("Failed to start TLS server: %v", err) + log.Fatal().Msgf("Failed to start TLS server: %v", err) } p.tlsListener = listener - log.Printf("TLS server listening on :%s for clients", p.config.TLSPort) + log.Info().Msgf("TLS server listening on :%s for clients", p.config.TLSPort) for { conn, err := listener.Accept() if err != nil { - log.Printf("Failed to accept TLS connection: %v", err) + log.Error().Msgf("Failed to accept TLS connection: %v", err) continue } go p.handleClient(conn) @@ -366,10 +364,10 @@ func (p *Proxy) handleClient(clientConn net.Conn) { var gatewayId string if tlsConn, ok := clientConn.(*tls.Conn); ok { - log.Printf("TLS connection detected, forcing handshake...") + log.Debug().Msg("TLS connection detected, forcing handshake...") err := tlsConn.Handshake() if err != nil { - log.Printf("TLS handshake failed: %v", err) + log.Error().Msgf("TLS handshake failed: %v", err) return } @@ -377,20 +375,14 @@ func (p *Proxy) handleClient(clientConn net.Conn) { if len(state.PeerCertificates) > 0 { cert := state.PeerCertificates[0] - log.Printf("Client connected with certificate: %s", cert.Subject.CommonName) - parts := strings.Split(cert.Subject.CommonName, ":") - if len(parts) >= 2 { - gatewayId = parts[1] - } else { - log.Printf("Invalid CommonName format, expected 'part1:part2', got: %s", cert.Subject.CommonName) - return - } + log.Info().Msgf("Client connected with certificate: %s", cert.Subject.CommonName) + gatewayId = cert.Subject.CommonName } else { - log.Printf("No peer certificates found") + log.Warn().Msg("No peer certificates found") return } } else { - log.Printf("Not a TLS connection, connection type: %T", clientConn) + log.Error().Msgf("Not a TLS connection, connection type: %T", clientConn) return } @@ -404,12 +396,12 @@ func (p *Proxy) handleClient(clientConn net.Conn) { p.mu.RUnlock() if !exists { - log.Printf("Gateway '%s' not connected", gatewayId) + log.Warn().Msgf("Gateway '%s' not connected", gatewayId) clientConn.Write([]byte("ERROR: Gateway not connected\n")) return } - log.Printf("Routing TCP connection to gateway: %s", gatewayId) + log.Info().Msgf("Routing TCP connection to gateway: %s", gatewayId) // Open SSH channel to connect to agent's local service through the tunnel payload := struct { @@ -421,7 +413,7 @@ func (p *Proxy) handleClient(clientConn net.Conn) { channel, _, err := conn.OpenChannel("direct-tcpip", ssh.Marshal(&payload)) if err != nil { - log.Printf("Failed to connect to agent: %v", err) + log.Error().Msgf("Failed to connect to agent: %v", err) clientConn.Write([]byte("ERROR: Failed to connect to agent\n")) return } @@ -434,11 +426,11 @@ func (p *Proxy) handleClient(clientConn net.Conn) { }() io.Copy(clientConn, channel) - log.Printf("Client %s disconnected", clientConn.RemoteAddr()) + log.Info().Msgf("Client %s disconnected", clientConn.RemoteAddr()) } func (p *Proxy) cleanup() { - log.Printf("Shutting down proxy server...") + log.Info().Msg("Shutting down proxy server...") if p.sshListener != nil { p.sshListener.Close() @@ -447,5 +439,5 @@ func (p *Proxy) cleanup() { p.tlsListener.Close() } - log.Printf("Proxy server shutdown complete") + log.Info().Msg("Proxy server shutdown complete") } From dc7a438f40ca56ebcb9f0dcedb924b504499f4a4 Mon Sep 17 00:00:00 2001 From: Sheen Capadngan Date: Tue, 2 Sep 2025 21:49:11 +0800 Subject: [PATCH 12/38] feat: added heartbeat --- packages/api/api.go | 18 ++++++ packages/cmd/network.go | 26 ++++++--- packages/gateway-v2/connection.go | 89 +++++++++++++++++++++++++---- packages/gateway-v2/gateway.go | 95 ++++++++++++++++++++++++++----- 4 files changed, 197 insertions(+), 31 deletions(-) diff --git a/packages/api/api.go b/packages/api/api.go index 7d61eb8a..e20e6daf 100644 --- a/packages/api/api.go +++ b/packages/api/api.go @@ -39,6 +39,7 @@ const ( operationCallRegisterGatewayIdentityV1 = "CallRegisterGatewayIdentityV1" operationCallExchangeRelayCertV1 = "CallExchangeRelayCertV1" operationCallGatewayHeartBeatV1 = "CallGatewayHeartBeatV1" + operationCallGatewayHeartBeatV2 = "CallGatewayHeartBeatV2" operationCallBootstrapInstance = "CallBootstrapInstance" operationCallRegisterInstanceProxy = "CallRegisterInstanceProxy" operationCallRegisterOrgProxy = "CallRegisterOrgProxy" @@ -655,6 +656,23 @@ func CallGatewayHeartBeatV1(httpClient *resty.Client) error { return nil } +func CallGatewayHeartBeatV2(httpClient *resty.Client) error { + response, err := httpClient. + R(). + SetHeader("User-Agent", USER_AGENT). + Post(fmt.Sprintf("%v/v2/gateways/heartbeat", config.INFISICAL_URL)) + + if err != nil { + return NewGenericRequestError(operationCallGatewayHeartBeatV2, err) + } + + if response.IsError() { + return NewAPIErrorWithResponse(operationCallGatewayHeartBeatV2, response, nil) + } + + return nil +} + func CallBootstrapInstance(httpClient *resty.Client, request BootstrapInstanceRequest) (BootstrapInstanceResponse, error) { var resBody BootstrapInstanceResponse response, err := httpClient. diff --git a/packages/cmd/network.go b/packages/cmd/network.go index 2d6d67f4..2bf4d063 100644 --- a/packages/cmd/network.go +++ b/packages/cmd/network.go @@ -91,10 +91,15 @@ var networkProxyCmd = &cobra.Command{ cancelCmd() cancelSdk() - // If we get a second signal, force exit - <-sigCh - log.Warn().Msgf("Force exit triggered") - os.Exit(1) + // Give graceful shutdown 10 seconds, then force exit on second signal + select { + case <-sigCh: + log.Warn().Msg("Second signal received, force exit triggered") + os.Exit(1) + case <-time.After(10 * time.Second): + log.Info().Msg("Graceful shutdown completed") + os.Exit(0) + } }() // Token refresh goroutine - runs every 10 seconds @@ -192,10 +197,15 @@ var networkGatewayCmd = &cobra.Command{ cancelCmd() cancelSdk() - // If we get a second signal, force exit - <-sigCh - log.Warn().Msgf("Force exit triggered") - os.Exit(1) + // Give graceful shutdown 10 seconds, then force exit on second signal + select { + case <-sigCh: + log.Warn().Msg("Second signal received, force exit triggered") + os.Exit(1) + case <-time.After(10 * time.Second): + log.Info().Msg("Graceful shutdown completed") + os.Exit(0) + } }() // Token refresh goroutine - runs every 10 seconds diff --git a/packages/gateway-v2/connection.go b/packages/gateway-v2/connection.go index 68517870..141681f8 100644 --- a/packages/gateway-v2/connection.go +++ b/packages/gateway-v2/connection.go @@ -2,6 +2,7 @@ package gatewayv2 import ( "bufio" + "context" "crypto/tls" "crypto/x509" "errors" @@ -20,7 +21,7 @@ func buildHttpInternalServerError(message string) string { return fmt.Sprintf("HTTP/1.1 500 Internal Server Error\r\nContent-Type: application/json\r\n\r\n{\"message\": \"gateway: %s\"}", message) } -func handleHTTPProxy(conn *tls.Conn, reader *bufio.Reader, forwardConfig *ForwardConfig) error { +func handleHTTPProxy(ctx context.Context, conn *tls.Conn, reader *bufio.Reader, forwardConfig *ForwardConfig) error { targetURL := fmt.Sprintf("%s:%d", forwardConfig.TargetHost, forwardConfig.TargetPort) caCert := forwardConfig.CACertificate verifyTLS := forwardConfig.VerifyTLS @@ -52,18 +53,45 @@ func handleHTTPProxy(conn *tls.Conn, reader *bufio.Reader, forwardConfig *Forwar // Loop to handle multiple HTTP requests on the same connection for { + select { + case <-ctx.Done(): + log.Info().Msg("Context cancelled, closing HTTP proxy connection") + return ctx.Err() + default: + } + log.Info().Msg("Attempting to read HTTP request...") - req, err := http.ReadRequest(reader) - if err != nil { + // Create a channel to receive the request or error + reqCh := make(chan *http.Request, 1) + errCh := make(chan error, 1) + + // Read request in a goroutine so we can cancel it + go func() { + req, err := http.ReadRequest(reader) + if err != nil { + errCh <- err + } else { + reqCh <- req + } + }() + + var req *http.Request + select { + case <-ctx.Done(): + log.Info().Msg("Context cancelled while reading HTTP request") + return ctx.Err() + case err := <-errCh: if errors.Is(err, io.EOF) { log.Info().Msg("Client closed HTTP connection") return nil } - log.Error().Msgf("Failed to read HTTP request: %v", err) return fmt.Errorf("failed to read HTTP request: %v", err) + case req = <-reqCh: + // Successfully received request } + log.Info().Msgf("Received HTTP request: %s", req.URL.Path) actionHeader := HttpProxyAction(req.Header.Get(INFISICAL_HTTP_PROXY_ACTION_HEADER)) @@ -78,7 +106,8 @@ func handleHTTPProxy(conn *tls.Conn, reader *bufio.Reader, forwardConfig *Forwar } req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", string(token))) log.Info().Msgf("Injected gateway k8s SA auth token in request to %s", targetURL) - } else if actionHeader == HttpProxyActionUseGatewayK8sServiceAccount { // will work without a target URL set + } else if actionHeader == HttpProxyActionUseGatewayK8sServiceAccount { + // will work without a target URL set // set the ca cert to the pod's k8s service account ca cert: caCert, err := os.ReadFile(KUBERNETES_SERVICE_ACCOUNT_CA_CERT_PATH) if err != nil { @@ -177,7 +206,7 @@ func handleHTTPProxy(conn *tls.Conn, reader *bufio.Reader, forwardConfig *Forwar } } -func handleTCPProxy(conn *tls.Conn, forwardConfig *ForwardConfig) error { +func handleTCPProxy(ctx context.Context, conn *tls.Conn, forwardConfig *ForwardConfig) error { target := fmt.Sprintf("%s:%d", forwardConfig.TargetHost, forwardConfig.TargetPort) localConn, err := net.Dial("tcp", target) if err != nil { @@ -186,15 +215,55 @@ func handleTCPProxy(conn *tls.Conn, forwardConfig *ForwardConfig) error { } defer localConn.Close() - // Create bidirectional tunnel with TLS + // Create a context for this connection that gets cancelled when the parent context is cancelled + // or when either connection closes + connCtx, cancel := context.WithCancel(ctx) + defer cancel() + + // Error channel to collect errors from both copy goroutines + errCh := make(chan error, 2) + // Forward data from TLS connection to local service go func() { - io.Copy(localConn, conn) - localConn.Close() + defer cancel() + _, err := io.Copy(localConn, conn) + if err != nil { + if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) { + log.Debug().Msgf("TLS to local copy ended normally: %v", err) + } else { + log.Error().Msgf("TLS to local copy failed: %v", err) + } + } + errCh <- err }() // Forward data from local service to TLS connection - io.Copy(conn, localConn) + go func() { + defer cancel() + _, err := io.Copy(conn, localConn) + if err != nil { + if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) { + log.Debug().Msgf("Local to TLS copy ended normally: %v", err) + } else { + log.Error().Msgf("Local to TLS copy failed: %v", err) + } + } + errCh <- err + }() + + // Wait for either context cancellation or one of the copy operations to complete + select { + case <-connCtx.Done(): + log.Info().Msg("TCP proxy connection cancelled") + return connCtx.Err() + case err := <-errCh: + // One of the copy operations completed (or failed) + // The defer cancel() will stop the other goroutine + return err + } +} +func handlePing(ctx context.Context, conn *tls.Conn, reader *bufio.Reader) error { + conn.Write([]byte("PONG\n")) return nil } diff --git a/packages/gateway-v2/gateway.go b/packages/gateway-v2/gateway.go index 881876b9..d7c2c61b 100644 --- a/packages/gateway-v2/gateway.go +++ b/packages/gateway-v2/gateway.go @@ -30,6 +30,7 @@ type ForwardMode string const ( ForwardModeHTTP ForwardMode = "HTTP" ForwardModeTCP ForwardMode = "TCP" + ForwardModePing ForwardMode = "PING" ) type ActorType string @@ -116,9 +117,59 @@ func NewGateway(config *GatewayConfig) (*Gateway, error) { }, nil } -// Change the Start method to accept a context +func (g *Gateway) registerHeartBeat(ctx context.Context, errCh chan error) { + sendHeartbeat := func() { + if err := api.CallGatewayHeartBeatV2(g.httpClient); err != nil { + log.Warn().Msgf("Heartbeat failed: %v", err) + select { + case errCh <- err: + default: + log.Warn().Msg("Error channel full, skipping heartbeat error report") + } + } else { + log.Info().Msg("Gateway is reachable by Infisical") + } + } + + go func() { + select { + case <-ctx.Done(): + return + case <-time.After(10 * time.Second): + sendHeartbeat() + } + + ticker := time.NewTicker(30 * time.Minute) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + sendHeartbeat() + } + } + }() +} + func (g *Gateway) Start(ctx context.Context) error { log.Info().Msgf("Starting gateway") + + errCh := make(chan error, 1) + g.registerHeartBeat(ctx, errCh) + + go func() { + for { + select { + case <-ctx.Done(): + return + case err := <-errCh: + log.Warn().Msgf("Heartbeat error received: %v", err) + } + } + }() + for { select { case <-ctx.Done(): @@ -179,6 +230,7 @@ func (g *Gateway) connectAndServe() error { if err != nil { return fmt.Errorf("failed to connect to SSH server: %v", err) } + log.Info().Msgf("SSH connection established for gateway") g.mu.Lock() g.sshClient = client @@ -193,20 +245,33 @@ func (g *Gateway) connectAndServe() error { client.Close() }() - log.Info().Msgf("SSH connection established for gateway") - // Handle incoming channels from the server channels := client.HandleChannelOpen("direct-tcpip") if channels == nil { return fmt.Errorf("failed to handle channel open") } - // Process incoming channels - for newChannel := range channels { - go g.handleIncomingChannel(newChannel) - } + // Monitor for context cancellation and close SSH client + go func() { + <-g.ctx.Done() + log.Info().Msg("Context cancelled, closing SSH connection...") + client.Close() + }() - return nil // Connection closed + // Process incoming channels with context cancellation support + for { + select { + case <-g.ctx.Done(): + log.Info().Msg("Context cancelled, stopping channel processing") + return g.ctx.Err() + case newChannel, ok := <-channels: + if !ok { + log.Info().Msg("SSH channels closed") + return nil + } + go g.handleIncomingChannel(newChannel) + } + } } func (g *Gateway) registerGateway() error { @@ -352,7 +417,6 @@ func (g *Gateway) validateHostCertificate(cert *ssh.Certificate, hostname string return fmt.Errorf("host certificate check failed: %v", err) } - log.Info().Msgf("Host certificate validated successfully for %s", hostname) return nil } @@ -400,8 +464,6 @@ func (g *Gateway) handleIncomingChannel(newChannel ssh.NewChannel) { return } - log.Info().Msgf("mTLS connection established with client") - // Create reader for the TLS connection reader := bufio.NewReader(tlsConn) @@ -415,10 +477,13 @@ func (g *Gateway) handleIncomingChannel(newChannel ssh.NewChannel) { log.Info().Msgf("Forward config: %+v", forwardConfig) if forwardConfig.Mode == ForwardModeHTTP { - handleHTTPProxy(tlsConn, reader, forwardConfig) + handleHTTPProxy(g.ctx, tlsConn, reader, forwardConfig) return } else if forwardConfig.Mode == ForwardModeTCP { - handleTCPProxy(tlsConn, forwardConfig) + handleTCPProxy(g.ctx, tlsConn, forwardConfig) + return + } else if forwardConfig.Mode == ForwardModePing { + handlePing(g.ctx, tlsConn, reader) return } } @@ -454,6 +519,10 @@ func (g *Gateway) parseForwardConfig(tlsConn *tls.Conn, reader *bufio.Reader) (* return config, nil + case "PING": + config.Mode = ForwardModePing + return config, nil + default: return nil, fmt.Errorf("invalid forward command: %s", cmd) } From 2dbb176e4a45e9b6ebfa755ff2fee7b6aff5c4ad Mon Sep 17 00:00:00 2001 From: Sheen Capadngan Date: Tue, 2 Sep 2025 23:54:55 +0800 Subject: [PATCH 13/38] feat: added systemd support --- packages/cmd/network.go | 105 +++++++++++++++++++++---- packages/gateway-v2/constants.go | 4 +- packages/gateway-v2/systemd.go | 128 +++++++++++++++++++++++++++++++ 3 files changed, 219 insertions(+), 18 deletions(-) create mode 100644 packages/gateway-v2/systemd.go diff --git a/packages/cmd/network.go b/packages/cmd/network.go index 2bf4d063..cbd43ca4 100644 --- a/packages/cmd/network.go +++ b/packages/cmd/network.go @@ -6,6 +6,7 @@ import ( "fmt" "os" "os/signal" + "runtime" "sync/atomic" "syscall" "time" @@ -24,9 +25,12 @@ var networkCmd = &cobra.Command{ } var networkProxyCmd = &cobra.Command{ - Use: "proxy", - Short: "Run the Infisical proxy component", - Long: "Run the Infisical proxy component", + Use: "proxy", + Short: "Run the Infisical proxy component", + Long: "Run the Infisical proxy component", + Example: "infisical network proxy --type=instance --ip= --name= --token=", + DisableFlagsInUseLine: true, + Args: cobra.NoArgs, Run: func(cmd *cobra.Command, args []string) { proxyName, err := cmd.Flags().GetString("name") @@ -135,19 +139,13 @@ var networkProxyCmd = &cobra.Command{ }, } -var networkProxyInstallCmd = &cobra.Command{ - Use: "proxy install", - Short: "Install and enable systemd service for the proxy (requires sudo)", - Long: "Install and enable systemd service for the proxy. Must be run with sudo on Linux.", - Run: func(cmd *cobra.Command, args []string) { - // TODO: Implement this - }, -} - var networkGatewayCmd = &cobra.Command{ - Use: "gateway", - Short: "Run the Infisical gateway component", - Long: "Run the Infisical gateway component", + Use: "gateway", + Short: "Run the Infisical gateway component", + Long: "Run the Infisical gateway component. Use 'network gateway install' to set up the systemd service.", + Example: "infisical network gateway --proxy-name= --name= --token=", + DisableFlagsInUseLine: true, + Args: cobra.NoArgs, Run: func(cmd *cobra.Command, args []string) { proxyName, err := util.GetCmdFlagOrEnv(cmd, "proxy-name", []string{gatewayv2.PROXY_NAME_ENV_NAME}) @@ -240,6 +238,75 @@ var networkGatewayCmd = &cobra.Command{ }, } +var networkGatewayInstallCmd = &cobra.Command{ + Use: "install", + Short: "Install and enable systemd service for the gateway (requires sudo)", + Long: "Install and enable systemd service for the gateway. Must be run with sudo on Linux.", + Example: "sudo infisical network gateway install --token= --domain= --name= --proxy-name=", + DisableFlagsInUseLine: true, + Args: cobra.NoArgs, + Run: func(cmd *cobra.Command, args []string) { + if runtime.GOOS != "linux" { + util.HandleError(fmt.Errorf("systemd service installation is only supported on Linux")) + } + + if os.Geteuid() != 0 { + util.HandleError(fmt.Errorf("systemd service installation requires root/sudo privileges")) + } + + token, err := util.GetInfisicalToken(cmd) + if err != nil { + util.HandleError(err, "Unable to parse flag") + } + + if token == nil { + util.HandleError(errors.New("Token not found")) + } + + domain, err := cmd.Flags().GetString("domain") + if err != nil { + util.HandleError(err, "Unable to parse domain flag") + } + + gatewayName, err := cmd.Flags().GetString("name") + if err != nil { + util.HandleError(err, "Unable to parse name flag") + } + + proxyName, err := cmd.Flags().GetString("proxy-name") + if err != nil { + util.HandleError(err, "Unable to parse proxy-name flag") + } + + err = gatewayv2.InstallGatewaySystemdService(token.Token, domain, gatewayName, proxyName) + if err != nil { + util.HandleError(err, "Unable to install systemd service") + } + }, +} + +var networkGatewayUninstallCmd = &cobra.Command{ + Use: "uninstall", + Short: "Uninstall and remove systemd service for the gateway (requires sudo)", + Long: "Uninstall and remove systemd service for the gateway. Must be run with sudo on Linux.", + Example: "sudo infisical network gateway uninstall", + DisableFlagsInUseLine: true, + Args: cobra.NoArgs, + Run: func(cmd *cobra.Command, args []string) { + if runtime.GOOS != "linux" { + util.HandleError(fmt.Errorf("systemd service installation is only supported on Linux")) + } + + if os.Geteuid() != 0 { + util.HandleError(fmt.Errorf("systemd service installation requires root/sudo privileges")) + } + + if err := gatewayv2.UninstallGatewaySystemdService(); err != nil { + util.HandleError(err, "Failed to uninstall systemd service") + } + }, +} + func init() { networkGatewayCmd.Flags().String("proxy-name", "", "The name of the proxy to connect to") networkGatewayCmd.Flags().String("name", "", "The name of the gateway") @@ -264,7 +331,13 @@ func init() { networkProxyCmd.Flags().String("service-account-key-file-path", "", "service account key file path for GCP IAM auth") networkProxyCmd.Flags().String("jwt", "", "JWT for jwt-based auth methods [oidc-auth, jwt-auth]") - networkProxyCmd.AddCommand(networkProxyInstallCmd) + networkGatewayInstallCmd.Flags().String("token", "", "Connect with Infisical using machine identity access token") + networkGatewayInstallCmd.Flags().String("domain", "", "Domain of your self-hosted Infisical instance") + networkGatewayInstallCmd.Flags().String("name", "", "The name of the gateway") + networkGatewayInstallCmd.Flags().String("proxy-name", "", "The name of the proxy") + + networkGatewayCmd.AddCommand(networkGatewayInstallCmd) + networkGatewayCmd.AddCommand(networkGatewayUninstallCmd) networkCmd.AddCommand(networkProxyCmd) networkCmd.AddCommand(networkGatewayCmd) diff --git a/packages/gateway-v2/constants.go b/packages/gateway-v2/constants.go index 996068d6..f746f558 100644 --- a/packages/gateway-v2/constants.go +++ b/packages/gateway-v2/constants.go @@ -6,8 +6,8 @@ const ( KUBERNETES_SERVICE_ACCOUNT_CA_CERT_PATH = "/var/run/secrets/kubernetes.io/serviceaccount/ca.crt" KUBERNETES_SERVICE_ACCOUNT_TOKEN_PATH = "/var/run/secrets/kubernetes.io/serviceaccount/token" - PROXY_NAME_ENV_NAME = "PROXY_NAME" - GATEWAY_NAME_ENV_NAME = "GATEWAY_NAME" + PROXY_NAME_ENV_NAME = "INFISICAL_PROXY_NAME" + GATEWAY_NAME_ENV_NAME = "INFISICAL_GATEWAY_NAME" PROXY_AUTH_SECRET_ENV_NAME = "PROXY_AUTH_SECRET" diff --git a/packages/gateway-v2/systemd.go b/packages/gateway-v2/systemd.go new file mode 100644 index 00000000..794509ea --- /dev/null +++ b/packages/gateway-v2/systemd.go @@ -0,0 +1,128 @@ +package gatewayv2 + +import ( + "fmt" + "os" + "os/exec" + "path/filepath" + "runtime" + + "github.com/rs/zerolog/log" +) + +const systemdServiceTemplate = `[Unit] +Description=Infisical Gateway Service +After=network.target + +[Service] +Type=notify +NotifyAccess=all +EnvironmentFile=/etc/infisical/gateway.conf +ExecStart=infisical network gateway +Restart=on-failure +InaccessibleDirectories=/home +PrivateTmp=yes +LimitCORE=infinity +LimitNOFILE=1000000 +LimitNPROC=60000 +LimitRTPRIO=infinity +LimitRTTIME=7000000 + +[Install] +WantedBy=multi-user.target +` + +func InstallGatewaySystemdService(token string, domain string, name string, proxyName string) error { + if runtime.GOOS != "linux" { + log.Info().Msg("Skipping systemd service installation - not on Linux") + return nil + } + + if os.Geteuid() != 0 { + log.Info().Msg("Skipping systemd service installation - not running as root/sudo") + return nil + } + + configDir := "/etc/infisical" + if err := os.MkdirAll(configDir, 0755); err != nil { + return fmt.Errorf("failed to create config directory: %v", err) + } + + configContent := fmt.Sprintf("INFISICAL_TOKEN=%s\n", token) + if domain != "" { + configContent += fmt.Sprintf("INFISICAL_API_URL=%s\n", domain) + } + + if name != "" { + configContent += fmt.Sprintf("%s=%s\n", GATEWAY_NAME_ENV_NAME, name) + } + if proxyName != "" { + configContent += fmt.Sprintf("%s=%s\n", PROXY_NAME_ENV_NAME, proxyName) + } + + configPath := filepath.Join(configDir, "gateway.conf") + if err := os.WriteFile(configPath, []byte(configContent), 0600); err != nil { + return fmt.Errorf("failed to write config file: %v", err) + } + + servicePath := "/etc/systemd/system/infisical-gateway.service" + if err := os.WriteFile(servicePath, []byte(systemdServiceTemplate), 0644); err != nil { + return fmt.Errorf("failed to write systemd service file: %v", err) + } + + reloadCmd := exec.Command("systemctl", "daemon-reload") + if err := reloadCmd.Run(); err != nil { + return fmt.Errorf("failed to reload systemd: %v", err) + } + + log.Info().Msg("Successfully installed systemd service") + log.Info().Msg("To start the service, run: sudo systemctl start infisical-gateway") + log.Info().Msg("To enable the service on boot, run: sudo systemctl enable infisical-gateway") + + return nil +} + +func UninstallGatewaySystemdService() error { + if runtime.GOOS != "linux" { + log.Info().Msg("Skipping systemd service uninstallation - not on Linux") + return nil + } + + if os.Geteuid() != 0 { + log.Info().Msg("Skipping systemd service uninstallation - not running as root/sudo") + return nil + } + + // Stop the service if it's running + stopCmd := exec.Command("systemctl", "stop", "infisical-gateway") + if err := stopCmd.Run(); err != nil { + log.Warn().Msgf("Failed to stop service: %v", err) + } + + // Disable the service + disableCmd := exec.Command("systemctl", "disable", "infisical-gateway") + if err := disableCmd.Run(); err != nil { + log.Warn().Msgf("Failed to disable service: %v", err) + } + + // Remove the service file + servicePath := "/etc/systemd/system/infisical-gateway.service" + if err := os.Remove(servicePath); err != nil && !os.IsNotExist(err) { + return fmt.Errorf("failed to remove systemd service file: %v", err) + } + + // Remove the configuration file + configPath := "/etc/infisical/gateway.conf" + if err := os.Remove(configPath); err != nil && !os.IsNotExist(err) { + return fmt.Errorf("failed to remove config file: %v", err) + } + + // Reload systemd to apply changes + reloadCmd := exec.Command("systemctl", "daemon-reload") + if err := reloadCmd.Run(); err != nil { + return fmt.Errorf("failed to reload systemd: %v", err) + } + + log.Info().Msg("Successfully uninstalled Infisical Gateway systemd service") + return nil +} From 085de6d98cca611a42de692350413704ea1ee563 Mon Sep 17 00:00:00 2001 From: Sheen Capadngan Date: Wed, 3 Sep 2025 01:17:58 +0800 Subject: [PATCH 14/38] misc: added proxy name validation --- packages/cmd/network.go | 7 ++++++- packages/proxy/proxy.go | 10 ++++++++-- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/packages/cmd/network.go b/packages/cmd/network.go index cbd43ca4..237aff9b 100644 --- a/packages/cmd/network.go +++ b/packages/cmd/network.go @@ -131,7 +131,6 @@ var networkProxyCmd = &cobra.Command{ }() } - // Use the same context for the proxy server err = proxyInstance.Start(cmd.Context()) if err != nil { util.HandleError(err, "unable to start proxy instance") @@ -272,11 +271,17 @@ var networkGatewayInstallCmd = &cobra.Command{ if err != nil { util.HandleError(err, "Unable to parse name flag") } + if gatewayName == "" { + util.HandleError(errors.New("Gateway name is required")) + } proxyName, err := cmd.Flags().GetString("proxy-name") if err != nil { util.HandleError(err, "Unable to parse proxy-name flag") } + if proxyName == "" { + util.HandleError(errors.New("Proxy name is required")) + } err = gatewayv2.InstallGatewaySystemdService(token.Token, domain, gatewayName, proxyName) if err != nil { diff --git a/packages/proxy/proxy.go b/packages/proxy/proxy.go index 1b050a6d..dcd0f100 100644 --- a/packages/proxy/proxy.go +++ b/packages/proxy/proxy.go @@ -79,7 +79,6 @@ func (p *Proxy) SetToken(token string) { } func (p *Proxy) Start(ctx context.Context) error { - // Register proxy and get certificates from API if err := p.registerProxy(); err != nil { return fmt.Errorf("failed to register proxy: %v", err) } @@ -184,6 +183,13 @@ func (p *Proxy) setupSSHServer() error { return nil, fmt.Errorf("gateway id is required") } + // Validate that the user is authorized to connect to the current proxy + expectedKeyId := "client-" + p.config.ProxyName + if cert.KeyId != expectedKeyId { + log.Error().Msgf("Gateway '%s' certificate Key ID '%s' does not match expected '%s'", conn.User(), cert.KeyId, expectedKeyId) + return nil, fmt.Errorf("certificate Key ID does not match expected value") + } + return &ssh.Permissions{ Extensions: map[string]string{ "gateway-id": gatewayId, @@ -266,7 +272,7 @@ func (p *Proxy) validateSSHCertificate(cert *ssh.Certificate, username string, c return fmt.Errorf("invalid certificate type: %d", cert.CertType) } - // Check if certificate is signed expected CA + // Check if certificate is signed by expected CA checker := &ssh.CertChecker{ IsUserAuthority: func(auth ssh.PublicKey) bool { return bytes.Equal(auth.Marshal(), caPubKey.Marshal()) From 99091419edcfa6e67435d9a17a4096a1bac4f141 Mon Sep 17 00:00:00 2001 From: Sheen Capadngan Date: Wed, 3 Sep 2025 01:47:46 +0800 Subject: [PATCH 15/38] misc: added proxy cert auto-renewal --- packages/proxy/proxy.go | 46 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/packages/proxy/proxy.go b/packages/proxy/proxy.go index dcd0f100..b301b26b 100644 --- a/packages/proxy/proxy.go +++ b/packages/proxy/proxy.go @@ -11,6 +11,7 @@ import ( "io" "net" "sync" + "time" "github.com/Infisical/infisical-merge/packages/api" "github.com/Infisical/infisical-merge/packages/util" @@ -93,6 +94,9 @@ func (p *Proxy) Start(ctx context.Context) error { return fmt.Errorf("failed to setup TLS server: %v", err) } + // Start certificate renewal goroutine + go p.startCertificateRenewal(ctx) + // Start SSH server go p.startSSHServer() @@ -447,3 +451,45 @@ func (p *Proxy) cleanup() { log.Info().Msg("Proxy server shutdown complete") } + +// startCertificateRenewal runs a background process to renew certificates every 24 hours +func (p *Proxy) startCertificateRenewal(ctx context.Context) { + log.Info().Msg("Starting certificate renewal goroutine") + ticker := time.NewTicker(30 * time.Second) // TODO: update this to be every 10 days + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + log.Info().Msg("Certificate renewal goroutine stopping...") + return + case <-ticker.C: + log.Info().Msg("Checking certificates for renewal...") + if err := p.renewCertificates(); err != nil { + log.Error().Msgf("Failed to renew certificates: %v", err) + } else { + log.Info().Msg("Certificates renewed successfully") + } + } + } +} + +// renewCertificates fetches new certificates and updates the server configurations +func (p *Proxy) renewCertificates() error { + // Re-register proxy to get fresh certificates + if err := p.registerProxy(); err != nil { + return fmt.Errorf("failed to register proxy: %v", err) + } + + // Update SSH server configuration + if err := p.setupSSHServer(); err != nil { + return fmt.Errorf("failed to setup SSH server: %v", err) + } + + // Update TLS server configuration + if err := p.setupTLSServer(); err != nil { + return fmt.Errorf("failed to setup TLS server: %v", err) + } + + return nil +} From 6d0a02105aec6e0ad4b97353b3ccd1723c511aa0 Mon Sep 17 00:00:00 2001 From: Sheen Capadngan Date: Wed, 3 Sep 2025 02:00:15 +0800 Subject: [PATCH 16/38] misc: updated proxy tls server handling for cert renewal --- packages/proxy/proxy.go | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/packages/proxy/proxy.go b/packages/proxy/proxy.go index b301b26b..829e5dbe 100644 --- a/packages/proxy/proxy.go +++ b/packages/proxy/proxy.go @@ -350,7 +350,7 @@ func (p *Proxy) handleSSHAgent(conn net.Conn) { } func (p *Proxy) startTLSServer() { - listener, err := tls.Listen("tcp", ":"+p.config.TLSPort, p.tlsConfig) + listener, err := net.Listen("tcp", ":"+p.config.TLSPort) if err != nil { log.Fatal().Msgf("Failed to start TLS server: %v", err) } @@ -364,10 +364,27 @@ func (p *Proxy) startTLSServer() { log.Error().Msgf("Failed to accept TLS connection: %v", err) continue } - go p.handleClient(conn) + go p.handleTLSClient(conn) } } +func (p *Proxy) handleTLSClient(conn net.Conn) { + defer conn.Close() + + // Perform TLS handshake using current TLS config + tlsConn := tls.Server(conn, p.tlsConfig) + defer tlsConn.Close() + + // Force TLS handshake + err := tlsConn.Handshake() + if err != nil { + log.Error().Msgf("TLS handshake failed: %v", err) + return + } + + p.handleClient(tlsConn) +} + func (p *Proxy) handleClient(clientConn net.Conn) { defer clientConn.Close() From b15233829bf7ec512a9d86f9fb99410fe4527b7b Mon Sep 17 00:00:00 2001 From: Sheen Capadngan Date: Wed, 3 Sep 2025 02:08:23 +0800 Subject: [PATCH 17/38] misc: corrected client handling --- packages/proxy/proxy.go | 40 ++++++++++++---------------------------- 1 file changed, 12 insertions(+), 28 deletions(-) diff --git a/packages/proxy/proxy.go b/packages/proxy/proxy.go index 829e5dbe..97489bfa 100644 --- a/packages/proxy/proxy.go +++ b/packages/proxy/proxy.go @@ -385,35 +385,19 @@ func (p *Proxy) handleTLSClient(conn net.Conn) { p.handleClient(tlsConn) } -func (p *Proxy) handleClient(clientConn net.Conn) { - defer clientConn.Close() - +func (p *Proxy) handleClient(tlsConn *tls.Conn) { var gatewayId string + state := tlsConn.ConnectionState() - if tlsConn, ok := clientConn.(*tls.Conn); ok { - log.Debug().Msg("TLS connection detected, forcing handshake...") - err := tlsConn.Handshake() - if err != nil { - log.Error().Msgf("TLS handshake failed: %v", err) - return - } - - state := tlsConn.ConnectionState() - - if len(state.PeerCertificates) > 0 { - cert := state.PeerCertificates[0] - log.Info().Msgf("Client connected with certificate: %s", cert.Subject.CommonName) - gatewayId = cert.Subject.CommonName - } else { - log.Warn().Msg("No peer certificates found") - return - } + if len(state.PeerCertificates) > 0 { + cert := state.PeerCertificates[0] + log.Info().Msgf("Client connected with certificate: %s", cert.Subject.CommonName) + gatewayId = cert.Subject.CommonName } else { - log.Error().Msgf("Not a TLS connection, connection type: %T", clientConn) + log.Warn().Msg("No peer certificates found") return } - // TODO: extract these from the certificate targetHost := "gateway" targetPort := uint32(22) @@ -424,7 +408,7 @@ func (p *Proxy) handleClient(clientConn net.Conn) { if !exists { log.Warn().Msgf("Gateway '%s' not connected", gatewayId) - clientConn.Write([]byte("ERROR: Gateway not connected\n")) + tlsConn.Write([]byte("ERROR: Gateway not connected\n")) return } @@ -441,19 +425,19 @@ func (p *Proxy) handleClient(clientConn net.Conn) { channel, _, err := conn.OpenChannel("direct-tcpip", ssh.Marshal(&payload)) if err != nil { log.Error().Msgf("Failed to connect to agent: %v", err) - clientConn.Write([]byte("ERROR: Failed to connect to agent\n")) + tlsConn.Write([]byte("ERROR: Failed to connect to agent\n")) return } defer channel.Close() // Bidirectional forwarding go func() { - io.Copy(channel, clientConn) + io.Copy(channel, tlsConn) channel.CloseWrite() }() - io.Copy(clientConn, channel) - log.Info().Msgf("Client %s disconnected", clientConn.RemoteAddr()) + io.Copy(tlsConn, channel) + log.Info().Msgf("Client %s disconnected", tlsConn.RemoteAddr()) } func (p *Proxy) cleanup() { From 3bcf34c7ff0a8237e969fbec31f057f664ccb7aa Mon Sep 17 00:00:00 2001 From: Sheen Capadngan Date: Wed, 3 Sep 2025 02:51:15 +0800 Subject: [PATCH 18/38] misc: addeed tls connection accept log --- packages/proxy/proxy.go | 1 + 1 file changed, 1 insertion(+) diff --git a/packages/proxy/proxy.go b/packages/proxy/proxy.go index 97489bfa..22b75fd2 100644 --- a/packages/proxy/proxy.go +++ b/packages/proxy/proxy.go @@ -364,6 +364,7 @@ func (p *Proxy) startTLSServer() { log.Error().Msgf("Failed to accept TLS connection: %v", err) continue } + log.Info().Msgf("TLS connection accepted from %s", conn.RemoteAddr()) go p.handleTLSClient(conn) } } From 9ccf30bfdb5aab88f008360029b42f0db24cee33 Mon Sep 17 00:00:00 2001 From: Sheen Capadngan Date: Wed, 3 Sep 2025 03:02:02 +0800 Subject: [PATCH 19/38] misc: add connection deadline for unauthenticated requests --- packages/proxy/proxy.go | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/packages/proxy/proxy.go b/packages/proxy/proxy.go index 22b75fd2..3fbeb9b3 100644 --- a/packages/proxy/proxy.go +++ b/packages/proxy/proxy.go @@ -364,7 +364,6 @@ func (p *Proxy) startTLSServer() { log.Error().Msgf("Failed to accept TLS connection: %v", err) continue } - log.Info().Msgf("TLS connection accepted from %s", conn.RemoteAddr()) go p.handleTLSClient(conn) } } @@ -376,13 +375,19 @@ func (p *Proxy) handleTLSClient(conn net.Conn) { tlsConn := tls.Server(conn, p.tlsConfig) defer tlsConn.Close() + // Set handshake timeout to avoid hanging on slow/malicious connections + tlsConn.SetDeadline(time.Now().Add(10 * time.Second)) + // Force TLS handshake err := tlsConn.Handshake() if err != nil { - log.Error().Msgf("TLS handshake failed: %v", err) + log.Debug().Msgf("TLS handshake failed from %s: %v", conn.RemoteAddr(), err) return } + // Clear deadline for actual data transfer + tlsConn.SetDeadline(time.Time{}) + p.handleClient(tlsConn) } From d39ef05297d27457fc9c8875724c1831548895cb Mon Sep 17 00:00:00 2001 From: Sheen Capadngan Date: Wed, 3 Sep 2025 03:05:40 +0800 Subject: [PATCH 20/38] misc: finalized cert renewal interval to 10 days --- packages/proxy/proxy.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/proxy/proxy.go b/packages/proxy/proxy.go index 3fbeb9b3..8d1a327c 100644 --- a/packages/proxy/proxy.go +++ b/packages/proxy/proxy.go @@ -462,7 +462,7 @@ func (p *Proxy) cleanup() { // startCertificateRenewal runs a background process to renew certificates every 24 hours func (p *Proxy) startCertificateRenewal(ctx context.Context) { log.Info().Msg("Starting certificate renewal goroutine") - ticker := time.NewTicker(30 * time.Second) // TODO: update this to be every 10 days + ticker := time.NewTicker(10 * 24 * time.Hour) defer ticker.Stop() for { From 60655841cd5338ee9dfd66038c9d330c358408b4 Mon Sep 17 00:00:00 2001 From: Sheen Capadngan Date: Wed, 3 Sep 2025 03:23:34 +0800 Subject: [PATCH 21/38] misc: add cert renewal to gateway server --- packages/gateway-v2/gateway.go | 45 +++++++++++++++++++++++++++++++++- packages/proxy/proxy.go | 2 +- 2 files changed, 45 insertions(+), 2 deletions(-) diff --git a/packages/gateway-v2/gateway.go b/packages/gateway-v2/gateway.go index d7c2c61b..98ac68f0 100644 --- a/packages/gateway-v2/gateway.go +++ b/packages/gateway-v2/gateway.go @@ -159,6 +159,9 @@ func (g *Gateway) Start(ctx context.Context) error { errCh := make(chan error, 1) g.registerHeartBeat(ctx, errCh) + // Start certificate renewal goroutine + go g.startCertificateRenewal(ctx) + go func() { for { select { @@ -289,7 +292,15 @@ func (g *Gateway) registerGateway() error { g.certificates = &certResp log.Info().Msgf("Successfully registered gateway and received certificates") - // Create mTLS config once during registration + // Setup mTLS config + if err := g.setupTLSConfig(); err != nil { + return fmt.Errorf("failed to setup TLS config: %v", err) + } + + return nil +} + +func (g *Gateway) setupTLSConfig() error { serverCertBlock, _ := pem.Decode([]byte(g.certificates.PKI.ServerCertificate)) if serverCertBlock == nil { return fmt.Errorf("failed to decode server certificate") @@ -622,3 +633,35 @@ func (vc *virtualConnection) SetReadDeadline(t time.Time) error { func (vc *virtualConnection) SetWriteDeadline(t time.Time) error { return nil } + +// startCertificateRenewal runs a background process to renew certificates every 10 days +func (g *Gateway) startCertificateRenewal(ctx context.Context) { + log.Info().Msg("Starting gateway certificate renewal goroutine") + ticker := time.NewTicker(10 * 24 * time.Hour) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + log.Info().Msg("Gateway certificate renewal goroutine stopping...") + return + case <-ticker.C: + log.Info().Msg("Renewing gateway certificates...") + if err := g.renewCertificates(); err != nil { + log.Error().Msgf("Failed to renew gateway certificates: %v", err) + } else { + log.Info().Msg("Gateway certificates renewed successfully") + } + } + } +} + +// renewCertificates fetches new certificates and updates the gateway configurations +func (g *Gateway) renewCertificates() error { + // Re-register gateway to get fresh certificates + if err := g.registerGateway(); err != nil { + return fmt.Errorf("failed to register gateway: %v", err) + } + + return nil +} diff --git a/packages/proxy/proxy.go b/packages/proxy/proxy.go index 8d1a327c..7d5d6251 100644 --- a/packages/proxy/proxy.go +++ b/packages/proxy/proxy.go @@ -471,7 +471,7 @@ func (p *Proxy) startCertificateRenewal(ctx context.Context) { log.Info().Msg("Certificate renewal goroutine stopping...") return case <-ticker.C: - log.Info().Msg("Checking certificates for renewal...") + log.Info().Msg("Renewing certificates...") if err := p.renewCertificates(); err != nil { log.Error().Msgf("Failed to renew certificates: %v", err) } else { From 4e6ee387be5f00ca3c0d9cb59601b9682bd19d60 Mon Sep 17 00:00:00 2001 From: Sheen Capadngan Date: Wed, 3 Sep 2025 19:04:01 +0800 Subject: [PATCH 22/38] misc: used non-standard port for proxy TLS --- packages/cmd/network.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/cmd/network.go b/packages/cmd/network.go index 237aff9b..753e7d22 100644 --- a/packages/cmd/network.go +++ b/packages/cmd/network.go @@ -51,7 +51,7 @@ var networkProxyCmd = &cobra.Command{ proxyInstance, err := proxy.NewProxy(&proxy.ProxyConfig{ ProxyName: proxyName, SSHPort: "2222", - TLSPort: "443", + TLSPort: "8443", StaticIP: ip, Type: instanceType, }) From 8eaf2a5ff18e5c530aa4d713c3c0d433cfdb5dbf Mon Sep 17 00:00:00 2001 From: Sheen Capadngan Date: Wed, 3 Sep 2025 19:55:22 +0800 Subject: [PATCH 23/38] misc: improved security posture of proxy server --- packages/proxy/proxy.go | 38 ++++++++++++++++++++++++++++++++++++-- 1 file changed, 36 insertions(+), 2 deletions(-) diff --git a/packages/proxy/proxy.go b/packages/proxy/proxy.go index 7d5d6251..c029bf88 100644 --- a/packages/proxy/proxy.go +++ b/packages/proxy/proxy.go @@ -164,6 +164,12 @@ func (p *Proxy) setupSSHServer() error { // Setup SSH server config p.sshConfig = &ssh.ServerConfig{ + MaxAuthTries: 3, + AuthLogCallback: func(conn ssh.ConnMetadata, method string, err error) { + if err != nil { + log.Warn().Msgf("Auth failed for %s@%s using %s: %v", conn.User(), conn.RemoteAddr(), method, err) + } + }, PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { // Check if this is an SSH certificate cert, ok := key.(*ssh.Certificate) @@ -315,7 +321,7 @@ func (p *Proxy) handleSSHAgent(conn net.Conn) { defer conn.Close() // SSH handshake - sshConn, chans, _, err := ssh.NewServerConn(conn, p.sshConfig) + sshConn, chans, reqs, err := ssh.NewServerConn(conn, p.sshConfig) if err != nil { log.Error().Msgf("SSH handshake failed: %v", err) return @@ -324,8 +330,16 @@ func (p *Proxy) handleSSHAgent(conn net.Conn) { gatewayId := sshConn.Permissions.Extensions["gateway-id"] log.Info().Msgf("SSH handshake successful for gateway: %s", gatewayId) - // Store the connection + // Store the connection (ensure only one connection per gateway) p.mu.Lock() + if existingConn, exists := p.tunnels[gatewayId]; exists { + p.mu.Unlock() + log.Warn().Msgf("Gateway '%s' already has an active connection, rejecting new connection", gatewayId) + sshConn.Close() + existingConn.Close() // Also close the existing connection to force re-auth + return + } + p.tunnels[gatewayId] = sshConn p.mu.Unlock() @@ -337,14 +351,34 @@ func (p *Proxy) handleSSHAgent(conn net.Conn) { log.Info().Msgf("Gateway %s disconnected", gatewayId) }() + // Handle global requests (reject all for security) + go func() { + for req := range reqs { + log.Debug().Msgf("Rejecting global request: %s from gateway %s", req.Type, gatewayId) + if req.WantReply { + req.Reply(false, nil) + } + } + }() + + // Handle channel requests for newChannel := range chans { switch newChannel.ChannelType() { case "session": + log.Debug().Msgf("Rejecting session channel from gateway %s", gatewayId) newChannel.Reject(ssh.Prohibited, "no shell access") case "x11": + log.Debug().Msgf("Rejecting X11 forwarding from gateway %s", gatewayId) newChannel.Reject(ssh.Prohibited, "no X11 forwarding") case "auth-agent": + log.Debug().Msgf("Rejecting auth-agent forwarding from gateway %s", gatewayId) newChannel.Reject(ssh.Prohibited, "no agent forwarding") + case "forwarded-tcpip": + log.Debug().Msgf("Rejecting forwarded-tcpip from gateway %s", gatewayId) + newChannel.Reject(ssh.Prohibited, "no port forwarding") + default: + log.Warn().Msgf("Rejecting unknown channel type '%s' from gateway %s", newChannel.ChannelType(), gatewayId) + newChannel.Reject(ssh.UnknownChannelType, "unknown channel type") } } } From ce41396d4a278749eef85691a1ebe2003a1bfb1d Mon Sep 17 00:00:00 2001 From: Sheen Capadngan Date: Wed, 3 Sep 2025 20:11:48 +0800 Subject: [PATCH 24/38] misc: added sending of error message when multiple gateway is detected --- packages/proxy/proxy.go | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/packages/proxy/proxy.go b/packages/proxy/proxy.go index c029bf88..de59db16 100644 --- a/packages/proxy/proxy.go +++ b/packages/proxy/proxy.go @@ -332,11 +332,23 @@ func (p *Proxy) handleSSHAgent(conn net.Conn) { // Store the connection (ensure only one connection per gateway) p.mu.Lock() - if existingConn, exists := p.tunnels[gatewayId]; exists { + if _, exists := p.tunnels[gatewayId]; exists { p.mu.Unlock() log.Warn().Msgf("Gateway '%s' already has an active connection, rejecting new connection", gatewayId) - sshConn.Close() - existingConn.Close() // Also close the existing connection to force re-auth + + // Send error message to the new connection before closing + go func() { + // Send a global request with error information + _, _, err := sshConn.SendRequest("duplicate-connection-error", false, []byte(fmt.Sprintf("Gateway '%s' already has an active connection. Only one connection per gateway is allowed.", gatewayId))) + if err != nil { + log.Debug().Msgf("Failed to send duplicate connection error message to gateway '%s': %v", gatewayId, err) + } + + // Give a moment for the message to be sent before closing + time.Sleep(1000 * time.Millisecond) + sshConn.Close() + }() + return } From c51d31f02f0a87f9e8cc1300391d6f65122bcc00 Mon Sep 17 00:00:00 2001 From: Sheen Capadngan Date: Wed, 3 Sep 2025 20:15:14 +0800 Subject: [PATCH 25/38] Revert "misc: added sending of error message when multiple gateway is detected" This reverts commit ce41396d4a278749eef85691a1ebe2003a1bfb1d. --- packages/proxy/proxy.go | 18 +++--------------- 1 file changed, 3 insertions(+), 15 deletions(-) diff --git a/packages/proxy/proxy.go b/packages/proxy/proxy.go index de59db16..c029bf88 100644 --- a/packages/proxy/proxy.go +++ b/packages/proxy/proxy.go @@ -332,23 +332,11 @@ func (p *Proxy) handleSSHAgent(conn net.Conn) { // Store the connection (ensure only one connection per gateway) p.mu.Lock() - if _, exists := p.tunnels[gatewayId]; exists { + if existingConn, exists := p.tunnels[gatewayId]; exists { p.mu.Unlock() log.Warn().Msgf("Gateway '%s' already has an active connection, rejecting new connection", gatewayId) - - // Send error message to the new connection before closing - go func() { - // Send a global request with error information - _, _, err := sshConn.SendRequest("duplicate-connection-error", false, []byte(fmt.Sprintf("Gateway '%s' already has an active connection. Only one connection per gateway is allowed.", gatewayId))) - if err != nil { - log.Debug().Msgf("Failed to send duplicate connection error message to gateway '%s': %v", gatewayId, err) - } - - // Give a moment for the message to be sent before closing - time.Sleep(1000 * time.Millisecond) - sshConn.Close() - }() - + sshConn.Close() + existingConn.Close() // Also close the existing connection to force re-auth return } From 21d61c1a1ec3714ff98e9418585da87485841904 Mon Sep 17 00:00:00 2001 From: Sheen Capadngan Date: Wed, 3 Sep 2025 20:16:11 +0800 Subject: [PATCH 26/38] misc: only close new connection for duplicate gateway --- packages/proxy/proxy.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/packages/proxy/proxy.go b/packages/proxy/proxy.go index c029bf88..533028b9 100644 --- a/packages/proxy/proxy.go +++ b/packages/proxy/proxy.go @@ -332,11 +332,10 @@ func (p *Proxy) handleSSHAgent(conn net.Conn) { // Store the connection (ensure only one connection per gateway) p.mu.Lock() - if existingConn, exists := p.tunnels[gatewayId]; exists { + if _, exists := p.tunnels[gatewayId]; exists { p.mu.Unlock() log.Warn().Msgf("Gateway '%s' already has an active connection, rejecting new connection", gatewayId) sshConn.Close() - existingConn.Close() // Also close the existing connection to force re-auth return } From 7d2276fd6f834aa84ad74bed9d8415ae36f9be47 Mon Sep 17 00:00:00 2001 From: Sheen Capadngan Date: Wed, 3 Sep 2025 20:47:31 +0800 Subject: [PATCH 27/38] misc: decreased tls deadline --- packages/proxy/proxy.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/proxy/proxy.go b/packages/proxy/proxy.go index 533028b9..bfbf0f5c 100644 --- a/packages/proxy/proxy.go +++ b/packages/proxy/proxy.go @@ -409,7 +409,7 @@ func (p *Proxy) handleTLSClient(conn net.Conn) { defer tlsConn.Close() // Set handshake timeout to avoid hanging on slow/malicious connections - tlsConn.SetDeadline(time.Now().Add(10 * time.Second)) + tlsConn.SetDeadline(time.Now().Add(5 * time.Second)) // Force TLS handshake err := tlsConn.Handshake() From e5a426d1680755c279c4b30b5de32191bf22566a Mon Sep 17 00:00:00 2001 From: Sheen Capadngan Date: Wed, 3 Sep 2025 21:45:29 +0800 Subject: [PATCH 28/38] misc: addressed greptile --- packages/gateway-v2/constants.go | 2 +- packages/gateway-v2/gateway.go | 23 ++++++----------------- packages/proxy/proxy.go | 13 +------------ 3 files changed, 8 insertions(+), 30 deletions(-) diff --git a/packages/gateway-v2/constants.go b/packages/gateway-v2/constants.go index f746f558..de54cd6f 100644 --- a/packages/gateway-v2/constants.go +++ b/packages/gateway-v2/constants.go @@ -9,7 +9,7 @@ const ( PROXY_NAME_ENV_NAME = "INFISICAL_PROXY_NAME" GATEWAY_NAME_ENV_NAME = "INFISICAL_GATEWAY_NAME" - PROXY_AUTH_SECRET_ENV_NAME = "PROXY_AUTH_SECRET" + PROXY_AUTH_SECRET_ENV_NAME = "INFISICAL_PROXY_AUTH_SECRET" INFISICAL_HTTP_PROXY_ACTION_HEADER = "x-infisical-action" ) diff --git a/packages/gateway-v2/gateway.go b/packages/gateway-v2/gateway.go index 98ac68f0..bbad19f9 100644 --- a/packages/gateway-v2/gateway.go +++ b/packages/gateway-v2/gateway.go @@ -4,7 +4,6 @@ import ( "bufio" "bytes" "context" - "crypto/rsa" "crypto/tls" "crypto/x509" "encoding/base64" @@ -83,8 +82,6 @@ type Gateway struct { // mTLS server components tlsConfig *tls.Config - tlsCACert []byte - tlsCAKey *rsa.PrivateKey // Connection management mu sync.RWMutex @@ -364,8 +361,13 @@ func (g *Gateway) createSSHConfig() (*ssh.ClientConfig, error) { return nil, fmt.Errorf("failed to parse certificate: %v", err) } + sshCert, ok := cert.(*ssh.Certificate) + if !ok { + return nil, fmt.Errorf("parsed key is not an SSH certificate, got type: %T", cert) + } + // Create certificate signer - certSigner, err := ssh.NewCertSigner(cert.(*ssh.Certificate), privateKey) + certSigner, err := ssh.NewCertSigner(sshCert, privateKey) if err != nil { return nil, fmt.Errorf("failed to create certificate signer: %v", err) } @@ -432,19 +434,6 @@ func (g *Gateway) validateHostCertificate(cert *ssh.Certificate, hostname string } func (g *Gateway) handleIncomingChannel(newChannel ssh.NewChannel) { - var req struct { - Host string - Port uint32 - OriginHost string - OriginPort uint32 - } - - if err := ssh.Unmarshal(newChannel.ExtraData(), &req); err != nil { - log.Info().Msgf("Failed to parse channel request: %v", err) - newChannel.Reject(ssh.Prohibited, "invalid request") - return - } - channel, requests, err := newChannel.Accept() if err != nil { log.Info().Msgf("Failed to accept channel: %v", err) diff --git a/packages/proxy/proxy.go b/packages/proxy/proxy.go index bfbf0f5c..26af7e15 100644 --- a/packages/proxy/proxy.go +++ b/packages/proxy/proxy.go @@ -437,9 +437,6 @@ func (p *Proxy) handleClient(tlsConn *tls.Conn) { return } - targetHost := "gateway" - targetPort := uint32(22) - // Get the SSH connection for this agent p.mu.RLock() conn, exists := p.tunnels[gatewayId] @@ -453,15 +450,7 @@ func (p *Proxy) handleClient(tlsConn *tls.Conn) { log.Info().Msgf("Routing TCP connection to gateway: %s", gatewayId) - // Open SSH channel to connect to agent's local service through the tunnel - payload := struct { - Host string - Port uint32 - _ string - _ uint32 - }{targetHost, targetPort, "", 0} - - channel, _, err := conn.OpenChannel("direct-tcpip", ssh.Marshal(&payload)) + channel, _, err := conn.OpenChannel("direct-tcpip", nil) if err != nil { log.Error().Msgf("Failed to connect to agent: %v", err) tlsConn.Write([]byte("ERROR: Failed to connect to agent\n")) From fcdc1456df2c98d294bca56349cd3c46a05af950 Mon Sep 17 00:00:00 2001 From: Sheen Capadngan Date: Wed, 3 Sep 2025 21:48:55 +0800 Subject: [PATCH 29/38] misc: removed proxy auth logging --- packages/proxy/proxy.go | 5 ----- 1 file changed, 5 deletions(-) diff --git a/packages/proxy/proxy.go b/packages/proxy/proxy.go index 26af7e15..2a2a6301 100644 --- a/packages/proxy/proxy.go +++ b/packages/proxy/proxy.go @@ -165,11 +165,6 @@ func (p *Proxy) setupSSHServer() error { // Setup SSH server config p.sshConfig = &ssh.ServerConfig{ MaxAuthTries: 3, - AuthLogCallback: func(conn ssh.ConnMetadata, method string, err error) { - if err != nil { - log.Warn().Msgf("Auth failed for %s@%s using %s: %v", conn.User(), conn.RemoteAddr(), method, err) - } - }, PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { // Check if this is an SSH certificate cert, ok := key.(*ssh.Certificate) From fc62acd90cbdf6ac0da64cdbfdf93a5a7cc8e740 Mon Sep 17 00:00:00 2001 From: Sheen Capadngan Date: Thu, 4 Sep 2025 04:05:30 +0800 Subject: [PATCH 30/38] misc: updated gateway logs --- packages/gateway-v2/gateway.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/packages/gateway-v2/gateway.go b/packages/gateway-v2/gateway.go index bbad19f9..46553798 100644 --- a/packages/gateway-v2/gateway.go +++ b/packages/gateway-v2/gateway.go @@ -225,12 +225,12 @@ func (g *Gateway) connectAndServe() error { } // Connect to Proxy server - log.Info().Msgf("Connecting to SSH server on %s:%d...", g.certificates.ProxyIP, g.config.SSHPort) + log.Info().Msgf("Connecting to proxy server %s on %s:%d...", g.config.ProxyName, g.certificates.ProxyIP, g.config.SSHPort) client, err := ssh.Dial("tcp", fmt.Sprintf("%s:%d", g.certificates.ProxyIP, g.config.SSHPort), sshConfig) if err != nil { return fmt.Errorf("failed to connect to SSH server: %v", err) } - log.Info().Msgf("SSH connection established for gateway") + log.Info().Msgf("Proxy connection established for gateway") g.mu.Lock() g.sshClient = client @@ -254,7 +254,7 @@ func (g *Gateway) connectAndServe() error { // Monitor for context cancellation and close SSH client go func() { <-g.ctx.Done() - log.Info().Msg("Context cancelled, closing SSH connection...") + log.Info().Msg("Context cancelled, closing proxy connection...") client.Close() }() From 7e9a71a7d2b26abf936c56562d3f47f5eafda482 Mon Sep 17 00:00:00 2001 From: Sheen Capadngan Date: Tue, 9 Sep 2025 02:09:46 +0800 Subject: [PATCH 31/38] misc: updated proxy terminology to relay and finalized command location --- packages/api/api.go | 24 +-- packages/api/model.go | 8 +- packages/cmd/gateway.go | 226 +++++++++++++++++++- packages/cmd/network.go | 152 ++----------- packages/cmd/relay.go | 156 ++++++++++++++ packages/gateway-v2/constants.go | 4 +- packages/gateway-v2/gateway.go | 14 +- packages/gateway-v2/systemd.go | 8 +- packages/{proxy/proxy.go => relay/relay.go} | 184 ++++++++-------- 9 files changed, 511 insertions(+), 265 deletions(-) create mode 100644 packages/cmd/relay.go rename packages/{proxy/proxy.go => relay/relay.go} (72%) diff --git a/packages/api/api.go b/packages/api/api.go index e20e6daf..352465ea 100644 --- a/packages/api/api.go +++ b/packages/api/api.go @@ -41,8 +41,8 @@ const ( operationCallGatewayHeartBeatV1 = "CallGatewayHeartBeatV1" operationCallGatewayHeartBeatV2 = "CallGatewayHeartBeatV2" operationCallBootstrapInstance = "CallBootstrapInstance" - operationCallRegisterInstanceProxy = "CallRegisterInstanceProxy" - operationCallRegisterOrgProxy = "CallRegisterOrgProxy" + operationCallRegisterInstanceRelay = "CallRegisterInstanceRelay" + operationCallRegisterOrgRelay = "CallRegisterOrgRelay" operationCallRegisterGateway = "CallRegisterGateway" ) @@ -693,41 +693,41 @@ func CallBootstrapInstance(httpClient *resty.Client, request BootstrapInstanceRe return resBody, nil } -func CallRegisterInstanceProxy(httpClient *resty.Client, request RegisterProxyRequest) (RegisterProxyResponse, error) { - var resBody RegisterProxyResponse +func CallRegisterInstanceRelay(httpClient *resty.Client, request RegisterRelayRequest) (RegisterRelayResponse, error) { + var resBody RegisterRelayResponse response, err := httpClient. R(). SetResult(&resBody). SetHeader("User-Agent", USER_AGENT). SetBody(request). - Post(fmt.Sprintf("%v/v1/proxies/register-instance-proxy", config.INFISICAL_URL)) + Post(fmt.Sprintf("%v/v1/relays/register-instance-relay", config.INFISICAL_URL)) if err != nil { - return RegisterProxyResponse{}, NewGenericRequestError(operationCallRegisterInstanceProxy, err) + return RegisterRelayResponse{}, NewGenericRequestError(operationCallRegisterInstanceRelay, err) } if response.IsError() { - return RegisterProxyResponse{}, NewAPIErrorWithResponse(operationCallRegisterInstanceProxy, response, nil) + return RegisterRelayResponse{}, NewAPIErrorWithResponse(operationCallRegisterInstanceRelay, response, nil) } return resBody, nil } -func CallRegisterProxy(httpClient *resty.Client, request RegisterProxyRequest) (RegisterProxyResponse, error) { - var resBody RegisterProxyResponse +func CallRegisterRelay(httpClient *resty.Client, request RegisterRelayRequest) (RegisterRelayResponse, error) { + var resBody RegisterRelayResponse response, err := httpClient. R(). SetResult(&resBody). SetHeader("User-Agent", USER_AGENT). SetBody(request). - Post(fmt.Sprintf("%v/v1/proxies/register-org-proxy", config.INFISICAL_URL)) + Post(fmt.Sprintf("%v/v1/relays/register-org-relay", config.INFISICAL_URL)) if err != nil { - return RegisterProxyResponse{}, NewGenericRequestError(operationCallRegisterOrgProxy, err) + return RegisterRelayResponse{}, NewGenericRequestError(operationCallRegisterOrgRelay, err) } if response.IsError() { - return RegisterProxyResponse{}, NewAPIErrorWithResponse(operationCallRegisterOrgProxy, response, nil) + return RegisterRelayResponse{}, NewAPIErrorWithResponse(operationCallRegisterOrgRelay, response, nil) } return resBody, nil diff --git a/packages/api/model.go b/packages/api/model.go index c436d117..a78c76d8 100644 --- a/packages/api/model.go +++ b/packages/api/model.go @@ -704,12 +704,12 @@ type BootstrapUser struct { SuperAdmin bool `json:"superAdmin"` } -type RegisterProxyRequest struct { +type RegisterRelayRequest struct { IP string `json:"ip"` Name string `json:"name"` } -type RegisterProxyResponse struct { +type RegisterRelayResponse struct { PKI struct { ServerCertificate string `json:"serverCertificate"` ServerPrivateKey string `json:"serverPrivateKey"` @@ -723,13 +723,13 @@ type RegisterProxyResponse struct { } type RegisterGatewayRequest struct { - ProxyName string `json:"proxyName"` + RelayName string `json:"relayName"` Name string `json:"name"` } type RegisterGatewayResponse struct { GatewayID string `json:"gatewayId"` - ProxyIP string `json:"proxyIp"` + RelayIP string `json:"relayIp"` PKI struct { ServerCertificate string `json:"serverCertificate"` ServerPrivateKey string `json:"serverPrivateKey"` diff --git a/packages/cmd/gateway.go b/packages/cmd/gateway.go index abc4d694..8dca18c9 100644 --- a/packages/cmd/gateway.go +++ b/packages/cmd/gateway.go @@ -14,6 +14,7 @@ import ( "github.com/Infisical/infisical-merge/packages/api" "github.com/Infisical/infisical-merge/packages/config" "github.com/Infisical/infisical-merge/packages/gateway" + gatewayv2 "github.com/Infisical/infisical-merge/packages/gateway-v2" "github.com/Infisical/infisical-merge/packages/util" infisicalSdk "github.com/infisical/go-sdk" "github.com/pkg/errors" @@ -87,6 +88,8 @@ var gatewayCmd = &cobra.Command{ DisableFlagsInUseLine: true, Args: cobra.NoArgs, Run: func(cmd *cobra.Command, args []string) { + log.Warn().Msg("DEPRECATION WARNING: The 'infisical gateway' command is deprecated. Please use 'infisical gateway start'") + log.Warn().Msg("This legacy gateway will be removed in a future version.") infisicalClient, cancelSdk, err := getInfisicalSdkInstance(cmd) if err != nil { @@ -199,6 +202,105 @@ var gatewayCmd = &cobra.Command{ }, } +var gatewayStartCmd = &cobra.Command{ + Use: "start", + Short: "Start the new Infisical gateway", + Long: "Start the new Infisical gateway component.", + Example: "infisical gateway start --relay= --name= --token=", + DisableFlagsInUseLine: true, + Args: cobra.NoArgs, + Run: func(cmd *cobra.Command, args []string) { + relayName, err := util.GetCmdFlagOrEnv(cmd, "relay", []string{gatewayv2.RELAY_NAME_ENV_NAME}) + if err != nil { + util.HandleError(err, fmt.Sprintf("unable to get relay flag or %s env", gatewayv2.RELAY_NAME_ENV_NAME)) + } + + gatewayName, err := util.GetCmdFlagOrEnv(cmd, "name", []string{gatewayv2.GATEWAY_NAME_ENV_NAME}) + if err != nil { + util.HandleError(err, fmt.Sprintf("unable to get name flag or %s env", gatewayv2.GATEWAY_NAME_ENV_NAME)) + } + + gatewayInstance, err := gatewayv2.NewGateway(&gatewayv2.GatewayConfig{ + Name: gatewayName, + RelayName: relayName, + ReconnectDelay: 10 * time.Second, + }) + + if err != nil { + util.HandleError(err, "unable to create gateway instance") + } + + infisicalClient, cancelSdk, err := getInfisicalSdkInstance(cmd) + if err != nil { + util.HandleError(err, "unable to get infisical client") + } + defer cancelSdk() + + var accessToken atomic.Value + accessToken.Store(infisicalClient.Auth().GetAccessToken()) + + if accessToken.Load().(string) == "" { + util.HandleError(errors.New("no access token found")) + } + + gatewayInstance.SetToken(accessToken.Load().(string)) + + Telemetry.CaptureEvent("cli-command:gateway-v2", posthog.NewProperties().Set("version", util.CLI_VERSION)) + + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) + + ctx, cancelCmd := context.WithCancel(cmd.Context()) + defer cancelCmd() + + go func() { + <-sigCh + log.Info().Msg("Received shutdown signal, shutting down gateway...") + cancelCmd() + cancelSdk() + + // Give graceful shutdown 10 seconds, then force exit on second signal + select { + case <-sigCh: + log.Warn().Msg("Second signal received, force exit triggered") + os.Exit(1) + case <-time.After(10 * time.Second): + log.Info().Msg("Graceful shutdown completed") + os.Exit(0) + } + }() + + // Token refresh goroutine - runs every 10 seconds + go func() { + tokenRefreshTicker := time.NewTicker(10 * time.Second) + defer tokenRefreshTicker.Stop() + + for { + select { + case <-tokenRefreshTicker.C: + if ctx.Err() != nil { + return + } + + newToken := infisicalClient.Auth().GetAccessToken() + if newToken != "" && newToken != accessToken.Load().(string) { + accessToken.Store(newToken) + gatewayInstance.SetToken(newToken) + } + + case <-ctx.Done(): + return + } + } + }() + + err = gatewayInstance.Start(ctx) + if err != nil { + util.HandleError(err, "unable to start gateway instance") + } + }, +} + var gatewayInstallCmd = &cobra.Command{ Use: "install", Short: "Install and enable systemd service for the gateway (requires sudo)", @@ -265,6 +367,99 @@ var gatewayUninstallCmd = &cobra.Command{ }, } +var gatewaySystemdCmd = &cobra.Command{ + Use: "systemd", + Short: "Manage systemd service for Infisical gateway", + Long: "Manage systemd service for Infisical gateway. Use 'systemd install' to install and enable the service.", + Example: `sudo infisical gateway systemd install --token= --domain= --name= --relay= + sudo infisical gateway systemd uninstall`, + DisableFlagsInUseLine: true, + Args: cobra.NoArgs, +} + +var gatewaySystemdInstallCmd = &cobra.Command{ + Use: "install", + Short: "Install and enable systemd service for the gateway (v2) (requires sudo)", + Long: "Install and enable systemd service for the new gateway (v2). Must be run with sudo on Linux.", + Example: "sudo infisical gateway systemd install --token= --domain= --name= --relay=", + DisableFlagsInUseLine: true, + Args: cobra.NoArgs, + Run: func(cmd *cobra.Command, args []string) { + if runtime.GOOS != "linux" { + util.HandleError(fmt.Errorf("systemd service installation is only supported on Linux")) + } + + if os.Geteuid() != 0 { + util.HandleError(fmt.Errorf("systemd service installation requires root/sudo privileges")) + } + + token, err := util.GetInfisicalToken(cmd) + if err != nil { + util.HandleError(err, "Unable to parse flag") + } + + if token == nil { + util.HandleError(errors.New("Token not found")) + } + + domain, err := cmd.Flags().GetString("domain") + if err != nil { + util.HandleError(err, "Unable to parse domain flag") + } + + gatewayName, err := cmd.Flags().GetString("name") + if err != nil { + util.HandleError(err, "Unable to parse name flag") + } + if gatewayName == "" { + util.HandleError(errors.New("Gateway name is required")) + } + + relayName, err := cmd.Flags().GetString("relay") + if err != nil { + util.HandleError(err, "Unable to parse relay flag") + } + if relayName == "" { + util.HandleError(errors.New("Relay is required")) + } + + err = gatewayv2.InstallGatewaySystemdService(token.Token, domain, gatewayName, relayName) + if err != nil { + util.HandleError(err, "Unable to install systemd service") + } + + enableCmd := exec.Command("systemctl", "enable", "infisical-gateway") + if err := enableCmd.Run(); err != nil { + util.HandleError(err, "Failed to enable systemd service") + } + + log.Info().Msg("Successfully installed and enabled infisical-gateway service") + log.Info().Msg("To start the service, run: sudo systemctl start infisical-gateway") + }, +} + +var gatewaySystemdUninstallCmd = &cobra.Command{ + Use: "uninstall", + Short: "Uninstall and remove systemd service for the gateway (requires sudo)", + Long: "Uninstall and remove systemd service for the gateway. Must be run with sudo on Linux.", + Example: "sudo infisical gateway systemd uninstall", + DisableFlagsInUseLine: true, + Args: cobra.NoArgs, + Run: func(cmd *cobra.Command, args []string) { + if runtime.GOOS != "linux" { + util.HandleError(fmt.Errorf("systemd service installation is only supported on Linux")) + } + + if os.Geteuid() != 0 { + util.HandleError(fmt.Errorf("systemd service installation requires root/sudo privileges")) + } + + if err := gatewayv2.UninstallGatewaySystemdService(); err != nil { + util.HandleError(err, "Failed to uninstall systemd service") + } + }, +} + var gatewayRelayCmd = &cobra.Command{ Example: `infisical gateway relay`, Short: "Used to run infisical gateway relay", @@ -293,24 +488,47 @@ var gatewayRelayCmd = &cobra.Command{ } func init() { + // Legacy gateway command flags (v1) gatewayCmd.Flags().String("token", "", "connect with Infisical using machine identity access token. if not provided, you must set the auth-method flag") - gatewayCmd.Flags().String("auth-method", "", "login method [universal-auth, kubernetes, azure, gcp-id-token, gcp-iam, aws-iam, oidc-auth]. if not provided, you must set the token flag") - gatewayCmd.Flags().String("client-id", "", "client id for universal auth") gatewayCmd.Flags().String("client-secret", "", "client secret for universal auth") - gatewayCmd.Flags().String("machine-identity-id", "", "machine identity id for kubernetes, azure, gcp-id-token, gcp-iam, and aws-iam auth methods") gatewayCmd.Flags().String("service-account-token-path", "", "service account token path for kubernetes auth") gatewayCmd.Flags().String("service-account-key-file-path", "", "service account key file path for GCP IAM auth") - gatewayCmd.Flags().String("jwt", "", "JWT for jwt-based auth methods [oidc-auth, jwt-auth]") + // Gateway start command flags (v2) + gatewayStartCmd.Flags().String("relay", "", "name of the relay to connect to") + gatewayStartCmd.Flags().String("name", "", "name of the gateway") + gatewayStartCmd.Flags().String("token", "", "connect with Infisical using machine identity access token. if not provided, you must set the auth-method flag") + gatewayStartCmd.Flags().String("auth-method", "", "login method [universal-auth, kubernetes, azure, gcp-id-token, gcp-iam, aws-iam, oidc-auth]. if not provided, you must set the token flag") + gatewayStartCmd.Flags().String("client-id", "", "client id for universal auth") + gatewayStartCmd.Flags().String("client-secret", "", "client secret for universal auth") + gatewayStartCmd.Flags().String("machine-identity-id", "", "machine identity id for kubernetes, azure, gcp-id-token, gcp-iam, and aws-iam auth methods") + gatewayStartCmd.Flags().String("service-account-token-path", "", "service account token path for kubernetes auth") + gatewayStartCmd.Flags().String("service-account-key-file-path", "", "service account key file path for GCP IAM auth") + gatewayStartCmd.Flags().String("jwt", "", "JWT for jwt-based auth methods [oidc-auth, jwt-auth]") + + // Legacy install command flags (v1) gatewayInstallCmd.Flags().String("token", "", "Connect with Infisical using machine identity access token") gatewayInstallCmd.Flags().String("domain", "", "Domain of your self-hosted Infisical instance") + // Systemd install command flags (v2) + gatewaySystemdInstallCmd.Flags().String("token", "", "Connect with Infisical using machine identity access token") + gatewaySystemdInstallCmd.Flags().String("domain", "", "Domain of your self-hosted Infisical instance") + gatewaySystemdInstallCmd.Flags().String("name", "", "The name of the gateway") + gatewaySystemdInstallCmd.Flags().String("relay", "", "The name of the relay") + + // Gateway relay command flags gatewayRelayCmd.Flags().String("config", "", "Relay config yaml file path") + // Wire up command hierarchy + gatewaySystemdCmd.AddCommand(gatewaySystemdInstallCmd) + gatewaySystemdCmd.AddCommand(gatewaySystemdUninstallCmd) + + gatewayCmd.AddCommand(gatewayStartCmd) + gatewayCmd.AddCommand(gatewaySystemdCmd) gatewayCmd.AddCommand(gatewayInstallCmd) gatewayCmd.AddCommand(gatewayUninstallCmd) gatewayCmd.AddCommand(gatewayRelayCmd) diff --git a/packages/cmd/network.go b/packages/cmd/network.go index 753e7d22..4e9f6935 100644 --- a/packages/cmd/network.go +++ b/packages/cmd/network.go @@ -12,7 +12,6 @@ import ( "time" gatewayv2 "github.com/Infisical/infisical-merge/packages/gateway-v2" - "github.com/Infisical/infisical-merge/packages/proxy" "github.com/Infisical/infisical-merge/packages/util" "github.com/rs/zerolog/log" "github.com/spf13/cobra" @@ -24,132 +23,18 @@ var networkCmd = &cobra.Command{ Long: "Network-related commands for Infisical", } -var networkProxyCmd = &cobra.Command{ - Use: "proxy", - Short: "Run the Infisical proxy component", - Long: "Run the Infisical proxy component", - Example: "infisical network proxy --type=instance --ip= --name= --token=", - DisableFlagsInUseLine: true, - Args: cobra.NoArgs, - Run: func(cmd *cobra.Command, args []string) { - - proxyName, err := cmd.Flags().GetString("name") - if err != nil || proxyName == "" { - util.HandleError(err, "unable to get name flag") - } - - ip, err := cmd.Flags().GetString("ip") - if err != nil || ip == "" { - util.HandleError(err, "unable to get ip flag") - } - - instanceType, err := cmd.Flags().GetString("type") - if err != nil { - util.HandleError(err, "unable to get type flag") - } - - proxyInstance, err := proxy.NewProxy(&proxy.ProxyConfig{ - ProxyName: proxyName, - SSHPort: "2222", - TLSPort: "8443", - StaticIP: ip, - Type: instanceType, - }) - - if err != nil { - util.HandleError(err, "unable to create proxy instance") - } - - if instanceType == "instance" { - proxyAuthSecret := os.Getenv(gatewayv2.PROXY_AUTH_SECRET_ENV_NAME) - if proxyAuthSecret == "" { - util.HandleError(fmt.Errorf("%s is not set", gatewayv2.PROXY_AUTH_SECRET_ENV_NAME), "unable to get proxy auth secret") - } - - proxyInstance.SetToken(proxyAuthSecret) - } else { - infisicalClient, cancelSdk, err := getInfisicalSdkInstance(cmd) - if err != nil { - util.HandleError(err, "unable to get infisical client") - } - defer cancelSdk() - - var accessToken atomic.Value - accessToken.Store(infisicalClient.Auth().GetAccessToken()) - - if accessToken.Load().(string) == "" { - util.HandleError(errors.New("no access token found")) - } - - proxyInstance.SetToken(accessToken.Load().(string)) - - sigCh := make(chan os.Signal, 1) - signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) - - ctx, cancelCmd := context.WithCancel(cmd.Context()) - defer cancelCmd() - - go func() { - <-sigCh - log.Info().Msg("Received shutdown signal, shutting down proxy...") - cancelCmd() - cancelSdk() - - // Give graceful shutdown 10 seconds, then force exit on second signal - select { - case <-sigCh: - log.Warn().Msg("Second signal received, force exit triggered") - os.Exit(1) - case <-time.After(10 * time.Second): - log.Info().Msg("Graceful shutdown completed") - os.Exit(0) - } - }() - - // Token refresh goroutine - runs every 10 seconds - go func() { - tokenRefreshTicker := time.NewTicker(10 * time.Second) - defer tokenRefreshTicker.Stop() - - for { - select { - case <-tokenRefreshTicker.C: - if ctx.Err() != nil { - return - } - - newToken := infisicalClient.Auth().GetAccessToken() - if newToken != "" && newToken != accessToken.Load().(string) { - accessToken.Store(newToken) - proxyInstance.SetToken(newToken) - } - - case <-ctx.Done(): - return - } - } - }() - } - - err = proxyInstance.Start(cmd.Context()) - if err != nil { - util.HandleError(err, "unable to start proxy instance") - } - }, -} - var networkGatewayCmd = &cobra.Command{ Use: "gateway", Short: "Run the Infisical gateway component", Long: "Run the Infisical gateway component. Use 'network gateway install' to set up the systemd service.", - Example: "infisical network gateway --proxy-name= --name= --token=", + Example: "infisical network gateway --relay= --name= --token=", DisableFlagsInUseLine: true, Args: cobra.NoArgs, Run: func(cmd *cobra.Command, args []string) { - proxyName, err := util.GetCmdFlagOrEnv(cmd, "proxy-name", []string{gatewayv2.PROXY_NAME_ENV_NAME}) + relayName, err := util.GetCmdFlagOrEnv(cmd, "relay", []string{gatewayv2.RELAY_NAME_ENV_NAME}) if err != nil { - util.HandleError(err, fmt.Sprintf("unable to get proxy-name flag or %s env", gatewayv2.PROXY_NAME_ENV_NAME)) + util.HandleError(err, fmt.Sprintf("unable to get relay flag or %s env", gatewayv2.RELAY_NAME_ENV_NAME)) } gatewayName, err := util.GetCmdFlagOrEnv(cmd, "name", []string{gatewayv2.GATEWAY_NAME_ENV_NAME}) @@ -159,7 +44,7 @@ var networkGatewayCmd = &cobra.Command{ gatewayInstance, err := gatewayv2.NewGateway(&gatewayv2.GatewayConfig{ Name: gatewayName, - ProxyName: proxyName, + RelayName: relayName, ReconnectDelay: 10 * time.Second, }) @@ -241,7 +126,7 @@ var networkGatewayInstallCmd = &cobra.Command{ Use: "install", Short: "Install and enable systemd service for the gateway (requires sudo)", Long: "Install and enable systemd service for the gateway. Must be run with sudo on Linux.", - Example: "sudo infisical network gateway install --token= --domain= --name= --proxy-name=", + Example: "sudo infisical network gateway install --token= --domain= --name= --relay=", DisableFlagsInUseLine: true, Args: cobra.NoArgs, Run: func(cmd *cobra.Command, args []string) { @@ -275,15 +160,15 @@ var networkGatewayInstallCmd = &cobra.Command{ util.HandleError(errors.New("Gateway name is required")) } - proxyName, err := cmd.Flags().GetString("proxy-name") + relayName, err := cmd.Flags().GetString("relay") if err != nil { - util.HandleError(err, "Unable to parse proxy-name flag") + util.HandleError(err, "Unable to parse relay flag") } - if proxyName == "" { - util.HandleError(errors.New("Proxy name is required")) + if relayName == "" { + util.HandleError(errors.New("Relay is required")) } - err = gatewayv2.InstallGatewaySystemdService(token.Token, domain, gatewayName, proxyName) + err = gatewayv2.InstallGatewaySystemdService(token.Token, domain, gatewayName, relayName) if err != nil { util.HandleError(err, "Unable to install systemd service") } @@ -313,7 +198,7 @@ var networkGatewayUninstallCmd = &cobra.Command{ } func init() { - networkGatewayCmd.Flags().String("proxy-name", "", "The name of the proxy to connect to") + networkGatewayCmd.Flags().String("relay", "", "The name of the relay to connect to") networkGatewayCmd.Flags().String("name", "", "The name of the gateway") networkGatewayCmd.Flags().String("token", "", "connect with Infisical using machine identity access token. if not provided, you must set the auth-method flag") networkGatewayCmd.Flags().String("auth-method", "", "login method [universal-auth, kubernetes, azure, gcp-id-token, gcp-iam, aws-iam, oidc-auth]. if not provided, you must set the token flag") @@ -324,27 +209,14 @@ func init() { networkGatewayCmd.Flags().String("service-account-key-file-path", "", "service account key file path for GCP IAM auth") networkGatewayCmd.Flags().String("jwt", "", "JWT for jwt-based auth methods [oidc-auth, jwt-auth]") - networkProxyCmd.Flags().String("type", "org", "The type of proxy to run. Must be either 'instance' or 'org'") - networkProxyCmd.Flags().String("ip", "", "The IP address of the proxy") - networkProxyCmd.Flags().String("name", "", "The name of the proxy") - networkProxyCmd.Flags().String("token", "", "connect with Infisical using machine identity access token. if not provided, you must set the auth-method flag") - networkProxyCmd.Flags().String("auth-method", "", "login method [universal-auth, kubernetes, azure, gcp-id-token, gcp-iam, aws-iam, oidc-auth]. if not provided, you must set the token flag") - networkProxyCmd.Flags().String("client-id", "", "client id for universal auth") - networkProxyCmd.Flags().String("client-secret", "", "client secret for universal auth") - networkProxyCmd.Flags().String("machine-identity-id", "", "machine identity id for kubernetes, azure, gcp-id-token, gcp-iam, and aws-iam auth methods") - networkProxyCmd.Flags().String("service-account-token-path", "", "service account token path for kubernetes auth") - networkProxyCmd.Flags().String("service-account-key-file-path", "", "service account key file path for GCP IAM auth") - networkProxyCmd.Flags().String("jwt", "", "JWT for jwt-based auth methods [oidc-auth, jwt-auth]") - networkGatewayInstallCmd.Flags().String("token", "", "Connect with Infisical using machine identity access token") networkGatewayInstallCmd.Flags().String("domain", "", "Domain of your self-hosted Infisical instance") networkGatewayInstallCmd.Flags().String("name", "", "The name of the gateway") - networkGatewayInstallCmd.Flags().String("proxy-name", "", "The name of the proxy") + networkGatewayInstallCmd.Flags().String("relay", "", "The name of the relay") networkGatewayCmd.AddCommand(networkGatewayInstallCmd) networkGatewayCmd.AddCommand(networkGatewayUninstallCmd) - networkCmd.AddCommand(networkProxyCmd) networkCmd.AddCommand(networkGatewayCmd) rootCmd.AddCommand(networkCmd) diff --git a/packages/cmd/relay.go b/packages/cmd/relay.go new file mode 100644 index 00000000..4a21cbc2 --- /dev/null +++ b/packages/cmd/relay.go @@ -0,0 +1,156 @@ +package cmd + +import ( + "context" + "errors" + "fmt" + "os" + "os/signal" + "sync/atomic" + "syscall" + "time" + + gatewayv2 "github.com/Infisical/infisical-merge/packages/gateway-v2" + "github.com/Infisical/infisical-merge/packages/relay" + "github.com/Infisical/infisical-merge/packages/util" + "github.com/rs/zerolog/log" + "github.com/spf13/cobra" +) + +var relayCmd = &cobra.Command{ + Use: "relay", + Short: "Relay-related commands", + Long: "Relay-related commands for Infisical", +} + +var relayStartCmd = &cobra.Command{ + Use: "start", + Short: "Start the Infisical relay component", + Long: "Start the Infisical relay component", + Example: "infisical relay start --type=instance --ip= --name= --token=", + DisableFlagsInUseLine: true, + Args: cobra.NoArgs, + Run: func(cmd *cobra.Command, args []string) { + + relayName, err := cmd.Flags().GetString("name") + if err != nil || relayName == "" { + util.HandleError(err, "unable to get name flag") + } + + ip, err := cmd.Flags().GetString("ip") + if err != nil || ip == "" { + util.HandleError(err, "unable to get ip flag") + } + + instanceType, err := cmd.Flags().GetString("type") + if err != nil { + util.HandleError(err, "unable to get type flag") + } + + relayInstance, err := relay.NewRelay(&relay.RelayConfig{ + RelayName: relayName, + SSHPort: "2222", + TLSPort: "8443", + StaticIP: ip, + Type: instanceType, + }) + + if err != nil { + util.HandleError(err, "unable to create relay instance") + } + + if instanceType == "instance" { + relayAuthSecret := os.Getenv(gatewayv2.RELAY_AUTH_SECRET_ENV_NAME) + if relayAuthSecret == "" { + util.HandleError(fmt.Errorf("%s is not set", gatewayv2.RELAY_AUTH_SECRET_ENV_NAME), "unable to get relay auth secret") + } + + relayInstance.SetToken(relayAuthSecret) + } else { + infisicalClient, cancelSdk, err := getInfisicalSdkInstance(cmd) + if err != nil { + util.HandleError(err, "unable to get infisical client") + } + defer cancelSdk() + + var accessToken atomic.Value + accessToken.Store(infisicalClient.Auth().GetAccessToken()) + + if accessToken.Load().(string) == "" { + util.HandleError(errors.New("no access token found")) + } + + relayInstance.SetToken(accessToken.Load().(string)) + + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) + + ctx, cancelCmd := context.WithCancel(cmd.Context()) + defer cancelCmd() + + go func() { + <-sigCh + log.Info().Msg("Received shutdown signal, shutting down relay...") + cancelCmd() + cancelSdk() + + // Give graceful shutdown 10 seconds, then force exit on second signal + select { + case <-sigCh: + log.Warn().Msg("Second signal received, force exit triggered") + os.Exit(1) + case <-time.After(10 * time.Second): + log.Info().Msg("Graceful shutdown completed") + os.Exit(0) + } + }() + + // Token refresh goroutine - runs every 10 seconds + go func() { + tokenRefreshTicker := time.NewTicker(10 * time.Second) + defer tokenRefreshTicker.Stop() + + for { + select { + case <-tokenRefreshTicker.C: + if ctx.Err() != nil { + return + } + + newToken := infisicalClient.Auth().GetAccessToken() + if newToken != "" && newToken != accessToken.Load().(string) { + accessToken.Store(newToken) + relayInstance.SetToken(newToken) + } + + case <-ctx.Done(): + return + } + } + }() + } + + err = relayInstance.Start(cmd.Context()) + if err != nil { + util.HandleError(err, "unable to start relay instance") + } + }, +} + +func init() { + relayStartCmd.Flags().String("type", "org", "The type of relay to run. Must be either 'instance' or 'org'") + relayStartCmd.Flags().String("ip", "", "The IP address of the relay") + relayStartCmd.Flags().String("name", "", "The name of the relay") + relayStartCmd.Flags().String("token", "", "connect with Infisical using machine identity access token. if not provided, you must set the auth-method flag") + relayStartCmd.Flags().String("auth-method", "", "login method [universal-auth, kubernetes, azure, gcp-id-token, gcp-iam, aws-iam, oidc-auth]. if not provided, you must set the token flag") + relayStartCmd.Flags().String("client-id", "", "client id for universal auth") + relayStartCmd.Flags().String("client-secret", "", "client secret for universal auth") + relayStartCmd.Flags().String("machine-identity-id", "", "machine identity id for kubernetes, azure, gcp-id-token, gcp-iam, and aws-iam auth methods") + relayStartCmd.Flags().String("service-account-token-path", "", "service account token path for kubernetes auth") + relayStartCmd.Flags().String("service-account-key-file-path", "", "service account key file path for GCP IAM auth") + relayStartCmd.Flags().String("jwt", "", "JWT for jwt-based auth methods [oidc-auth, jwt-auth]") + + relayCmd.AddCommand(relayStartCmd) + + rootCmd.AddCommand(relayCmd) +} diff --git a/packages/gateway-v2/constants.go b/packages/gateway-v2/constants.go index de54cd6f..87597511 100644 --- a/packages/gateway-v2/constants.go +++ b/packages/gateway-v2/constants.go @@ -6,10 +6,10 @@ const ( KUBERNETES_SERVICE_ACCOUNT_CA_CERT_PATH = "/var/run/secrets/kubernetes.io/serviceaccount/ca.crt" KUBERNETES_SERVICE_ACCOUNT_TOKEN_PATH = "/var/run/secrets/kubernetes.io/serviceaccount/token" - PROXY_NAME_ENV_NAME = "INFISICAL_PROXY_NAME" + RELAY_NAME_ENV_NAME = "INFISICAL_RELAY_NAME" GATEWAY_NAME_ENV_NAME = "INFISICAL_GATEWAY_NAME" - PROXY_AUTH_SECRET_ENV_NAME = "INFISICAL_PROXY_AUTH_SECRET" + RELAY_AUTH_SECRET_ENV_NAME = "INFISICAL_RELAY_AUTH_SECRET" INFISICAL_HTTP_PROXY_ACTION_HEADER = "x-infisical-action" ) diff --git a/packages/gateway-v2/gateway.go b/packages/gateway-v2/gateway.go index 46553798..3d3622c1 100644 --- a/packages/gateway-v2/gateway.go +++ b/packages/gateway-v2/gateway.go @@ -64,7 +64,7 @@ type ActorDetails struct { type GatewayConfig struct { Name string - ProxyName string + RelayName string IdentityToken string SSHPort int ReconnectDelay time.Duration @@ -224,13 +224,13 @@ func (g *Gateway) connectAndServe() error { return fmt.Errorf("failed to create SSH config: %v", err) } - // Connect to Proxy server - log.Info().Msgf("Connecting to proxy server %s on %s:%d...", g.config.ProxyName, g.certificates.ProxyIP, g.config.SSHPort) - client, err := ssh.Dial("tcp", fmt.Sprintf("%s:%d", g.certificates.ProxyIP, g.config.SSHPort), sshConfig) + // Connect to Relay server + log.Info().Msgf("Connecting to relay server %s on %s:%d...", g.config.RelayName, g.certificates.RelayIP, g.config.SSHPort) + client, err := ssh.Dial("tcp", fmt.Sprintf("%s:%d", g.certificates.RelayIP, g.config.SSHPort), sshConfig) if err != nil { return fmt.Errorf("failed to connect to SSH server: %v", err) } - log.Info().Msgf("Proxy connection established for gateway") + log.Info().Msgf("Relay connection established for gateway") g.mu.Lock() g.sshClient = client @@ -254,7 +254,7 @@ func (g *Gateway) connectAndServe() error { // Monitor for context cancellation and close SSH client go func() { <-g.ctx.Done() - log.Info().Msg("Context cancelled, closing proxy connection...") + log.Info().Msg("Context cancelled, closing relay connection...") client.Close() }() @@ -276,7 +276,7 @@ func (g *Gateway) connectAndServe() error { func (g *Gateway) registerGateway() error { body := api.RegisterGatewayRequest{ - ProxyName: g.config.ProxyName, + RelayName: g.config.RelayName, Name: g.config.Name, } diff --git a/packages/gateway-v2/systemd.go b/packages/gateway-v2/systemd.go index 794509ea..d4fa2940 100644 --- a/packages/gateway-v2/systemd.go +++ b/packages/gateway-v2/systemd.go @@ -18,7 +18,7 @@ After=network.target Type=notify NotifyAccess=all EnvironmentFile=/etc/infisical/gateway.conf -ExecStart=infisical network gateway +ExecStart=infisical gateway start Restart=on-failure InaccessibleDirectories=/home PrivateTmp=yes @@ -32,7 +32,7 @@ LimitRTTIME=7000000 WantedBy=multi-user.target ` -func InstallGatewaySystemdService(token string, domain string, name string, proxyName string) error { +func InstallGatewaySystemdService(token string, domain string, name string, relayName string) error { if runtime.GOOS != "linux" { log.Info().Msg("Skipping systemd service installation - not on Linux") return nil @@ -56,8 +56,8 @@ func InstallGatewaySystemdService(token string, domain string, name string, prox if name != "" { configContent += fmt.Sprintf("%s=%s\n", GATEWAY_NAME_ENV_NAME, name) } - if proxyName != "" { - configContent += fmt.Sprintf("%s=%s\n", PROXY_NAME_ENV_NAME, proxyName) + if relayName != "" { + configContent += fmt.Sprintf("%s=%s\n", RELAY_NAME_ENV_NAME, relayName) } configPath := filepath.Join(configDir, "gateway.conf") diff --git a/packages/proxy/proxy.go b/packages/relay/relay.go similarity index 72% rename from packages/proxy/proxy.go rename to packages/relay/relay.go index 2a2a6301..e7d678bb 100644 --- a/packages/proxy/proxy.go +++ b/packages/relay/relay.go @@ -1,4 +1,4 @@ -package proxy +package relay import ( "bytes" @@ -20,10 +20,10 @@ import ( "golang.org/x/crypto/ssh" ) -type ProxyConfig struct { +type RelayConfig struct { // API Configuration Token string - ProxyName string + RelayName string Type string @@ -35,12 +35,12 @@ type ProxyConfig struct { StaticIP string } -type Proxy struct { +type Relay struct { httpClient *resty.Client - config *ProxyConfig + config *RelayConfig // Certificate storage - certificates *api.RegisterProxyResponse + certificates *api.RegisterRelayResponse // SSH server components sshConfig *ssh.ServerConfig @@ -60,7 +60,7 @@ type Proxy struct { tlsListener net.Listener } -func NewProxy(config *ProxyConfig) (*Proxy, error) { +func NewRelay(config *RelayConfig) (*Relay, error) { httpClient, err := util.GetRestyClientWithCustomHeaders() if err != nil { return nil, fmt.Errorf("unable to get client with custom headers [err=%v]", err) @@ -68,90 +68,90 @@ func NewProxy(config *ProxyConfig) (*Proxy, error) { httpClient.SetAuthToken(config.Token) - return &Proxy{ + return &Relay{ httpClient: httpClient, config: config, tunnels: make(map[string]*ssh.ServerConn), }, nil } -func (p *Proxy) SetToken(token string) { - p.httpClient.SetAuthToken(token) +func (r *Relay) SetToken(token string) { + r.httpClient.SetAuthToken(token) } -func (p *Proxy) Start(ctx context.Context) error { - if err := p.registerProxy(); err != nil { - return fmt.Errorf("failed to register proxy: %v", err) +func (r *Relay) Start(ctx context.Context) error { + if err := r.registerRelay(); err != nil { + return fmt.Errorf("failed to register relay: %v", err) } // Setup SSH server - if err := p.setupSSHServer(); err != nil { + if err := r.setupSSHServer(); err != nil { return fmt.Errorf("failed to setup SSH server: %v", err) } // Setup TLS server - if err := p.setupTLSServer(); err != nil { + if err := r.setupTLSServer(); err != nil { return fmt.Errorf("failed to setup TLS server: %v", err) } // Start certificate renewal goroutine - go p.startCertificateRenewal(ctx) + go r.startCertificateRenewal(ctx) // Start SSH server - go p.startSSHServer() + go r.startSSHServer() // Start TLS server - go p.startTLSServer() + go r.startTLSServer() - log.Info().Msg("Proxy server started successfully") + log.Info().Msg("Relay server started successfully") // Wait for context cancellation <-ctx.Done() // Cleanup - p.cleanup() + r.cleanup() return nil } -func (p *Proxy) registerProxy() error { - body := api.RegisterProxyRequest{ - IP: p.config.StaticIP, - Name: p.config.ProxyName, +func (r *Relay) registerRelay() error { + body := api.RegisterRelayRequest{ + IP: r.config.StaticIP, + Name: r.config.RelayName, } - if p.config.Type == "instance" { - certResp, err := api.CallRegisterInstanceProxy(p.httpClient, body) + if r.config.Type == "instance" { + certResp, err := api.CallRegisterInstanceRelay(r.httpClient, body) if err != nil { - return fmt.Errorf("failed to register instance proxy: %v", err) + return fmt.Errorf("failed to register instance relay: %v", err) } - p.certificates = &certResp + r.certificates = &certResp } else { - certResp, err := api.CallRegisterProxy(p.httpClient, body) + certResp, err := api.CallRegisterRelay(r.httpClient, body) if err != nil { - return fmt.Errorf("failed to register org proxy: %v", err) + return fmt.Errorf("failed to register org relay: %v", err) } - p.certificates = &certResp + r.certificates = &certResp } - log.Info().Msg("Successfully registered proxy and received certificates from API") + log.Info().Msg("Successfully registered relay and received certificates from API") return nil } -func (p *Proxy) setupSSHServer() error { +func (r *Relay) setupSSHServer() error { // Parse SSH CA public key - sshCAPubKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(p.certificates.SSH.ClientCAPublicKey)) + sshCAPubKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(r.certificates.SSH.ClientCAPublicKey)) if err != nil { return fmt.Errorf("failed to parse SSH CA public key: %v", err) } // Parse SSH server private key - sshServerKey, err := ssh.ParsePrivateKey([]byte(p.certificates.SSH.ServerPrivateKey)) + sshServerKey, err := ssh.ParsePrivateKey([]byte(r.certificates.SSH.ServerPrivateKey)) if err != nil { return fmt.Errorf("failed to parse SSH server private key: %v", err) } // Parse SSH server certificate - sshServerCert, _, _, _, err := ssh.ParseAuthorizedKey([]byte(p.certificates.SSH.ServerCertificate)) + sshServerCert, _, _, _, err := ssh.ParseAuthorizedKey([]byte(r.certificates.SSH.ServerCertificate)) if err != nil { return fmt.Errorf("failed to parse SSH server certificate: %v", err) } @@ -163,7 +163,7 @@ func (p *Proxy) setupSSHServer() error { } // Setup SSH server config - p.sshConfig = &ssh.ServerConfig{ + r.sshConfig = &ssh.ServerConfig{ MaxAuthTries: 3, PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { // Check if this is an SSH certificate @@ -174,7 +174,7 @@ func (p *Proxy) setupSSHServer() error { } // Validate the certificate - if err := p.validateSSHCertificate(cert, conn.User(), sshCAPubKey); err != nil { + if err := r.validateSSHCertificate(cert, conn.User(), sshCAPubKey); err != nil { log.Error().Msgf("Gateway '%s' certificate validation failed: %v", conn.User(), err) return nil, err } @@ -188,8 +188,8 @@ func (p *Proxy) setupSSHServer() error { return nil, fmt.Errorf("gateway id is required") } - // Validate that the user is authorized to connect to the current proxy - expectedKeyId := "client-" + p.config.ProxyName + // Validate that the user is authorized to connect to the current relay + expectedKeyId := "client-" + r.config.RelayName if cert.KeyId != expectedKeyId { log.Error().Msgf("Gateway '%s' certificate Key ID '%s' does not match expected '%s'", conn.User(), cert.KeyId, expectedKeyId) return nil, fmt.Errorf("certificate Key ID does not match expected value") @@ -203,13 +203,13 @@ func (p *Proxy) setupSSHServer() error { }, } - p.sshConfig.AddHostKey(certSigner) + r.sshConfig.AddHostKey(certSigner) return nil } -func (p *Proxy) setupTLSServer() error { +func (r *Relay) setupTLSServer() error { // Parse TLS server certificate - serverCertBlock, _ := pem.Decode([]byte(p.certificates.PKI.ServerCertificate)) + serverCertBlock, _ := pem.Decode([]byte(r.certificates.PKI.ServerCertificate)) if serverCertBlock == nil { return fmt.Errorf("failed to decode server certificate") } @@ -222,7 +222,7 @@ func (p *Proxy) setupTLSServer() error { } // Parse TLS server private key - serverKeyBlock, _ := pem.Decode([]byte(p.certificates.PKI.ServerPrivateKey)) + serverKeyBlock, _ := pem.Decode([]byte(r.certificates.PKI.ServerPrivateKey)) if serverKeyBlock == nil { return fmt.Errorf("failed to decode server private key") } @@ -236,7 +236,7 @@ func (p *Proxy) setupTLSServer() error { clientCAPool := x509.NewCertPool() var chainCerts [][]byte - chainData := []byte(p.certificates.PKI.ClientCertificateChain) + chainData := []byte(r.certificates.PKI.ClientCertificateChain) for { block, rest := pem.Decode(chainData) if block == nil { @@ -256,7 +256,7 @@ func (p *Proxy) setupTLSServer() error { } // Create TLS config - p.tlsConfig = &tls.Config{ + r.tlsConfig = &tls.Config{ Certificates: []tls.Certificate{ { Certificate: [][]byte{serverCertBlock.Bytes}, @@ -271,7 +271,7 @@ func (p *Proxy) setupTLSServer() error { return nil } -func (p *Proxy) validateSSHCertificate(cert *ssh.Certificate, username string, caPubKey ssh.PublicKey) error { +func (r *Relay) validateSSHCertificate(cert *ssh.Certificate, username string, caPubKey ssh.PublicKey) error { // Check certificate type if cert.CertType != ssh.UserCert { return fmt.Errorf("invalid certificate type: %d", cert.CertType) @@ -293,14 +293,14 @@ func (p *Proxy) validateSSHCertificate(cert *ssh.Certificate, username string, c return nil } -func (p *Proxy) startSSHServer() { - listener, err := net.Listen("tcp", ":"+p.config.SSHPort) +func (r *Relay) startSSHServer() { + listener, err := net.Listen("tcp", ":"+r.config.SSHPort) if err != nil { log.Fatal().Msgf("Failed to start SSH server: %v", err) } - p.sshListener = listener + r.sshListener = listener - log.Info().Msgf("SSH server listening on :%s for gateways", p.config.SSHPort) + log.Info().Msgf("SSH server listening on :%s for gateways", r.config.SSHPort) for { conn, err := listener.Accept() @@ -308,15 +308,15 @@ func (p *Proxy) startSSHServer() { log.Error().Msgf("Failed to accept SSH connection: %v", err) continue } - go p.handleSSHAgent(conn) + go r.handleSSHAgent(conn) } } -func (p *Proxy) handleSSHAgent(conn net.Conn) { +func (r *Relay) handleSSHAgent(conn net.Conn) { defer conn.Close() // SSH handshake - sshConn, chans, reqs, err := ssh.NewServerConn(conn, p.sshConfig) + sshConn, chans, reqs, err := ssh.NewServerConn(conn, r.sshConfig) if err != nil { log.Error().Msgf("SSH handshake failed: %v", err) return @@ -326,22 +326,22 @@ func (p *Proxy) handleSSHAgent(conn net.Conn) { log.Info().Msgf("SSH handshake successful for gateway: %s", gatewayId) // Store the connection (ensure only one connection per gateway) - p.mu.Lock() - if _, exists := p.tunnels[gatewayId]; exists { - p.mu.Unlock() + r.mu.Lock() + if _, exists := r.tunnels[gatewayId]; exists { + r.mu.Unlock() log.Warn().Msgf("Gateway '%s' already has an active connection, rejecting new connection", gatewayId) sshConn.Close() return } - p.tunnels[gatewayId] = sshConn - p.mu.Unlock() + r.tunnels[gatewayId] = sshConn + r.mu.Unlock() // Clean up when agent disconnects defer func() { - p.mu.Lock() - delete(p.tunnels, gatewayId) - p.mu.Unlock() + r.mu.Lock() + delete(r.tunnels, gatewayId) + r.mu.Unlock() log.Info().Msgf("Gateway %s disconnected", gatewayId) }() @@ -377,14 +377,14 @@ func (p *Proxy) handleSSHAgent(conn net.Conn) { } } -func (p *Proxy) startTLSServer() { - listener, err := net.Listen("tcp", ":"+p.config.TLSPort) +func (r *Relay) startTLSServer() { + listener, err := net.Listen("tcp", ":"+r.config.TLSPort) if err != nil { log.Fatal().Msgf("Failed to start TLS server: %v", err) } - p.tlsListener = listener + r.tlsListener = listener - log.Info().Msgf("TLS server listening on :%s for clients", p.config.TLSPort) + log.Info().Msgf("TLS server listening on :%s for clients", r.config.TLSPort) for { conn, err := listener.Accept() @@ -392,15 +392,15 @@ func (p *Proxy) startTLSServer() { log.Error().Msgf("Failed to accept TLS connection: %v", err) continue } - go p.handleTLSClient(conn) + go r.handleTLSClient(conn) } } -func (p *Proxy) handleTLSClient(conn net.Conn) { +func (r *Relay) handleTLSClient(conn net.Conn) { defer conn.Close() // Perform TLS handshake using current TLS config - tlsConn := tls.Server(conn, p.tlsConfig) + tlsConn := tls.Server(conn, r.tlsConfig) defer tlsConn.Close() // Set handshake timeout to avoid hanging on slow/malicious connections @@ -416,10 +416,10 @@ func (p *Proxy) handleTLSClient(conn net.Conn) { // Clear deadline for actual data transfer tlsConn.SetDeadline(time.Time{}) - p.handleClient(tlsConn) + r.handleClient(tlsConn) } -func (p *Proxy) handleClient(tlsConn *tls.Conn) { +func (r *Relay) handleClient(tlsConn *tls.Conn) { var gatewayId string state := tlsConn.ConnectionState() @@ -432,10 +432,10 @@ func (p *Proxy) handleClient(tlsConn *tls.Conn) { return } - // Get the SSH connection for this agent - p.mu.RLock() - conn, exists := p.tunnels[gatewayId] - p.mu.RUnlock() + // Get the SSH connection for this gateway + r.mu.RLock() + conn, exists := r.tunnels[gatewayId] + r.mu.RUnlock() if !exists { log.Warn().Msgf("Gateway '%s' not connected", gatewayId) @@ -447,8 +447,8 @@ func (p *Proxy) handleClient(tlsConn *tls.Conn) { channel, _, err := conn.OpenChannel("direct-tcpip", nil) if err != nil { - log.Error().Msgf("Failed to connect to agent: %v", err) - tlsConn.Write([]byte("ERROR: Failed to connect to agent\n")) + log.Error().Msgf("Failed to connect to gateway: %v", err) + tlsConn.Write([]byte("ERROR: Failed to connect to gateway\n")) return } defer channel.Close() @@ -463,21 +463,21 @@ func (p *Proxy) handleClient(tlsConn *tls.Conn) { log.Info().Msgf("Client %s disconnected", tlsConn.RemoteAddr()) } -func (p *Proxy) cleanup() { - log.Info().Msg("Shutting down proxy server...") +func (r *Relay) cleanup() { + log.Info().Msg("Shutting down relay server...") - if p.sshListener != nil { - p.sshListener.Close() + if r.sshListener != nil { + r.sshListener.Close() } - if p.tlsListener != nil { - p.tlsListener.Close() + if r.tlsListener != nil { + r.tlsListener.Close() } - log.Info().Msg("Proxy server shutdown complete") + log.Info().Msg("Relay server shutdown complete") } // startCertificateRenewal runs a background process to renew certificates every 24 hours -func (p *Proxy) startCertificateRenewal(ctx context.Context) { +func (r *Relay) startCertificateRenewal(ctx context.Context) { log.Info().Msg("Starting certificate renewal goroutine") ticker := time.NewTicker(10 * 24 * time.Hour) defer ticker.Stop() @@ -489,7 +489,7 @@ func (p *Proxy) startCertificateRenewal(ctx context.Context) { return case <-ticker.C: log.Info().Msg("Renewing certificates...") - if err := p.renewCertificates(); err != nil { + if err := r.renewCertificates(); err != nil { log.Error().Msgf("Failed to renew certificates: %v", err) } else { log.Info().Msg("Certificates renewed successfully") @@ -499,19 +499,19 @@ func (p *Proxy) startCertificateRenewal(ctx context.Context) { } // renewCertificates fetches new certificates and updates the server configurations -func (p *Proxy) renewCertificates() error { - // Re-register proxy to get fresh certificates - if err := p.registerProxy(); err != nil { - return fmt.Errorf("failed to register proxy: %v", err) +func (r *Relay) renewCertificates() error { + // Re-register relay to get fresh certificates + if err := r.registerRelay(); err != nil { + return fmt.Errorf("failed to register relay: %v", err) } // Update SSH server configuration - if err := p.setupSSHServer(); err != nil { + if err := r.setupSSHServer(); err != nil { return fmt.Errorf("failed to setup SSH server: %v", err) } // Update TLS server configuration - if err := p.setupTLSServer(); err != nil { + if err := r.setupTLSServer(); err != nil { return fmt.Errorf("failed to setup TLS server: %v", err) } From e79f425ecdfc8ae885c1dd3f79dee0a4e0b56fec Mon Sep 17 00:00:00 2001 From: Sheen Capadngan Date: Tue, 9 Sep 2025 03:39:41 +0800 Subject: [PATCH 32/38] misc: updated ip flag to be host instead for relay --- packages/api/model.go | 2 +- packages/cmd/relay.go | 10 +++++----- packages/relay/relay.go | 4 ++-- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/packages/api/model.go b/packages/api/model.go index a78c76d8..4f9ab7d9 100644 --- a/packages/api/model.go +++ b/packages/api/model.go @@ -705,7 +705,7 @@ type BootstrapUser struct { } type RegisterRelayRequest struct { - IP string `json:"ip"` + Host string `json:"host"` Name string `json:"name"` } diff --git a/packages/cmd/relay.go b/packages/cmd/relay.go index 4a21cbc2..dac6b1c5 100644 --- a/packages/cmd/relay.go +++ b/packages/cmd/relay.go @@ -37,9 +37,9 @@ var relayStartCmd = &cobra.Command{ util.HandleError(err, "unable to get name flag") } - ip, err := cmd.Flags().GetString("ip") - if err != nil || ip == "" { - util.HandleError(err, "unable to get ip flag") + host, err := cmd.Flags().GetString("host") + if err != nil || host == "" { + util.HandleError(err, "unable to get host flag") } instanceType, err := cmd.Flags().GetString("type") @@ -51,7 +51,7 @@ var relayStartCmd = &cobra.Command{ RelayName: relayName, SSHPort: "2222", TLSPort: "8443", - StaticIP: ip, + Host: host, Type: instanceType, }) @@ -139,7 +139,7 @@ var relayStartCmd = &cobra.Command{ func init() { relayStartCmd.Flags().String("type", "org", "The type of relay to run. Must be either 'instance' or 'org'") - relayStartCmd.Flags().String("ip", "", "The IP address of the relay") + relayStartCmd.Flags().String("host", "", "The IP or hostname for the relay") relayStartCmd.Flags().String("name", "", "The name of the relay") relayStartCmd.Flags().String("token", "", "connect with Infisical using machine identity access token. if not provided, you must set the auth-method flag") relayStartCmd.Flags().String("auth-method", "", "login method [universal-auth, kubernetes, azure, gcp-id-token, gcp-iam, aws-iam, oidc-auth]. if not provided, you must set the token flag") diff --git a/packages/relay/relay.go b/packages/relay/relay.go index e7d678bb..73e54333 100644 --- a/packages/relay/relay.go +++ b/packages/relay/relay.go @@ -32,7 +32,7 @@ type RelayConfig struct { TLSPort string // Network Configuration - StaticIP string + Host string } type Relay struct { @@ -115,7 +115,7 @@ func (r *Relay) Start(ctx context.Context) error { func (r *Relay) registerRelay() error { body := api.RegisterRelayRequest{ - IP: r.config.StaticIP, + Host: r.config.Host, Name: r.config.RelayName, } From 97da19883ed785761c11fa33b48f71ae25903636 Mon Sep 17 00:00:00 2001 From: Sheen Capadngan Date: Tue, 9 Sep 2025 03:49:37 +0800 Subject: [PATCH 33/38] misc: updated logs --- packages/api/model.go | 2 +- packages/gateway-v2/gateway.go | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/packages/api/model.go b/packages/api/model.go index 4f9ab7d9..699bc8c9 100644 --- a/packages/api/model.go +++ b/packages/api/model.go @@ -729,7 +729,7 @@ type RegisterGatewayRequest struct { type RegisterGatewayResponse struct { GatewayID string `json:"gatewayId"` - RelayIP string `json:"relayIp"` + RelayHost string `json:"relayHost"` PKI struct { ServerCertificate string `json:"serverCertificate"` ServerPrivateKey string `json:"serverPrivateKey"` diff --git a/packages/gateway-v2/gateway.go b/packages/gateway-v2/gateway.go index 3d3622c1..02029bd1 100644 --- a/packages/gateway-v2/gateway.go +++ b/packages/gateway-v2/gateway.go @@ -225,8 +225,8 @@ func (g *Gateway) connectAndServe() error { } // Connect to Relay server - log.Info().Msgf("Connecting to relay server %s on %s:%d...", g.config.RelayName, g.certificates.RelayIP, g.config.SSHPort) - client, err := ssh.Dial("tcp", fmt.Sprintf("%s:%d", g.certificates.RelayIP, g.config.SSHPort), sshConfig) + log.Info().Msgf("Connecting to relay server %s on %s:%d...", g.config.RelayName, g.certificates.RelayHost, g.config.SSHPort) + client, err := ssh.Dial("tcp", fmt.Sprintf("%s:%d", g.certificates.RelayHost, g.config.SSHPort), sshConfig) if err != nil { return fmt.Errorf("failed to connect to SSH server: %v", err) } From e00b1a8ecf4d78174e780ea55f39312e274136e1 Mon Sep 17 00:00:00 2001 From: Sheen Capadngan Date: Tue, 9 Sep 2025 04:43:07 +0800 Subject: [PATCH 34/38] misc: updated gateway to negotiate protocol through alpn headers --- packages/gateway-v2/gateway.go | 70 ++++++++++++++++++++-------------- 1 file changed, 42 insertions(+), 28 deletions(-) diff --git a/packages/gateway-v2/gateway.go b/packages/gateway-v2/gateway.go index 02029bd1..4e1880ca 100644 --- a/packages/gateway-v2/gateway.go +++ b/packages/gateway-v2/gateway.go @@ -344,6 +344,7 @@ func (g *Gateway) setupTLSConfig() error { ClientCAs: clientCAPool, ClientAuth: tls.RequireAndVerifyClientCert, MinVersion: tls.VersionTLS12, + NextProtos: []string{"infisical-http-proxy", "infisical-tcp-proxy", "infisical-ping"}, } return nil @@ -467,10 +468,10 @@ func (g *Gateway) handleIncomingChannel(newChannel ssh.NewChannel) { // Create reader for the TLS connection reader := bufio.NewReader(tlsConn) - // Get the forward mode here - forwardConfig, err := g.parseForwardConfig(tlsConn, reader) + // Get the negotiated protocol from ALPN + forwardConfig, err := g.parseForwardConfigFromALPN(tlsConn, reader) if err != nil { - log.Info().Msgf("Failed to parse forward command: %v", err) + log.Info().Msgf("Failed to parse forward config from ALPN: %v", err) return } @@ -488,45 +489,58 @@ func (g *Gateway) handleIncomingChannel(newChannel ssh.NewChannel) { } } -func (g *Gateway) parseForwardConfig(tlsConn *tls.Conn, reader *bufio.Reader) (*ForwardConfig, error) { +func (g *Gateway) parseForwardConfigFromALPN(tlsConn *tls.Conn, reader *bufio.Reader) (*ForwardConfig, error) { config := &ForwardConfig{} + // Parse routing information from the client certificate if err := g.parseDetailsFromCertificate(tlsConn, config); err != nil { return nil, fmt.Errorf("failed to parse routing info from certificate: %v", err) } - for { - msg, err := reader.ReadBytes('\n') - if err != nil { - return nil, fmt.Errorf("failed to read command: %v", err) - } + // Get the negotiated ALPN protocol + state := tlsConn.ConnectionState() + negotiatedProtocol := state.NegotiatedProtocol - cmd := strings.ToUpper(strings.TrimSpace(string(strings.Split(string(msg), " ")[0]))) - args := strings.TrimSpace(strings.TrimPrefix(string(msg), strings.Split(string(msg), " ")[0])) + log.Info().Msgf("Negotiated ALPN protocol: %s", negotiatedProtocol) - switch cmd { - case "FORWARD-TCP": - config.Mode = ForwardModeTCP - return config, nil + // Map ALPN protocol to ForwardMode + switch negotiatedProtocol { + case "infisical-http-proxy": + config.Mode = ForwardModeHTTP + // For HTTP proxy, read additional parameters from the connection + if err := g.parseHTTPParametersFromConnection(reader, config); err != nil { + return nil, fmt.Errorf("failed to parse HTTP parameters: %v", err) + } + return config, nil - case "FORWARD-HTTP": - config.Mode = ForwardModeHTTP - if args != "" { - if err := g.parseForwardHTTPParams(args, config); err != nil { - return nil, fmt.Errorf("failed to parse HTTP parameters: %v", err) - } - } + case "infisical-tcp-proxy": + config.Mode = ForwardModeTCP + return config, nil - return config, nil + case "infisical-ping": + config.Mode = ForwardModePing + return config, nil - case "PING": - config.Mode = ForwardModePing - return config, nil + default: + return nil, fmt.Errorf("unsupported ALPN protocol: %s", negotiatedProtocol) + } +} - default: - return nil, fmt.Errorf("invalid forward command: %s", cmd) +func (g *Gateway) parseHTTPParametersFromConnection(reader *bufio.Reader, config *ForwardConfig) error { + // Read the first line which should contain HTTP parameters + msg, err := reader.ReadBytes('\n') + if err != nil { + return fmt.Errorf("failed to read HTTP parameters: %v", err) + } + + params := strings.TrimSpace(string(msg)) + if params != "" { + if err := g.parseForwardHTTPParams(params, config); err != nil { + return fmt.Errorf("failed to parse HTTP parameters: %v", err) } } + + return nil } func (g *Gateway) parseForwardHTTPParams(params string, config *ForwardConfig) error { From 61182d956ae7c89df494a3e07fa72eddc421d916 Mon Sep 17 00:00:00 2001 From: Sheen Capadngan Date: Thu, 11 Sep 2025 00:23:58 +0800 Subject: [PATCH 35/38] misc: addressed comments --- packages/cmd/gateway.go | 3 +-- packages/gateway-v2/connection.go | 10 +++++++-- packages/gateway-v2/gateway.go | 35 ++++++++++++++++++++++++++----- packages/relay/relay.go | 5 +++-- 4 files changed, 42 insertions(+), 11 deletions(-) diff --git a/packages/cmd/gateway.go b/packages/cmd/gateway.go index 8dca18c9..42bdbe61 100644 --- a/packages/cmd/gateway.go +++ b/packages/cmd/gateway.go @@ -88,8 +88,7 @@ var gatewayCmd = &cobra.Command{ DisableFlagsInUseLine: true, Args: cobra.NoArgs, Run: func(cmd *cobra.Command, args []string) { - log.Warn().Msg("DEPRECATION WARNING: The 'infisical gateway' command is deprecated. Please use 'infisical gateway start'") - log.Warn().Msg("This legacy gateway will be removed in a future version.") + log.Info().Msg("DEPRECATION NOTICE: The 'infisical gateway' command will be deprecated in a future version. Please use 'infisical gateway start'.\nNOTE: This requires manually updating your existing resources to point to the new gateway.") infisicalClient, cancelSdk, err := getInfisicalSdkInstance(cmd) if err != nil { diff --git a/packages/gateway-v2/connection.go b/packages/gateway-v2/connection.go index 141681f8..04a856b1 100644 --- a/packages/gateway-v2/connection.go +++ b/packages/gateway-v2/connection.go @@ -215,6 +215,10 @@ func handleTCPProxy(ctx context.Context, conn *tls.Conn, forwardConfig *ForwardC } defer localConn.Close() + log.Info(). + Str("target", target). + Msg("TCP proxy connection established to local service") + // Create a context for this connection that gets cancelled when the parent context is cancelled // or when either connection closes connCtx, cancel := context.WithCancel(ctx) @@ -226,7 +230,8 @@ func handleTCPProxy(ctx context.Context, conn *tls.Conn, forwardConfig *ForwardC // Forward data from TLS connection to local service go func() { defer cancel() - _, err := io.Copy(localConn, conn) + bytesCopied, err := io.Copy(localConn, conn) + log.Info().Int64("bytes", bytesCopied).Msg("Copied from client to local service") if err != nil { if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) { log.Debug().Msgf("TLS to local copy ended normally: %v", err) @@ -240,7 +245,8 @@ func handleTCPProxy(ctx context.Context, conn *tls.Conn, forwardConfig *ForwardC // Forward data from local service to TLS connection go func() { defer cancel() - _, err := io.Copy(conn, localConn) + bytesCopied, err := io.Copy(conn, localConn) + log.Info().Int64("bytes", bytesCopied).Msg("Copied from local service to client") if err != nil { if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) { log.Debug().Msgf("Local to TLS copy ended normally: %v", err) diff --git a/packages/gateway-v2/gateway.go b/packages/gateway-v2/gateway.go index 4e1880ca..bcaaff0b 100644 --- a/packages/gateway-v2/gateway.go +++ b/packages/gateway-v2/gateway.go @@ -460,31 +460,56 @@ func (g *Gateway) handleIncomingChannel(newChannel ssh.NewChannel) { tlsConn := tls.Server(virtualConn, tlsConfig) // Perform TLS handshake + log.Info().Msg("Starting TLS handshake on incoming channel") if err := tlsConn.Handshake(); err != nil { log.Info().Msgf("TLS handshake failed: %v", err) return } + log.Info().Msg("TLS handshake completed successfully") // Create reader for the TLS connection reader := bufio.NewReader(tlsConn) // Get the negotiated protocol from ALPN + log.Info().Msg("Parsing forwarding configuration from ALPN and client certificate") forwardConfig, err := g.parseForwardConfigFromALPN(tlsConn, reader) if err != nil { log.Info().Msgf("Failed to parse forward config from ALPN: %v", err) return } - log.Info().Msgf("Forward config: %+v", forwardConfig) - if forwardConfig.Mode == ForwardModeHTTP { - handleHTTPProxy(g.ctx, tlsConn, reader, forwardConfig) + log.Info(). + Str("mode", string(forwardConfig.Mode)). + Str("target", fmt.Sprintf("%s:%d", forwardConfig.TargetHost, forwardConfig.TargetPort)). + Str("actorType", string(forwardConfig.ActorType)). + Bool("verifyTLS", forwardConfig.VerifyTLS). + Msg("Starting HTTP proxy handler") + if err := handleHTTPProxy(g.ctx, tlsConn, reader, forwardConfig); err != nil { + log.Error().Err(err).Msg("HTTP proxy handler ended with error") + } else { + log.Info().Msg("HTTP proxy handler completed") + } return } else if forwardConfig.Mode == ForwardModeTCP { - handleTCPProxy(g.ctx, tlsConn, forwardConfig) + log.Info(). + Str("mode", string(forwardConfig.Mode)). + Str("target", fmt.Sprintf("%s:%d", forwardConfig.TargetHost, forwardConfig.TargetPort)). + Str("actorType", string(forwardConfig.ActorType)). + Msg("Starting TCP proxy handler") + if err := handleTCPProxy(g.ctx, tlsConn, forwardConfig); err != nil { + log.Error().Err(err).Msg("TCP proxy handler ended with error") + } else { + log.Info().Msg("TCP proxy handler completed") + } return } else if forwardConfig.Mode == ForwardModePing { - handlePing(g.ctx, tlsConn, reader) + log.Info().Msg("Starting ping handler") + if err := handlePing(g.ctx, tlsConn, reader); err != nil { + log.Error().Err(err).Msg("Ping handler ended with error") + } else { + log.Info().Msg("Ping handler completed") + } return } } diff --git a/packages/relay/relay.go b/packages/relay/relay.go index 73e54333..a7306a74 100644 --- a/packages/relay/relay.go +++ b/packages/relay/relay.go @@ -421,12 +421,14 @@ func (r *Relay) handleTLSClient(conn net.Conn) { func (r *Relay) handleClient(tlsConn *tls.Conn) { var gatewayId string + var orgDetails string state := tlsConn.ConnectionState() if len(state.PeerCertificates) > 0 { cert := state.PeerCertificates[0] log.Info().Msgf("Client connected with certificate: %s", cert.Subject.CommonName) gatewayId = cert.Subject.CommonName + orgDetails = cert.Subject.Organization[0] } else { log.Warn().Msg("No peer certificates found") return @@ -443,7 +445,7 @@ func (r *Relay) handleClient(tlsConn *tls.Conn) { return } - log.Info().Msgf("Routing TCP connection to gateway: %s", gatewayId) + log.Info().Msgf("Routing TCP connection from %s to Gateway with ID: %s", orgDetails, gatewayId) channel, _, err := conn.OpenChannel("direct-tcpip", nil) if err != nil { @@ -478,7 +480,6 @@ func (r *Relay) cleanup() { // startCertificateRenewal runs a background process to renew certificates every 24 hours func (r *Relay) startCertificateRenewal(ctx context.Context) { - log.Info().Msg("Starting certificate renewal goroutine") ticker := time.NewTicker(10 * 24 * time.Hour) defer ticker.Stop() From 88571797ac4266a9ecb852321c9b5e8a78b69ff7 Mon Sep 17 00:00:00 2001 From: Sheen Capadngan Date: Thu, 11 Sep 2025 00:46:51 +0800 Subject: [PATCH 36/38] misc: more improvements to logging --- packages/gateway-v2/gateway.go | 4 +--- packages/relay/relay.go | 2 +- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/packages/gateway-v2/gateway.go b/packages/gateway-v2/gateway.go index bcaaff0b..939f2d48 100644 --- a/packages/gateway-v2/gateway.go +++ b/packages/gateway-v2/gateway.go @@ -460,7 +460,7 @@ func (g *Gateway) handleIncomingChannel(newChannel ssh.NewChannel) { tlsConn := tls.Server(virtualConn, tlsConfig) // Perform TLS handshake - log.Info().Msg("Starting TLS handshake on incoming channel") + log.Info().Msg("Received incoming connection, starting TLS handshake") if err := tlsConn.Handshake(); err != nil { log.Info().Msgf("TLS handshake failed: %v", err) return @@ -470,8 +470,6 @@ func (g *Gateway) handleIncomingChannel(newChannel ssh.NewChannel) { // Create reader for the TLS connection reader := bufio.NewReader(tlsConn) - // Get the negotiated protocol from ALPN - log.Info().Msg("Parsing forwarding configuration from ALPN and client certificate") forwardConfig, err := g.parseForwardConfigFromALPN(tlsConn, reader) if err != nil { log.Info().Msgf("Failed to parse forward config from ALPN: %v", err) diff --git a/packages/relay/relay.go b/packages/relay/relay.go index a7306a74..9cc4af9c 100644 --- a/packages/relay/relay.go +++ b/packages/relay/relay.go @@ -445,7 +445,7 @@ func (r *Relay) handleClient(tlsConn *tls.Conn) { return } - log.Info().Msgf("Routing TCP connection from %s to Gateway with ID: %s", orgDetails, gatewayId) + log.Info().Msgf("Routing connection from Organization %s to Gateway with ID: %s", orgDetails, gatewayId) channel, _, err := conn.OpenChannel("direct-tcpip", nil) if err != nil { From 242eb0a36ce9fb1c067030f323babbbc53660b77 Mon Sep 17 00:00:00 2001 From: Sheen Capadngan Date: Thu, 11 Sep 2025 02:54:55 +0800 Subject: [PATCH 37/38] misc: deleted network.go file --- packages/cmd/network.go | 223 ---------------------------------------- 1 file changed, 223 deletions(-) delete mode 100644 packages/cmd/network.go diff --git a/packages/cmd/network.go b/packages/cmd/network.go deleted file mode 100644 index 4e9f6935..00000000 --- a/packages/cmd/network.go +++ /dev/null @@ -1,223 +0,0 @@ -package cmd - -import ( - "context" - "errors" - "fmt" - "os" - "os/signal" - "runtime" - "sync/atomic" - "syscall" - "time" - - gatewayv2 "github.com/Infisical/infisical-merge/packages/gateway-v2" - "github.com/Infisical/infisical-merge/packages/util" - "github.com/rs/zerolog/log" - "github.com/spf13/cobra" -) - -var networkCmd = &cobra.Command{ - Use: "network", - Short: "Network-related commands", - Long: "Network-related commands for Infisical", -} - -var networkGatewayCmd = &cobra.Command{ - Use: "gateway", - Short: "Run the Infisical gateway component", - Long: "Run the Infisical gateway component. Use 'network gateway install' to set up the systemd service.", - Example: "infisical network gateway --relay= --name= --token=", - DisableFlagsInUseLine: true, - Args: cobra.NoArgs, - Run: func(cmd *cobra.Command, args []string) { - - relayName, err := util.GetCmdFlagOrEnv(cmd, "relay", []string{gatewayv2.RELAY_NAME_ENV_NAME}) - if err != nil { - util.HandleError(err, fmt.Sprintf("unable to get relay flag or %s env", gatewayv2.RELAY_NAME_ENV_NAME)) - } - - gatewayName, err := util.GetCmdFlagOrEnv(cmd, "name", []string{gatewayv2.GATEWAY_NAME_ENV_NAME}) - if err != nil { - util.HandleError(err, fmt.Sprintf("unable to get name flag or %s env", gatewayv2.GATEWAY_NAME_ENV_NAME)) - } - - gatewayInstance, err := gatewayv2.NewGateway(&gatewayv2.GatewayConfig{ - Name: gatewayName, - RelayName: relayName, - ReconnectDelay: 10 * time.Second, - }) - - if err != nil { - util.HandleError(err, "unable to create gateway instance") - } - - infisicalClient, cancelSdk, err := getInfisicalSdkInstance(cmd) - if err != nil { - util.HandleError(err, "unable to get infisical client") - } - defer cancelSdk() - - var accessToken atomic.Value - accessToken.Store(infisicalClient.Auth().GetAccessToken()) - - if accessToken.Load().(string) == "" { - util.HandleError(errors.New("no access token found")) - } - - gatewayInstance.SetToken(accessToken.Load().(string)) - - sigCh := make(chan os.Signal, 1) - signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) - - ctx, cancelCmd := context.WithCancel(cmd.Context()) - defer cancelCmd() - - go func() { - <-sigCh - log.Info().Msg("Received shutdown signal, shutting down gateway...") - cancelCmd() - cancelSdk() - - // Give graceful shutdown 10 seconds, then force exit on second signal - select { - case <-sigCh: - log.Warn().Msg("Second signal received, force exit triggered") - os.Exit(1) - case <-time.After(10 * time.Second): - log.Info().Msg("Graceful shutdown completed") - os.Exit(0) - } - }() - - // Token refresh goroutine - runs every 10 seconds - go func() { - tokenRefreshTicker := time.NewTicker(10 * time.Second) - defer tokenRefreshTicker.Stop() - - for { - select { - case <-tokenRefreshTicker.C: - if ctx.Err() != nil { - return - } - - newToken := infisicalClient.Auth().GetAccessToken() - if newToken != "" && newToken != accessToken.Load().(string) { - accessToken.Store(newToken) - gatewayInstance.SetToken(newToken) - } - - case <-ctx.Done(): - return - } - } - }() - - err = gatewayInstance.Start(ctx) - if err != nil { - util.HandleError(err, "unable to start gateway instance") - } - - }, -} - -var networkGatewayInstallCmd = &cobra.Command{ - Use: "install", - Short: "Install and enable systemd service for the gateway (requires sudo)", - Long: "Install and enable systemd service for the gateway. Must be run with sudo on Linux.", - Example: "sudo infisical network gateway install --token= --domain= --name= --relay=", - DisableFlagsInUseLine: true, - Args: cobra.NoArgs, - Run: func(cmd *cobra.Command, args []string) { - if runtime.GOOS != "linux" { - util.HandleError(fmt.Errorf("systemd service installation is only supported on Linux")) - } - - if os.Geteuid() != 0 { - util.HandleError(fmt.Errorf("systemd service installation requires root/sudo privileges")) - } - - token, err := util.GetInfisicalToken(cmd) - if err != nil { - util.HandleError(err, "Unable to parse flag") - } - - if token == nil { - util.HandleError(errors.New("Token not found")) - } - - domain, err := cmd.Flags().GetString("domain") - if err != nil { - util.HandleError(err, "Unable to parse domain flag") - } - - gatewayName, err := cmd.Flags().GetString("name") - if err != nil { - util.HandleError(err, "Unable to parse name flag") - } - if gatewayName == "" { - util.HandleError(errors.New("Gateway name is required")) - } - - relayName, err := cmd.Flags().GetString("relay") - if err != nil { - util.HandleError(err, "Unable to parse relay flag") - } - if relayName == "" { - util.HandleError(errors.New("Relay is required")) - } - - err = gatewayv2.InstallGatewaySystemdService(token.Token, domain, gatewayName, relayName) - if err != nil { - util.HandleError(err, "Unable to install systemd service") - } - }, -} - -var networkGatewayUninstallCmd = &cobra.Command{ - Use: "uninstall", - Short: "Uninstall and remove systemd service for the gateway (requires sudo)", - Long: "Uninstall and remove systemd service for the gateway. Must be run with sudo on Linux.", - Example: "sudo infisical network gateway uninstall", - DisableFlagsInUseLine: true, - Args: cobra.NoArgs, - Run: func(cmd *cobra.Command, args []string) { - if runtime.GOOS != "linux" { - util.HandleError(fmt.Errorf("systemd service installation is only supported on Linux")) - } - - if os.Geteuid() != 0 { - util.HandleError(fmt.Errorf("systemd service installation requires root/sudo privileges")) - } - - if err := gatewayv2.UninstallGatewaySystemdService(); err != nil { - util.HandleError(err, "Failed to uninstall systemd service") - } - }, -} - -func init() { - networkGatewayCmd.Flags().String("relay", "", "The name of the relay to connect to") - networkGatewayCmd.Flags().String("name", "", "The name of the gateway") - networkGatewayCmd.Flags().String("token", "", "connect with Infisical using machine identity access token. if not provided, you must set the auth-method flag") - networkGatewayCmd.Flags().String("auth-method", "", "login method [universal-auth, kubernetes, azure, gcp-id-token, gcp-iam, aws-iam, oidc-auth]. if not provided, you must set the token flag") - networkGatewayCmd.Flags().String("client-id", "", "client id for universal auth") - networkGatewayCmd.Flags().String("client-secret", "", "client secret for universal auth") - networkGatewayCmd.Flags().String("machine-identity-id", "", "machine identity id for kubernetes, azure, gcp-id-token, gcp-iam, and aws-iam auth methods") - networkGatewayCmd.Flags().String("service-account-token-path", "", "service account token path for kubernetes auth") - networkGatewayCmd.Flags().String("service-account-key-file-path", "", "service account key file path for GCP IAM auth") - networkGatewayCmd.Flags().String("jwt", "", "JWT for jwt-based auth methods [oidc-auth, jwt-auth]") - - networkGatewayInstallCmd.Flags().String("token", "", "Connect with Infisical using machine identity access token") - networkGatewayInstallCmd.Flags().String("domain", "", "Domain of your self-hosted Infisical instance") - networkGatewayInstallCmd.Flags().String("name", "", "The name of the gateway") - networkGatewayInstallCmd.Flags().String("relay", "", "The name of the relay") - - networkGatewayCmd.AddCommand(networkGatewayInstallCmd) - networkGatewayCmd.AddCommand(networkGatewayUninstallCmd) - - networkCmd.AddCommand(networkGatewayCmd) - - rootCmd.AddCommand(networkCmd) -} From 0da9260b63b75b25f6da62fe052f07c5e1dd452e Mon Sep 17 00:00:00 2001 From: Sheen Capadngan Date: Thu, 11 Sep 2025 03:20:09 +0800 Subject: [PATCH 38/38] misc: updated cert renewal every 6 hours --- packages/gateway-v2/gateway.go | 4 ++-- packages/relay/relay.go | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/packages/gateway-v2/gateway.go b/packages/gateway-v2/gateway.go index 939f2d48..22b23eae 100644 --- a/packages/gateway-v2/gateway.go +++ b/packages/gateway-v2/gateway.go @@ -660,10 +660,10 @@ func (vc *virtualConnection) SetWriteDeadline(t time.Time) error { return nil } -// startCertificateRenewal runs a background process to renew certificates every 10 days +// startCertificateRenewal runs a background process to renew certificates every 6 hours func (g *Gateway) startCertificateRenewal(ctx context.Context) { log.Info().Msg("Starting gateway certificate renewal goroutine") - ticker := time.NewTicker(10 * 24 * time.Hour) + ticker := time.NewTicker(6 * 60 * time.Minute) defer ticker.Stop() for { diff --git a/packages/relay/relay.go b/packages/relay/relay.go index 9cc4af9c..285e9e0a 100644 --- a/packages/relay/relay.go +++ b/packages/relay/relay.go @@ -478,9 +478,9 @@ func (r *Relay) cleanup() { log.Info().Msg("Relay server shutdown complete") } -// startCertificateRenewal runs a background process to renew certificates every 24 hours +// startCertificateRenewal runs a background process to renew certificates every 6 hours func (r *Relay) startCertificateRenewal(ctx context.Context) { - ticker := time.NewTicker(10 * 24 * time.Hour) + ticker := time.NewTicker(6 * 60 * time.Minute) defer ticker.Stop() for {