diff --git a/packages/api/api.go b/packages/api/api.go index a9b204b6..352465ea 100644 --- a/packages/api/api.go +++ b/packages/api/api.go @@ -39,7 +39,11 @@ const ( operationCallRegisterGatewayIdentityV1 = "CallRegisterGatewayIdentityV1" operationCallExchangeRelayCertV1 = "CallExchangeRelayCertV1" operationCallGatewayHeartBeatV1 = "CallGatewayHeartBeatV1" + operationCallGatewayHeartBeatV2 = "CallGatewayHeartBeatV2" operationCallBootstrapInstance = "CallBootstrapInstance" + operationCallRegisterInstanceRelay = "CallRegisterInstanceRelay" + operationCallRegisterOrgRelay = "CallRegisterOrgRelay" + operationCallRegisterGateway = "CallRegisterGateway" ) func CallGetEncryptedWorkspaceKey(httpClient *resty.Client, request GetEncryptedWorkspaceKeyRequest) (GetEncryptedWorkspaceKeyResponse, error) { @@ -652,6 +656,23 @@ func CallGatewayHeartBeatV1(httpClient *resty.Client) error { return nil } +func CallGatewayHeartBeatV2(httpClient *resty.Client) error { + response, err := httpClient. + R(). + SetHeader("User-Agent", USER_AGENT). + Post(fmt.Sprintf("%v/v2/gateways/heartbeat", config.INFISICAL_URL)) + + if err != nil { + return NewGenericRequestError(operationCallGatewayHeartBeatV2, err) + } + + if response.IsError() { + return NewAPIErrorWithResponse(operationCallGatewayHeartBeatV2, response, nil) + } + + return nil +} + func CallBootstrapInstance(httpClient *resty.Client, request BootstrapInstanceRequest) (BootstrapInstanceResponse, error) { var resBody BootstrapInstanceResponse response, err := httpClient. @@ -671,3 +692,63 @@ func CallBootstrapInstance(httpClient *resty.Client, request BootstrapInstanceRe return resBody, nil } + +func CallRegisterInstanceRelay(httpClient *resty.Client, request RegisterRelayRequest) (RegisterRelayResponse, error) { + var resBody RegisterRelayResponse + response, err := httpClient. + R(). + SetResult(&resBody). + SetHeader("User-Agent", USER_AGENT). + SetBody(request). + Post(fmt.Sprintf("%v/v1/relays/register-instance-relay", config.INFISICAL_URL)) + + if err != nil { + return RegisterRelayResponse{}, NewGenericRequestError(operationCallRegisterInstanceRelay, err) + } + + if response.IsError() { + return RegisterRelayResponse{}, NewAPIErrorWithResponse(operationCallRegisterInstanceRelay, response, nil) + } + + return resBody, nil +} + +func CallRegisterRelay(httpClient *resty.Client, request RegisterRelayRequest) (RegisterRelayResponse, error) { + var resBody RegisterRelayResponse + response, err := httpClient. + R(). + SetResult(&resBody). + SetHeader("User-Agent", USER_AGENT). + SetBody(request). + Post(fmt.Sprintf("%v/v1/relays/register-org-relay", config.INFISICAL_URL)) + + if err != nil { + return RegisterRelayResponse{}, NewGenericRequestError(operationCallRegisterOrgRelay, err) + } + + if response.IsError() { + return RegisterRelayResponse{}, NewAPIErrorWithResponse(operationCallRegisterOrgRelay, response, nil) + } + + return resBody, nil +} + +func CallRegisterGateway(httpClient *resty.Client, request RegisterGatewayRequest) (RegisterGatewayResponse, error) { + var resBody RegisterGatewayResponse + response, err := httpClient. + R(). + SetResult(&resBody). + SetHeader("User-Agent", USER_AGENT). + SetBody(request). + Post(fmt.Sprintf("%v/v2/gateways", config.INFISICAL_URL)) + + if err != nil { + return RegisterGatewayResponse{}, NewGenericRequestError(operationCallRegisterGateway, err) + } + + if response.IsError() { + return RegisterGatewayResponse{}, NewAPIErrorWithResponse(operationCallRegisterGateway, response, nil) + } + + return resBody, nil +} diff --git a/packages/api/model.go b/packages/api/model.go index 3f10b4ca..699bc8c9 100644 --- a/packages/api/model.go +++ b/packages/api/model.go @@ -703,3 +703,41 @@ type BootstrapUser struct { Username string `json:"username"` SuperAdmin bool `json:"superAdmin"` } + +type RegisterRelayRequest struct { + Host string `json:"host"` + Name string `json:"name"` +} + +type RegisterRelayResponse struct { + PKI struct { + ServerCertificate string `json:"serverCertificate"` + ServerPrivateKey string `json:"serverPrivateKey"` + ClientCertificateChain string `json:"clientCertificateChain"` + } `json:"pki"` + SSH struct { + ServerCertificate string `json:"serverCertificate"` + ServerPrivateKey string `json:"serverPrivateKey"` + ClientCAPublicKey string `json:"clientCAPublicKey"` + } `json:"ssh"` +} + +type RegisterGatewayRequest struct { + RelayName string `json:"relayName"` + Name string `json:"name"` +} + +type RegisterGatewayResponse struct { + GatewayID string `json:"gatewayId"` + RelayHost string `json:"relayHost"` + PKI struct { + ServerCertificate string `json:"serverCertificate"` + ServerPrivateKey string `json:"serverPrivateKey"` + ClientCertificateChain string `json:"clientCertificateChain"` + } `json:"pki"` + SSH struct { + ClientCertificate string `json:"clientCertificate"` + ClientPrivateKey string `json:"clientPrivateKey"` + ServerCAPublicKey string `json:"serverCAPublicKey"` + } `json:"ssh"` +} diff --git a/packages/cmd/gateway.go b/packages/cmd/gateway.go index abc4d694..42bdbe61 100644 --- a/packages/cmd/gateway.go +++ b/packages/cmd/gateway.go @@ -14,6 +14,7 @@ import ( "github.com/Infisical/infisical-merge/packages/api" "github.com/Infisical/infisical-merge/packages/config" "github.com/Infisical/infisical-merge/packages/gateway" + gatewayv2 "github.com/Infisical/infisical-merge/packages/gateway-v2" "github.com/Infisical/infisical-merge/packages/util" infisicalSdk "github.com/infisical/go-sdk" "github.com/pkg/errors" @@ -87,6 +88,7 @@ var gatewayCmd = &cobra.Command{ DisableFlagsInUseLine: true, Args: cobra.NoArgs, Run: func(cmd *cobra.Command, args []string) { + log.Info().Msg("DEPRECATION NOTICE: The 'infisical gateway' command will be deprecated in a future version. Please use 'infisical gateway start'.\nNOTE: This requires manually updating your existing resources to point to the new gateway.") infisicalClient, cancelSdk, err := getInfisicalSdkInstance(cmd) if err != nil { @@ -199,6 +201,105 @@ var gatewayCmd = &cobra.Command{ }, } +var gatewayStartCmd = &cobra.Command{ + Use: "start", + Short: "Start the new Infisical gateway", + Long: "Start the new Infisical gateway component.", + Example: "infisical gateway start --relay= --name= --token=", + DisableFlagsInUseLine: true, + Args: cobra.NoArgs, + Run: func(cmd *cobra.Command, args []string) { + relayName, err := util.GetCmdFlagOrEnv(cmd, "relay", []string{gatewayv2.RELAY_NAME_ENV_NAME}) + if err != nil { + util.HandleError(err, fmt.Sprintf("unable to get relay flag or %s env", gatewayv2.RELAY_NAME_ENV_NAME)) + } + + gatewayName, err := util.GetCmdFlagOrEnv(cmd, "name", []string{gatewayv2.GATEWAY_NAME_ENV_NAME}) + if err != nil { + util.HandleError(err, fmt.Sprintf("unable to get name flag or %s env", gatewayv2.GATEWAY_NAME_ENV_NAME)) + } + + gatewayInstance, err := gatewayv2.NewGateway(&gatewayv2.GatewayConfig{ + Name: gatewayName, + RelayName: relayName, + ReconnectDelay: 10 * time.Second, + }) + + if err != nil { + util.HandleError(err, "unable to create gateway instance") + } + + infisicalClient, cancelSdk, err := getInfisicalSdkInstance(cmd) + if err != nil { + util.HandleError(err, "unable to get infisical client") + } + defer cancelSdk() + + var accessToken atomic.Value + accessToken.Store(infisicalClient.Auth().GetAccessToken()) + + if accessToken.Load().(string) == "" { + util.HandleError(errors.New("no access token found")) + } + + gatewayInstance.SetToken(accessToken.Load().(string)) + + Telemetry.CaptureEvent("cli-command:gateway-v2", posthog.NewProperties().Set("version", util.CLI_VERSION)) + + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) + + ctx, cancelCmd := context.WithCancel(cmd.Context()) + defer cancelCmd() + + go func() { + <-sigCh + log.Info().Msg("Received shutdown signal, shutting down gateway...") + cancelCmd() + cancelSdk() + + // Give graceful shutdown 10 seconds, then force exit on second signal + select { + case <-sigCh: + log.Warn().Msg("Second signal received, force exit triggered") + os.Exit(1) + case <-time.After(10 * time.Second): + log.Info().Msg("Graceful shutdown completed") + os.Exit(0) + } + }() + + // Token refresh goroutine - runs every 10 seconds + go func() { + tokenRefreshTicker := time.NewTicker(10 * time.Second) + defer tokenRefreshTicker.Stop() + + for { + select { + case <-tokenRefreshTicker.C: + if ctx.Err() != nil { + return + } + + newToken := infisicalClient.Auth().GetAccessToken() + if newToken != "" && newToken != accessToken.Load().(string) { + accessToken.Store(newToken) + gatewayInstance.SetToken(newToken) + } + + case <-ctx.Done(): + return + } + } + }() + + err = gatewayInstance.Start(ctx) + if err != nil { + util.HandleError(err, "unable to start gateway instance") + } + }, +} + var gatewayInstallCmd = &cobra.Command{ Use: "install", Short: "Install and enable systemd service for the gateway (requires sudo)", @@ -265,6 +366,99 @@ var gatewayUninstallCmd = &cobra.Command{ }, } +var gatewaySystemdCmd = &cobra.Command{ + Use: "systemd", + Short: "Manage systemd service for Infisical gateway", + Long: "Manage systemd service for Infisical gateway. Use 'systemd install' to install and enable the service.", + Example: `sudo infisical gateway systemd install --token= --domain= --name= --relay= + sudo infisical gateway systemd uninstall`, + DisableFlagsInUseLine: true, + Args: cobra.NoArgs, +} + +var gatewaySystemdInstallCmd = &cobra.Command{ + Use: "install", + Short: "Install and enable systemd service for the gateway (v2) (requires sudo)", + Long: "Install and enable systemd service for the new gateway (v2). Must be run with sudo on Linux.", + Example: "sudo infisical gateway systemd install --token= --domain= --name= --relay=", + DisableFlagsInUseLine: true, + Args: cobra.NoArgs, + Run: func(cmd *cobra.Command, args []string) { + if runtime.GOOS != "linux" { + util.HandleError(fmt.Errorf("systemd service installation is only supported on Linux")) + } + + if os.Geteuid() != 0 { + util.HandleError(fmt.Errorf("systemd service installation requires root/sudo privileges")) + } + + token, err := util.GetInfisicalToken(cmd) + if err != nil { + util.HandleError(err, "Unable to parse flag") + } + + if token == nil { + util.HandleError(errors.New("Token not found")) + } + + domain, err := cmd.Flags().GetString("domain") + if err != nil { + util.HandleError(err, "Unable to parse domain flag") + } + + gatewayName, err := cmd.Flags().GetString("name") + if err != nil { + util.HandleError(err, "Unable to parse name flag") + } + if gatewayName == "" { + util.HandleError(errors.New("Gateway name is required")) + } + + relayName, err := cmd.Flags().GetString("relay") + if err != nil { + util.HandleError(err, "Unable to parse relay flag") + } + if relayName == "" { + util.HandleError(errors.New("Relay is required")) + } + + err = gatewayv2.InstallGatewaySystemdService(token.Token, domain, gatewayName, relayName) + if err != nil { + util.HandleError(err, "Unable to install systemd service") + } + + enableCmd := exec.Command("systemctl", "enable", "infisical-gateway") + if err := enableCmd.Run(); err != nil { + util.HandleError(err, "Failed to enable systemd service") + } + + log.Info().Msg("Successfully installed and enabled infisical-gateway service") + log.Info().Msg("To start the service, run: sudo systemctl start infisical-gateway") + }, +} + +var gatewaySystemdUninstallCmd = &cobra.Command{ + Use: "uninstall", + Short: "Uninstall and remove systemd service for the gateway (requires sudo)", + Long: "Uninstall and remove systemd service for the gateway. Must be run with sudo on Linux.", + Example: "sudo infisical gateway systemd uninstall", + DisableFlagsInUseLine: true, + Args: cobra.NoArgs, + Run: func(cmd *cobra.Command, args []string) { + if runtime.GOOS != "linux" { + util.HandleError(fmt.Errorf("systemd service installation is only supported on Linux")) + } + + if os.Geteuid() != 0 { + util.HandleError(fmt.Errorf("systemd service installation requires root/sudo privileges")) + } + + if err := gatewayv2.UninstallGatewaySystemdService(); err != nil { + util.HandleError(err, "Failed to uninstall systemd service") + } + }, +} + var gatewayRelayCmd = &cobra.Command{ Example: `infisical gateway relay`, Short: "Used to run infisical gateway relay", @@ -293,24 +487,47 @@ var gatewayRelayCmd = &cobra.Command{ } func init() { + // Legacy gateway command flags (v1) gatewayCmd.Flags().String("token", "", "connect with Infisical using machine identity access token. if not provided, you must set the auth-method flag") - gatewayCmd.Flags().String("auth-method", "", "login method [universal-auth, kubernetes, azure, gcp-id-token, gcp-iam, aws-iam, oidc-auth]. if not provided, you must set the token flag") - gatewayCmd.Flags().String("client-id", "", "client id for universal auth") gatewayCmd.Flags().String("client-secret", "", "client secret for universal auth") - gatewayCmd.Flags().String("machine-identity-id", "", "machine identity id for kubernetes, azure, gcp-id-token, gcp-iam, and aws-iam auth methods") gatewayCmd.Flags().String("service-account-token-path", "", "service account token path for kubernetes auth") gatewayCmd.Flags().String("service-account-key-file-path", "", "service account key file path for GCP IAM auth") - gatewayCmd.Flags().String("jwt", "", "JWT for jwt-based auth methods [oidc-auth, jwt-auth]") + // Gateway start command flags (v2) + gatewayStartCmd.Flags().String("relay", "", "name of the relay to connect to") + gatewayStartCmd.Flags().String("name", "", "name of the gateway") + gatewayStartCmd.Flags().String("token", "", "connect with Infisical using machine identity access token. if not provided, you must set the auth-method flag") + gatewayStartCmd.Flags().String("auth-method", "", "login method [universal-auth, kubernetes, azure, gcp-id-token, gcp-iam, aws-iam, oidc-auth]. if not provided, you must set the token flag") + gatewayStartCmd.Flags().String("client-id", "", "client id for universal auth") + gatewayStartCmd.Flags().String("client-secret", "", "client secret for universal auth") + gatewayStartCmd.Flags().String("machine-identity-id", "", "machine identity id for kubernetes, azure, gcp-id-token, gcp-iam, and aws-iam auth methods") + gatewayStartCmd.Flags().String("service-account-token-path", "", "service account token path for kubernetes auth") + gatewayStartCmd.Flags().String("service-account-key-file-path", "", "service account key file path for GCP IAM auth") + gatewayStartCmd.Flags().String("jwt", "", "JWT for jwt-based auth methods [oidc-auth, jwt-auth]") + + // Legacy install command flags (v1) gatewayInstallCmd.Flags().String("token", "", "Connect with Infisical using machine identity access token") gatewayInstallCmd.Flags().String("domain", "", "Domain of your self-hosted Infisical instance") + // Systemd install command flags (v2) + gatewaySystemdInstallCmd.Flags().String("token", "", "Connect with Infisical using machine identity access token") + gatewaySystemdInstallCmd.Flags().String("domain", "", "Domain of your self-hosted Infisical instance") + gatewaySystemdInstallCmd.Flags().String("name", "", "The name of the gateway") + gatewaySystemdInstallCmd.Flags().String("relay", "", "The name of the relay") + + // Gateway relay command flags gatewayRelayCmd.Flags().String("config", "", "Relay config yaml file path") + // Wire up command hierarchy + gatewaySystemdCmd.AddCommand(gatewaySystemdInstallCmd) + gatewaySystemdCmd.AddCommand(gatewaySystemdUninstallCmd) + + gatewayCmd.AddCommand(gatewayStartCmd) + gatewayCmd.AddCommand(gatewaySystemdCmd) gatewayCmd.AddCommand(gatewayInstallCmd) gatewayCmd.AddCommand(gatewayUninstallCmd) gatewayCmd.AddCommand(gatewayRelayCmd) diff --git a/packages/cmd/relay.go b/packages/cmd/relay.go new file mode 100644 index 00000000..dac6b1c5 --- /dev/null +++ b/packages/cmd/relay.go @@ -0,0 +1,156 @@ +package cmd + +import ( + "context" + "errors" + "fmt" + "os" + "os/signal" + "sync/atomic" + "syscall" + "time" + + gatewayv2 "github.com/Infisical/infisical-merge/packages/gateway-v2" + "github.com/Infisical/infisical-merge/packages/relay" + "github.com/Infisical/infisical-merge/packages/util" + "github.com/rs/zerolog/log" + "github.com/spf13/cobra" +) + +var relayCmd = &cobra.Command{ + Use: "relay", + Short: "Relay-related commands", + Long: "Relay-related commands for Infisical", +} + +var relayStartCmd = &cobra.Command{ + Use: "start", + Short: "Start the Infisical relay component", + Long: "Start the Infisical relay component", + Example: "infisical relay start --type=instance --ip= --name= --token=", + DisableFlagsInUseLine: true, + Args: cobra.NoArgs, + Run: func(cmd *cobra.Command, args []string) { + + relayName, err := cmd.Flags().GetString("name") + if err != nil || relayName == "" { + util.HandleError(err, "unable to get name flag") + } + + host, err := cmd.Flags().GetString("host") + if err != nil || host == "" { + util.HandleError(err, "unable to get host flag") + } + + instanceType, err := cmd.Flags().GetString("type") + if err != nil { + util.HandleError(err, "unable to get type flag") + } + + relayInstance, err := relay.NewRelay(&relay.RelayConfig{ + RelayName: relayName, + SSHPort: "2222", + TLSPort: "8443", + Host: host, + Type: instanceType, + }) + + if err != nil { + util.HandleError(err, "unable to create relay instance") + } + + if instanceType == "instance" { + relayAuthSecret := os.Getenv(gatewayv2.RELAY_AUTH_SECRET_ENV_NAME) + if relayAuthSecret == "" { + util.HandleError(fmt.Errorf("%s is not set", gatewayv2.RELAY_AUTH_SECRET_ENV_NAME), "unable to get relay auth secret") + } + + relayInstance.SetToken(relayAuthSecret) + } else { + infisicalClient, cancelSdk, err := getInfisicalSdkInstance(cmd) + if err != nil { + util.HandleError(err, "unable to get infisical client") + } + defer cancelSdk() + + var accessToken atomic.Value + accessToken.Store(infisicalClient.Auth().GetAccessToken()) + + if accessToken.Load().(string) == "" { + util.HandleError(errors.New("no access token found")) + } + + relayInstance.SetToken(accessToken.Load().(string)) + + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) + + ctx, cancelCmd := context.WithCancel(cmd.Context()) + defer cancelCmd() + + go func() { + <-sigCh + log.Info().Msg("Received shutdown signal, shutting down relay...") + cancelCmd() + cancelSdk() + + // Give graceful shutdown 10 seconds, then force exit on second signal + select { + case <-sigCh: + log.Warn().Msg("Second signal received, force exit triggered") + os.Exit(1) + case <-time.After(10 * time.Second): + log.Info().Msg("Graceful shutdown completed") + os.Exit(0) + } + }() + + // Token refresh goroutine - runs every 10 seconds + go func() { + tokenRefreshTicker := time.NewTicker(10 * time.Second) + defer tokenRefreshTicker.Stop() + + for { + select { + case <-tokenRefreshTicker.C: + if ctx.Err() != nil { + return + } + + newToken := infisicalClient.Auth().GetAccessToken() + if newToken != "" && newToken != accessToken.Load().(string) { + accessToken.Store(newToken) + relayInstance.SetToken(newToken) + } + + case <-ctx.Done(): + return + } + } + }() + } + + err = relayInstance.Start(cmd.Context()) + if err != nil { + util.HandleError(err, "unable to start relay instance") + } + }, +} + +func init() { + relayStartCmd.Flags().String("type", "org", "The type of relay to run. Must be either 'instance' or 'org'") + relayStartCmd.Flags().String("host", "", "The IP or hostname for the relay") + relayStartCmd.Flags().String("name", "", "The name of the relay") + relayStartCmd.Flags().String("token", "", "connect with Infisical using machine identity access token. if not provided, you must set the auth-method flag") + relayStartCmd.Flags().String("auth-method", "", "login method [universal-auth, kubernetes, azure, gcp-id-token, gcp-iam, aws-iam, oidc-auth]. if not provided, you must set the token flag") + relayStartCmd.Flags().String("client-id", "", "client id for universal auth") + relayStartCmd.Flags().String("client-secret", "", "client secret for universal auth") + relayStartCmd.Flags().String("machine-identity-id", "", "machine identity id for kubernetes, azure, gcp-id-token, gcp-iam, and aws-iam auth methods") + relayStartCmd.Flags().String("service-account-token-path", "", "service account token path for kubernetes auth") + relayStartCmd.Flags().String("service-account-key-file-path", "", "service account key file path for GCP IAM auth") + relayStartCmd.Flags().String("jwt", "", "JWT for jwt-based auth methods [oidc-auth, jwt-auth]") + + relayCmd.AddCommand(relayStartCmd) + + rootCmd.AddCommand(relayCmd) +} diff --git a/packages/gateway-v2/connection.go b/packages/gateway-v2/connection.go new file mode 100644 index 00000000..04a856b1 --- /dev/null +++ b/packages/gateway-v2/connection.go @@ -0,0 +1,275 @@ +package gatewayv2 + +import ( + "bufio" + "context" + "crypto/tls" + "crypto/x509" + "errors" + "fmt" + "io" + "net" + "net/http" + "os" + "strings" + "time" + + "github.com/rs/zerolog/log" +) + +func buildHttpInternalServerError(message string) string { + return fmt.Sprintf("HTTP/1.1 500 Internal Server Error\r\nContent-Type: application/json\r\n\r\n{\"message\": \"gateway: %s\"}", message) +} + +func handleHTTPProxy(ctx context.Context, conn *tls.Conn, reader *bufio.Reader, forwardConfig *ForwardConfig) error { + targetURL := fmt.Sprintf("%s:%d", forwardConfig.TargetHost, forwardConfig.TargetPort) + caCert := forwardConfig.CACertificate + verifyTLS := forwardConfig.VerifyTLS + + transport := &http.Transport{ + DisableKeepAlives: false, + MaxIdleConns: 10, + IdleConnTimeout: 30 * time.Second, + } + + if strings.HasPrefix(targetURL, "https://") { + tlsConfig := &tls.Config{} + + if len(caCert) > 0 { + caCertPool := x509.NewCertPool() + if caCertPool.AppendCertsFromPEM(caCert) { + tlsConfig.RootCAs = caCertPool + log.Info().Msg("Using provided CA certificate from gateway client") + } else { + log.Error().Msg("Failed to parse provided CA certificate") + } + } + + tlsConfig.InsecureSkipVerify = !verifyTLS + log.Info().Msgf("TLS verification set to: %v", verifyTLS) + + transport.TLSClientConfig = tlsConfig + } + + // Loop to handle multiple HTTP requests on the same connection + for { + select { + case <-ctx.Done(): + log.Info().Msg("Context cancelled, closing HTTP proxy connection") + return ctx.Err() + default: + } + + log.Info().Msg("Attempting to read HTTP request...") + + // Create a channel to receive the request or error + reqCh := make(chan *http.Request, 1) + errCh := make(chan error, 1) + + // Read request in a goroutine so we can cancel it + go func() { + req, err := http.ReadRequest(reader) + if err != nil { + errCh <- err + } else { + reqCh <- req + } + }() + + var req *http.Request + select { + case <-ctx.Done(): + log.Info().Msg("Context cancelled while reading HTTP request") + return ctx.Err() + case err := <-errCh: + if errors.Is(err, io.EOF) { + log.Info().Msg("Client closed HTTP connection") + return nil + } + log.Error().Msgf("Failed to read HTTP request: %v", err) + return fmt.Errorf("failed to read HTTP request: %v", err) + case req = <-reqCh: + // Successfully received request + } + + log.Info().Msgf("Received HTTP request: %s", req.URL.Path) + + actionHeader := HttpProxyAction(req.Header.Get(INFISICAL_HTTP_PROXY_ACTION_HEADER)) + + // Only platform actor can perform privileged actions + if actionHeader != "" && forwardConfig.ActorType == ActorTypePlatform { + if actionHeader == HttpProxyActionInjectGatewayK8sServiceAccountToken { + token, err := os.ReadFile(KUBERNETES_SERVICE_ACCOUNT_TOKEN_PATH) + if err != nil { + conn.Write([]byte(buildHttpInternalServerError("failed to read k8s sa auth token"))) + continue // Continue to next request instead of returning + } + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", string(token))) + log.Info().Msgf("Injected gateway k8s SA auth token in request to %s", targetURL) + } else if actionHeader == HttpProxyActionUseGatewayK8sServiceAccount { + // will work without a target URL set + // set the ca cert to the pod's k8s service account ca cert: + caCert, err := os.ReadFile(KUBERNETES_SERVICE_ACCOUNT_CA_CERT_PATH) + if err != nil { + conn.Write([]byte(buildHttpInternalServerError("failed to read k8s sa ca cert"))) + continue + } + + caCertPool := x509.NewCertPool() + if ok := caCertPool.AppendCertsFromPEM(caCert); !ok { + conn.Write([]byte(buildHttpInternalServerError("failed to parse k8s sa ca cert"))) + continue + } + + transport.TLSClientConfig = &tls.Config{ + RootCAs: caCertPool, + } + + // set authorization header to the pod's k8s service account token: + token, err := os.ReadFile(KUBERNETES_SERVICE_ACCOUNT_TOKEN_PATH) + if err != nil { + conn.Write([]byte(buildHttpInternalServerError("failed to read k8s sa auth token"))) + continue + } + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", string(token))) + + // update the target URL to point to the kubernetes API server: + kubernetesServiceHost := os.Getenv(KUBERNETES_SERVICE_HOST_ENV_NAME) + kubernetesServicePort := os.Getenv(KUBERNETES_SERVICE_PORT_HTTPS_ENV_NAME) + + fullBaseUrl := fmt.Sprintf("https://%s:%s", kubernetesServiceHost, kubernetesServicePort) + targetURL = fullBaseUrl + + log.Info().Msgf("Redirected request to Kubernetes API server: %s", targetURL) + } + + req.Header.Del(INFISICAL_HTTP_PROXY_ACTION_HEADER) + } + + // Build full target URL + var targetFullURL string + if strings.HasPrefix(targetURL, "http://") || strings.HasPrefix(targetURL, "https://") { + baseURL := strings.TrimSuffix(targetURL, "/") + targetFullURL = baseURL + req.URL.Path + if req.URL.RawQuery != "" { + targetFullURL += "?" + req.URL.RawQuery + } + } else { + baseURL := strings.TrimSuffix("http://"+targetURL, "/") + targetFullURL = baseURL + req.URL.Path + if req.URL.RawQuery != "" { + targetFullURL += "?" + req.URL.RawQuery + } + } + + // create the request to the target + proxyReq, err := http.NewRequest(req.Method, targetFullURL, req.Body) + if err != nil { + log.Error().Msgf("Failed to create proxy request: %v", err) + conn.Write([]byte(buildHttpInternalServerError("failed to create proxy request"))) + continue // Continue to next request + } + proxyReq.Header = req.Header.Clone() + + log.Info().Msgf("Proxying %s %s to %s", req.Method, req.URL.Path, targetFullURL) + + client := &http.Client{ + Transport: transport, + Timeout: 30 * time.Second, + } + + resp, err := client.Do(proxyReq) + if err != nil { + log.Error().Msgf("Failed to reach target: %v", err) + conn.Write([]byte(buildHttpInternalServerError(fmt.Sprintf("failed to reach target due to networking error: %s", err.Error())))) + continue // Continue to next request + } + + // Write the entire response (status line, headers, body) to the connection + resp.Header.Del("Connection") + + log.Info().Msgf("Writing response to connection: %s", resp.Status) + + if err := resp.Write(conn); err != nil { + log.Error().Err(err).Msg("Failed to write response to connection") + resp.Body.Close() + return fmt.Errorf("failed to write response to connection: %w", err) + } + + resp.Body.Close() + + // Check if client wants to close connection + if req.Header.Get("Connection") == "close" { + log.Info().Msg("Client requested connection close") + return nil + } + } +} + +func handleTCPProxy(ctx context.Context, conn *tls.Conn, forwardConfig *ForwardConfig) error { + target := fmt.Sprintf("%s:%d", forwardConfig.TargetHost, forwardConfig.TargetPort) + localConn, err := net.Dial("tcp", target) + if err != nil { + log.Error().Msgf("Failed to connect to local service %s: %v", target, err) + return fmt.Errorf("failed to connect to local service %s: %v", target, err) + } + defer localConn.Close() + + log.Info(). + Str("target", target). + Msg("TCP proxy connection established to local service") + + // Create a context for this connection that gets cancelled when the parent context is cancelled + // or when either connection closes + connCtx, cancel := context.WithCancel(ctx) + defer cancel() + + // Error channel to collect errors from both copy goroutines + errCh := make(chan error, 2) + + // Forward data from TLS connection to local service + go func() { + defer cancel() + bytesCopied, err := io.Copy(localConn, conn) + log.Info().Int64("bytes", bytesCopied).Msg("Copied from client to local service") + if err != nil { + if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) { + log.Debug().Msgf("TLS to local copy ended normally: %v", err) + } else { + log.Error().Msgf("TLS to local copy failed: %v", err) + } + } + errCh <- err + }() + + // Forward data from local service to TLS connection + go func() { + defer cancel() + bytesCopied, err := io.Copy(conn, localConn) + log.Info().Int64("bytes", bytesCopied).Msg("Copied from local service to client") + if err != nil { + if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) { + log.Debug().Msgf("Local to TLS copy ended normally: %v", err) + } else { + log.Error().Msgf("Local to TLS copy failed: %v", err) + } + } + errCh <- err + }() + + // Wait for either context cancellation or one of the copy operations to complete + select { + case <-connCtx.Done(): + log.Info().Msg("TCP proxy connection cancelled") + return connCtx.Err() + case err := <-errCh: + // One of the copy operations completed (or failed) + // The defer cancel() will stop the other goroutine + return err + } +} + +func handlePing(ctx context.Context, conn *tls.Conn, reader *bufio.Reader) error { + conn.Write([]byte("PONG\n")) + return nil +} diff --git a/packages/gateway-v2/constants.go b/packages/gateway-v2/constants.go new file mode 100644 index 00000000..87597511 --- /dev/null +++ b/packages/gateway-v2/constants.go @@ -0,0 +1,22 @@ +package gatewayv2 + +const ( + KUBERNETES_SERVICE_HOST_ENV_NAME = "KUBERNETES_SERVICE_HOST" + KUBERNETES_SERVICE_PORT_HTTPS_ENV_NAME = "KUBERNETES_SERVICE_PORT_HTTPS" + KUBERNETES_SERVICE_ACCOUNT_CA_CERT_PATH = "/var/run/secrets/kubernetes.io/serviceaccount/ca.crt" + KUBERNETES_SERVICE_ACCOUNT_TOKEN_PATH = "/var/run/secrets/kubernetes.io/serviceaccount/token" + + RELAY_NAME_ENV_NAME = "INFISICAL_RELAY_NAME" + GATEWAY_NAME_ENV_NAME = "INFISICAL_GATEWAY_NAME" + + RELAY_AUTH_SECRET_ENV_NAME = "INFISICAL_RELAY_AUTH_SECRET" + + INFISICAL_HTTP_PROXY_ACTION_HEADER = "x-infisical-action" +) + +type HttpProxyAction string + +const ( + HttpProxyActionInjectGatewayK8sServiceAccountToken HttpProxyAction = "inject-k8s-sa-auth-token" + HttpProxyActionUseGatewayK8sServiceAccount HttpProxyAction = "use-k8s-sa" +) diff --git a/packages/gateway-v2/gateway.go b/packages/gateway-v2/gateway.go new file mode 100644 index 00000000..22b23eae --- /dev/null +++ b/packages/gateway-v2/gateway.go @@ -0,0 +1,693 @@ +package gatewayv2 + +import ( + "bufio" + "bytes" + "context" + "crypto/tls" + "crypto/x509" + "encoding/base64" + "encoding/json" + "encoding/pem" + "fmt" + "net" + "strconv" + "strings" + "sync" + "time" + + "github.com/Infisical/infisical-merge/packages/api" + "github.com/Infisical/infisical-merge/packages/util" + "github.com/go-resty/resty/v2" + "github.com/rs/zerolog/log" + "golang.org/x/crypto/ssh" +) + +// ForwardMode represents the type of forwarding +type ForwardMode string + +const ( + ForwardModeHTTP ForwardMode = "HTTP" + ForwardModeTCP ForwardMode = "TCP" + ForwardModePing ForwardMode = "PING" +) + +type ActorType string + +const ( + ActorTypePlatform ActorType = "platform" + ActorTypeUser ActorType = "user" +) + +const GATEWAY_ROUTING_INFO_OID = "1.3.6.1.4.1.12345.100.1" +const GATEWAY_ACTOR_OID = "1.3.6.1.4.1.12345.100.2" + +// ForwardConfig contains the configuration for forwarding +type ForwardConfig struct { + Mode ForwardMode + CACertificate []byte // Decoded CA certificate for HTTPS verification + VerifyTLS bool // Whether to verify TLS certificates + TargetHost string + TargetPort int + ActorType ActorType +} + +// RoutingInfo represents the routing information embedded in client certificates +type RoutingInfo struct { + TargetHost string `json:"targetHost"` + TargetPort int `json:"targetPort"` +} + +type ActorDetails struct { + Type string `json:"type"` +} + +type GatewayConfig struct { + Name string + RelayName string + IdentityToken string + SSHPort int + ReconnectDelay time.Duration +} + +type Gateway struct { + GatewayID string + + httpClient *resty.Client + config *GatewayConfig + sshClient *ssh.Client + + // Certificate storage + certificates *api.RegisterGatewayResponse + + // mTLS server components + tlsConfig *tls.Config + + // Connection management + mu sync.RWMutex + isConnected bool + ctx context.Context + cancel context.CancelFunc +} + +// NewGateway creates a new gateway instance +func NewGateway(config *GatewayConfig) (*Gateway, error) { + httpClient, err := util.GetRestyClientWithCustomHeaders() + if err != nil { + return nil, fmt.Errorf("unable to get client with custom headers [err=%v]", err) + } + + httpClient.SetAuthToken(config.IdentityToken) + + ctx, cancel := context.WithCancel(context.Background()) + + // Set default SSH port if not specified + if config.SSHPort == 0 { + config.SSHPort = 2222 + } + + return &Gateway{ + httpClient: httpClient, + config: config, + ctx: ctx, + cancel: cancel, + }, nil +} + +func (g *Gateway) registerHeartBeat(ctx context.Context, errCh chan error) { + sendHeartbeat := func() { + if err := api.CallGatewayHeartBeatV2(g.httpClient); err != nil { + log.Warn().Msgf("Heartbeat failed: %v", err) + select { + case errCh <- err: + default: + log.Warn().Msg("Error channel full, skipping heartbeat error report") + } + } else { + log.Info().Msg("Gateway is reachable by Infisical") + } + } + + go func() { + select { + case <-ctx.Done(): + return + case <-time.After(10 * time.Second): + sendHeartbeat() + } + + ticker := time.NewTicker(30 * time.Minute) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + sendHeartbeat() + } + } + }() +} + +func (g *Gateway) Start(ctx context.Context) error { + log.Info().Msgf("Starting gateway") + + errCh := make(chan error, 1) + g.registerHeartBeat(ctx, errCh) + + // Start certificate renewal goroutine + go g.startCertificateRenewal(ctx) + + go func() { + for { + select { + case <-ctx.Done(): + return + case err := <-errCh: + log.Warn().Msgf("Heartbeat error received: %v", err) + } + } + }() + + for { + select { + case <-ctx.Done(): + log.Info().Msgf("Gateway stopped by context cancellation") + return nil + default: + if err := g.connectAndServe(); err != nil { + log.Error().Msgf("Connection failed: %v, retrying in %v...", err, g.config.ReconnectDelay) + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(g.config.ReconnectDelay): + continue + } + } + // If we get here, the connection was closed gracefully + log.Info().Msgf("Connection closed, reconnecting in 10 seconds...") + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(10 * time.Second): + continue + } + } + } +} + +func (g *Gateway) SetToken(token string) { + g.httpClient.SetAuthToken(token) +} + +func (g *Gateway) Stop() { + g.cancel() + + g.mu.Lock() + if g.sshClient != nil { + g.sshClient.Close() + g.sshClient = nil + } + g.isConnected = false + g.mu.Unlock() +} + +func (g *Gateway) connectAndServe() error { + if err := g.registerGateway(); err != nil { + return fmt.Errorf("failed to register gateway: %v", err) + } + + // Create SSH client config + sshConfig, err := g.createSSHConfig() + if err != nil { + return fmt.Errorf("failed to create SSH config: %v", err) + } + + // Connect to Relay server + log.Info().Msgf("Connecting to relay server %s on %s:%d...", g.config.RelayName, g.certificates.RelayHost, g.config.SSHPort) + client, err := ssh.Dial("tcp", fmt.Sprintf("%s:%d", g.certificates.RelayHost, g.config.SSHPort), sshConfig) + if err != nil { + return fmt.Errorf("failed to connect to SSH server: %v", err) + } + log.Info().Msgf("Relay connection established for gateway") + + g.mu.Lock() + g.sshClient = client + g.isConnected = true + g.mu.Unlock() + + defer func() { + g.mu.Lock() + g.sshClient = nil + g.isConnected = false + g.mu.Unlock() + client.Close() + }() + + // Handle incoming channels from the server + channels := client.HandleChannelOpen("direct-tcpip") + if channels == nil { + return fmt.Errorf("failed to handle channel open") + } + + // Monitor for context cancellation and close SSH client + go func() { + <-g.ctx.Done() + log.Info().Msg("Context cancelled, closing relay connection...") + client.Close() + }() + + // Process incoming channels with context cancellation support + for { + select { + case <-g.ctx.Done(): + log.Info().Msg("Context cancelled, stopping channel processing") + return g.ctx.Err() + case newChannel, ok := <-channels: + if !ok { + log.Info().Msg("SSH channels closed") + return nil + } + go g.handleIncomingChannel(newChannel) + } + } +} + +func (g *Gateway) registerGateway() error { + body := api.RegisterGatewayRequest{ + RelayName: g.config.RelayName, + Name: g.config.Name, + } + + certResp, err := api.CallRegisterGateway(g.httpClient, body) + if err != nil { + return fmt.Errorf("failed to register gateway: %v", err) + } + + g.GatewayID = certResp.GatewayID + g.certificates = &certResp + log.Info().Msgf("Successfully registered gateway and received certificates") + + // Setup mTLS config + if err := g.setupTLSConfig(); err != nil { + return fmt.Errorf("failed to setup TLS config: %v", err) + } + + return nil +} + +func (g *Gateway) setupTLSConfig() error { + serverCertBlock, _ := pem.Decode([]byte(g.certificates.PKI.ServerCertificate)) + if serverCertBlock == nil { + return fmt.Errorf("failed to decode server certificate") + } + + serverKeyBlock, _ := pem.Decode([]byte(g.certificates.PKI.ServerPrivateKey)) + if serverKeyBlock == nil { + return fmt.Errorf("failed to decode server private key") + } + + serverKey, err := x509.ParsePKCS8PrivateKey(serverKeyBlock.Bytes) + if err != nil { + return fmt.Errorf("failed to parse server private key: %v", err) + } + + clientCAPool := x509.NewCertPool() + var chainCerts [][]byte + chainData := []byte(g.certificates.PKI.ClientCertificateChain) + for { + block, rest := pem.Decode(chainData) + if block == nil { + break + } + chainCerts = append(chainCerts, block.Bytes) + chainData = rest + } + + for i, certBytes := range chainCerts { + cert, err := x509.ParseCertificate(certBytes) + if err != nil { + log.Info().Msgf("Failed to parse client chain certificate %d: %v", i+1, err) + continue + } + clientCAPool.AddCert(cert) + } + + g.tlsConfig = &tls.Config{ + Certificates: []tls.Certificate{ + { + Certificate: [][]byte{serverCertBlock.Bytes}, + PrivateKey: serverKey, + }, + }, + ClientCAs: clientCAPool, + ClientAuth: tls.RequireAndVerifyClientCert, + MinVersion: tls.VersionTLS12, + NextProtos: []string{"infisical-http-proxy", "infisical-tcp-proxy", "infisical-ping"}, + } + + return nil +} + +func (g *Gateway) createSSHConfig() (*ssh.ClientConfig, error) { + privateKey, err := ssh.ParsePrivateKey([]byte(g.certificates.SSH.ClientPrivateKey)) + if err != nil { + return nil, fmt.Errorf("failed to parse SSH private key: %v", err) + } + + // Parse certificate + cert, _, _, _, err := ssh.ParseAuthorizedKey([]byte(g.certificates.SSH.ClientCertificate)) + if err != nil { + return nil, fmt.Errorf("failed to parse certificate: %v", err) + } + + sshCert, ok := cert.(*ssh.Certificate) + if !ok { + return nil, fmt.Errorf("parsed key is not an SSH certificate, got type: %T", cert) + } + + // Create certificate signer + certSigner, err := ssh.NewCertSigner(sshCert, privateKey) + if err != nil { + return nil, fmt.Errorf("failed to create certificate signer: %v", err) + } + + // Create SSH client config + config := &ssh.ClientConfig{ + User: g.GatewayID, + Auth: []ssh.AuthMethod{ + ssh.PublicKeys(certSigner), + }, + HostKeyCallback: g.createHostKeyCallback(), + Timeout: 30 * time.Second, + Config: ssh.Config{ + KeyExchanges: []string{ + "diffie-hellman-group14-sha256", + "diffie-hellman-group16-sha512", + "diffie-hellman-group18-sha512", + }, + Ciphers: []string{ + "aes128-ctr", + "aes192-ctr", + "aes256-ctr", + }, + MACs: []string{ + "hmac-sha2-256", + "hmac-sha2-512", + }, + }, + } + + return config, nil +} + +func (g *Gateway) createHostKeyCallback() ssh.HostKeyCallback { + caKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(g.certificates.SSH.ServerCAPublicKey)) + if err != nil { + return func(hostname string, remote net.Addr, key ssh.PublicKey) error { + return fmt.Errorf("failed to parse CA public key: %v", err) + } + } + + return func(hostname string, remote net.Addr, key ssh.PublicKey) error { + cert, ok := key.(*ssh.Certificate) + if !ok { + return fmt.Errorf("host certificates required, raw host keys not allowed") + } + + return g.validateHostCertificate(cert, hostname, caKey) + } +} + +func (g *Gateway) validateHostCertificate(cert *ssh.Certificate, hostname string, caKey ssh.PublicKey) error { + checker := &ssh.CertChecker{ + IsHostAuthority: func(auth ssh.PublicKey, address string) bool { + return bytes.Equal(auth.Marshal(), caKey.Marshal()) + }, + } + + if err := checker.CheckCert(hostname, cert); err != nil { + return fmt.Errorf("host certificate check failed: %v", err) + } + + return nil +} + +func (g *Gateway) handleIncomingChannel(newChannel ssh.NewChannel) { + channel, requests, err := newChannel.Accept() + if err != nil { + log.Info().Msgf("Failed to accept channel: %v", err) + return + } + defer channel.Close() + + go ssh.DiscardRequests(requests) + + // Create mTLS server configuration + tlsConfig := g.tlsConfig + if tlsConfig == nil { + log.Info().Msgf("TLS config not initialized, cannot create mTLS server") + return + } + + // Create a virtual connection that pipes data between SSH channel and TLS + virtualConn := &virtualConnection{ + channel: channel, + } + + // Wrap the virtual connection with TLS + tlsConn := tls.Server(virtualConn, tlsConfig) + + // Perform TLS handshake + log.Info().Msg("Received incoming connection, starting TLS handshake") + if err := tlsConn.Handshake(); err != nil { + log.Info().Msgf("TLS handshake failed: %v", err) + return + } + log.Info().Msg("TLS handshake completed successfully") + + // Create reader for the TLS connection + reader := bufio.NewReader(tlsConn) + + forwardConfig, err := g.parseForwardConfigFromALPN(tlsConn, reader) + if err != nil { + log.Info().Msgf("Failed to parse forward config from ALPN: %v", err) + return + } + + if forwardConfig.Mode == ForwardModeHTTP { + log.Info(). + Str("mode", string(forwardConfig.Mode)). + Str("target", fmt.Sprintf("%s:%d", forwardConfig.TargetHost, forwardConfig.TargetPort)). + Str("actorType", string(forwardConfig.ActorType)). + Bool("verifyTLS", forwardConfig.VerifyTLS). + Msg("Starting HTTP proxy handler") + if err := handleHTTPProxy(g.ctx, tlsConn, reader, forwardConfig); err != nil { + log.Error().Err(err).Msg("HTTP proxy handler ended with error") + } else { + log.Info().Msg("HTTP proxy handler completed") + } + return + } else if forwardConfig.Mode == ForwardModeTCP { + log.Info(). + Str("mode", string(forwardConfig.Mode)). + Str("target", fmt.Sprintf("%s:%d", forwardConfig.TargetHost, forwardConfig.TargetPort)). + Str("actorType", string(forwardConfig.ActorType)). + Msg("Starting TCP proxy handler") + if err := handleTCPProxy(g.ctx, tlsConn, forwardConfig); err != nil { + log.Error().Err(err).Msg("TCP proxy handler ended with error") + } else { + log.Info().Msg("TCP proxy handler completed") + } + return + } else if forwardConfig.Mode == ForwardModePing { + log.Info().Msg("Starting ping handler") + if err := handlePing(g.ctx, tlsConn, reader); err != nil { + log.Error().Err(err).Msg("Ping handler ended with error") + } else { + log.Info().Msg("Ping handler completed") + } + return + } +} + +func (g *Gateway) parseForwardConfigFromALPN(tlsConn *tls.Conn, reader *bufio.Reader) (*ForwardConfig, error) { + config := &ForwardConfig{} + + // Parse routing information from the client certificate + if err := g.parseDetailsFromCertificate(tlsConn, config); err != nil { + return nil, fmt.Errorf("failed to parse routing info from certificate: %v", err) + } + + // Get the negotiated ALPN protocol + state := tlsConn.ConnectionState() + negotiatedProtocol := state.NegotiatedProtocol + + log.Info().Msgf("Negotiated ALPN protocol: %s", negotiatedProtocol) + + // Map ALPN protocol to ForwardMode + switch negotiatedProtocol { + case "infisical-http-proxy": + config.Mode = ForwardModeHTTP + // For HTTP proxy, read additional parameters from the connection + if err := g.parseHTTPParametersFromConnection(reader, config); err != nil { + return nil, fmt.Errorf("failed to parse HTTP parameters: %v", err) + } + return config, nil + + case "infisical-tcp-proxy": + config.Mode = ForwardModeTCP + return config, nil + + case "infisical-ping": + config.Mode = ForwardModePing + return config, nil + + default: + return nil, fmt.Errorf("unsupported ALPN protocol: %s", negotiatedProtocol) + } +} + +func (g *Gateway) parseHTTPParametersFromConnection(reader *bufio.Reader, config *ForwardConfig) error { + // Read the first line which should contain HTTP parameters + msg, err := reader.ReadBytes('\n') + if err != nil { + return fmt.Errorf("failed to read HTTP parameters: %v", err) + } + + params := strings.TrimSpace(string(msg)) + if params != "" { + if err := g.parseForwardHTTPParams(params, config); err != nil { + return fmt.Errorf("failed to parse HTTP parameters: %v", err) + } + } + + return nil +} + +func (g *Gateway) parseForwardHTTPParams(params string, config *ForwardConfig) error { + parts := strings.Fields(params) + + for _, part := range parts { + if strings.HasPrefix(part, "ca=") { + caB64 := strings.TrimPrefix(part, "ca=") + caCert, err := base64.StdEncoding.DecodeString(caB64) + if err != nil { + return fmt.Errorf("invalid base64 CA certificate: %v", err) + } + config.CACertificate = caCert + } else if strings.HasPrefix(part, "verify=") { + verifyStr := strings.TrimPrefix(part, "verify=") + verify, err := strconv.ParseBool(verifyStr) + if err != nil { + return fmt.Errorf("invalid verify parameter: %s", verifyStr) + } + config.VerifyTLS = verify + } + } + + return nil +} + +func (g *Gateway) parseDetailsFromCertificate(tlsConn *tls.Conn, config *ForwardConfig) error { + // Get the peer certificates + state := tlsConn.ConnectionState() + if len(state.PeerCertificates) == 0 { + return fmt.Errorf("no peer certificates found") + } + + clientCert := state.PeerCertificates[0] + + for _, ext := range clientCert.Extensions { + // Extract target host and port from client certificate custom extension + if ext.Id.String() == GATEWAY_ROUTING_INFO_OID { + var routingInfo RoutingInfo + if err := json.Unmarshal(ext.Value, &routingInfo); err != nil { + return fmt.Errorf("failed to parse routing info JSON: %v", err) + } + + config.TargetHost = routingInfo.TargetHost + config.TargetPort = routingInfo.TargetPort + } + // Extract actor type from client certificate custom extension + if ext.Id.String() == GATEWAY_ACTOR_OID { + var actorDetails ActorDetails + if err := json.Unmarshal(ext.Value, &actorDetails); err != nil { + return fmt.Errorf("failed to parse actor details JSON: %v", err) + } + config.ActorType = ActorType(actorDetails.Type) + } + } + + return nil +} + +// virtualConnection implements net.Conn to bridge SSH channel and TLS +type virtualConnection struct { + channel ssh.Channel +} + +func (vc *virtualConnection) Read(b []byte) (n int, err error) { + return vc.channel.Read(b) +} + +func (vc *virtualConnection) Write(b []byte) (n int, err error) { + return vc.channel.Write(b) +} + +func (vc *virtualConnection) Close() error { + return vc.channel.Close() +} + +func (vc *virtualConnection) LocalAddr() net.Addr { + return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0} +} + +func (vc *virtualConnection) RemoteAddr() net.Addr { + return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0} +} + +func (vc *virtualConnection) SetDeadline(t time.Time) error { + return nil +} + +func (vc *virtualConnection) SetReadDeadline(t time.Time) error { + return nil +} + +func (vc *virtualConnection) SetWriteDeadline(t time.Time) error { + return nil +} + +// startCertificateRenewal runs a background process to renew certificates every 6 hours +func (g *Gateway) startCertificateRenewal(ctx context.Context) { + log.Info().Msg("Starting gateway certificate renewal goroutine") + ticker := time.NewTicker(6 * 60 * time.Minute) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + log.Info().Msg("Gateway certificate renewal goroutine stopping...") + return + case <-ticker.C: + log.Info().Msg("Renewing gateway certificates...") + if err := g.renewCertificates(); err != nil { + log.Error().Msgf("Failed to renew gateway certificates: %v", err) + } else { + log.Info().Msg("Gateway certificates renewed successfully") + } + } + } +} + +// renewCertificates fetches new certificates and updates the gateway configurations +func (g *Gateway) renewCertificates() error { + // Re-register gateway to get fresh certificates + if err := g.registerGateway(); err != nil { + return fmt.Errorf("failed to register gateway: %v", err) + } + + return nil +} diff --git a/packages/gateway-v2/systemd.go b/packages/gateway-v2/systemd.go new file mode 100644 index 00000000..d4fa2940 --- /dev/null +++ b/packages/gateway-v2/systemd.go @@ -0,0 +1,128 @@ +package gatewayv2 + +import ( + "fmt" + "os" + "os/exec" + "path/filepath" + "runtime" + + "github.com/rs/zerolog/log" +) + +const systemdServiceTemplate = `[Unit] +Description=Infisical Gateway Service +After=network.target + +[Service] +Type=notify +NotifyAccess=all +EnvironmentFile=/etc/infisical/gateway.conf +ExecStart=infisical gateway start +Restart=on-failure +InaccessibleDirectories=/home +PrivateTmp=yes +LimitCORE=infinity +LimitNOFILE=1000000 +LimitNPROC=60000 +LimitRTPRIO=infinity +LimitRTTIME=7000000 + +[Install] +WantedBy=multi-user.target +` + +func InstallGatewaySystemdService(token string, domain string, name string, relayName string) error { + if runtime.GOOS != "linux" { + log.Info().Msg("Skipping systemd service installation - not on Linux") + return nil + } + + if os.Geteuid() != 0 { + log.Info().Msg("Skipping systemd service installation - not running as root/sudo") + return nil + } + + configDir := "/etc/infisical" + if err := os.MkdirAll(configDir, 0755); err != nil { + return fmt.Errorf("failed to create config directory: %v", err) + } + + configContent := fmt.Sprintf("INFISICAL_TOKEN=%s\n", token) + if domain != "" { + configContent += fmt.Sprintf("INFISICAL_API_URL=%s\n", domain) + } + + if name != "" { + configContent += fmt.Sprintf("%s=%s\n", GATEWAY_NAME_ENV_NAME, name) + } + if relayName != "" { + configContent += fmt.Sprintf("%s=%s\n", RELAY_NAME_ENV_NAME, relayName) + } + + configPath := filepath.Join(configDir, "gateway.conf") + if err := os.WriteFile(configPath, []byte(configContent), 0600); err != nil { + return fmt.Errorf("failed to write config file: %v", err) + } + + servicePath := "/etc/systemd/system/infisical-gateway.service" + if err := os.WriteFile(servicePath, []byte(systemdServiceTemplate), 0644); err != nil { + return fmt.Errorf("failed to write systemd service file: %v", err) + } + + reloadCmd := exec.Command("systemctl", "daemon-reload") + if err := reloadCmd.Run(); err != nil { + return fmt.Errorf("failed to reload systemd: %v", err) + } + + log.Info().Msg("Successfully installed systemd service") + log.Info().Msg("To start the service, run: sudo systemctl start infisical-gateway") + log.Info().Msg("To enable the service on boot, run: sudo systemctl enable infisical-gateway") + + return nil +} + +func UninstallGatewaySystemdService() error { + if runtime.GOOS != "linux" { + log.Info().Msg("Skipping systemd service uninstallation - not on Linux") + return nil + } + + if os.Geteuid() != 0 { + log.Info().Msg("Skipping systemd service uninstallation - not running as root/sudo") + return nil + } + + // Stop the service if it's running + stopCmd := exec.Command("systemctl", "stop", "infisical-gateway") + if err := stopCmd.Run(); err != nil { + log.Warn().Msgf("Failed to stop service: %v", err) + } + + // Disable the service + disableCmd := exec.Command("systemctl", "disable", "infisical-gateway") + if err := disableCmd.Run(); err != nil { + log.Warn().Msgf("Failed to disable service: %v", err) + } + + // Remove the service file + servicePath := "/etc/systemd/system/infisical-gateway.service" + if err := os.Remove(servicePath); err != nil && !os.IsNotExist(err) { + return fmt.Errorf("failed to remove systemd service file: %v", err) + } + + // Remove the configuration file + configPath := "/etc/infisical/gateway.conf" + if err := os.Remove(configPath); err != nil && !os.IsNotExist(err) { + return fmt.Errorf("failed to remove config file: %v", err) + } + + // Reload systemd to apply changes + reloadCmd := exec.Command("systemctl", "daemon-reload") + if err := reloadCmd.Run(); err != nil { + return fmt.Errorf("failed to reload systemd: %v", err) + } + + log.Info().Msg("Successfully uninstalled Infisical Gateway systemd service") + return nil +} diff --git a/packages/relay/relay.go b/packages/relay/relay.go new file mode 100644 index 00000000..285e9e0a --- /dev/null +++ b/packages/relay/relay.go @@ -0,0 +1,520 @@ +package relay + +import ( + "bytes" + "context" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "encoding/pem" + "fmt" + "io" + "net" + "sync" + "time" + + "github.com/Infisical/infisical-merge/packages/api" + "github.com/Infisical/infisical-merge/packages/util" + "github.com/go-resty/resty/v2" + "github.com/rs/zerolog/log" + "golang.org/x/crypto/ssh" +) + +type RelayConfig struct { + // API Configuration + Token string + RelayName string + + Type string + + // Server Ports + SSHPort string + TLSPort string + + // Network Configuration + Host string +} + +type Relay struct { + httpClient *resty.Client + config *RelayConfig + + // Certificate storage + certificates *api.RegisterRelayResponse + + // SSH server components + sshConfig *ssh.ServerConfig + sshCA ssh.Signer + + // TLS server components + tlsConfig *tls.Config + tlsCACert []byte + tlsCAKey *rsa.PrivateKey + + // Tunnel storage (Gateway ID -> SSH connection) + tunnels map[string]*ssh.ServerConn + mu sync.RWMutex + + // Server listeners + sshListener net.Listener + tlsListener net.Listener +} + +func NewRelay(config *RelayConfig) (*Relay, error) { + httpClient, err := util.GetRestyClientWithCustomHeaders() + if err != nil { + return nil, fmt.Errorf("unable to get client with custom headers [err=%v]", err) + } + + httpClient.SetAuthToken(config.Token) + + return &Relay{ + httpClient: httpClient, + config: config, + tunnels: make(map[string]*ssh.ServerConn), + }, nil +} + +func (r *Relay) SetToken(token string) { + r.httpClient.SetAuthToken(token) +} + +func (r *Relay) Start(ctx context.Context) error { + if err := r.registerRelay(); err != nil { + return fmt.Errorf("failed to register relay: %v", err) + } + + // Setup SSH server + if err := r.setupSSHServer(); err != nil { + return fmt.Errorf("failed to setup SSH server: %v", err) + } + + // Setup TLS server + if err := r.setupTLSServer(); err != nil { + return fmt.Errorf("failed to setup TLS server: %v", err) + } + + // Start certificate renewal goroutine + go r.startCertificateRenewal(ctx) + + // Start SSH server + go r.startSSHServer() + + // Start TLS server + go r.startTLSServer() + + log.Info().Msg("Relay server started successfully") + + // Wait for context cancellation + <-ctx.Done() + + // Cleanup + r.cleanup() + return nil +} + +func (r *Relay) registerRelay() error { + body := api.RegisterRelayRequest{ + Host: r.config.Host, + Name: r.config.RelayName, + } + + if r.config.Type == "instance" { + certResp, err := api.CallRegisterInstanceRelay(r.httpClient, body) + if err != nil { + return fmt.Errorf("failed to register instance relay: %v", err) + } + r.certificates = &certResp + } else { + certResp, err := api.CallRegisterRelay(r.httpClient, body) + if err != nil { + return fmt.Errorf("failed to register org relay: %v", err) + } + r.certificates = &certResp + } + + log.Info().Msg("Successfully registered relay and received certificates from API") + return nil +} + +func (r *Relay) setupSSHServer() error { + // Parse SSH CA public key + sshCAPubKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(r.certificates.SSH.ClientCAPublicKey)) + if err != nil { + return fmt.Errorf("failed to parse SSH CA public key: %v", err) + } + + // Parse SSH server private key + sshServerKey, err := ssh.ParsePrivateKey([]byte(r.certificates.SSH.ServerPrivateKey)) + if err != nil { + return fmt.Errorf("failed to parse SSH server private key: %v", err) + } + + // Parse SSH server certificate + sshServerCert, _, _, _, err := ssh.ParseAuthorizedKey([]byte(r.certificates.SSH.ServerCertificate)) + if err != nil { + return fmt.Errorf("failed to parse SSH server certificate: %v", err) + } + + // Create certificate signer + certSigner, err := ssh.NewCertSigner(sshServerCert.(*ssh.Certificate), sshServerKey) + if err != nil { + return fmt.Errorf("failed to create SSH certificate signer: %v", err) + } + + // Setup SSH server config + r.sshConfig = &ssh.ServerConfig{ + MaxAuthTries: 3, + PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { + // Check if this is an SSH certificate + cert, ok := key.(*ssh.Certificate) + if !ok { + log.Warn().Msgf("Gateway '%s' tried to authenticate with raw public key (rejected)", conn.User()) + return nil, fmt.Errorf("certificates required, raw public keys not allowed") + } + + // Validate the certificate + if err := r.validateSSHCertificate(cert, conn.User(), sshCAPubKey); err != nil { + log.Error().Msgf("Gateway '%s' certificate validation failed: %v", conn.User(), err) + return nil, err + } + + gatewayId := "" + if len(cert.ValidPrincipals) > 0 { + gatewayId = cert.ValidPrincipals[0] + } + + if gatewayId == "" { + return nil, fmt.Errorf("gateway id is required") + } + + // Validate that the user is authorized to connect to the current relay + expectedKeyId := "client-" + r.config.RelayName + if cert.KeyId != expectedKeyId { + log.Error().Msgf("Gateway '%s' certificate Key ID '%s' does not match expected '%s'", conn.User(), cert.KeyId, expectedKeyId) + return nil, fmt.Errorf("certificate Key ID does not match expected value") + } + + return &ssh.Permissions{ + Extensions: map[string]string{ + "gateway-id": gatewayId, + }, + }, nil + }, + } + + r.sshConfig.AddHostKey(certSigner) + return nil +} + +func (r *Relay) setupTLSServer() error { + // Parse TLS server certificate + serverCertBlock, _ := pem.Decode([]byte(r.certificates.PKI.ServerCertificate)) + if serverCertBlock == nil { + return fmt.Errorf("failed to decode server certificate") + } + + // Note: serverCert is parsed for validation but not used in the TLS config + // since we use the raw bytes directly + _, err := x509.ParseCertificate(serverCertBlock.Bytes) + if err != nil { + return fmt.Errorf("failed to parse server certificate: %v", err) + } + + // Parse TLS server private key + serverKeyBlock, _ := pem.Decode([]byte(r.certificates.PKI.ServerPrivateKey)) + if serverKeyBlock == nil { + return fmt.Errorf("failed to decode server private key") + } + + serverKey, err := x509.ParsePKCS8PrivateKey(serverKeyBlock.Bytes) + if err != nil { + return fmt.Errorf("failed to parse server private key: %v", err) + } + + // Create certificate pool for client CAs + clientCAPool := x509.NewCertPool() + + var chainCerts [][]byte + chainData := []byte(r.certificates.PKI.ClientCertificateChain) + for { + block, rest := pem.Decode(chainData) + if block == nil { + break + } + chainCerts = append(chainCerts, block.Bytes) + chainData = rest + } + + for i, certBytes := range chainCerts { + cert, err := x509.ParseCertificate(certBytes) + if err != nil { + log.Error().Msgf("Failed to parse client chain certificate %d: %v", i+1, err) + continue + } + clientCAPool.AddCert(cert) + } + + // Create TLS config + r.tlsConfig = &tls.Config{ + Certificates: []tls.Certificate{ + { + Certificate: [][]byte{serverCertBlock.Bytes}, + PrivateKey: serverKey, + }, + }, + ClientCAs: clientCAPool, + ClientAuth: tls.RequireAndVerifyClientCert, + MinVersion: tls.VersionTLS12, + } + + return nil +} + +func (r *Relay) validateSSHCertificate(cert *ssh.Certificate, username string, caPubKey ssh.PublicKey) error { + // Check certificate type + if cert.CertType != ssh.UserCert { + return fmt.Errorf("invalid certificate type: %d", cert.CertType) + } + + // Check if certificate is signed by expected CA + checker := &ssh.CertChecker{ + IsUserAuthority: func(auth ssh.PublicKey) bool { + return bytes.Equal(auth.Marshal(), caPubKey.Marshal()) + }, + } + + // Validate the certificate + if err := checker.CheckCert(username, cert); err != nil { + return fmt.Errorf("certificate check failed: %v", err) + } + + log.Debug().Msgf("SSH certificate valid for user '%s', principals: %v", username, cert.ValidPrincipals) + return nil +} + +func (r *Relay) startSSHServer() { + listener, err := net.Listen("tcp", ":"+r.config.SSHPort) + if err != nil { + log.Fatal().Msgf("Failed to start SSH server: %v", err) + } + r.sshListener = listener + + log.Info().Msgf("SSH server listening on :%s for gateways", r.config.SSHPort) + + for { + conn, err := listener.Accept() + if err != nil { + log.Error().Msgf("Failed to accept SSH connection: %v", err) + continue + } + go r.handleSSHAgent(conn) + } +} + +func (r *Relay) handleSSHAgent(conn net.Conn) { + defer conn.Close() + + // SSH handshake + sshConn, chans, reqs, err := ssh.NewServerConn(conn, r.sshConfig) + if err != nil { + log.Error().Msgf("SSH handshake failed: %v", err) + return + } + + gatewayId := sshConn.Permissions.Extensions["gateway-id"] + log.Info().Msgf("SSH handshake successful for gateway: %s", gatewayId) + + // Store the connection (ensure only one connection per gateway) + r.mu.Lock() + if _, exists := r.tunnels[gatewayId]; exists { + r.mu.Unlock() + log.Warn().Msgf("Gateway '%s' already has an active connection, rejecting new connection", gatewayId) + sshConn.Close() + return + } + + r.tunnels[gatewayId] = sshConn + r.mu.Unlock() + + // Clean up when agent disconnects + defer func() { + r.mu.Lock() + delete(r.tunnels, gatewayId) + r.mu.Unlock() + log.Info().Msgf("Gateway %s disconnected", gatewayId) + }() + + // Handle global requests (reject all for security) + go func() { + for req := range reqs { + log.Debug().Msgf("Rejecting global request: %s from gateway %s", req.Type, gatewayId) + if req.WantReply { + req.Reply(false, nil) + } + } + }() + + // Handle channel requests + for newChannel := range chans { + switch newChannel.ChannelType() { + case "session": + log.Debug().Msgf("Rejecting session channel from gateway %s", gatewayId) + newChannel.Reject(ssh.Prohibited, "no shell access") + case "x11": + log.Debug().Msgf("Rejecting X11 forwarding from gateway %s", gatewayId) + newChannel.Reject(ssh.Prohibited, "no X11 forwarding") + case "auth-agent": + log.Debug().Msgf("Rejecting auth-agent forwarding from gateway %s", gatewayId) + newChannel.Reject(ssh.Prohibited, "no agent forwarding") + case "forwarded-tcpip": + log.Debug().Msgf("Rejecting forwarded-tcpip from gateway %s", gatewayId) + newChannel.Reject(ssh.Prohibited, "no port forwarding") + default: + log.Warn().Msgf("Rejecting unknown channel type '%s' from gateway %s", newChannel.ChannelType(), gatewayId) + newChannel.Reject(ssh.UnknownChannelType, "unknown channel type") + } + } +} + +func (r *Relay) startTLSServer() { + listener, err := net.Listen("tcp", ":"+r.config.TLSPort) + if err != nil { + log.Fatal().Msgf("Failed to start TLS server: %v", err) + } + r.tlsListener = listener + + log.Info().Msgf("TLS server listening on :%s for clients", r.config.TLSPort) + + for { + conn, err := listener.Accept() + if err != nil { + log.Error().Msgf("Failed to accept TLS connection: %v", err) + continue + } + go r.handleTLSClient(conn) + } +} + +func (r *Relay) handleTLSClient(conn net.Conn) { + defer conn.Close() + + // Perform TLS handshake using current TLS config + tlsConn := tls.Server(conn, r.tlsConfig) + defer tlsConn.Close() + + // Set handshake timeout to avoid hanging on slow/malicious connections + tlsConn.SetDeadline(time.Now().Add(5 * time.Second)) + + // Force TLS handshake + err := tlsConn.Handshake() + if err != nil { + log.Debug().Msgf("TLS handshake failed from %s: %v", conn.RemoteAddr(), err) + return + } + + // Clear deadline for actual data transfer + tlsConn.SetDeadline(time.Time{}) + + r.handleClient(tlsConn) +} + +func (r *Relay) handleClient(tlsConn *tls.Conn) { + var gatewayId string + var orgDetails string + state := tlsConn.ConnectionState() + + if len(state.PeerCertificates) > 0 { + cert := state.PeerCertificates[0] + log.Info().Msgf("Client connected with certificate: %s", cert.Subject.CommonName) + gatewayId = cert.Subject.CommonName + orgDetails = cert.Subject.Organization[0] + } else { + log.Warn().Msg("No peer certificates found") + return + } + + // Get the SSH connection for this gateway + r.mu.RLock() + conn, exists := r.tunnels[gatewayId] + r.mu.RUnlock() + + if !exists { + log.Warn().Msgf("Gateway '%s' not connected", gatewayId) + tlsConn.Write([]byte("ERROR: Gateway not connected\n")) + return + } + + log.Info().Msgf("Routing connection from Organization %s to Gateway with ID: %s", orgDetails, gatewayId) + + channel, _, err := conn.OpenChannel("direct-tcpip", nil) + if err != nil { + log.Error().Msgf("Failed to connect to gateway: %v", err) + tlsConn.Write([]byte("ERROR: Failed to connect to gateway\n")) + return + } + defer channel.Close() + + // Bidirectional forwarding + go func() { + io.Copy(channel, tlsConn) + channel.CloseWrite() + }() + + io.Copy(tlsConn, channel) + log.Info().Msgf("Client %s disconnected", tlsConn.RemoteAddr()) +} + +func (r *Relay) cleanup() { + log.Info().Msg("Shutting down relay server...") + + if r.sshListener != nil { + r.sshListener.Close() + } + if r.tlsListener != nil { + r.tlsListener.Close() + } + + log.Info().Msg("Relay server shutdown complete") +} + +// startCertificateRenewal runs a background process to renew certificates every 6 hours +func (r *Relay) startCertificateRenewal(ctx context.Context) { + ticker := time.NewTicker(6 * 60 * time.Minute) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + log.Info().Msg("Certificate renewal goroutine stopping...") + return + case <-ticker.C: + log.Info().Msg("Renewing certificates...") + if err := r.renewCertificates(); err != nil { + log.Error().Msgf("Failed to renew certificates: %v", err) + } else { + log.Info().Msg("Certificates renewed successfully") + } + } + } +} + +// renewCertificates fetches new certificates and updates the server configurations +func (r *Relay) renewCertificates() error { + // Re-register relay to get fresh certificates + if err := r.registerRelay(); err != nil { + return fmt.Errorf("failed to register relay: %v", err) + } + + // Update SSH server configuration + if err := r.setupSSHServer(); err != nil { + return fmt.Errorf("failed to setup SSH server: %v", err) + } + + // Update TLS server configuration + if err := r.setupTLSServer(); err != nil { + return fmt.Errorf("failed to setup TLS server: %v", err) + } + + return nil +}