diff --git a/packages/api/api.go b/packages/api/api.go index a6820870..a52ccc29 100644 --- a/packages/api/api.go +++ b/packages/api/api.go @@ -54,6 +54,7 @@ const ( operationCallGetPamSessionKey = "CallGetPamSessionKey" operationCallUploadPamSessionLog = "CallUploadPamSessionLog" operationCallPAMSessionTermination = "CallPAMSessionTermination" + operationCallUploadPamSessionEventBatch = "CallUploadPamSessionEventBatch" operationCallGetMFASessionStatus = "CallGetMFASessionStatus" operationCallOrgRelayHeartBeat = "CallOrgRelayHeartBeat" operationCallInstanceRelayHeartBeat = "CallInstanceRelayHeartBeat" @@ -1008,6 +1009,23 @@ func CallUploadPamSessionLogs(httpClient *resty.Client, sessionId string, reques return nil } +func CallUploadPamSessionEventBatch(httpClient *resty.Client, sessionId string, startOffset int64, data []byte) error { + response, err := httpClient. + R(). + SetHeader("User-Agent", USER_AGENT). + SetHeader("Content-Type", "application/octet-stream"). + SetBody(data). + Post(fmt.Sprintf("%v/v1/pam/sessions/%s/event-batches?startOffset=%d", config.INFISICAL_URL, sessionId, startOffset)) + + if err != nil { + return NewGenericRequestError(operationCallUploadPamSessionEventBatch, err) + } + if response.IsError() { + return NewAPIErrorWithResponse(operationCallUploadPamSessionEventBatch, response, nil) + } + return nil +} + func CallPAMSessionTermination(httpClient *resty.Client, sessionId string) error { response, err := httpClient. R(). diff --git a/packages/pam/pam-proxy.go b/packages/pam/pam-proxy.go index a4464b52..ac5a4615 100644 --- a/packages/pam/pam-proxy.go +++ b/packages/pam/pam-proxy.go @@ -151,6 +151,7 @@ func HandlePAMProxy(ctx context.Context, conn *tls.Conn, pamConfig *GatewayPAMCo if err != nil { return fmt.Errorf("failed to create session logger: %w", err) } + pamConfig.SessionUploader.RegisterSession(pamConfig.SessionId) serverName := credentials.Host if pamConfig.ResourceType == session.ResourceTypeKubernetes { diff --git a/packages/pam/session/uploader.go b/packages/pam/session/uploader.go index 3ca26cb7..cb30e525 100644 --- a/packages/pam/session/uploader.go +++ b/packages/pam/session/uploader.go @@ -6,10 +6,12 @@ import ( "errors" "fmt" "io" + "net/http" "os" "path/filepath" "regexp" "strconv" + "strings" "sync" "time" @@ -37,12 +39,23 @@ type SessionFileInfo struct { ResourceType string // ResourceTypeSSH, ResourceTypePostgres, ResourceTypeMysql (empty for legacy files) } +// sessionUploadState tracks incremental upload progress for an active session. +type sessionUploadState struct { + fileOffset int64 + filename string // base filename (not full path) of the session recording + legacyMode bool // true if the batch upload endpoint returned 404 (platform too old); fall back to bulk upload at session end + mu sync.Mutex +} + type SessionUploader struct { httpClient *resty.Client credentialsManager *CredentialsManager ticker *time.Ticker stopChan chan struct{} startOnce sync.Once + + activeSessions map[string]*sessionUploadState + activeSessionsMu sync.RWMutex } func NewSessionUploader(httpClient *resty.Client, credentialsManager *CredentialsManager) *SessionUploader { @@ -50,6 +63,7 @@ func NewSessionUploader(httpClient *resty.Client, credentialsManager *Credential httpClient: httpClient, credentialsManager: credentialsManager, stopChan: make(chan struct{}), + activeSessions: make(map[string]*sessionUploadState), } } @@ -124,7 +138,7 @@ func ListSessionFiles() ([]*SessionFileInfo, error) { fileInfo, err := ParseSessionFilename(entry.Name()) if err != nil { - // Skip files that don't match our format + // Skip files that don't match our format (including .offset sidecars) continue } @@ -216,6 +230,136 @@ func ReadEncryptedHttpEventsFromFile(filename string, encryptionKey string) ([]H return readEncryptedEntries[HttpEvent](filename, encryptionKey) } +// offsetFilePath returns the path to the persisted offset file for a given recording filename. +func offsetFilePath(filename string) string { + return filepath.Join(GetSessionRecordingDir(), strings.TrimSuffix(filename, ".enc")+".offset") +} + +// readPersistedOffset reads the persisted file offset for a session recording. +func readPersistedOffset(filename string) (int64, bool) { + data, err := os.ReadFile(offsetFilePath(filename)) + if err != nil { + return 0, false + } + offset, err := strconv.ParseInt(strings.TrimSpace(string(data)), 10, 64) + if err != nil { + return 0, false + } + return offset, true +} + +// writePersistedOffset atomically writes the current file offset to disk. +func writePersistedOffset(filename string, offset int64) error { + path := offsetFilePath(filename) + tmpPath := path + ".tmp" + if err := os.WriteFile(tmpPath, []byte(strconv.FormatInt(offset, 10)), 0600); err != nil { + return err + } + return os.Rename(tmpPath, path) +} + +// deletePersistedOffset removes the offset file for a session. +func deletePersistedOffset(filename string) { + _ = os.Remove(offsetFilePath(filename)) +} + +// readFromOffset reads length-prefixed encrypted records from filename starting at offset, +// decrypts each, and returns them as a JSON array payload plus the new file offset. +// Returns nil payload (and the unchanged offset) if there are no new records. +func readFromOffset(filename, encryptionKey string, offset int64) ([]byte, int64, error) { + recordingDir := GetSessionRecordingDir() + fullPath := filepath.Join(recordingDir, filename) + + file, err := os.Open(fullPath) + if err != nil { + return nil, offset, fmt.Errorf("failed to open session file: %w", err) + } + defer file.Close() + + if _, err := file.Seek(offset, io.SeekStart); err != nil { + return nil, offset, fmt.Errorf("failed to seek to offset %d: %w", offset, err) + } + + var entries []json.RawMessage + newOffset := offset + + for { + lengthBytes := make([]byte, 4) + if _, err := io.ReadFull(file, lengthBytes); err != nil { + if err == io.EOF || err == io.ErrUnexpectedEOF { + break // No more complete records + } + return nil, newOffset, fmt.Errorf("failed to read length prefix: %w", err) + } + + length := binary.BigEndian.Uint32(lengthBytes) + encryptedData := make([]byte, length) + if _, err := io.ReadFull(file, encryptedData); err != nil { + break // Partial record at EOF, stop here and retry next tick + } + + decryptedData, err := DecryptData(encryptedData, encryptionKey) + if err != nil { + return nil, newOffset, fmt.Errorf("failed to decrypt record at offset %d: %w", newOffset, err) + } + + entries = append(entries, json.RawMessage(decryptedData)) + newOffset += int64(4 + length) + } + + if len(entries) == 0 { + return nil, newOffset, nil + } + + payload, err := json.Marshal(entries) + if err != nil { + return nil, newOffset, fmt.Errorf("failed to marshal event batch: %w", err) + } + + return payload, newOffset, nil +} + +// RegisterSession registers a session for incremental batch uploads, resuming from +// any previously persisted offset if present. +func (su *SessionUploader) RegisterSession(sessionID string) { + fileInfo, err := FindSessionFileBySessionID(sessionID) + if err != nil { + log.Warn().Err(err).Str("sessionId", sessionID).Msg("[RegisterSession] session file not found, will retry on first flush") + return + } + + var startOffset int64 + if offset, ok := readPersistedOffset(fileInfo.Filename); ok { + startOffset = offset + log.Info().Str("sessionId", sessionID).Int64("resumeOffset", startOffset).Msg("Resuming incremental upload from persisted offset") + } + + su.activeSessionsMu.Lock() + su.activeSessions[sessionID] = &sessionUploadState{ + fileOffset: startOffset, + filename: fileInfo.Filename, + } + su.activeSessionsMu.Unlock() + + log.Debug().Str("sessionId", sessionID).Msg("Registered session for incremental batch upload") +} + +// UnregisterSession removes a session from incremental tracking and deletes its persisted offset. +func (su *SessionUploader) UnregisterSession(sessionID string) { + su.activeSessionsMu.Lock() + state, ok := su.activeSessions[sessionID] + if ok { + delete(su.activeSessions, sessionID) + } + su.activeSessionsMu.Unlock() + + if ok && state.filename != "" { + deletePersistedOffset(state.filename) + } + + log.Debug().Str("sessionId", sessionID).Msg("Unregistered session from incremental batch upload") +} + func (su *SessionUploader) Start() { su.startOnce.Do(su.startUploadRoutine) } @@ -224,17 +368,25 @@ func (su *SessionUploader) startUploadRoutine() { log.Info().Msg("Starting PAM session uploader routine") su.ticker = time.NewTicker(5 * time.Minute) + flushTicker := time.NewTicker(10 * time.Second) go func() { defer su.ticker.Stop() + defer flushTicker.Stop() - // call once immediately + // On startup, re-register any non-expired sessions that were in progress when + // the gateway last shut down or crashed so the flush ticker resumes uploading them. + su.resumeInProgressSessions() + + // Process any orphaned expired files from previous runs immediately su.uploadExpiredSessionFiles() for { select { case <-su.ticker.C: su.uploadExpiredSessionFiles() + case <-flushTicker.C: + su.flushActiveSessions() case <-su.stopChan: return } @@ -242,6 +394,22 @@ func (su *SessionUploader) startUploadRoutine() { }() } +// resumeInProgressSessions re-registers all session files found on disk at startup so +// the flush ticker resumes uploading them after a crash or restart. Expired sessions +// will be cleaned up naturally by uploadExpiredSessionFiles on the next tick. +func (su *SessionUploader) resumeInProgressSessions() { + allFiles, err := ListSessionFiles() + if err != nil { + log.Error().Err(err).Msg("Failed to list session files for resume on startup") + return + } + + for _, fileInfo := range allFiles { + log.Info().Str("sessionId", fileInfo.SessionID).Str("filename", fileInfo.Filename).Msg("Resuming session upload after restart") + su.RegisterSession(fileInfo.SessionID) + } +} + func (su *SessionUploader) uploadExpiredSessionFiles() { expiredFiles, err := GetExpiredSessionFiles() if err != nil { @@ -271,15 +439,79 @@ func (su *SessionUploader) uploadExpiredSessionFiles() { } } +// flushActiveSessions uploads new events for all currently active sessions. +func (su *SessionUploader) flushActiveSessions() { + encryptionKey, err := su.credentialsManager.GetPAMSessionEncryptionKey() + if err != nil { + log.Error().Err(err).Msg("[flushActiveSessions] failed to get encryption key") + return + } + + su.activeSessionsMu.RLock() + sessionIDs := make([]string, 0, len(su.activeSessions)) + for id := range su.activeSessions { + sessionIDs = append(sessionIDs, id) + } + su.activeSessionsMu.RUnlock() + + for _, sessionID := range sessionIDs { + su.flushSession(sessionID, encryptionKey) + } +} + +// flushSession reads new events from the session recording file since the last uploaded offset, +// uploads them as a batch, and advances the offset on success. +func (su *SessionUploader) flushSession(sessionID, encryptionKey string) { + su.activeSessionsMu.RLock() + state, ok := su.activeSessions[sessionID] + su.activeSessionsMu.RUnlock() + if !ok { + return + } + + state.mu.Lock() + defer state.mu.Unlock() + + if state.legacyMode { + return // Platform does not support batch uploads; bulk upload will happen at session end + } + + payload, newOffset, err := readFromOffset(state.filename, encryptionKey, state.fileOffset) + if err != nil { + log.Error().Err(err).Str("sessionId", sessionID).Msg("Failed to read session events for batch upload") + return + } + if len(payload) == 0 { + return // No new events since last flush + } + + if err := api.CallUploadPamSessionEventBatch(su.httpClient, sessionID, state.fileOffset, payload); err != nil { + var apiErr *api.APIError + if errors.As(err, &apiErr) && apiErr.StatusCode == http.StatusNotFound { + // Platform does not support the batch upload endpoint yet; fall back to bulk upload at session end + log.Warn().Str("sessionId", sessionID).Msg("Batch upload endpoint not supported by platform, will use legacy bulk upload at session end") + state.legacyMode = true + return + } + log.Error().Err(err).Str("sessionId", sessionID).Int64("startOffset", state.fileOffset).Msg("Failed to upload session event batch, will retry next tick") + return // Do not advance offset on failure so the batch is retried + } + + state.fileOffset = newOffset + if err := writePersistedOffset(state.filename, newOffset); err != nil { + log.Warn().Err(err).Str("sessionId", sessionID).Msg("Failed to persist offset after flush") + } + + log.Debug().Str("sessionId", sessionID).Int64("newOffset", newOffset).Msg("Flushed session event batch") +} + func (su *SessionUploader) uploadSessionFile(fileInfo *SessionFileInfo) error { encryptionKey, err := su.credentialsManager.GetPAMSessionEncryptionKey() if err != nil { return fmt.Errorf("failed to get encryption key: %w", err) } - // Use resource type to determine how to read the file if fileInfo.ResourceType == ResourceTypeSSH { - // SSH session - read as terminal events terminalEvents, err := ReadEncryptedTerminalEventsFromFile(fileInfo.Filename, encryptionKey) if err != nil { return fmt.Errorf("failed to read SSH session file: %w", err) @@ -302,16 +534,13 @@ func (su *SessionUploader) uploadSessionFile(fileInfo *SessionFileInfo) error { }) } - request := api.UploadPAMSessionLogsRequest{ - Logs: logs, - } - - return api.CallUploadPamSessionLogs(su.httpClient, fileInfo.SessionID, request) + return api.CallUploadPamSessionLogs(su.httpClient, fileInfo.SessionID, api.UploadPAMSessionLogsRequest{Logs: logs}) } + if fileInfo.ResourceType == ResourceTypeKubernetes { httpEvents, err := ReadEncryptedHttpEventsFromFile(fileInfo.Filename, encryptionKey) if err != nil { - return fmt.Errorf("failed to read SSH session file: %w", err) + return fmt.Errorf("failed to read Kubernetes session file: %w", err) } log.Debug(). @@ -334,14 +563,10 @@ func (su *SessionUploader) uploadSessionFile(fileInfo *SessionFileInfo) error { }) } - request := api.UploadPAMSessionLogsRequest{ - Logs: logs, - } - - return api.CallUploadPamSessionLogs(su.httpClient, fileInfo.SessionID, request) + return api.CallUploadPamSessionLogs(su.httpClient, fileInfo.SessionID, api.UploadPAMSessionLogsRequest{Logs: logs}) } - // Database session (postgres, mysql, or legacy format) - read as request/response logs + // Database session (postgres, mysql, mssql, redis, or legacy format) entries, err := ReadEncryptedSessionLogByFilename(fileInfo.Filename, encryptionKey) if err != nil { return fmt.Errorf("failed to read session file: %w", err) @@ -367,11 +592,7 @@ func (su *SessionUploader) uploadSessionFile(fileInfo *SessionFileInfo) error { }) } - request := api.UploadPAMSessionLogsRequest{ - Logs: logs, - } - - return api.CallUploadPamSessionLogs(su.httpClient, fileInfo.SessionID, request) + return api.CallUploadPamSessionLogs(su.httpClient, fileInfo.SessionID, api.UploadPAMSessionLogsRequest{Logs: logs}) } func FindSessionFileBySessionID(sessionID string) (*SessionFileInfo, error) { @@ -389,55 +610,69 @@ func FindSessionFileBySessionID(sessionID string) (*SessionFileInfo, error) { return nil, ErrSessionFileNotFound } -func (su *SessionUploader) UploadSessionLogsBySessionID(sessionID string) error { - fileInfo, err := FindSessionFileBySessionID(sessionID) - if err != nil { - if errors.Is(err, ErrSessionFileNotFound) { - log.Debug().Str("sessionId", sessionID).Msg("Session file not found, skipping upload") - return nil - } - return fmt.Errorf("failed to find session file: %w", err) - } - - log.Info().Str("sessionId", sessionID).Str("filename", fileInfo.Filename).Msg("Uploading session logs for terminating session") +// CleanupPAMSession performs a final batch upload, unregisters the session, +// deletes the local recording file, and notifies the server that the session has ended. +func (su *SessionUploader) CleanupPAMSession(sessionID string, reason string) error { + log.Info().Str("sessionId", sessionID).Str("reason", reason).Msg("Starting PAM session cleanup") - if err := su.uploadSessionFile(fileInfo); err != nil { - return fmt.Errorf("failed to upload session logs: %w", err) + // Ensure the session is registered so the final flush can read from the correct offset. + // This handles both active sessions (already registered) and orphaned files from previous runs. + su.activeSessionsMu.RLock() + _, isRegistered := su.activeSessions[sessionID] + su.activeSessionsMu.RUnlock() + if !isRegistered { + su.RegisterSession(sessionID) } - // Delete the uploaded file - recordingDir := GetSessionRecordingDir() - fullPath := filepath.Join(recordingDir, fileInfo.Filename) - if err := os.Remove(fullPath); err != nil { - log.Warn().Err(err).Str("filename", fileInfo.Filename).Msg("Failed to delete uploaded session file") - return fmt.Errorf("failed to delete uploaded session file: %w", err) + // Final flush: upload any remaining events before we delete the file. + encryptionKey, err := su.credentialsManager.GetPAMSessionEncryptionKey() + if err != nil { + log.Warn().Err(err).Str("sessionId", sessionID).Msg("Could not get encryption key for final flush") + } else { + su.flushSession(sessionID, encryptionKey) } - log.Info().Str("sessionId", sessionID).Str("filename", fileInfo.Filename).Msg("Successfully uploaded and deleted session file") - return nil -} - -// CleanupPAMSession handles the complete cleanup process for a PAM session -func (su *SessionUploader) CleanupPAMSession(sessionID string, reason string) error { - log.Info().Str("sessionId", sessionID).Str("reason", reason).Msg("Starting PAM session cleanup") + // If the batch endpoint was not supported, fall back to a single bulk upload. + su.activeSessionsMu.RLock() + state, stateExists := su.activeSessions[sessionID] + su.activeSessionsMu.RUnlock() + if stateExists { + state.mu.Lock() + useLegacy := state.legacyMode + state.mu.Unlock() + if useLegacy { + if fileInfo, err := FindSessionFileBySessionID(sessionID); err == nil { + if err := su.uploadSessionFile(fileInfo); err != nil { + log.Error().Err(err).Str("sessionId", sessionID).Msg("Legacy bulk upload failed at session end") + } + } + } + } - // Upload session logs - if err := su.UploadSessionLogsBySessionID(sessionID); err != nil { - log.Error().Err(err).Str("sessionId", sessionID).Msg("Failed to upload session logs") - } else { - log.Info().Str("sessionId", sessionID).Msg("Successfully uploaded session logs") + // Unregister: removes from activeSessions and deletes persisted offset. + su.UnregisterSession(sessionID) + + // Delete local recording file. + fileInfo, findErr := FindSessionFileBySessionID(sessionID) + if findErr == nil { + recordingDir := GetSessionRecordingDir() + fullPath := filepath.Join(recordingDir, fileInfo.Filename) + if removeErr := os.Remove(fullPath); removeErr != nil && !os.IsNotExist(removeErr) { + log.Warn().Err(removeErr).Str("filename", fileInfo.Filename).Msg("Failed to delete session recording file") + } else { + log.Info().Str("sessionId", sessionID).Str("filename", fileInfo.Filename).Msg("Deleted local session recording file") + } } - // Cleanup session resources + // Cleanup in-memory session state. CleanupSessionMutex(sessionID) su.credentialsManager.CleanupSessionCredentials(sessionID) if err := api.CallPAMSessionTermination(su.httpClient, sessionID); err != nil { log.Error().Err(err).Str("sessionId", sessionID).Msg("Failed to notify session termination via API") return err - } else { - log.Info().Str("sessionId", sessionID).Msg("Session termination processed successfully") } + log.Info().Str("sessionId", sessionID).Msg("Session termination processed successfully") return nil }