diff --git a/packages/api/model.go b/packages/api/model.go index 0d56e1c0..d0f89033 100644 --- a/packages/api/model.go +++ b/packages/api/model.go @@ -787,6 +787,8 @@ type PAMSessionCredentials struct { SSLCertificate string `json:"sslCertificate,omitempty"` Username string `json:"username"` Password string `json:"password"` + AuthMethod string `json:"authMethod,omitempty"` + PrivateKey string `json:"privateKey,omitempty"` } type UploadSessionLogEntry struct { @@ -795,8 +797,16 @@ type UploadSessionLogEntry struct { Output string `json:"output"` } +// UploadTerminalEvent represents a terminal session event for upload +type UploadTerminalEvent struct { + Timestamp time.Time `json:"timestamp"` + EventType string `json:"eventType"` + Data []byte `json:"data"` + ElapsedTime float64 `json:"elapsedTime"` +} + type UploadPAMSessionLogsRequest struct { - Logs []UploadSessionLogEntry `json:"logs"` + Logs interface{} `json:"logs"` // Can be []UploadSessionLogEntry or []UploadTerminalEvent } type RelayHeartbeatRequest struct { diff --git a/packages/cmd/pam.go b/packages/cmd/pam.go index fd2701ac..90781ec4 100644 --- a/packages/cmd/pam.go +++ b/packages/cmd/pam.go @@ -3,11 +3,10 @@ package cmd import ( "time" + pam "github.com/Infisical/infisical-merge/packages/pam/local" "github.com/Infisical/infisical-merge/packages/util" "github.com/rs/zerolog/log" "github.com/spf13/cobra" - - "github.com/Infisical/infisical-merge/packages/pam" ) var pamCmd = &cobra.Command{ @@ -75,11 +74,67 @@ var pamDbAccessAccountCmd = &cobra.Command{ }, } +var pamSshCmd = &cobra.Command{ + Use: "ssh", + Short: "SSH-related PAM commands", + Long: "SSH-related PAM commands for Infisical", + DisableFlagsInUseLine: true, + Args: cobra.NoArgs, +} + +var pamSshAccessAccountCmd = &cobra.Command{ + Use: "access-account ", + Short: "Start SSH session to PAM account", + Long: "Start an SSH session to a PAM-managed SSH account. This command automatically launches an SSH client connected through the Infisical Gateway.", + Example: "infisical pam ssh access-account --duration 2h", + DisableFlagsInUseLine: true, + Args: cobra.ExactArgs(1), + Run: func(cmd *cobra.Command, args []string) { + util.RequireLogin() + + accountID := args[0] + + durationStr, err := cmd.Flags().GetString("duration") + if err != nil { + util.HandleError(err, "Unable to parse duration flag") + } + + // Parse duration + _, err = time.ParseDuration(durationStr) + if err != nil { + util.HandleError(err, "Invalid duration format. Use formats like '1h', '30m', '2h30m'") + } + + log.Debug().Msg("PAM SSH Access: Trying to fetch credentials using logged in details") + + loggedInUserDetails, err := util.GetCurrentLoggedInUserDetails(true) + isConnected := util.ValidateInfisicalAPIConnection() + + if isConnected { + log.Debug().Msg("PAM SSH Access: Connected to Infisical instance, checking logged in creds") + } + + if err != nil { + util.HandleError(err, "Unable to get logged in user details") + } + + if isConnected && loggedInUserDetails.LoginExpired { + loggedInUserDetails = util.EstablishUserLoginSession() + } + + pam.StartSSHLocalProxy(loggedInUserDetails.UserCredentials.JTWToken, accountID, durationStr) + }, +} + func init() { pamDbCmd.AddCommand(pamDbAccessAccountCmd) pamDbAccessAccountCmd.Flags().String("duration", "1h", "Duration for database access session (e.g., '1h', '30m', '2h30m')") pamDbAccessAccountCmd.Flags().Int("port", 0, "Port for the local database proxy server (0 for auto-assign)") + pamSshCmd.AddCommand(pamSshAccessAccountCmd) + pamSshAccessAccountCmd.Flags().String("duration", "1h", "Duration for SSH access session (e.g., '1h', '30m', '2h30m')") + pamCmd.AddCommand(pamDbCmd) + pamCmd.AddCommand(pamSshCmd) rootCmd.AddCommand(pamCmd) } diff --git a/packages/pam/handlers/ssh/keys.go b/packages/pam/handlers/ssh/keys.go new file mode 100644 index 00000000..6d50a864 --- /dev/null +++ b/packages/pam/handlers/ssh/keys.go @@ -0,0 +1,16 @@ +package ssh + +import ( + "crypto/rand" + "crypto/rsa" + "fmt" +) + +// generateRSAKey generates a 2048-bit RSA private key +func generateRSAKey() (*rsa.PrivateKey, error) { + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return nil, fmt.Errorf("failed to generate RSA key: %w", err) + } + return privateKey, nil +} diff --git a/packages/pam/handlers/ssh/proxy.go b/packages/pam/handlers/ssh/proxy.go new file mode 100644 index 00000000..20d676b2 --- /dev/null +++ b/packages/pam/handlers/ssh/proxy.go @@ -0,0 +1,411 @@ +package ssh + +import ( + "context" + "fmt" + "io" + "net" + "sync" + "time" + + "github.com/Infisical/infisical-merge/packages/pam/session" + "github.com/rs/zerolog/log" + "golang.org/x/crypto/ssh" +) + +// SSHProxyConfig holds configuration for the SSH proxy +type SSHProxyConfig struct { + TargetAddr string // e.g., "target-host:22" + AuthMethod string + InjectUsername string + InjectPassword string + InjectPrivateKey string + SessionID string + SessionLogger session.SessionLogger +} + +// SSHProxy handles proxying SSH connections with credential injection +type SSHProxy struct { + config SSHProxyConfig + mutex sync.Mutex + sessionData []byte // Store session data for logging + inputBuffer []byte // Buffer for input data to batch keystrokes +} + +// NewSSHProxy creates a new SSH proxy instance +func NewSSHProxy(config SSHProxyConfig) *SSHProxy { + return &SSHProxy{ + config: config, + } +} + +// HandleConnection handles a single SSH client connection +func (p *SSHProxy) HandleConnection(ctx context.Context, clientConn net.Conn) error { + defer clientConn.Close() + + sessionID := p.config.SessionID + + // Ensure session logger cleanup + defer func() { + if err := p.config.SessionLogger.Close(); err != nil { + log.Error().Err(err).Str("sessionID", sessionID).Msg("Failed to close session logger") + } + }() + + log.Info(). + Str("sessionID", sessionID). + Str("targetAddr", p.config.TargetAddr). + Msg("New SSH connection for PAM session") + + // Configure SSH server (proxy acts as SSH server to the client) + serverConfig := &ssh.ServerConfig{ + // Accept any credentials from client - we'll inject our own to the target + NoClientAuth: true, + // Alternative: accept any password + PasswordCallback: func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) { + return nil, nil + }, + } + + // Generate a temporary host key for the proxy + hostKey, err := p.generateHostKey() + if err != nil { + log.Error().Err(err).Str("sessionID", sessionID).Msg("Failed to generate host key") + return fmt.Errorf("failed to generate host key: %w", err) + } + + serverConfig.AddHostKey(hostKey) + + // Perform SSH handshake with client + clientSSHConn, clientChannels, clientRequests, err := ssh.NewServerConn(clientConn, serverConfig) + if err != nil { + log.Error().Err(err).Str("sessionID", sessionID).Msg("Failed to establish SSH server connection with client") + return fmt.Errorf("failed to establish SSH connection with client: %w", err) + } + defer clientSSHConn.Close() + + log.Info(). + Str("sessionID", sessionID). + Str("clientUser", clientSSHConn.User()). + Str("clientVersion", string(clientSSHConn.ClientVersion())). + Msg("SSH client connected") + + // Connect to target SSH server with injected credentials + serverSSHConn, err := p.connectToTargetServer() + if err != nil { + log.Error().Err(err).Str("sessionID", sessionID).Msg("Failed to connect to target SSH server") + return fmt.Errorf("failed to connect to target SSH server: %w", err) + } + defer serverSSHConn.Close() + + log.Info(). + Str("sessionID", sessionID). + Str("serverVersion", string(serverSSHConn.ServerVersion())). + Msg("Connected to target SSH server with injected credentials") + + // Discard global requests (not needed for basic remote access) + go ssh.DiscardRequests(clientRequests) + + // Handle channels from client (this is where actual SSH sessions happen) + for newChannel := range clientChannels { + go p.handleChannel(ctx, newChannel, serverSSHConn, sessionID) + } + + log.Info(). + Str("sessionID", sessionID). + Msg("SSH connection closed") + + return nil +} + +// connectToTargetServer establishes connection to the actual SSH server with injected credentials +func (p *SSHProxy) connectToTargetServer() (*ssh.Client, error) { + var authMethods []ssh.AuthMethod + + switch p.config.AuthMethod { + case "public-key": + // Parse private key (convert PEM string to bytes) + signer, err := ssh.ParsePrivateKey([]byte(p.config.InjectPrivateKey)) + if err != nil { + return nil, fmt.Errorf("failed to parse private key: %w", err) + } + authMethods = append(authMethods, ssh.PublicKeys(signer)) + log.Debug(). + Str("sessionID", p.config.SessionID). + Msg("Using public key authentication") + case "password": + authMethods = append(authMethods, ssh.Password(p.config.InjectPassword)) + log.Debug(). + Str("sessionID", p.config.SessionID). + Msg("Using password authentication") + default: + return nil, fmt.Errorf("invalid or unspecified auth method: %s (must be 'public-key' or 'password')", p.config.AuthMethod) + } + + // Configure SSH client (proxy acts as client to the target server) + clientConfig := &ssh.ClientConfig{ + User: p.config.InjectUsername, + Auth: authMethods, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), // TODO: add support for passing in host key + Timeout: 10 * time.Second, + } + + // Connect to target server + client, err := ssh.Dial("tcp", p.config.TargetAddr, clientConfig) + if err != nil { + return nil, fmt.Errorf("failed to dial target SSH server: %w", err) + } + + return client, nil +} + +// handleChannel handles a single SSH channel (session, direct-tcpip, etc.) +func (p *SSHProxy) handleChannel(ctx context.Context, newChannel ssh.NewChannel, serverConn *ssh.Client, sessionID string) { + channelType := newChannel.ChannelType() + + log.Debug(). + Str("sessionID", sessionID). + Str("channelType", channelType). + Msg("← CLIENT new channel request") + + // Open corresponding channel on server + serverChannel, serverRequests, err := serverConn.OpenChannel(channelType, newChannel.ExtraData()) + if err != nil { + log.Error().Err(err). + Str("sessionID", sessionID). + Str("channelType", channelType). + Msg("Failed to open channel on server") + newChannel.Reject(ssh.ConnectionFailed, fmt.Sprintf("failed to open channel: %v", err)) + return + } + defer serverChannel.Close() + + // Accept the channel from client + clientChannel, clientRequests, err := newChannel.Accept() + if err != nil { + log.Error().Err(err).Str("sessionID", sessionID).Msg("Failed to accept client channel") + serverChannel.Close() + return + } + defer clientChannel.Close() + + log.Info(). + Str("sessionID", sessionID). + Str("channelType", channelType). + Msg("SSH channel established") + + // Handle requests for this channel (pty-req, shell, exec, etc.) + go p.handleChannelRequests(clientRequests, serverChannel, sessionID, channelType) + go p.handleChannelRequests(serverRequests, clientChannel, sessionID, channelType) + + // Proxy data bidirectionally with logging + errChan := make(chan error, 2) + + // Client to Server + go func() { + err := p.proxyData(clientChannel, serverChannel, "client→server", sessionID, true) + errChan <- err + }() + + // Server to Client + go func() { + err := p.proxyData(serverChannel, clientChannel, "server→client", sessionID, false) + errChan <- err + }() + + // Wait for either direction to finish or context cancellation + select { + case err := <-errChan: + if err != nil && err != io.EOF { + log.Debug().Err(err).Str("sessionID", sessionID).Msg("Channel proxy error") + } + case <-ctx.Done(): + log.Info().Str("sessionID", sessionID).Msg("Channel cancelled by context") + } + + log.Debug(). + Str("sessionID", sessionID). + Str("channelType", channelType). + Msg("SSH channel closed") +} + +// handleChannelRequests handles channel-specific requests (pty, shell, exec, etc.) +func (p *SSHProxy) handleChannelRequests(requests <-chan *ssh.Request, targetChannel ssh.Channel, sessionID string, channelType string) { + for req := range requests { + log.Debug(). + Str("sessionID", sessionID). + Str("channelType", channelType). + Str("requestType", req.Type). + Bool("wantReply", req.WantReply). + Msg("Channel request") + + // Log exec and shell requests for audit + switch req.Type { + case "exec": + if len(req.Payload) > 4 { + cmdLen := int(req.Payload[3]) + if len(req.Payload) >= 4+cmdLen { + command := string(req.Payload[4 : 4+cmdLen]) + log.Info(). + Str("sessionID", sessionID). + Str("command", command). + Msg("SSH exec command") + + // Log the exec command to the session recording + // Format it similar to how it would appear in a shell + commandWithPrompt := fmt.Sprintf("$ %s\n", command) + event := session.TerminalEvent{ + Timestamp: time.Now(), + EventType: session.TerminalEventInput, + Data: []byte(commandWithPrompt), + } + if err := p.config.SessionLogger.LogTerminalEvent(event); err != nil { + log.Error().Err(err). + Str("sessionID", sessionID). + Str("command", command). + Msg("Failed to log exec command to session recording") + } + } + } + case "shell": + log.Info(). + Str("sessionID", sessionID). + Msg("SSH interactive shell requested") + case "pty-req": + log.Debug(). + Str("sessionID", sessionID). + Msg("PTY requested") + } + + // Forward request to target channel + ok, err := targetChannel.SendRequest(req.Type, req.WantReply, req.Payload) + if err != nil { + log.Error().Err(err). + Str("sessionID", sessionID). + Str("requestType", req.Type). + Msg("Failed to forward channel request") + if req.WantReply { + req.Reply(false, nil) + } + continue + } + + if req.WantReply { + req.Reply(ok, nil) + } + } +} + +// proxyData proxies data between channels with optional logging +func (p *SSHProxy) proxyData(src io.Reader, dst io.Writer, direction string, sessionID string, logInput bool) error { + buf := make([]byte, 32*1024) // 32KB buffer + + // Flush any remaining input buffer on exit + defer func() { + if logInput && len(p.inputBuffer) > 0 { + p.flushInputBuffer(sessionID) + } + }() + + for { + n, err := src.Read(buf) + if n > 0 { + // For input, buffer until we see newline or control chars + if logInput { + p.bufferInput(buf[:n], sessionID) + } else { + // For output, log immediately as before + event := session.TerminalEvent{ + Timestamp: time.Now(), + EventType: session.TerminalEventOutput, + Data: make([]byte, n), + } + copy(event.Data, buf[:n]) + + if err := p.config.SessionLogger.LogTerminalEvent(event); err != nil { + log.Error().Err(err). + Str("sessionID", sessionID). + Str("eventType", string(session.TerminalEventOutput)). + Msg("Failed to log terminal event") + } + } + + // Write to destination + written, writeErr := dst.Write(buf[:n]) + if writeErr != nil { + return writeErr + } + if written != n { + return io.ErrShortWrite + } + } + + if err != nil { + if err == io.EOF { + return nil + } + return err + } + } +} + +// bufferInput accumulates input data and logs only when newline or control chars are encountered +func (p *SSHProxy) bufferInput(data []byte, sessionID string) { + p.mutex.Lock() + defer p.mutex.Unlock() + + for _, b := range data { + p.inputBuffer = append(p.inputBuffer, b) + + // Check if we should flush the buffer + // CR (0x0D), LF (0x0A), or if buffer gets too large + if b == 0x0D || b == 0x0A || len(p.inputBuffer) >= 1024 { + p.flushInputBufferUnsafe(sessionID) + } + } +} + +// flushInputBuffer flushes the input buffer with locking +func (p *SSHProxy) flushInputBuffer(sessionID string) { + p.mutex.Lock() + defer p.mutex.Unlock() + p.flushInputBufferUnsafe(sessionID) +} + +// flushInputBufferUnsafe flushes the input buffer without locking (caller must hold lock) +func (p *SSHProxy) flushInputBufferUnsafe(sessionID string) { + if len(p.inputBuffer) == 0 { + return + } + + event := session.TerminalEvent{ + Timestamp: time.Now(), + EventType: session.TerminalEventInput, + Data: make([]byte, len(p.inputBuffer)), + } + copy(event.Data, p.inputBuffer) + + if err := p.config.SessionLogger.LogTerminalEvent(event); err != nil { + log.Error().Err(err). + Str("sessionID", sessionID). + Str("eventType", string(session.TerminalEventInput)). + Msg("Failed to log terminal event") + } + + // Clear the buffer + p.inputBuffer = p.inputBuffer[:0] +} + +// generateHostKey generates a temporary RSA key for the SSH server +func (p *SSHProxy) generateHostKey() (ssh.Signer, error) { + rsaKey, err := generateRSAKey() + if err != nil { + return nil, fmt.Errorf("failed to generate RSA key: %w", err) + } + + privateKey, err := ssh.NewSignerFromSigner(rsaKey) + if err != nil { + return nil, fmt.Errorf("failed to create signer: %w", err) + } + return privateKey, nil +} diff --git a/packages/pam/local/base-proxy.go b/packages/pam/local/base-proxy.go new file mode 100644 index 00000000..49e7bc94 --- /dev/null +++ b/packages/pam/local/base-proxy.go @@ -0,0 +1,172 @@ +package pam + +import ( + "context" + "crypto/tls" + "crypto/x509" + "fmt" + "net" + "strconv" + "strings" + "sync" + "time" + + "github.com/Infisical/infisical-merge/packages/api" + "github.com/go-resty/resty/v2" + "github.com/rs/zerolog/log" +) + +// BaseProxyServer contains common functionality for all local proxy types +type BaseProxyServer struct { + httpClient *resty.Client + relayHost string + relayClientCert string + relayClientKey string + relayServerCertChain string + gatewayClientCert string + gatewayClientKey string + gatewayServerCertChain string + sessionExpiry time.Time + sessionId string + ctx context.Context + cancel context.CancelFunc + activeConnections sync.WaitGroup + shutdownOnce sync.Once + shutdownCh chan struct{} +} + +// CreateRelayConnection establishes a TLS connection to the relay server +func (b *BaseProxyServer) CreateRelayConnection() (net.Conn, error) { + var host string + var port int = 8443 + + if strings.Contains(b.relayHost, ":") { + var portStr string + var err error + host, portStr, err = net.SplitHostPort(b.relayHost) + if err != nil { + return nil, fmt.Errorf("invalid relay host format: %w", err) + } + port, err = strconv.Atoi(portStr) + if err != nil { + return nil, fmt.Errorf("invalid port in relay host: %w", err) + } + } else { + host = b.relayHost + } + + // Load relay certificates + cert, err := tls.X509KeyPair([]byte(b.relayClientCert), []byte(b.relayClientKey)) + if err != nil { + return nil, fmt.Errorf("failed to load relay client certificate: %w", err) + } + + caCertPool := x509.NewCertPool() + if !caCertPool.AppendCertsFromPEM([]byte(b.relayServerCertChain)) { + return nil, fmt.Errorf("failed to parse relay server certificate chain") + } + + tlsConfig := &tls.Config{ + Certificates: []tls.Certificate{cert}, + RootCAs: caCertPool, + ServerName: host, + MinVersion: tls.VersionTLS12, + } + + conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", host, port), tlsConfig) + if err != nil { + return nil, fmt.Errorf("failed to connect to relay: %w", err) + } + + log.Debug().Msg("Relay TLS connection established") + return conn, nil +} + +// CreateGatewayConnection establishes a mTLS connection to the gateway over the relay +func (b *BaseProxyServer) CreateGatewayConnection(relayConn net.Conn, alpn ALPN) (net.Conn, error) { + // Load gateway certificates + cert, err := tls.X509KeyPair([]byte(b.gatewayClientCert), []byte(b.gatewayClientKey)) + if err != nil { + return nil, fmt.Errorf("failed to load gateway client certificate: %w", err) + } + + caCertPool := x509.NewCertPool() + if !caCertPool.AppendCertsFromPEM([]byte(b.gatewayServerCertChain)) { + return nil, fmt.Errorf("failed to parse gateway server certificate chain") + } + + tlsConfig := &tls.Config{ + Certificates: []tls.Certificate{cert}, + RootCAs: caCertPool, + MinVersion: tls.VersionTLS12, + MaxVersion: tls.VersionTLS13, + NextProtos: []string{string(alpn)}, + ServerName: "localhost", + } + + gatewayConn := tls.Client(relayConn, tlsConfig) + + err = gatewayConn.Handshake() + if err != nil { + return nil, fmt.Errorf("failed to establish gateway mTLS: %w", err) + } + + state := gatewayConn.ConnectionState() + if !state.HandshakeComplete { + return nil, fmt.Errorf("gateway TLS handshake not complete") + } + + log.Debug().Msg("Gateway mTLS connection established") + return gatewayConn, nil +} + +// NotifySessionTermination sends a termination notification through the gateway +func (b *BaseProxyServer) NotifySessionTermination() { + log.Debug().Msgf("Notifying session termination for session ID: %s", b.sessionId) + + // Try to notify via gateway connection first + relayConn, err := b.CreateRelayConnection() + if err != nil { + log.Error().Err(err).Msg("Failed to connect to relay for termination notification") + // Fallback to API call if relay connection fails + b.FallbackToAPITermination() + return + } + defer relayConn.Close() + + gatewayConn, err := b.CreateGatewayConnection(relayConn, ALPNInfisicalPAMCancellation) + if err != nil { + log.Error().Err(err).Msg("Failed to connect to gateway for termination notification") + // Fallback to API call if gateway connection fails + b.FallbackToAPITermination() + return + } + defer gatewayConn.Close() + log.Debug().Msg("Session termination notification sent successfully") +} + +// FallbackToAPITermination terminates the session via API call +func (b *BaseProxyServer) FallbackToAPITermination() { + err := api.CallPAMSessionTermination(b.httpClient, b.sessionId) + if err != nil { + log.Error().Err(err).Msg("Failed to terminate session via API fallback") + } else { + log.Debug().Msg("Session terminated successfully via API fallback") + } +} + +// WaitForConnectionsWithTimeout waits for active connections to close with a timeout +func (b *BaseProxyServer) WaitForConnectionsWithTimeout(timeout time.Duration) { + done := make(chan struct{}) + go func() { + b.activeConnections.Wait() + close(done) + }() + + select { + case <-done: + log.Debug().Msg("All connections closed gracefully") + case <-time.After(timeout): + log.Warn().Msg("Timeout waiting for connections to close, forcing shutdown") + } +} diff --git a/packages/pam/local-database-proxy.go b/packages/pam/local/database-proxy.go similarity index 56% rename from packages/pam/local-database-proxy.go rename to packages/pam/local/database-proxy.go index eceb95d9..66f9505e 100644 --- a/packages/pam/local-database-proxy.go +++ b/packages/pam/local/database-proxy.go @@ -2,43 +2,25 @@ package pam import ( "context" - "crypto/tls" - "crypto/x509" "fmt" "io" "net" "os" "os/signal" - "strconv" - "strings" - "sync" "syscall" "time" "github.com/Infisical/infisical-merge/packages/api" + "github.com/Infisical/infisical-merge/packages/pam/session" "github.com/Infisical/infisical-merge/packages/util" "github.com/go-resty/resty/v2" "github.com/rs/zerolog/log" ) type DatabaseProxyServer struct { - httpClient *resty.Client - server net.Listener - port int - relayHost string - relayClientCert string - relayClientKey string - relayServerCertChain string - gatewayClientCert string - gatewayClientKey string - gatewayServerCertChain string - sessionExpiry time.Time - sessionId string - ctx context.Context - cancel context.CancelFunc - activeConnections sync.WaitGroup - shutdownOnce sync.Once - shutdownCh chan struct{} + BaseProxyServer // Embed common functionality + server net.Listener + port int } type ALPN string @@ -78,19 +60,21 @@ func StartDatabaseLocalProxy(accessToken string, accountID string, durationStr s ctx, cancel := context.WithCancel(context.Background()) proxy := &DatabaseProxyServer{ - httpClient: httpClient, - relayHost: pamResponse.RelayHost, - relayClientCert: pamResponse.RelayClientCertificate, - relayClientKey: pamResponse.RelayClientPrivateKey, - relayServerCertChain: pamResponse.RelayServerCertificateChain, - gatewayClientCert: pamResponse.GatewayClientCertificate, - gatewayClientKey: pamResponse.GatewayClientPrivateKey, - gatewayServerCertChain: pamResponse.GatewayServerCertificateChain, - sessionExpiry: time.Now().Add(duration), - sessionId: pamResponse.SessionId, - ctx: ctx, - cancel: cancel, - shutdownCh: make(chan struct{}), + BaseProxyServer: BaseProxyServer{ + httpClient: httpClient, + relayHost: pamResponse.RelayHost, + relayClientCert: pamResponse.RelayClientCertificate, + relayClientKey: pamResponse.RelayClientPrivateKey, + relayServerCertChain: pamResponse.RelayServerCertificateChain, + gatewayClientCert: pamResponse.GatewayClientCertificate, + gatewayClientKey: pamResponse.GatewayClientPrivateKey, + gatewayServerCertChain: pamResponse.GatewayServerCertificateChain, + sessionExpiry: time.Now().Add(duration), + sessionId: pamResponse.SessionId, + ctx: ctx, + cancel: cancel, + shutdownCh: make(chan struct{}), + }, } err = proxy.Start(port) @@ -136,9 +120,9 @@ func StartDatabaseLocalProxy(accessToken string, accountID string, durationStr s fmt.Printf("You can now connect to your database using this connection string:\n") switch pamResponse.ResourceType { - case ResourceTypePostgres: + case session.ResourceTypePostgres: fmt.Printf("postgres://%s@localhost:%d/%s", username, proxy.port, database) - case ResourceTypeMysql: + case session.ResourceTypeMysql: fmt.Printf("mysql://%s@localhost:%d/%s", username, proxy.port, database) default: fmt.Printf("localhost:%d", proxy.port) @@ -181,7 +165,7 @@ func (p *DatabaseProxyServer) gracefulShutdown() { log.Info().Msg("Starting graceful shutdown of database proxy...") // Send session termination notification before cancelling context - p.notifySessionTermination() + p.NotifySessionTermination() // Signal the accept loop to stop close(p.shutdownCh) @@ -194,58 +178,14 @@ func (p *DatabaseProxyServer) gracefulShutdown() { // Cancel context to signal all goroutines to stop p.cancel() - done := make(chan struct{}) - go func() { - p.activeConnections.Wait() - close(done) - }() - - select { - case <-done: - log.Info().Msg("All connections closed gracefully") - case <-time.After(10 * time.Second): - log.Warn().Msg("Timeout waiting for connections to close, forcing shutdown") - } + // Wait for connections to close + p.WaitForConnectionsWithTimeout(10 * time.Second) log.Info().Msg("Database proxy shutdown complete") os.Exit(0) }) } -// notifySessionTermination sends a termination notification through the gateway -func (p *DatabaseProxyServer) notifySessionTermination() { - log.Info().Msgf("Notifying session termination for session ID: %s", p.sessionId) - - // Try to notify via gateway connection first - relayConn, err := p.createRelayConnection() - if err != nil { - log.Error().Err(err).Msg("Failed to connect to relay for termination notification") - // Fallback to API call if relay connection fails - p.fallbackToAPITermination() - return - } - defer relayConn.Close() - - gatewayConn, err := p.createGatewayConnection(relayConn, ALPNInfisicalPAMCancellation) - if err != nil { - log.Error().Err(err).Msg("Failed to connect to gateway for termination notification") - // Fallback to API call if gateway connection fails - p.fallbackToAPITermination() - return - } - defer gatewayConn.Close() - log.Info().Msg("Session termination notification sent successfully") -} - -func (p *DatabaseProxyServer) fallbackToAPITermination() { - err := api.CallPAMSessionTermination(p.httpClient, p.sessionId) - if err != nil { - log.Error().Err(err).Msg("Failed to terminate session via API fallback") - } else { - log.Info().Msg("Session terminated successfully via API fallback") - } -} - func (p *DatabaseProxyServer) Run() { defer p.server.Close() @@ -307,14 +247,14 @@ func (p *DatabaseProxyServer) handleConnection(clientConn net.Conn) { default: } - relayConn, err := p.createRelayConnection() + relayConn, err := p.CreateRelayConnection() if err != nil { log.Error().Err(err).Msg("Failed to connect to relay") return } defer relayConn.Close() - gatewayConn, err := p.createGatewayConnection(relayConn, ALPNInfisicalPAMProxy) + gatewayConn, err := p.CreateGatewayConnection(relayConn, ALPNInfisicalPAMProxy) if err != nil { log.Error().Err(err).Msg("Failed to connect to gateway") return @@ -363,86 +303,3 @@ func (p *DatabaseProxyServer) handleConnection(clientConn net.Conn) { log.Info().Msgf("Connection closed for client: %s", clientConn.RemoteAddr().String()) } - -func (p *DatabaseProxyServer) createRelayConnection() (net.Conn, error) { - var host string - var port int = 8443 - - if strings.Contains(p.relayHost, ":") { - var portStr string - var err error - host, portStr, err = net.SplitHostPort(p.relayHost) - if err != nil { - return nil, fmt.Errorf("invalid relay host format: %w", err) - } - port, err = strconv.Atoi(portStr) - if err != nil { - return nil, fmt.Errorf("invalid port in relay host: %w", err) - } - } else { - host = p.relayHost - } - - // Load relay certificates - cert, err := tls.X509KeyPair([]byte(p.relayClientCert), []byte(p.relayClientKey)) - if err != nil { - return nil, fmt.Errorf("failed to load relay client certificate: %w", err) - } - - caCertPool := x509.NewCertPool() - if !caCertPool.AppendCertsFromPEM([]byte(p.relayServerCertChain)) { - return nil, fmt.Errorf("failed to parse relay server certificate chain") - } - - tlsConfig := &tls.Config{ - Certificates: []tls.Certificate{cert}, - RootCAs: caCertPool, - ServerName: host, - MinVersion: tls.VersionTLS12, - } - - conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", host, port), tlsConfig) - if err != nil { - return nil, fmt.Errorf("failed to connect to relay: %w", err) - } - - log.Debug().Msg("Relay TLS connection established") - return conn, nil -} - -func (p *DatabaseProxyServer) createGatewayConnection(relayConn net.Conn, alpn ALPN) (net.Conn, error) { - // Load gateway certificates - cert, err := tls.X509KeyPair([]byte(p.gatewayClientCert), []byte(p.gatewayClientKey)) - if err != nil { - return nil, fmt.Errorf("failed to load gateway client certificate: %w", err) - } - - caCertPool := x509.NewCertPool() - if !caCertPool.AppendCertsFromPEM([]byte(p.gatewayServerCertChain)) { - return nil, fmt.Errorf("failed to parse gateway server certificate chain") - } - - tlsConfig := &tls.Config{ - Certificates: []tls.Certificate{cert}, - RootCAs: caCertPool, - MinVersion: tls.VersionTLS12, - MaxVersion: tls.VersionTLS13, - NextProtos: []string{string(alpn)}, - ServerName: "localhost", - } - - gatewayConn := tls.Client(relayConn, tlsConfig) - - err = gatewayConn.Handshake() - if err != nil { - return nil, fmt.Errorf("failed to establish gateway mTLS: %w", err) - } - - state := gatewayConn.ConnectionState() - if !state.HandshakeComplete { - return nil, fmt.Errorf("gateway TLS handshake not complete") - } - - log.Debug().Msg("Gateway mTLS connection established") - return gatewayConn, nil -} diff --git a/packages/pam/local/ssh-proxy.go b/packages/pam/local/ssh-proxy.go new file mode 100644 index 00000000..5215b3c8 --- /dev/null +++ b/packages/pam/local/ssh-proxy.go @@ -0,0 +1,344 @@ +package pam + +import ( + "context" + "fmt" + "io" + "net" + "os" + "os/exec" + "os/signal" + "strconv" + "strings" + "syscall" + "time" + + "github.com/Infisical/infisical-merge/packages/api" + "github.com/Infisical/infisical-merge/packages/pam/session" + "github.com/Infisical/infisical-merge/packages/util" + "github.com/go-resty/resty/v2" + "github.com/rs/zerolog/log" +) + +type SSHProxyServer struct { + BaseProxyServer // Embed common functionality + server net.Listener + port int + sshProcess *exec.Cmd +} + +func StartSSHLocalProxy(accessToken string, accountID string, durationStr string) { + httpClient := resty.New() + httpClient.SetAuthToken(accessToken) + httpClient.SetHeader("User-Agent", "infisical-cli") + + pamRequest := api.PAMAccessRequest{ + Duration: durationStr, + AccountId: accountID, + } + + pamResponse, err := api.CallPAMAccess(httpClient, pamRequest) + if err != nil { + util.HandleError(err, "Failed to access PAM account") + return + } + + // Verify this is an SSH resource + if pamResponse.ResourceType != session.ResourceTypeSSH { + util.HandleError(fmt.Errorf("account is not an SSH resource, got: %s", pamResponse.ResourceType), "Invalid resource type") + return + } + + duration, err := time.ParseDuration(durationStr) + if err != nil { + util.HandleError(err, "Failed to parse duration") + return + } + + ctx, cancel := context.WithCancel(context.Background()) + + proxy := &SSHProxyServer{ + BaseProxyServer: BaseProxyServer{ + httpClient: httpClient, + relayHost: pamResponse.RelayHost, + relayClientCert: pamResponse.RelayClientCertificate, + relayClientKey: pamResponse.RelayClientPrivateKey, + relayServerCertChain: pamResponse.RelayServerCertificateChain, + gatewayClientCert: pamResponse.GatewayClientCertificate, + gatewayClientKey: pamResponse.GatewayClientPrivateKey, + gatewayServerCertChain: pamResponse.GatewayServerCertificateChain, + sessionExpiry: time.Now().Add(duration), + sessionId: pamResponse.SessionId, + ctx: ctx, + cancel: cancel, + shutdownCh: make(chan struct{}), + }, + } + + // Start the local TCP proxy on a random port + err = proxy.Start(0) // 0 = random port + if err != nil { + util.HandleError(err, "Failed to start SSH proxy server") + return + } + + // Extract metadata + username, ok := pamResponse.Metadata["username"] + if !ok { + util.HandleError(fmt.Errorf("PAM response metadata is missing 'username'"), "Failed to start proxy server") + return + } + + log.Debug(). + Str("sessionID", pamResponse.SessionId). + Str("username", username). + Int("port", proxy.port). + Msg("SSH proxy ready") + + // Set up signal handling + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) + + go func() { + sig := <-sigChan + log.Debug().Msgf("Received signal %v, initiating graceful shutdown...", sig) + proxy.gracefulShutdown() + }() + + // Start the proxy server in a goroutine + go proxy.Run() + + // Give the proxy a moment to start accepting connections + time.Sleep(500 * time.Millisecond) + + // Launch SSH client connected to the local proxy (transparent to user) + err = proxy.launchSSHClient(username) + if err != nil { + log.Error().Err(err).Msg("Failed to launch SSH client") + proxy.gracefulShutdown() + return + } + + // Wait for SSH process to complete + proxy.waitForSSHCompletion() + + // SSH client exited, shutdown gracefully + proxy.gracefulShutdown() +} + +func (p *SSHProxyServer) Start(port int) error { + var err error + if port == 0 { + p.server, err = net.Listen("tcp", "127.0.0.1:0") // Bind to localhost only + } else { + p.server, err = net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", port)) + } + + if err != nil { + return fmt.Errorf("failed to start server: %w", err) + } + + addr := p.server.Addr().(*net.TCPAddr) + p.port = addr.Port + + log.Debug().Msgf("SSH proxy server listening on 127.0.0.1:%d", p.port) + + return nil +} + +func (p *SSHProxyServer) launchSSHClient(username string) error { + // Build SSH command: ssh -p @localhost + sshArgs := []string{ + "-p", strconv.Itoa(p.port), + "-o", "StrictHostKeyChecking=no", // Skip host key verification (we're connecting to localhost) + "-o", "UserKnownHostsFile=/dev/null", + "-o", "LogLevel=ERROR", + fmt.Sprintf("%s@127.0.0.1", username), + } + + p.sshProcess = exec.Command("ssh", sshArgs...) + p.sshProcess.Stdin = os.Stdin + p.sshProcess.Stdout = os.Stdout + p.sshProcess.Stderr = os.Stderr + + log.Debug().Msgf("Executing: ssh %s", strings.Join(sshArgs, " ")) + + err := p.sshProcess.Start() + if err != nil { + return fmt.Errorf("failed to start SSH client: %w", err) + } + + log.Debug().Msgf("SSH client started with PID: %d", p.sshProcess.Process.Pid) + return nil +} + +func (p *SSHProxyServer) waitForSSHCompletion() { + if p.sshProcess == nil { + return + } + + err := p.sshProcess.Wait() + if err != nil { + if exitErr, ok := err.(*exec.ExitError); ok { + log.Debug().Msgf("SSH client exited with code: %d", exitErr.ExitCode()) + } else { + log.Error().Err(err).Msg("Error waiting for SSH client") + } + } else { + log.Debug().Msg("SSH client exited successfully") + } +} + +func (p *SSHProxyServer) gracefulShutdown() { + p.shutdownOnce.Do(func() { + log.Debug().Msg("Starting graceful shutdown of SSH proxy...") + + // Kill SSH process if it's still running + if p.sshProcess != nil && p.sshProcess.Process != nil { + log.Debug().Msg("Terminating SSH client process") + p.sshProcess.Process.Signal(syscall.SIGTERM) + } + + // Send session termination notification before cancelling context + p.NotifySessionTermination() + + // Signal the accept loop to stop + close(p.shutdownCh) + + // Close the server to stop accepting new connections + if p.server != nil { + p.server.Close() + } + + // Cancel context to signal all goroutines to stop + p.cancel() + + // Wait for connections to close + p.WaitForConnectionsWithTimeout(10 * time.Second) + + log.Debug().Msg("SSH proxy shutdown complete") + os.Exit(0) + }) +} + +func (p *SSHProxyServer) Run() { + defer p.server.Close() + + for { + select { + case <-p.ctx.Done(): + log.Debug().Msg("Context cancelled, stopping proxy server") + return + case <-p.shutdownCh: + log.Debug().Msg("Shutdown signal received, stopping proxy server") + return + default: + // Check if session has expired + if time.Now().After(p.sessionExpiry) { + log.Warn().Msg("SSH session expired, shutting down proxy") + p.gracefulShutdown() + return + } + + if tcpListener, ok := p.server.(*net.TCPListener); ok { + tcpListener.SetDeadline(time.Now().Add(1 * time.Second)) + } + + conn, err := p.server.Accept() + if err != nil { + if netErr, ok := err.(net.Error); ok && netErr.Timeout() { + continue + } + select { + case <-p.ctx.Done(): + return + case <-p.shutdownCh: + return + default: + log.Error().Err(err).Msg("Failed to accept connection") + continue + } + } + + // Track active connection + p.activeConnections.Add(1) + go p.handleConnection(conn) + } + } +} + +func (p *SSHProxyServer) handleConnection(clientConn net.Conn) { + defer func() { + clientConn.Close() + p.activeConnections.Done() + }() + + log.Debug().Msgf("New SSH connection from %s", clientConn.RemoteAddr()) + + select { + case <-p.ctx.Done(): + log.Debug().Msg("Context cancelled, closing connection immediately") + return + default: + } + + // Connect to relay + relayConn, err := p.CreateRelayConnection() + if err != nil { + log.Error().Err(err).Msg("Failed to connect to relay") + return + } + defer relayConn.Close() + + // Connect to gateway (SSH proxy will handle the SSH protocol) + gatewayConn, err := p.CreateGatewayConnection(relayConn, ALPNInfisicalPAMProxy) + if err != nil { + log.Error().Err(err).Msg("Failed to connect to gateway") + return + } + defer gatewayConn.Close() + + log.Debug().Msg("Established connection to SSH gateway") + + connCtx, connCancel := context.WithCancel(p.ctx) + defer connCancel() + + errCh := make(chan error, 2) + + // Bidirectional data forwarding with context cancellation + // Client (local SSH) → Gateway (SSH proxy) + go func() { + defer connCancel() + _, err := io.Copy(gatewayConn, clientConn) + if err != nil { + select { + case <-connCtx.Done(): + default: + log.Debug().Err(err).Msg("Client to gateway copy ended") + } + } + errCh <- err + }() + + // Gateway (SSH proxy) → Client (local SSH) + go func() { + defer connCancel() + _, err := io.Copy(clientConn, gatewayConn) + if err != nil { + select { + case <-connCtx.Done(): + default: + log.Debug().Err(err).Msg("Gateway to client copy ended") + } + } + errCh <- err + }() + + select { + case <-errCh: + case <-connCtx.Done(): + log.Debug().Msg("Connection cancelled by context") + } + + log.Debug().Msgf("SSH connection closed for client: %s", clientConn.RemoteAddr().String()) +} diff --git a/packages/pam/pam-proxy.go b/packages/pam/pam-proxy.go index bcea3fb6..78e55ca7 100644 --- a/packages/pam/pam-proxy.go +++ b/packages/pam/pam-proxy.go @@ -5,20 +5,16 @@ import ( "crypto/tls" "crypto/x509" "fmt" - "github.com/Infisical/infisical-merge/packages/pam/handlers/mysql" "time" "github.com/Infisical/infisical-merge/packages/pam/handlers" + "github.com/Infisical/infisical-merge/packages/pam/handlers/mysql" + "github.com/Infisical/infisical-merge/packages/pam/handlers/ssh" "github.com/Infisical/infisical-merge/packages/pam/session" "github.com/go-resty/resty/v2" "github.com/rs/zerolog/log" ) -const ( - ResourceTypePostgres = "postgres" - ResourceTypeMysql = "mysql" -) - type GatewayPAMConfig struct { SessionId string ResourceType string @@ -89,7 +85,7 @@ func HandlePAMProxy(ctx context.Context, conn *tls.Conn, pamConfig *GatewayPAMCo if err != nil { return fmt.Errorf("failed to get PAM session encryption key: %w", err) } - sessionLogger, err := session.NewSessionLogger(pamConfig.SessionId, encryptionKey, pamConfig.ExpiryTime) + sessionLogger, err := session.NewSessionLogger(pamConfig.SessionId, encryptionKey, pamConfig.ExpiryTime, pamConfig.ResourceType) if err != nil { return fmt.Errorf("failed to create session logger: %w", err) } @@ -114,7 +110,7 @@ func HandlePAMProxy(ctx context.Context, conn *tls.Conn, pamConfig *GatewayPAMCo } switch pamConfig.ResourceType { - case ResourceTypePostgres: + case session.ResourceTypePostgres: proxyConfig := handlers.PostgresProxyConfig{ TargetAddr: fmt.Sprintf("%s:%d", credentials.Host, credentials.Port), InjectUsername: credentials.Username, @@ -132,7 +128,7 @@ func HandlePAMProxy(ctx context.Context, conn *tls.Conn, pamConfig *GatewayPAMCo Bool("sslEnabled", credentials.SSLEnabled). Msg("Starting PostgreSQL PAM proxy") return proxy.HandleConnection(ctx, conn) - case ResourceTypeMysql: + case session.ResourceTypeMysql: mysqlConfig := mysql.MysqlProxyConfig{ TargetAddr: fmt.Sprintf("%s:%d", credentials.Host, credentials.Port), InjectUsername: credentials.Username, @@ -151,6 +147,23 @@ func HandlePAMProxy(ctx context.Context, conn *tls.Conn, pamConfig *GatewayPAMCo Bool("sslEnabled", credentials.SSLEnabled). Msg("Starting MySQL PAM proxy") return proxy.HandleConnection(ctx, conn) + case session.ResourceTypeSSH: + sshConfig := ssh.SSHProxyConfig{ + TargetAddr: fmt.Sprintf("%s:%d", credentials.Host, credentials.Port), + AuthMethod: credentials.AuthMethod, + InjectUsername: credentials.Username, + InjectPassword: credentials.Password, + InjectPrivateKey: credentials.PrivateKey, + SessionID: pamConfig.SessionId, + SessionLogger: sessionLogger, + } + proxy := ssh.NewSSHProxy(sshConfig) + log.Info(). + Str("sessionId", pamConfig.SessionId). + Str("target", sshConfig.TargetAddr). + Msg("Starting SSH PAM proxy") + + return proxy.HandleConnection(ctx, conn) default: return fmt.Errorf("unsupported resource type: %s", pamConfig.ResourceType) } diff --git a/packages/pam/session/credentials.go b/packages/pam/session/credentials.go index f5efd3df..3472bcf0 100644 --- a/packages/pam/session/credentials.go +++ b/packages/pam/session/credentials.go @@ -11,9 +11,11 @@ import ( ) type PAMCredentials struct { + AuthMethod string Username string Password string Database string + PrivateKey string Host string Port int SSLEnabled bool @@ -80,9 +82,11 @@ func (cm *CredentialsManager) GetPAMSessionCredentials(sessionId string, expiryT } credentials := &PAMCredentials{ + AuthMethod: response.Credentials.AuthMethod, Username: response.Credentials.Username, Password: response.Credentials.Password, Database: response.Credentials.Database, + PrivateKey: response.Credentials.PrivateKey, Host: response.Credentials.Host, Port: response.Credentials.Port, SSLEnabled: response.Credentials.SSLEnabled, diff --git a/packages/pam/session/logger.go b/packages/pam/session/logger.go index 033b2216..3f9f5040 100644 --- a/packages/pam/session/logger.go +++ b/packages/pam/session/logger.go @@ -23,8 +23,25 @@ type SessionLogEntry struct { Output string `json:"output"` } +// TerminalEventType represents the type of terminal event +type TerminalEventType string + +const ( + TerminalEventInput TerminalEventType = "input" // Data from user to server + TerminalEventOutput TerminalEventType = "output" // Data from server to user +) + +// TerminalEvent represents a single event in a terminal session +type TerminalEvent struct { + Timestamp time.Time `json:"timestamp"` + EventType TerminalEventType `json:"eventType"` + Data []byte `json:"data"` // Raw terminal data + ElapsedTime float64 `json:"elapsedTime"` // Seconds since session start (for replay) +} + type SessionLogger interface { LogEntry(entry SessionLogEntry) error + LogTerminalEvent(event TerminalEvent) error Close() error } @@ -34,6 +51,7 @@ type EncryptedSessionLogger struct { expiresAt time.Time file *os.File mutex sync.Mutex + sessionStart time.Time // Track session start time for elapsed time calculation } type RequestResponsePair struct { @@ -133,7 +151,7 @@ func CleanupSessionMutex(sessionID string) { } } -func NewSessionLogger(sessionID string, encryptionKey string, expiresAt time.Time) (*EncryptedSessionLogger, error) { +func NewSessionLogger(sessionID string, encryptionKey string, expiresAt time.Time, resourceType string) (*EncryptedSessionLogger, error) { if sessionID == "" { return nil, fmt.Errorf("session ID cannot be empty") } @@ -147,7 +165,14 @@ func NewSessionLogger(sessionID string, encryptionKey string, expiresAt time.Tim return nil, fmt.Errorf("failed to create session recording directory: %w", err) } - filename := fmt.Sprintf("pam_session_%s_expires_%d.enc", sessionID, expiresAt.Unix()) + // Use new filename format with resource type if provided + var filename string + if resourceType != "" { + filename = fmt.Sprintf("pam_session_%s_%s_expires_%d.enc", sessionID, resourceType, expiresAt.Unix()) + } else { + // Legacy format for backwards compatibility + filename = fmt.Sprintf("pam_session_%s_expires_%d.enc", sessionID, expiresAt.Unix()) + } fullPath := filepath.Join(recordingDir, filename) // Open file in append mode to support multiple connections per session @@ -161,6 +186,7 @@ func NewSessionLogger(sessionID string, encryptionKey string, expiresAt time.Tim encryptionKey: encryptionKey, expiresAt: expiresAt, file: file, + sessionStart: time.Now(), }, nil } @@ -206,6 +232,55 @@ func (sl *EncryptedSessionLogger) LogEntry(entry SessionLogEntry) error { return nil } +func (sl *EncryptedSessionLogger) LogTerminalEvent(event TerminalEvent) error { + sl.mutex.Lock() + defer sl.mutex.Unlock() + + if sl.file == nil { + return fmt.Errorf("session logger not initialized") + } + + // Calculate elapsed time if not already set + if event.ElapsedTime == 0 { + event.ElapsedTime = time.Since(sl.sessionStart).Seconds() + } + + jsonData, err := json.Marshal(event) + if err != nil { + return fmt.Errorf("failed to marshal terminal event: %w", err) + } + + encryptedData, err := EncryptData(jsonData, sl.encryptionKey) + if err != nil { + return fmt.Errorf("failed to encrypt data: %w", err) + } + + // Use session-level mutex to ensure atomic writes across concurrent connections + sessionMutex := getSessionMutex(sl.sessionID, sl.expiresAt) + sessionMutex.Lock() + defer sessionMutex.Unlock() + + // Write length-prefixed encrypted record (4 bytes length + encrypted data) + lengthBytes := make([]byte, 4) + binary.BigEndian.PutUint32(lengthBytes, uint32(len(encryptedData))) + + if _, err := sl.file.Write(lengthBytes); err != nil { + return fmt.Errorf("failed to write length prefix: %w", err) + } + + if _, err := sl.file.Write(encryptedData); err != nil { + return fmt.Errorf("failed to write encrypted data: %w", err) + } + + // For high-frequency events like terminal I/O, we might want to buffer + // But for now, sync to ensure durability + if err := sl.file.Sync(); err != nil { + return fmt.Errorf("failed to sync file: %w", err) + } + + return nil +} + func (sl *EncryptedSessionLogger) Close() error { sl.mutex.Lock() defer sl.mutex.Unlock() diff --git a/packages/pam/session/uploader.go b/packages/pam/session/uploader.go index a9179f07..afe51448 100644 --- a/packages/pam/session/uploader.go +++ b/packages/pam/session/uploader.go @@ -20,10 +20,18 @@ import ( var ErrSessionFileNotFound = errors.New("session file not found") +// Resource type constants +const ( + ResourceTypePostgres = "postgres" + ResourceTypeMysql = "mysql" + ResourceTypeSSH = "ssh" +) + type SessionFileInfo struct { - SessionID string - ExpiresAt time.Time - Filename string + SessionID string + ExpiresAt time.Time + Filename string + ResourceType string // ResourceTypeSSH, ResourceTypePostgres, ResourceTypeMysql (empty for legacy files) } type SessionUploader struct { @@ -43,10 +51,35 @@ func NewSessionUploader(httpClient *resty.Client, credentialsManager *Credential } func ParseSessionFilename(filename string) (*SessionFileInfo, error) { - regex := regexp.MustCompile(`^pam_session_(.+)_expires_(\d+)\.enc$`) - matches := regex.FindStringSubmatch(filename) + // Try new format first: pam_session_{sessionID}_{resourceType}_expires_{timestamp}.enc + // Build regex pattern using constants + resourceTypePattern := fmt.Sprintf("(%s|%s|%s)", ResourceTypeSSH, ResourceTypePostgres, ResourceTypeMysql) + newFormatRegex := regexp.MustCompile(fmt.Sprintf(`^pam_session_(.+)_%s_expires_(\d+)\.enc$`, resourceTypePattern)) + matches := newFormatRegex.FindStringSubmatch(filename) + + if len(matches) == 4 { + sessionID := matches[1] + resourceType := matches[2] + timestampStr := matches[3] + + timestamp, err := strconv.ParseInt(timestampStr, 10, 64) + if err != nil { + return nil, fmt.Errorf("invalid timestamp in filename %s: %w", filename, err) + } + + return &SessionFileInfo{ + SessionID: sessionID, + ExpiresAt: time.Unix(timestamp, 0), + Filename: filename, + ResourceType: resourceType, + }, nil + } + + // Fall back to legacy format for backwards compatibility: pam_session_{sessionID}_expires_{timestamp}.enc + legacyFormatRegex := regexp.MustCompile(`^pam_session_(.+)_expires_(\d+)\.enc$`) + matches = legacyFormatRegex.FindStringSubmatch(filename) if len(matches) != 3 { - return nil, fmt.Errorf("filename %s does not match expected format: pam_session_{sessionID}_expires_{timestamp}.enc", filename) + return nil, fmt.Errorf("filename %s does not match expected format", filename) } sessionID := matches[1] @@ -58,9 +91,10 @@ func ParseSessionFilename(filename string) (*SessionFileInfo, error) { } return &SessionFileInfo{ - SessionID: sessionID, - ExpiresAt: time.Unix(timestamp, 0), - Filename: filename, + SessionID: sessionID, + ExpiresAt: time.Unix(timestamp, 0), + Filename: filename, + ResourceType: "", // Empty for legacy files (assume database format) }, nil } @@ -168,6 +202,60 @@ func ReadEncryptedSessionLogByFilename(filename string, encryptionKey string) ([ return entries, nil } +// ReadEncryptedTerminalEventsFromFile reads terminal events from an encrypted session file +func ReadEncryptedTerminalEventsFromFile(filename string, encryptionKey string) ([]TerminalEvent, error) { + recordingDir := GetSessionRecordingDir() + fullPath := filepath.Join(recordingDir, filename) + + file, err := os.Open(fullPath) + if err != nil { + return nil, fmt.Errorf("failed to open session file: %w", err) + } + defer file.Close() + + var events []TerminalEvent + + for { + // Read length prefix (4 bytes) + lengthBytes := make([]byte, 4) + n, err := file.Read(lengthBytes) + if err == io.EOF { + break // End of file + } + if err != nil { + return nil, fmt.Errorf("failed to read length prefix: %w", err) + } + if n != 4 { + return nil, fmt.Errorf("incomplete length prefix read") + } + + length := binary.BigEndian.Uint32(lengthBytes) + + encryptedData := make([]byte, length) + n, err = io.ReadFull(file, encryptedData) + if err != nil { + return nil, fmt.Errorf("failed to read encrypted data: %w", err) + } + if uint32(n) != length { + return nil, fmt.Errorf("incomplete encrypted data read") + } + + decryptedData, err := DecryptData(encryptedData, encryptionKey) + if err != nil { + return nil, fmt.Errorf("failed to decrypt session data: %w", err) + } + + var event TerminalEvent + if err := json.Unmarshal(decryptedData, &event); err != nil { + return nil, fmt.Errorf("failed to unmarshal terminal event: %w", err) + } + + events = append(events, event) + } + + return events, nil +} + func (su *SessionUploader) Start() { su.startOnce.Do(su.startUploadRoutine) } @@ -229,11 +317,54 @@ func (su *SessionUploader) uploadSessionFile(fileInfo *SessionFileInfo) error { return fmt.Errorf("failed to get encryption key: %w", err) } + // Use resource type to determine how to read the file + if fileInfo.ResourceType == ResourceTypeSSH { + // SSH session - read as terminal events + terminalEvents, err := ReadEncryptedTerminalEventsFromFile(fileInfo.Filename, encryptionKey) + if err != nil { + return fmt.Errorf("failed to read SSH session file: %w", err) + } + + log.Debug(). + Str("sessionId", fileInfo.SessionID). + Str("resourceType", fileInfo.ResourceType). + Int("eventCount", len(terminalEvents)). + Msg("Uploading terminal session events") + + var logs []api.UploadTerminalEvent + for _, event := range terminalEvents { + logs = append(logs, api.UploadTerminalEvent{ + Timestamp: event.Timestamp, + EventType: string(event.EventType), + Data: event.Data, + ElapsedTime: event.ElapsedTime, + }) + } + + request := api.UploadPAMSessionLogsRequest{ + Logs: logs, + } + + return api.CallUploadPamSessionLogs(su.httpClient, fileInfo.SessionID, request) + } + + // Database session (postgres, mysql, or legacy format) - read as request/response logs entries, err := ReadEncryptedSessionLogByFilename(fileInfo.Filename, encryptionKey) if err != nil { return fmt.Errorf("failed to read session file: %w", err) } + resourceTypeMsg := fileInfo.ResourceType + if resourceTypeMsg == "" { + resourceTypeMsg = "legacy" + } + + log.Debug(). + Str("sessionId", fileInfo.SessionID). + Str("resourceType", resourceTypeMsg). + Int("entryCount", len(entries)). + Msg("Uploading database session logs") + var logs []api.UploadSessionLogEntry for _, entry := range entries { logs = append(logs, api.UploadSessionLogEntry{