diff --git a/api/pkg/auth/session_manager.go b/api/pkg/auth/session_manager.go new file mode 100644 index 0000000000..e392803168 --- /dev/null +++ b/api/pkg/auth/session_manager.go @@ -0,0 +1,368 @@ +package auth + +import ( + "context" + "crypto/rand" + "encoding/base64" + "errors" + "net/http" + "strings" + "time" + + "github.com/helixml/helix/api/pkg/config" + "github.com/helixml/helix/api/pkg/store" + "github.com/helixml/helix/api/pkg/system" + "github.com/helixml/helix/api/pkg/types" + "github.com/rs/zerolog/log" + "gorm.io/gorm" +) + +const ( + // SessionCookieName is the name of the HttpOnly session cookie + SessionCookieName = "helix_session" + + // CSRFCookieName is the name of the CSRF token cookie (readable by JS) + CSRFCookieName = "helix_csrf" + + // CSRFHeaderName is the name of the header that must contain the CSRF token + CSRFHeaderName = "X-CSRF-Token" + + // DefaultSessionDuration is the default session lifetime (30 days) + DefaultSessionDuration = 30 * 24 * time.Hour +) + +var ( + ErrSessionNotFound = errors.New("session not found") + ErrSessionExpired = errors.New("session expired") +) + +// SessionManager handles user session lifecycle in the BFF pattern +// It stores session IDs in HttpOnly cookies and manages OIDC token refresh transparently +type SessionManager struct { + store store.Store + oidcClient OIDC + cfg *config.ServerConfig +} + +// NewSessionManager creates a new session manager +func NewSessionManager(store store.Store, oidcClient OIDC, cfg *config.ServerConfig) *SessionManager { + return &SessionManager{ + store: store, + oidcClient: oidcClient, + cfg: cfg, + } +} + +// CreateSession creates a new user session and sets the session cookie +func (sm *SessionManager) CreateSession( + ctx context.Context, + w http.ResponseWriter, + r *http.Request, + userID string, + authProvider types.AuthProvider, + oidcAccessToken, oidcRefreshToken string, + oidcTokenExpiry time.Time, +) (*types.UserSession, error) { + session := &types.UserSession{ + ID: system.GenerateUserSessionID(), + UserID: userID, + AuthProvider: authProvider, + ExpiresAt: time.Now().Add(DefaultSessionDuration), + OIDCAccessToken: oidcAccessToken, + OIDCRefreshToken: oidcRefreshToken, + OIDCTokenExpiry: oidcTokenExpiry, + UserAgent: r.UserAgent(), + IPAddress: getClientIP(r), + } + + createdSession, err := sm.store.CreateUserSession(ctx, session) + if err != nil { + return nil, err + } + + // Set the session cookie + sm.setSessionCookie(w, createdSession.ID, createdSession.ExpiresAt) + + log.Info(). + Str("session_id", createdSession.ID). + Str("user_id", userID). + Str("auth_provider", string(authProvider)). + Msg("Created new user session") + + return createdSession, nil +} + +// GetSessionFromRequest extracts and validates the session from the request cookie +// If the session has OIDC tokens that need refresh, it refreshes them transparently +func (sm *SessionManager) GetSessionFromRequest(ctx context.Context, r *http.Request) (*types.UserSession, error) { + sessionCookie, err := r.Cookie(SessionCookieName) + if err != nil { + log.Debug().Err(err).Str("path", r.URL.Path).Msg("No session cookie found") + return nil, ErrSessionNotFound + } + + sessionID := sessionCookie.Value + if sessionID == "" { + log.Debug().Str("path", r.URL.Path).Msg("Session cookie is empty") + return nil, ErrSessionNotFound + } + + log.Debug().Str("session_id", sessionID).Str("path", r.URL.Path).Msg("Looking up session from request") + + session, err := sm.store.GetUserSession(ctx, sessionID) + if err != nil { + log.Debug().Err(err).Str("session_id", sessionID).Str("path", r.URL.Path).Msg("Session lookup failed in store") + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, ErrSessionNotFound + } + return nil, err + } + + // Check if session is expired + if session.IsExpired() { + // Clean up expired session + _ = sm.store.DeleteUserSession(ctx, session.ID) + return nil, ErrSessionExpired + } + + // For OIDC sessions, check if access token needs refresh + if session.NeedsOIDCRefresh() && sm.oidcClient != nil { + if err := sm.refreshOIDCToken(ctx, session); err != nil { + log.Warn().Err(err).Str("session_id", session.ID).Msg("Failed to refresh OIDC token") + // Continue with potentially expired token - the API call might still work + // or will fail and force re-login + } + } + + // Touch the session periodically (not every request to avoid DB load) + // Update if last used more than 1 hour ago + if time.Since(session.LastUsedAt) > time.Hour { + session.Touch() + _, _ = sm.store.UpdateUserSession(ctx, session) + } + + return session, nil +} + +// refreshOIDCToken refreshes the OIDC access token using the refresh token +func (sm *SessionManager) refreshOIDCToken(ctx context.Context, session *types.UserSession) error { + if session.OIDCRefreshToken == "" { + return errors.New("no refresh token available") + } + + newToken, err := sm.oidcClient.RefreshAccessToken(ctx, session.OIDCRefreshToken) + if err != nil { + return err + } + + // Update session with new tokens + session.UpdateOIDCTokens(newToken.AccessToken, newToken.RefreshToken, newToken.Expiry) + + // Persist the updated session + _, err = sm.store.UpdateUserSession(ctx, session) + if err != nil { + return err + } + + log.Debug(). + Str("session_id", session.ID). + Time("new_expiry", newToken.Expiry). + Msg("Refreshed OIDC token for session") + + return nil +} + +// DeleteSession removes a session and clears the session cookie +func (sm *SessionManager) DeleteSession(ctx context.Context, w http.ResponseWriter, sessionID string) error { + err := sm.store.DeleteUserSession(ctx, sessionID) + if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { + return err + } + + // Clear the session cookie + sm.ClearSessionCookie(w) + + log.Info().Str("session_id", sessionID).Msg("Deleted user session") + return nil +} + +// DeleteAllUserSessions removes all sessions for a user (logout from all devices) +func (sm *SessionManager) DeleteAllUserSessions(ctx context.Context, w http.ResponseWriter, userID string) error { + err := sm.store.DeleteUserSessionsByUser(ctx, userID) + if err != nil { + return err + } + + // Clear the session cookie + sm.ClearSessionCookie(w) + + log.Info().Str("user_id", userID).Msg("Deleted all user sessions") + return nil +} + +// setSessionCookie sets the session cookie with proper security settings +// It also sets a companion CSRF token cookie (readable by JS) +func (sm *SessionManager) setSessionCookie(w http.ResponseWriter, sessionID string, expiresAt time.Time) { + secure := sm.isSecureCookies() + + // Set the HttpOnly session cookie + http.SetCookie(w, &http.Cookie{ + Name: SessionCookieName, + Value: sessionID, + Path: "/", + Expires: expiresAt, + MaxAge: int(time.Until(expiresAt).Seconds()), + HttpOnly: true, + Secure: secure, + SameSite: http.SameSiteLaxMode, + }) + + // Set the CSRF token cookie (readable by JS, used for X-CSRF-Token header) + csrfToken := generateCSRFToken() + http.SetCookie(w, &http.Cookie{ + Name: CSRFCookieName, + Value: csrfToken, + Path: "/", + Expires: expiresAt, + MaxAge: int(time.Until(expiresAt).Seconds()), + HttpOnly: false, // JS needs to read this + Secure: secure, + SameSite: http.SameSiteLaxMode, + }) +} + +// ClearSessionCookie clears the session and CSRF cookies +func (sm *SessionManager) ClearSessionCookie(w http.ResponseWriter) { + http.SetCookie(w, &http.Cookie{ + Name: SessionCookieName, + Value: "", + Path: "/", + MaxAge: -1, + HttpOnly: true, + Secure: true, + SameSite: http.SameSiteLaxMode, + }) + http.SetCookie(w, &http.Cookie{ + Name: CSRFCookieName, + Value: "", + Path: "/", + MaxAge: -1, + HttpOnly: false, + Secure: true, + SameSite: http.SameSiteLaxMode, + }) +} + +// generateCSRFToken generates a cryptographically secure random CSRF token +func generateCSRFToken() string { + b := make([]byte, 32) + if _, err := rand.Read(b); err != nil { + // Fallback to ULID if crypto/rand fails (shouldn't happen) + return system.GenerateID() + } + return base64.RawURLEncoding.EncodeToString(b) +} + +// ValidateCSRF checks that the X-CSRF-Token header matches the helix_csrf cookie +// Returns true if valid, false otherwise +func ValidateCSRF(r *http.Request) bool { + csrfCookie, err := r.Cookie(CSRFCookieName) + if err != nil || csrfCookie.Value == "" { + return false + } + + csrfHeader := r.Header.Get(CSRFHeaderName) + if csrfHeader == "" { + return false + } + + return csrfCookie.Value == csrfHeader +} + +// getClientIP extracts the client IP from the request +func getClientIP(r *http.Request) string { + // Check X-Forwarded-For first (for proxied requests) + if xff := r.Header.Get("X-Forwarded-For"); xff != "" { + // Take the first IP in the chain + if idx := len(xff); idx > 0 { + for i, c := range xff { + if c == ',' { + return xff[:i] + } + } + return xff + } + } + + // Check X-Real-IP + if xri := r.Header.Get("X-Real-IP"); xri != "" { + return xri + } + + // Fall back to RemoteAddr + return r.RemoteAddr +} + +// StartBackgroundRefresh starts a background goroutine that refreshes OIDC tokens +// before they expire, similar to OAuth manager's RefreshExpiredTokens +func (sm *SessionManager) StartBackgroundRefresh(ctx context.Context) { + ticker := time.NewTicker(1 * time.Minute) + go func() { + for { + select { + case <-ctx.Done(): + ticker.Stop() + return + case <-ticker.C: + sm.refreshExpiredSessions(ctx) + } + } + }() +} + +// refreshExpiredSessions refreshes OIDC tokens for sessions approaching expiry +func (sm *SessionManager) refreshExpiredSessions(ctx context.Context) { + // Get sessions with tokens expiring in the next 5 minutes + sessions, err := sm.store.GetUserSessionsNearOIDCExpiry(ctx, time.Now().Add(5*time.Minute)) + if err != nil { + log.Error().Err(err).Msg("Failed to get sessions near OIDC expiry") + return + } + + if len(sessions) == 0 { + return + } + + log.Debug().Int("count", len(sessions)).Msg("Refreshing OIDC tokens for sessions") + + for _, session := range sessions { + if err := sm.refreshOIDCToken(ctx, session); err != nil { + log.Warn(). + Err(err). + Str("session_id", session.ID). + Str("user_id", session.UserID). + Msg("Failed to refresh OIDC token in background job") + } + } +} + +// CleanupExpiredSessions removes expired sessions from the database +// This should be called periodically (e.g., daily) +func (sm *SessionManager) CleanupExpiredSessions(ctx context.Context) error { + return sm.store.DeleteExpiredUserSessions(ctx) +} + +// isSecureCookies determines if cookies should have the Secure flag set. +// Returns true if OIDC_SECURE_COOKIES is explicitly set to true, +// or if SERVER_URL starts with https:// +func (sm *SessionManager) isSecureCookies() bool { + if sm.cfg == nil { + return true // Safe default + } + // Explicit setting takes precedence + if sm.cfg.Auth.OIDC.SecureCookies { + return true + } + // Auto-detect from server URL + return strings.HasPrefix(sm.cfg.WebServer.URL, "https://") +} diff --git a/api/pkg/auth/session_manager_test.go b/api/pkg/auth/session_manager_test.go new file mode 100644 index 0000000000..eeaa50f813 --- /dev/null +++ b/api/pkg/auth/session_manager_test.go @@ -0,0 +1,429 @@ +package auth + +import ( + "context" + "errors" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/helixml/helix/api/pkg/config" + "github.com/helixml/helix/api/pkg/store" + "github.com/helixml/helix/api/pkg/types" + "github.com/stretchr/testify/suite" + "go.uber.org/mock/gomock" + "golang.org/x/oauth2" + "gorm.io/gorm" +) + +type SessionManagerSuite struct { + suite.Suite + ctrl *gomock.Controller + mockStore *store.MockStore + mockOIDC *MockOIDC + sessionManager *SessionManager + ctx context.Context + cfg *config.ServerConfig +} + +func TestSessionManagerSuite(t *testing.T) { + suite.Run(t, new(SessionManagerSuite)) +} + +func (s *SessionManagerSuite) SetupTest() { + s.ctrl = gomock.NewController(s.T()) + s.ctx = context.Background() + s.mockStore = store.NewMockStore(s.ctrl) + s.mockOIDC = NewMockOIDC(s.ctrl) + s.cfg = &config.ServerConfig{} + s.cfg.WebServer.URL = "https://example.com" + s.sessionManager = NewSessionManager(s.mockStore, s.mockOIDC, s.cfg) +} + +func (s *SessionManagerSuite) TearDownTest() { + s.ctrl.Finish() +} + +func (s *SessionManagerSuite) TestCreateSession() { + userID := "usr_test123" + authProvider := types.AuthProviderOIDC + + s.mockStore.EXPECT(). + CreateUserSession(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, session *types.UserSession) (*types.UserSession, error) { + s.Equal(userID, session.UserID) + s.Equal(authProvider, session.AuthProvider) + s.NotEmpty(session.ID) + s.True(session.ExpiresAt.After(time.Now())) + return session, nil + }) + + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodGet, "/", nil) + r.Header.Set("User-Agent", "Test-Agent") + + session, err := s.sessionManager.CreateSession( + s.ctx, w, r, + userID, + authProvider, + "access_token", + "refresh_token", + time.Now().Add(time.Hour), + ) + + s.NoError(err) + s.NotNil(session) + s.Equal(userID, session.UserID) + + // Check cookies were set + cookies := w.Result().Cookies() + s.Require().Len(cookies, 2) + + var sessionCookie, csrfCookie *http.Cookie + for _, c := range cookies { + if c.Name == SessionCookieName { + sessionCookie = c + } + if c.Name == CSRFCookieName { + csrfCookie = c + } + } + + s.NotNil(sessionCookie) + s.True(sessionCookie.HttpOnly) + s.True(sessionCookie.Secure) + + s.NotNil(csrfCookie) + s.False(csrfCookie.HttpOnly) // JS needs to read this +} + +func (s *SessionManagerSuite) TestGetSessionFromRequest_NoSessionCookie() { + r := httptest.NewRequest(http.MethodGet, "/", nil) + + session, err := s.sessionManager.GetSessionFromRequest(s.ctx, r) + + s.Nil(session) + s.ErrorIs(err, ErrSessionNotFound) +} + +func (s *SessionManagerSuite) TestGetSessionFromRequest_EmptySessionCookie() { + r := httptest.NewRequest(http.MethodGet, "/", nil) + r.AddCookie(&http.Cookie{Name: SessionCookieName, Value: ""}) + + session, err := s.sessionManager.GetSessionFromRequest(s.ctx, r) + + s.Nil(session) + s.ErrorIs(err, ErrSessionNotFound) +} + +func (s *SessionManagerSuite) TestGetSessionFromRequest_SessionNotFound() { + sessionID := "uss_nonexistent" + r := httptest.NewRequest(http.MethodGet, "/", nil) + r.AddCookie(&http.Cookie{Name: SessionCookieName, Value: sessionID}) + + s.mockStore.EXPECT(). + GetUserSession(gomock.Any(), sessionID). + Return(nil, gorm.ErrRecordNotFound) + + session, err := s.sessionManager.GetSessionFromRequest(s.ctx, r) + + s.Nil(session) + s.ErrorIs(err, ErrSessionNotFound) +} + +func (s *SessionManagerSuite) TestGetSessionFromRequest_SessionExpired() { + sessionID := "uss_expired" + r := httptest.NewRequest(http.MethodGet, "/", nil) + r.AddCookie(&http.Cookie{Name: SessionCookieName, Value: sessionID}) + + expiredSession := &types.UserSession{ + ID: sessionID, + UserID: "usr_test", + ExpiresAt: time.Now().Add(-time.Hour), // Expired + } + + s.mockStore.EXPECT(). + GetUserSession(gomock.Any(), sessionID). + Return(expiredSession, nil) + + s.mockStore.EXPECT(). + DeleteUserSession(gomock.Any(), sessionID). + Return(nil) + + session, err := s.sessionManager.GetSessionFromRequest(s.ctx, r) + + s.Nil(session) + s.ErrorIs(err, ErrSessionExpired) +} + +func (s *SessionManagerSuite) TestGetSessionFromRequest_ValidSession() { + sessionID := "uss_valid" + r := httptest.NewRequest(http.MethodGet, "/", nil) + r.AddCookie(&http.Cookie{Name: SessionCookieName, Value: sessionID}) + + validSession := &types.UserSession{ + ID: sessionID, + UserID: "usr_test", + ExpiresAt: time.Now().Add(time.Hour), + LastUsedAt: time.Now(), // Recently used, no touch needed + } + + s.mockStore.EXPECT(). + GetUserSession(gomock.Any(), sessionID). + Return(validSession, nil) + + session, err := s.sessionManager.GetSessionFromRequest(s.ctx, r) + + s.NoError(err) + s.NotNil(session) + s.Equal(sessionID, session.ID) +} + +func (s *SessionManagerSuite) TestGetSessionFromRequest_TouchesOldSession() { + sessionID := "uss_old" + r := httptest.NewRequest(http.MethodGet, "/", nil) + r.AddCookie(&http.Cookie{Name: SessionCookieName, Value: sessionID}) + + oldSession := &types.UserSession{ + ID: sessionID, + UserID: "usr_test", + ExpiresAt: time.Now().Add(time.Hour), + LastUsedAt: time.Now().Add(-2 * time.Hour), // Last used more than 1 hour ago + } + + s.mockStore.EXPECT(). + GetUserSession(gomock.Any(), sessionID). + Return(oldSession, nil) + + s.mockStore.EXPECT(). + UpdateUserSession(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, session *types.UserSession) (*types.UserSession, error) { + s.True(session.LastUsedAt.After(time.Now().Add(-time.Minute))) + return session, nil + }) + + session, err := s.sessionManager.GetSessionFromRequest(s.ctx, r) + + s.NoError(err) + s.NotNil(session) +} + +func (s *SessionManagerSuite) TestDeleteSession() { + sessionID := "uss_todelete" + + s.mockStore.EXPECT(). + DeleteUserSession(gomock.Any(), sessionID). + Return(nil) + + w := httptest.NewRecorder() + + err := s.sessionManager.DeleteSession(s.ctx, w, sessionID) + + s.NoError(err) + + // Check cookies were cleared + cookies := w.Result().Cookies() + for _, c := range cookies { + if c.Name == SessionCookieName || c.Name == CSRFCookieName { + s.Equal(-1, c.MaxAge) + } + } +} + +func (s *SessionManagerSuite) TestDeleteAllUserSessions() { + userID := "usr_test123" + + s.mockStore.EXPECT(). + DeleteUserSessionsByUser(gomock.Any(), userID). + Return(nil) + + w := httptest.NewRecorder() + + err := s.sessionManager.DeleteAllUserSessions(s.ctx, w, userID) + + s.NoError(err) +} + +func (s *SessionManagerSuite) TestValidateCSRF_Valid() { + r := httptest.NewRequest(http.MethodPost, "/", nil) + r.AddCookie(&http.Cookie{Name: CSRFCookieName, Value: "test-csrf-token"}) + r.Header.Set(CSRFHeaderName, "test-csrf-token") + + s.True(ValidateCSRF(r)) +} + +func (s *SessionManagerSuite) TestValidateCSRF_MissingCookie() { + r := httptest.NewRequest(http.MethodPost, "/", nil) + r.Header.Set(CSRFHeaderName, "test-csrf-token") + + s.False(ValidateCSRF(r)) +} + +func (s *SessionManagerSuite) TestValidateCSRF_MissingHeader() { + r := httptest.NewRequest(http.MethodPost, "/", nil) + r.AddCookie(&http.Cookie{Name: CSRFCookieName, Value: "test-csrf-token"}) + + s.False(ValidateCSRF(r)) +} + +func (s *SessionManagerSuite) TestValidateCSRF_Mismatch() { + r := httptest.NewRequest(http.MethodPost, "/", nil) + r.AddCookie(&http.Cookie{Name: CSRFCookieName, Value: "token-a"}) + r.Header.Set(CSRFHeaderName, "token-b") + + s.False(ValidateCSRF(r)) +} + +func (s *SessionManagerSuite) TestIsSecureCookies_HTTPS() { + s.cfg.WebServer.URL = "https://example.com" + s.True(s.sessionManager.isSecureCookies()) +} + +func (s *SessionManagerSuite) TestIsSecureCookies_HTTP() { + s.cfg.WebServer.URL = "http://localhost:8080" + s.cfg.Auth.OIDC.SecureCookies = false + s.False(s.sessionManager.isSecureCookies()) +} + +func (s *SessionManagerSuite) TestIsSecureCookies_ExplicitTrue() { + s.cfg.WebServer.URL = "http://localhost:8080" + s.cfg.Auth.OIDC.SecureCookies = true + s.True(s.sessionManager.isSecureCookies()) +} + +func (s *SessionManagerSuite) TestGetClientIP_XForwardedFor() { + r := httptest.NewRequest(http.MethodGet, "/", nil) + r.Header.Set("X-Forwarded-For", "192.168.1.1, 10.0.0.1") + + s.Equal("192.168.1.1", getClientIP(r)) +} + +func (s *SessionManagerSuite) TestGetClientIP_XRealIP() { + r := httptest.NewRequest(http.MethodGet, "/", nil) + r.Header.Set("X-Real-IP", "192.168.1.2") + + s.Equal("192.168.1.2", getClientIP(r)) +} + +func (s *SessionManagerSuite) TestGetClientIP_RemoteAddr() { + r := httptest.NewRequest(http.MethodGet, "/", nil) + r.RemoteAddr = "192.168.1.3:12345" + + s.Equal("192.168.1.3:12345", getClientIP(r)) +} + +// Tests for OIDC token refresh scenarios (addresses PR review comment about OIDC_OFFLINE_ACCESS) + +func (s *SessionManagerSuite) TestGetSessionFromRequest_OIDCRefreshNeeded_NoRefreshToken() { + // Test case: OIDC session needs refresh but no refresh token available + // This happens when OIDC_OFFLINE_ACCESS is not enabled + sessionID := "uss_no_refresh_token" + r := httptest.NewRequest(http.MethodGet, "/", nil) + r.AddCookie(&http.Cookie{Name: SessionCookieName, Value: sessionID}) + + oidcSession := &types.UserSession{ + ID: sessionID, + UserID: "usr_test", + AuthProvider: types.AuthProviderOIDC, + ExpiresAt: time.Now().Add(time.Hour), + LastUsedAt: time.Now(), + OIDCRefreshToken: "", // No refresh token (OIDC_OFFLINE_ACCESS not enabled) + OIDCAccessToken: "expired_access_token", + OIDCTokenExpiry: time.Now().Add(-time.Minute), // Token expired + } + + s.mockStore.EXPECT(). + GetUserSession(gomock.Any(), sessionID). + Return(oidcSession, nil) + + // Session should still be returned even without refresh capability + // The access token is expired but we continue - API calls may fail and force re-login + session, err := s.sessionManager.GetSessionFromRequest(s.ctx, r) + + s.NoError(err) + s.NotNil(session) + s.Equal(sessionID, session.ID) + s.Empty(session.OIDCRefreshToken) // No refresh token +} + +func (s *SessionManagerSuite) TestGetSessionFromRequest_OIDCRefreshNeeded_RefreshSucceeds() { + // Test case: OIDC session needs refresh and refresh succeeds + sessionID := "uss_refresh_success" + r := httptest.NewRequest(http.MethodGet, "/", nil) + r.AddCookie(&http.Cookie{Name: SessionCookieName, Value: sessionID}) + + oidcSession := &types.UserSession{ + ID: sessionID, + UserID: "usr_test", + AuthProvider: types.AuthProviderOIDC, + ExpiresAt: time.Now().Add(time.Hour), + LastUsedAt: time.Now(), + OIDCRefreshToken: "valid_refresh_token", + OIDCAccessToken: "expired_access_token", + OIDCTokenExpiry: time.Now().Add(-time.Minute), // Token expired + } + + s.mockStore.EXPECT(). + GetUserSession(gomock.Any(), sessionID). + Return(oidcSession, nil) + + // Mock successful token refresh + newExpiry := time.Now().Add(time.Hour) + s.mockOIDC.EXPECT(). + RefreshAccessToken(gomock.Any(), "valid_refresh_token"). + Return(&oauth2.Token{ + AccessToken: "new_access_token", + RefreshToken: "new_refresh_token", + Expiry: newExpiry, + }, nil) + + // Mock session update + s.mockStore.EXPECT(). + UpdateUserSession(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, session *types.UserSession) (*types.UserSession, error) { + s.Equal("new_access_token", session.OIDCAccessToken) + s.Equal("new_refresh_token", session.OIDCRefreshToken) + return session, nil + }) + + session, err := s.sessionManager.GetSessionFromRequest(s.ctx, r) + + s.NoError(err) + s.NotNil(session) +} + +func (s *SessionManagerSuite) TestGetSessionFromRequest_OIDCRefreshNeeded_RefreshFails() { + // Test case: OIDC session needs refresh but refresh fails + // Session should still be returned - the expired token may still work or force re-login + sessionID := "uss_refresh_fails" + r := httptest.NewRequest(http.MethodGet, "/", nil) + r.AddCookie(&http.Cookie{Name: SessionCookieName, Value: sessionID}) + + oidcSession := &types.UserSession{ + ID: sessionID, + UserID: "usr_test", + AuthProvider: types.AuthProviderOIDC, + ExpiresAt: time.Now().Add(time.Hour), + LastUsedAt: time.Now(), + OIDCRefreshToken: "invalid_refresh_token", + OIDCAccessToken: "expired_access_token", + OIDCTokenExpiry: time.Now().Add(-time.Minute), // Token expired + } + + s.mockStore.EXPECT(). + GetUserSession(gomock.Any(), sessionID). + Return(oidcSession, nil) + + // Mock failed token refresh + s.mockOIDC.EXPECT(). + RefreshAccessToken(gomock.Any(), "invalid_refresh_token"). + Return(nil, errors.New("refresh token expired")) + + // Session should still be returned despite refresh failure + session, err := s.sessionManager.GetSessionFromRequest(s.ctx, r) + + s.NoError(err) // No error returned to caller + s.NotNil(session) + s.Equal(sessionID, session.ID) +} diff --git a/api/pkg/server/auth.go b/api/pkg/server/auth.go index f4d91182bb..69d8457ca8 100644 --- a/api/pkg/server/auth.go +++ b/api/pkg/server/auth.go @@ -537,22 +537,28 @@ func (s *HelixAPIServer) login(w http.ResponseWriter, r *http.Request) { return } - // Generate a new token - token, err := s.authenticator.GenerateUserToken(r.Context(), user) + // Create a BFF session for this user + // For regular auth, we don't have OIDC tokens + _, err = s.sessionManager.CreateSession( + r.Context(), + w, + r, + user.ID, + types.AuthProviderRegular, + "", // no OIDC access token + "", // no OIDC refresh token + time.Time{}, // no OIDC token expiry + ) if err != nil { - log.Error().Err(err).Msg("Failed to generate user token") - http.Error(w, "Failed to generate user token: "+err.Error(), http.StatusInternalServerError) + log.Error().Err(err).Msg("Failed to create session") + http.Error(w, "Failed to create session: "+err.Error(), http.StatusInternalServerError) return } - // OK, set authentication cookies and redirect - cookieManager.Set(w, accessTokenCookie, token) - cookieManager.Set(w, refreshTokenCookie, token) - + // Return user info without token (frontend doesn't need it with BFF) response := types.UserResponse{ ID: user.ID, Email: user.Email, - Token: token, Name: user.FullName, } writeResponse(w, response, http.StatusOK) @@ -653,24 +659,50 @@ func (s *HelixAPIServer) callback(w http.ResponseWriter, r *http.Request) { return } - // Set cookies, if applicable - if oauth2Token.AccessToken != "" { - cm.Set(w, accessTokenCookie, oauth2Token.AccessToken) - } else { + // Validate access token is present + if oauth2Token.AccessToken == "" { log.Debug().Msg("access_token is empty") http.Error(w, "access_token is empty", http.StatusBadRequest) return } - if oauth2Token.RefreshToken != "" { - cm.Set(w, refreshTokenCookie, oauth2Token.RefreshToken) - log.Info().Msg("Refresh token received and stored") - } else { - // No refresh token - this is expected for providers like Google when - // OIDC_OFFLINE_ACCESS is not enabled. Without a refresh token, the session - // will expire when the access token expires (typically 1 hour for Google). + + // Get or create user from OIDC claims via ValidateUserToken + // This handles user lookup/creation in the database + user, err := s.oidcClient.ValidateUserToken(ctx, oauth2Token.AccessToken) + if err != nil { + log.Error().Err(err).Msg("Failed to validate user token and get/create user") + http.Error(w, "Failed to get user info: "+err.Error(), http.StatusInternalServerError) + return + } + + // Log if no refresh token received + if oauth2Token.RefreshToken == "" { log.Warn().Msg("No refresh token received from OIDC provider. Set OIDC_OFFLINE_ACCESS=true for Google to enable token refresh.") } + // Create a BFF session with OIDC tokens stored on the backend + _, err = s.sessionManager.CreateSession( + ctx, + w, + r, + user.ID, + types.AuthProviderOIDC, + oauth2Token.AccessToken, + oauth2Token.RefreshToken, + oauth2Token.Expiry, + ) + if err != nil { + log.Error().Err(err).Msg("Failed to create session") + http.Error(w, "Failed to create session: "+err.Error(), http.StatusInternalServerError) + return + } + + log.Info(). + Str("user_id", user.ID). + Bool("has_refresh_token", oauth2Token.RefreshToken != ""). + Time("token_expiry", oauth2Token.Expiry). + Msg("Created BFF session for OIDC user") + // Check if we have a stored redirect URI redirectURI := s.Cfg.WebServer.URL // default redirect if cookie, err := cm.Get(r, redirectURICookie); err == nil { @@ -692,6 +724,36 @@ func (s *HelixAPIServer) user(w http.ResponseWriter, r *http.Request) { return } + // BFF pattern: Check for session cookie first + if s.sessionManager != nil { + session, err := s.sessionManager.GetSessionFromRequest(ctx, r) + if err == nil && session != nil { + // Get user info from database using session + user, err := s.Store.GetUser(ctx, &store.GetUserQuery{ID: session.UserID}) + if err != nil { + log.Error().Err(err).Str("user_id", session.UserID).Msg("Failed to get user for BFF session") + http.Error(w, "Failed to get user info", http.StatusInternalServerError) + return + } + + log.Debug(). + Str("session_id", session.ID). + Str("user_id", session.UserID). + Msg("User info retrieved via BFF session") + + response := types.UserResponse{ + ID: user.ID, + Email: user.Email, + Token: "", // No token exposed with BFF pattern + Name: user.FullName, + Admin: user.Admin, + } + writeResponse(w, response, http.StatusOK) + return + } + } + + // Fallback: Check for legacy access_token cookie (for backwards compatibility) cm := NewCookieManager(s.Cfg) accessToken, err := cm.Get(r, accessTokenCookie) if err != nil { @@ -769,6 +831,42 @@ func (s *HelixAPIServer) user(w http.ResponseWriter, r *http.Request) { writeResponse(w, response, http.StatusOK) } +// session godoc +// @Summary Get current session info +// @Description Returns session info for BFF authentication. The frontend uses this to check if the user is logged in. +// @Tags auth +// @Success 200 {object} types.SessionInfo +// @Failure 401 {string} string "Not authenticated" +// @Router /api/v1/auth/session [get] +func (s *HelixAPIServer) session(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + // Try to get user from BFF session + session, err := s.sessionManager.GetSessionFromRequest(ctx, r) + if err != nil { + // No valid session + http.Error(w, "Not authenticated", http.StatusUnauthorized) + return + } + + // Get user info from database + user, err := s.Store.GetUser(ctx, &store.GetUserQuery{ID: session.UserID}) + if err != nil { + log.Error().Err(err).Str("user_id", session.UserID).Msg("Failed to get user for session") + http.Error(w, "Failed to get user info", http.StatusInternalServerError) + return + } + + // Return user info without tokens + response := types.UserResponse{ + ID: user.ID, + Email: user.Email, + Name: user.FullName, + Admin: s.authMiddleware.isAdminWithContext(ctx, user.ID), + } + writeResponse(w, response, http.StatusOK) +} + // user godoc // @Summary Logout // @Description Logout the user @@ -780,7 +878,14 @@ func (s *HelixAPIServer) logout(w http.ResponseWriter, r *http.Request) { return } - // Remove cookies + // Delete BFF session if one exists + if session, err := s.sessionManager.GetSessionFromRequest(r.Context(), r); err == nil && session != nil { + if err := s.sessionManager.DeleteSession(r.Context(), w, session.ID); err != nil { + log.Warn().Err(err).Str("session_id", session.ID).Msg("Failed to delete session during logout") + } + } + + // Remove legacy cookies (for backward compatibility during migration) NewCookieManager(s.Cfg).DeleteAllCookies(w) // Use redirect_uri from query param if provided (set by frontend from window.location.origin) @@ -849,6 +954,27 @@ func (s *HelixAPIServer) authenticated(w http.ResponseWriter, r *http.Request) { return } + // BFF pattern: Check for session cookie first + if s.sessionManager != nil { + session, err := s.sessionManager.GetSessionFromRequest(ctx, r) + if err != nil { + log.Debug().Err(err).Msg("BFF session lookup failed in authenticated endpoint") + } + if session != nil { + log.Debug(). + Str("session_id", session.ID). + Str("user_id", session.UserID). + Msg("User authenticated via BFF session in authenticated endpoint") + writeResponse(w, types.AuthenticatedResponse{ + Authenticated: true, + }, http.StatusOK) + return + } + } else { + log.Debug().Msg("sessionManager is nil in authenticated endpoint") + } + + // Fallback: Check for legacy access_token cookie (for backwards compatibility) cm := NewCookieManager(s.Cfg) accessToken, err := cm.Get(r, accessTokenCookie) if err != nil { diff --git a/api/pkg/server/auth_middleware.go b/api/pkg/server/auth_middleware.go index 6faba09d72..008c3f1bd4 100644 --- a/api/pkg/server/auth_middleware.go +++ b/api/pkg/server/auth_middleware.go @@ -42,11 +42,12 @@ type authMiddlewareConfig struct { } type authMiddleware struct { - authenticator authpkg.Authenticator - oidcClient authpkg.OIDC // For OIDC token validation (nil if not using OIDC) - store store.Store - cfg authMiddlewareConfig - serverCfg *config.ServerConfig // Server config for cookie management + authenticator authpkg.Authenticator + oidcClient authpkg.OIDC // For OIDC token validation (nil if not using OIDC) + store store.Store + cfg authMiddlewareConfig + serverCfg *config.ServerConfig // Server config for cookie management + sessionManager *authpkg.SessionManager // BFF session manager (nil if not using sessions) } func newAuthMiddleware( @@ -55,13 +56,15 @@ func newAuthMiddleware( store store.Store, cfg authMiddlewareConfig, serverCfg *config.ServerConfig, + sessionManager *authpkg.SessionManager, ) *authMiddleware { return &authMiddleware{ - authenticator: authenticator, - oidcClient: oidcClient, - store: store, - cfg: cfg, - serverCfg: serverCfg, + authenticator: authenticator, + oidcClient: oidcClient, + store: store, + cfg: cfg, + serverCfg: serverCfg, + sessionManager: sessionManager, } } @@ -259,13 +262,82 @@ func (auth *authMiddleware) getUserFromToken(ctx context.Context, token string) return user, nil } +// getUserFromSession checks for a valid BFF session and returns the user +// This is the primary auth method for browser-based clients using HttpOnly session cookies +func (auth *authMiddleware) getUserFromSession(ctx context.Context, r *http.Request) (*types.User, error) { + if auth.sessionManager == nil { + return nil, authpkg.ErrSessionNotFound + } + + session, err := auth.sessionManager.GetSessionFromRequest(ctx, r) + if err != nil { + return nil, err + } + + // Load user from database + dbUser, err := auth.store.GetUser(ctx, &store.GetUserQuery{ID: session.UserID}) + if err != nil { + return nil, fmt.Errorf("failed to load user for session: %w", err) + } + if dbUser == nil { + return nil, fmt.Errorf("user not found for session: %s", session.UserID) + } + + // Set token type based on auth provider + if session.AuthProvider == types.AuthProviderOIDC { + dbUser.Token = session.OIDCAccessToken + dbUser.TokenType = types.TokenTypeOIDC + } else { + dbUser.TokenType = types.TokenTypeSession + } + + dbUser.Admin = auth.isAdminWithContext(ctx, dbUser.ID) + + log.Debug(). + Str("session_id", session.ID). + Str("user_id", dbUser.ID). + Str("auth_provider", string(session.AuthProvider)). + Msg("Authenticated user via BFF session") + + return dbUser, nil +} + // this will extract the token from the request and then load the correct // user based on what type of token it is // if there is no token, a default user object will be written to the // request context func (auth *authMiddleware) extractMiddleware(next http.Handler) http.Handler { f := func(w http.ResponseWriter, r *http.Request) { - user, err := auth.getUserFromToken(r.Context(), getRequestToken(r)) + var user *types.User + var err error + + // First, try BFF session authentication (from helix_session cookie) + // This is the primary auth method for browser clients + if auth.sessionManager != nil { + user, err = auth.getUserFromSession(r.Context(), r) + if err == nil && user != nil { + // Successfully authenticated via session + r = r.WithContext(setRequestUser(r.Context(), *user)) + next.ServeHTTP(w, r) + return + } + + // Session expired - clear the session cookie + if errors.Is(err, authpkg.ErrSessionExpired) { + auth.sessionManager.ClearSessionCookie(w) + } + + // If session auth failed but it's not a "not found" error, log it + if err != nil && !errors.Is(err, authpkg.ErrSessionNotFound) { + log.Debug().Err(err).Str("path", r.URL.Path).Msg("BFF session auth failed, trying token auth") + } + + // Fall through to token-based auth + err = nil + } + + // Fall back to token-based authentication (API keys, runner tokens, OIDC tokens) + user, err = auth.getUserFromToken(r.Context(), getRequestToken(r)) if err != nil { // Check if error is due to server not ready vs invalid token // Return 503 for server errors so frontend doesn't auto-logout during API restart @@ -283,39 +355,8 @@ func (auth *authMiddleware) extractMiddleware(next http.Handler) http.Handler { return } - // Attempt transparent token refresh using refresh_token cookie - // This allows requests to succeed even when access token is expired, - // as long as we have a valid refresh token - if auth.oidcClient != nil && auth.serverCfg != nil { - cm := NewCookieManager(auth.serverCfg) - refreshToken, refreshErr := cm.Get(r, refreshTokenCookie) - if refreshErr == nil && refreshToken != "" && !looksLikeHelixJWT(refreshToken) { - newToken, refreshErr := auth.oidcClient.RefreshAccessToken(r.Context(), refreshToken) - if refreshErr == nil && newToken.AccessToken != "" { - // Update cookies with new tokens - cm.Set(w, accessTokenCookie, newToken.AccessToken) - if newToken.RefreshToken != "" { - cm.Set(w, refreshTokenCookie, newToken.RefreshToken) - } - - // Set header for frontend to update its in-memory token - w.Header().Set("X-Token-Refreshed", newToken.AccessToken) - - // Retry auth with the new token - user, err = auth.getUserFromToken(r.Context(), newToken.AccessToken) - if err == nil { - log.Info().Str("path", r.URL.Path).Msg("Token refreshed transparently in middleware") - // Fall through to continue request processing - } else { - log.Warn().Err(err).Str("path", r.URL.Path).Msg("Token refresh succeeded but validation failed") - } - } else if refreshErr != nil { - log.Debug().Err(refreshErr).Str("path", r.URL.Path).Msg("Transparent token refresh failed") - } - } - } - - // If still no valid user after refresh attempt, return 401 + // With BFF pattern, token refresh is handled by SessionManager + // API keys and runner tokens don't need refresh if err != nil { log.Debug().Err(err).Str("path", r.URL.Path).Msg("Auth error - returning 401") http.Error(w, err.Error(), http.StatusUnauthorized) @@ -378,3 +419,55 @@ func (auth *authMiddleware) auth(f http.HandlerFunc) http.HandlerFunc { f(w, r) } } + +// csrfExemptPaths are paths that don't require CSRF protection +// These are typically auth endpoints or APIs used before/during session creation +var csrfExemptPaths = map[string]bool{ + "/api/v1/auth/login": true, // Login doesn't have CSRF cookie yet + "/api/v1/auth/logout": true, // Logout clears the session + "/api/v1/auth/oidc": true, // OIDC redirect + "/api/v1/auth/oidc/callback": true, // OIDC callback + "/api/v1/auth/authenticated": true, // Read-only check + "/api/v1/auth/session": true, // Session info (GET-like semantics) + "/api/v1/auth/user": true, // Get user info +} + +// csrfMiddleware validates CSRF tokens for state-changing requests +// when the request uses cookie-based session authentication. +// API key and runner token authenticated requests skip CSRF validation. +func (auth *authMiddleware) csrfMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Only check CSRF for state-changing methods + if r.Method != "POST" && r.Method != "PUT" && r.Method != "DELETE" && r.Method != "PATCH" { + next.ServeHTTP(w, r) + return + } + + // Check if path is exempt from CSRF + if csrfExemptPaths[r.URL.Path] { + next.ServeHTTP(w, r) + return + } + + // Check if this request was authenticated via session cookie + // If using API key or runner token, skip CSRF validation + _, err := r.Cookie(authpkg.SessionCookieName) + if err != nil { + // No session cookie - this is API key or runner token auth, skip CSRF + next.ServeHTTP(w, r) + return + } + + // Session cookie exists - validate CSRF token + if !authpkg.ValidateCSRF(r) { + log.Warn(). + Str("path", r.URL.Path). + Str("method", r.Method). + Msg("CSRF validation failed") + http.Error(w, "CSRF validation failed", http.StatusForbidden) + return + } + + next.ServeHTTP(w, r) + }) +} diff --git a/api/pkg/server/auth_middleware_test.go b/api/pkg/server/auth_middleware_test.go index 53ee59431e..dd28ed4bac 100644 --- a/api/pkg/server/auth_middleware_test.go +++ b/api/pkg/server/auth_middleware_test.go @@ -21,7 +21,7 @@ func TestIsAdminWithContext_EmptyUserID(t *testing.T) { auth := newAuthMiddleware(nil, nil, mockStore, authMiddlewareConfig{ adminUserIDs: nil, - }, nil) + }, nil, nil) result := auth.isAdminWithContext(context.Background(), "") assert.False(t, result, "empty userID should return false") @@ -36,7 +36,7 @@ func TestIsAdminWithContext_DevMode_EveryoneIsAdmin(t *testing.T) { auth := newAuthMiddleware(nil, nil, mockStore, authMiddlewareConfig{ adminUserIDs: []string{config.AdminAllUsers}, - }, nil) + }, nil, nil) result := auth.isAdminWithContext(context.Background(), "any-user-id") assert.True(t, result, "with ADMIN_USER_IDS=all, any user should be admin") @@ -52,7 +52,7 @@ func TestIsAdminWithContext_SpecificUserInList(t *testing.T) { auth := newAuthMiddleware(nil, nil, mockStore, authMiddlewareConfig{ adminUserIDs: []string{"user-123", "user-456"}, - }, nil) + }, nil, nil) result := auth.isAdminWithContext(context.Background(), userID) assert.True(t, result, "user in ADMIN_USER_IDS list should be admin") @@ -75,7 +75,7 @@ func TestIsAdminWithContext_UserNotInList_ChecksDatabase(t *testing.T) { auth := newAuthMiddleware(nil, nil, mockStore, authMiddlewareConfig{ adminUserIDs: []string{"user-123", "user-456"}, // user-789 not in list - }, nil) + }, nil, nil) result := auth.isAdminWithContext(context.Background(), userID) assert.True(t, result, "user not in list but Admin=true in database should be admin") @@ -97,7 +97,7 @@ func TestIsAdminWithContext_DatabaseAdmin_True(t *testing.T) { auth := newAuthMiddleware(nil, nil, mockStore, authMiddlewareConfig{ adminUserIDs: nil, // Empty list - use database - }, nil) + }, nil, nil) result := auth.isAdminWithContext(context.Background(), userID) assert.True(t, result, "user with Admin=true in database should be admin") @@ -119,7 +119,7 @@ func TestIsAdminWithContext_DatabaseAdmin_False(t *testing.T) { auth := newAuthMiddleware(nil, nil, mockStore, authMiddlewareConfig{ adminUserIDs: nil, // Empty list - use database - }, nil) + }, nil, nil) result := auth.isAdminWithContext(context.Background(), userID) assert.False(t, result, "user with Admin=false in database should not be admin") @@ -138,7 +138,7 @@ func TestIsAdminWithContext_UserNotFound(t *testing.T) { auth := newAuthMiddleware(nil, nil, mockStore, authMiddlewareConfig{ adminUserIDs: nil, // Empty list - use database - }, nil) + }, nil, nil) result := auth.isAdminWithContext(context.Background(), userID) assert.False(t, result, "user not found should return false") @@ -157,7 +157,7 @@ func TestIsAdminWithContext_DatabaseError(t *testing.T) { auth := newAuthMiddleware(nil, nil, mockStore, authMiddlewareConfig{ adminUserIDs: nil, // Empty list - use database - }, nil) + }, nil, nil) result := auth.isAdminWithContext(context.Background(), userID) assert.False(t, result, "database error should return false") diff --git a/api/pkg/server/auth_test.go b/api/pkg/server/auth_test.go index 2f3698e657..22272dad92 100644 --- a/api/pkg/server/auth_test.go +++ b/api/pkg/server/auth_test.go @@ -58,11 +58,13 @@ func (suite *AuthSuite) SetupTest() { cfg.Auth.Provider = types.AuthProviderOIDC suite.oidcClient = auth.NewMockOIDC(ctrl) suite.authenticator = auth.NewMockAuthenticator(ctrl) + sessionManager := auth.NewSessionManager(suite.store, suite.oidcClient, cfg) suite.server = &HelixAPIServer{ - Cfg: cfg, - oidcClient: suite.oidcClient, - authenticator: suite.authenticator, - Store: suite.store, + Cfg: cfg, + oidcClient: suite.oidcClient, + authenticator: suite.authenticator, + Store: suite.store, + sessionManager: sessionManager, } } @@ -265,25 +267,33 @@ func (suite *AuthSuite) TestCallback() { mockIDToken := &oidc.IDToken{Nonce: testNonce} suite.oidcClient.EXPECT().Exchange(gomock.Any(), testCode).Return(mockToken, nil) suite.oidcClient.EXPECT().VerifyIDToken(gomock.Any(), mockToken).Return(mockIDToken, nil) + // BFF: Now we call ValidateUserToken to get/create user + suite.oidcClient.EXPECT().ValidateUserToken(gomock.Any(), testAccessToken).Return(&types.User{ + ID: "user-123", + Email: testEmail, + }, nil) + // BFF: CreateUserSession is called to create the session + suite.store.EXPECT().CreateUserSession(gomock.Any(), gomock.Any()).Return(&types.UserSession{ + ID: "session-123", + UserID: "user-123", + }, nil) }, expectedStatus: http.StatusFound, checkResponse: func(rec *httptest.ResponseRecorder) { suite.Equal(testServerURL+"/dashboard", rec.Header().Get("Location")) - // Verify access_token and refresh_token are set + // Verify helix_session cookie is set (BFF pattern) res := rec.Result() defer res.Body.Close() + foundSessionCookie := false for _, cookie := range res.Cookies() { - switch cookie.Name { - case "access_token": - suite.NotEmpty(cookie.Value) - suite.Equal(testAccessToken, cookie.Value) - suite.Equal("/", cookie.Path) - case "refresh_token": + if cookie.Name == "helix_session" { + foundSessionCookie = true suite.NotEmpty(cookie.Value) - suite.Equal(testRefreshToken, cookie.Value) + suite.True(cookie.HttpOnly, "Session cookie should be HttpOnly") suite.Equal("/", cookie.Path) } } + suite.True(foundSessionCookie, "helix_session cookie should be set") }, }, { @@ -300,6 +310,15 @@ func (suite *AuthSuite) TestCallback() { mockIDToken := &oidc.IDToken{Nonce: testNonce} suite.oidcClient.EXPECT().Exchange(gomock.Any(), testCode).Return(mockToken, nil) suite.oidcClient.EXPECT().VerifyIDToken(gomock.Any(), mockToken).Return(mockIDToken, nil) + // BFF: ValidateUserToken and CreateUserSession are still called + suite.oidcClient.EXPECT().ValidateUserToken(gomock.Any(), testAccessToken).Return(&types.User{ + ID: "user-123", + Email: testEmail, + }, nil) + suite.store.EXPECT().CreateUserSession(gomock.Any(), gomock.Any()).Return(&types.UserSession{ + ID: "session-123", + UserID: "user-123", + }, nil) }, expectedStatus: http.StatusFound, checkResponse: func(rec *httptest.ResponseRecorder) { diff --git a/api/pkg/server/auth_utils.go b/api/pkg/server/auth_utils.go index 97301b6d24..79485a71bd 100644 --- a/api/pkg/server/auth_utils.go +++ b/api/pkg/server/auth_utils.go @@ -163,8 +163,6 @@ func addCorsHeaders(w http.ResponseWriter) { w.Header().Set("Access-Control-Allow-Origin", "*") w.Header().Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE") w.Header().Set("Access-Control-Allow-Headers", "Accept, Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization") - // Expose X-Token-Refreshed header so frontend can read it for transparent token refresh - w.Header().Set("Access-Control-Expose-Headers", "X-Token-Refreshed") } /* diff --git a/api/pkg/server/server.go b/api/pkg/server/server.go index 0049c6f354..25eae50e08 100644 --- a/api/pkg/server/server.go +++ b/api/pkg/server/server.go @@ -108,6 +108,7 @@ type HelixAPIServer struct { authenticator auth.Authenticator oidcClient auth.OIDC oauthManager *oauth.Manager + sessionManager *auth.SessionManager fileServerHandler http.Handler cache *ristretto.Cache[string, string] avatarsBucket *blob.Bucket @@ -254,16 +255,7 @@ func NewServer( requestToCommenterMapping: make(map[string]string), streamingRateLimiter: make(map[string]time.Time), inferenceServer: inferenceServer, - authMiddleware: newAuthMiddleware( - authenticator, - oidcClient, - store, - authMiddlewareConfig{ - adminUserIDs: cfg.WebServer.AdminUserIDs, - runnerToken: cfg.WebServer.RunnerToken, - }, - cfg, - ), + sessionManager: auth.NewSessionManager(store, oidcClient, cfg), providerManager: providerManager, modelInfoProvider: modelInfoProvider, pubsub: ps, @@ -297,6 +289,19 @@ func NewServer( auditLogService: services.NewAuditLogService(store), } + // Initialize auth middleware with session manager for BFF authentication + apiServer.authMiddleware = newAuthMiddleware( + authenticator, + oidcClient, + store, + authMiddlewareConfig{ + adminUserIDs: cfg.WebServer.AdminUserIDs, + runnerToken: cfg.WebServer.RunnerToken, + }, + cfg, + apiServer.sessionManager, + ) + // Initialize SummaryService for async interaction summaries and session titles apiServer.summaryService = NewSummaryService(store, providerManager, ps) @@ -506,6 +511,7 @@ func (apiServer *HelixAPIServer) registerRoutes(_ context.Context) (*mux.Router, // Extract auth for /api/v1 routes only (not frontend static assets) subRouter := router.PathPrefix(APIPrefix).Subrouter() subRouter.Use(apiServer.authMiddleware.extractMiddleware) + subRouter.Use(apiServer.authMiddleware.csrfMiddleware) // auth router requires a valid token from keycloak or api key authRouter := subRouter.MatcherFunc(matchAllRoutes).Subrouter() @@ -770,6 +776,7 @@ func (apiServer *HelixAPIServer) registerRoutes(_ context.Context) (*mux.Router, insecureRouter.HandleFunc("/auth/login", apiServer.login).Methods(http.MethodPost) insecureRouter.HandleFunc("/auth/callback", apiServer.callback).Methods(http.MethodGet) insecureRouter.HandleFunc("/auth/user", apiServer.user).Methods(http.MethodGet) + insecureRouter.HandleFunc("/auth/session", apiServer.session).Methods(http.MethodGet) insecureRouter.HandleFunc("/auth/logout", apiServer.logout).Methods(http.MethodPost) insecureRouter.HandleFunc("/auth/authenticated", apiServer.authenticated).Methods(http.MethodGet) insecureRouter.HandleFunc("/auth/refresh", apiServer.refresh).Methods(http.MethodPost) diff --git a/api/pkg/store/postgres.go b/api/pkg/store/postgres.go index 61ae29ca72..cd63ede2ac 100644 --- a/api/pkg/store/postgres.go +++ b/api/pkg/store/postgres.go @@ -150,6 +150,7 @@ func (s *PostgresStore) runMigrations() error { &types.OAuthProvider{}, &types.OAuthConnection{}, &types.OAuthRequestToken{}, + &types.UserSession{}, &types.GitProviderConnection{}, &types.ServiceConnection{}, &types.UsageMetric{}, diff --git a/api/pkg/store/store.go b/api/pkg/store/store.go index 7fa099d08c..f935ee81d2 100644 --- a/api/pkg/store/store.go +++ b/api/pkg/store/store.go @@ -363,6 +363,16 @@ type Store interface { UpdateServiceConnection(ctx context.Context, connection *types.ServiceConnection) error DeleteServiceConnection(ctx context.Context, id string) error + // User Session methods (BFF authentication) + CreateUserSession(ctx context.Context, session *types.UserSession) (*types.UserSession, error) + GetUserSession(ctx context.Context, id string) (*types.UserSession, error) + GetUserSessionsByUser(ctx context.Context, userID string) ([]*types.UserSession, error) + UpdateUserSession(ctx context.Context, session *types.UserSession) (*types.UserSession, error) + DeleteUserSession(ctx context.Context, id string) error + DeleteUserSessionsByUser(ctx context.Context, userID string) error + GetUserSessionsNearOIDCExpiry(ctx context.Context, expiresBefore time.Time) ([]*types.UserSession, error) + DeleteExpiredUserSessions(ctx context.Context) error + CreateUsageMetric(ctx context.Context, metric *types.UsageMetric) (*types.UsageMetric, error) GetAppUsageMetrics(ctx context.Context, appID string, from time.Time, to time.Time) ([]*types.UsageMetric, error) GetAppDailyUsageMetrics(ctx context.Context, appID string, from time.Time, to time.Time) ([]*types.AggregatedUsageMetric, error) diff --git a/api/pkg/store/store_mocks.go b/api/pkg/store/store_mocks.go index dd6d624761..c4e707bda1 100644 --- a/api/pkg/store/store_mocks.go +++ b/api/pkg/store/store_mocks.go @@ -521,20 +521,6 @@ func (mr *MockStoreMockRecorder) CreateProjectRepository(ctx, projectID, reposit return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateProjectRepository", reflect.TypeOf((*MockStore)(nil).CreateProjectRepository), ctx, projectID, repositoryID, organizationID) } -// UpdateProjectRepository mocks base method. -func (m *MockStore) UpdateProjectRepository(ctx context.Context, pr *types.ProjectRepository) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpdateProjectRepository", ctx, pr) - ret0, _ := ret[0].(error) - return ret0 -} - -// UpdateProjectRepository indicates an expected call of UpdateProjectRepository. -func (mr *MockStoreMockRecorder) UpdateProjectRepository(ctx, pr any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateProjectRepository", reflect.TypeOf((*MockStore)(nil).UpdateProjectRepository), ctx, pr) -} - // CreateProviderEndpoint mocks base method. func (m *MockStore) CreateProviderEndpoint(ctx context.Context, providerEndpoint *types.ProviderEndpoint) (*types.ProviderEndpoint, error) { m.ctrl.T.Helper() @@ -945,6 +931,21 @@ func (mr *MockStoreMockRecorder) CreateUserMeta(ctx, UserMeta any) *gomock.Call return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateUserMeta", reflect.TypeOf((*MockStore)(nil).CreateUserMeta), ctx, UserMeta) } +// CreateUserSession mocks base method. +func (m *MockStore) CreateUserSession(ctx context.Context, session *types.UserSession) (*types.UserSession, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateUserSession", ctx, session) + ret0, _ := ret[0].(*types.UserSession) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CreateUserSession indicates an expected call of CreateUserSession. +func (mr *MockStoreMockRecorder) CreateUserSession(ctx, session any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateUserSession", reflect.TypeOf((*MockStore)(nil).CreateUserSession), ctx, session) +} + // CreateWallet mocks base method. func (m *MockStore) CreateWallet(ctx context.Context, wallet *types.Wallet) (*types.Wallet, error) { m.ctrl.T.Helper() @@ -1086,6 +1087,20 @@ func (mr *MockStoreMockRecorder) DeleteDynamicModelInfo(ctx, id any) *gomock.Cal return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteDynamicModelInfo", reflect.TypeOf((*MockStore)(nil).DeleteDynamicModelInfo), ctx, id) } +// DeleteExpiredUserSessions mocks base method. +func (m *MockStore) DeleteExpiredUserSessions(ctx context.Context) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteExpiredUserSessions", ctx) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteExpiredUserSessions indicates an expected call of DeleteExpiredUserSessions. +func (mr *MockStoreMockRecorder) DeleteExpiredUserSessions(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteExpiredUserSessions", reflect.TypeOf((*MockStore)(nil).DeleteExpiredUserSessions), ctx) +} + // DeleteGitProviderConnection mocks base method. func (m *MockStore) DeleteGitProviderConnection(ctx context.Context, id string) error { m.ctrl.T.Helper() @@ -1648,6 +1663,34 @@ func (mr *MockStoreMockRecorder) DeleteUser(ctx, id any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteUser", reflect.TypeOf((*MockStore)(nil).DeleteUser), ctx, id) } +// DeleteUserSession mocks base method. +func (m *MockStore) DeleteUserSession(ctx context.Context, id string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteUserSession", ctx, id) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteUserSession indicates an expected call of DeleteUserSession. +func (mr *MockStoreMockRecorder) DeleteUserSession(ctx, id any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteUserSession", reflect.TypeOf((*MockStore)(nil).DeleteUserSession), ctx, id) +} + +// DeleteUserSessionsByUser mocks base method. +func (m *MockStore) DeleteUserSessionsByUser(ctx context.Context, userID string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteUserSessionsByUser", ctx, userID) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteUserSessionsByUser indicates an expected call of DeleteUserSessionsByUser. +func (mr *MockStoreMockRecorder) DeleteUserSessionsByUser(ctx, userID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteUserSessionsByUser", reflect.TypeOf((*MockStore)(nil).DeleteUserSessionsByUser), ctx, userID) +} + // DeleteWallet mocks base method. func (m *MockStore) DeleteWallet(ctx context.Context, id string) error { m.ctrl.T.Helper() @@ -3054,6 +3097,51 @@ func (mr *MockStoreMockRecorder) GetUserMonthlyTokenUsage(ctx, userID, providers return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserMonthlyTokenUsage", reflect.TypeOf((*MockStore)(nil).GetUserMonthlyTokenUsage), ctx, userID, providers) } +// GetUserSession mocks base method. +func (m *MockStore) GetUserSession(ctx context.Context, id string) (*types.UserSession, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetUserSession", ctx, id) + ret0, _ := ret[0].(*types.UserSession) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetUserSession indicates an expected call of GetUserSession. +func (mr *MockStoreMockRecorder) GetUserSession(ctx, id any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserSession", reflect.TypeOf((*MockStore)(nil).GetUserSession), ctx, id) +} + +// GetUserSessionsByUser mocks base method. +func (m *MockStore) GetUserSessionsByUser(ctx context.Context, userID string) ([]*types.UserSession, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetUserSessionsByUser", ctx, userID) + ret0, _ := ret[0].([]*types.UserSession) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetUserSessionsByUser indicates an expected call of GetUserSessionsByUser. +func (mr *MockStoreMockRecorder) GetUserSessionsByUser(ctx, userID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserSessionsByUser", reflect.TypeOf((*MockStore)(nil).GetUserSessionsByUser), ctx, userID) +} + +// GetUserSessionsNearOIDCExpiry mocks base method. +func (m *MockStore) GetUserSessionsNearOIDCExpiry(ctx context.Context, expiresBefore time.Time) ([]*types.UserSession, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetUserSessionsNearOIDCExpiry", ctx, expiresBefore) + ret0, _ := ret[0].([]*types.UserSession) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetUserSessionsNearOIDCExpiry indicates an expected call of GetUserSessionsNearOIDCExpiry. +func (mr *MockStoreMockRecorder) GetUserSessionsNearOIDCExpiry(ctx, expiresBefore any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserSessionsNearOIDCExpiry", reflect.TypeOf((*MockStore)(nil).GetUserSessionsNearOIDCExpiry), ctx, expiresBefore) +} + // GetUsersAggregatedUsageMetrics mocks base method. func (m *MockStore) GetUsersAggregatedUsageMetrics(ctx context.Context, provider string, from, to time.Time) ([]*types.UsersAggregatedUsageMetric, error) { m.ctrl.T.Helper() @@ -4720,6 +4808,20 @@ func (mr *MockStoreMockRecorder) UpdateProject(ctx, project any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateProject", reflect.TypeOf((*MockStore)(nil).UpdateProject), ctx, project) } +// UpdateProjectRepository mocks base method. +func (m *MockStore) UpdateProjectRepository(ctx context.Context, pr *types.ProjectRepository) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateProjectRepository", ctx, pr) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpdateProjectRepository indicates an expected call of UpdateProjectRepository. +func (mr *MockStoreMockRecorder) UpdateProjectRepository(ctx, pr any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateProjectRepository", reflect.TypeOf((*MockStore)(nil).UpdateProjectRepository), ctx, pr) +} + // UpdatePromptPin mocks base method. func (m *MockStore) UpdatePromptPin(ctx context.Context, promptID string, pinned bool) error { m.ctrl.T.Helper() @@ -5154,6 +5256,21 @@ func (mr *MockStoreMockRecorder) UpdateUserMeta(ctx, UserMeta any) *gomock.Call return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUserMeta", reflect.TypeOf((*MockStore)(nil).UpdateUserMeta), ctx, UserMeta) } +// UpdateUserSession mocks base method. +func (m *MockStore) UpdateUserSession(ctx context.Context, session *types.UserSession) (*types.UserSession, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateUserSession", ctx, session) + ret0, _ := ret[0].(*types.UserSession) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// UpdateUserSession indicates an expected call of UpdateUserSession. +func (mr *MockStoreMockRecorder) UpdateUserSession(ctx, session any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUserSession", reflect.TypeOf((*MockStore)(nil).UpdateUserSession), ctx, session) +} + // UpdateWallet mocks base method. func (m *MockStore) UpdateWallet(ctx context.Context, wallet *types.Wallet) (*types.Wallet, error) { m.ctrl.T.Helper() diff --git a/api/pkg/store/store_user_sessions.go b/api/pkg/store/store_user_sessions.go new file mode 100644 index 0000000000..10821a9eda --- /dev/null +++ b/api/pkg/store/store_user_sessions.go @@ -0,0 +1,123 @@ +package store + +import ( + "context" + "fmt" + "time" + + "github.com/helixml/helix/api/pkg/system" + "github.com/helixml/helix/api/pkg/types" +) + +// CreateUserSession creates a new user session +func (s *PostgresStore) CreateUserSession(ctx context.Context, session *types.UserSession) (*types.UserSession, error) { + if session.UserID == "" { + return nil, fmt.Errorf("user ID is required") + } + + if session.ID == "" { + session.ID = system.GenerateUserSessionID() + } + + now := time.Now() + session.CreatedAt = now + session.UpdatedAt = now + session.LastUsedAt = now + + err := s.gdb.WithContext(ctx).Create(session).Error + if err != nil { + return nil, err + } + + return session, nil +} + +// GetUserSession retrieves a user session by ID +func (s *PostgresStore) GetUserSession(ctx context.Context, id string) (*types.UserSession, error) { + if id == "" { + return nil, fmt.Errorf("session ID is required") + } + + var session types.UserSession + err := s.gdb.WithContext(ctx).Where("id = ?", id).First(&session).Error + if err != nil { + return nil, err + } + return &session, nil +} + +// GetUserSessionsByUser retrieves all sessions for a user +func (s *PostgresStore) GetUserSessionsByUser(ctx context.Context, userID string) ([]*types.UserSession, error) { + if userID == "" { + return nil, fmt.Errorf("user ID is required") + } + + var sessions []*types.UserSession + err := s.gdb.WithContext(ctx). + Where("user_id = ?", userID). + Order("created_at DESC"). + Find(&sessions).Error + if err != nil { + return nil, err + } + return sessions, nil +} + +// UpdateUserSession updates an existing user session +func (s *PostgresStore) UpdateUserSession(ctx context.Context, session *types.UserSession) (*types.UserSession, error) { + if session.ID == "" { + return nil, fmt.Errorf("session ID is required") + } + if session.UserID == "" { + return nil, fmt.Errorf("user ID is required") + } + + session.UpdatedAt = time.Now() + err := s.gdb.WithContext(ctx).Save(session).Error + if err != nil { + return nil, err + } + return session, nil +} + +// DeleteUserSession deletes a user session by ID +func (s *PostgresStore) DeleteUserSession(ctx context.Context, id string) error { + if id == "" { + return fmt.Errorf("session ID is required") + } + + return s.gdb.WithContext(ctx).Where("id = ?", id).Delete(&types.UserSession{}).Error +} + +// DeleteUserSessionsByUser deletes all sessions for a user (e.g., on logout from all devices) +func (s *PostgresStore) DeleteUserSessionsByUser(ctx context.Context, userID string) error { + if userID == "" { + return fmt.Errorf("user ID is required") + } + + return s.gdb.WithContext(ctx).Where("user_id = ?", userID).Delete(&types.UserSession{}).Error +} + +// GetUserSessionsNearOIDCExpiry retrieves OIDC sessions that need token refresh +// This is used by the background refresh job +func (s *PostgresStore) GetUserSessionsNearOIDCExpiry(ctx context.Context, expiresBefore time.Time) ([]*types.UserSession, error) { + var sessions []*types.UserSession + err := s.gdb.WithContext(ctx). + Where("auth_provider = ?", types.AuthProviderOIDC). + Where("oidc_refresh_token != ''"). + Where("oidc_token_expiry < ?", expiresBefore). + Where("expires_at > ?", time.Now()). // Only non-expired sessions + Find(&sessions).Error + if err != nil { + return nil, err + } + return sessions, nil +} + +// DeleteExpiredUserSessions deletes all expired sessions +// This should be run periodically to clean up the database +func (s *PostgresStore) DeleteExpiredUserSessions(ctx context.Context) error { + return s.gdb.WithContext(ctx). + Where("expires_at < ?", time.Now()). + Delete(&types.UserSession{}).Error +} diff --git a/api/pkg/store/store_user_sessions_test.go b/api/pkg/store/store_user_sessions_test.go new file mode 100644 index 0000000000..8a77843dcc --- /dev/null +++ b/api/pkg/store/store_user_sessions_test.go @@ -0,0 +1,273 @@ +package store + +import ( + "context" + "time" + + "github.com/helixml/helix/api/pkg/system" + "github.com/helixml/helix/api/pkg/types" +) + +func (suite *PostgresStoreTestSuite) TestPostgresStore_CreateUserSession() { + session := &types.UserSession{ + UserID: system.GenerateUserID(), + AuthProvider: types.AuthProviderOIDC, + ExpiresAt: time.Now().Add(24 * time.Hour), + UserAgent: "Test-Agent/1.0", + IPAddress: "192.168.1.1", + } + + createdSession, err := suite.db.CreateUserSession(suite.ctx, session) + + suite.NoError(err) + suite.NotNil(createdSession) + suite.NotEmpty(createdSession.ID) + suite.True(createdSession.ID[:4] == "uss_") // ULID with prefix + suite.Equal(session.UserID, createdSession.UserID) + suite.Equal(session.AuthProvider, createdSession.AuthProvider) + suite.False(createdSession.CreatedAt.IsZero()) + suite.False(createdSession.UpdatedAt.IsZero()) + suite.False(createdSession.LastUsedAt.IsZero()) + + suite.T().Cleanup(func() { + _ = suite.db.DeleteUserSession(suite.ctx, createdSession.ID) + }) +} + +func (suite *PostgresStoreTestSuite) TestPostgresStore_CreateUserSession_RequiresUserID() { + session := &types.UserSession{ + AuthProvider: types.AuthProviderOIDC, + ExpiresAt: time.Now().Add(24 * time.Hour), + } + + _, err := suite.db.CreateUserSession(suite.ctx, session) + + suite.Error(err) + suite.Contains(err.Error(), "user ID is required") +} + +func (suite *PostgresStoreTestSuite) TestPostgresStore_GetUserSession() { + session := &types.UserSession{ + UserID: system.GenerateUserID(), + AuthProvider: types.AuthProviderRegular, + ExpiresAt: time.Now().Add(24 * time.Hour), + } + + created, err := suite.db.CreateUserSession(suite.ctx, session) + suite.NoError(err) + + suite.T().Cleanup(func() { + _ = suite.db.DeleteUserSession(suite.ctx, created.ID) + }) + + retrieved, err := suite.db.GetUserSession(suite.ctx, created.ID) + + suite.NoError(err) + suite.NotNil(retrieved) + suite.Equal(created.ID, retrieved.ID) + suite.Equal(created.UserID, retrieved.UserID) +} + +func (suite *PostgresStoreTestSuite) TestPostgresStore_GetUserSession_RequiresID() { + _, err := suite.db.GetUserSession(suite.ctx, "") + + suite.Error(err) + suite.Contains(err.Error(), "session ID is required") +} + +func (suite *PostgresStoreTestSuite) TestPostgresStore_GetUserSessionsByUser() { + userID := system.GenerateUserID() + + // Create multiple sessions for the same user + for i := 0; i < 3; i++ { + session := &types.UserSession{ + UserID: userID, + AuthProvider: types.AuthProviderOIDC, + ExpiresAt: time.Now().Add(24 * time.Hour), + } + created, err := suite.db.CreateUserSession(suite.ctx, session) + suite.NoError(err) + + suite.T().Cleanup(func() { + _ = suite.db.DeleteUserSession(suite.ctx, created.ID) + }) + } + + sessions, err := suite.db.GetUserSessionsByUser(suite.ctx, userID) + + suite.NoError(err) + suite.Len(sessions, 3) + for _, s := range sessions { + suite.Equal(userID, s.UserID) + } +} + +func (suite *PostgresStoreTestSuite) TestPostgresStore_GetUserSessionsByUser_RequiresUserID() { + _, err := suite.db.GetUserSessionsByUser(suite.ctx, "") + + suite.Error(err) + suite.Contains(err.Error(), "user ID is required") +} + +func (suite *PostgresStoreTestSuite) TestPostgresStore_UpdateUserSession() { + session := &types.UserSession{ + UserID: system.GenerateUserID(), + AuthProvider: types.AuthProviderOIDC, + ExpiresAt: time.Now().Add(24 * time.Hour), + } + + created, err := suite.db.CreateUserSession(suite.ctx, session) + suite.NoError(err) + + suite.T().Cleanup(func() { + _ = suite.db.DeleteUserSession(suite.ctx, created.ID) + }) + + // Update the session + created.UserAgent = "Updated-Agent/2.0" + created.IPAddress = "10.0.0.1" + + updated, err := suite.db.UpdateUserSession(suite.ctx, created) + + suite.NoError(err) + suite.NotNil(updated) + suite.Equal("Updated-Agent/2.0", updated.UserAgent) + suite.Equal("10.0.0.1", updated.IPAddress) + suite.True(updated.UpdatedAt.After(created.CreatedAt)) +} + +func (suite *PostgresStoreTestSuite) TestPostgresStore_UpdateUserSession_RequiresID() { + session := &types.UserSession{ + UserID: system.GenerateUserID(), + } + + _, err := suite.db.UpdateUserSession(suite.ctx, session) + + suite.Error(err) + suite.Contains(err.Error(), "session ID is required") +} + +func (suite *PostgresStoreTestSuite) TestPostgresStore_UpdateUserSession_RequiresUserID() { + session := &types.UserSession{ + ID: system.GenerateUserSessionID(), + } + + _, err := suite.db.UpdateUserSession(suite.ctx, session) + + suite.Error(err) + suite.Contains(err.Error(), "user ID is required") +} + +func (suite *PostgresStoreTestSuite) TestPostgresStore_DeleteUserSession() { + session := &types.UserSession{ + UserID: system.GenerateUserID(), + AuthProvider: types.AuthProviderOIDC, + ExpiresAt: time.Now().Add(24 * time.Hour), + } + + created, err := suite.db.CreateUserSession(suite.ctx, session) + suite.NoError(err) + + err = suite.db.DeleteUserSession(suite.ctx, created.ID) + suite.NoError(err) + + // Verify it was deleted + _, err = suite.db.GetUserSession(suite.ctx, created.ID) + suite.Error(err) +} + +func (suite *PostgresStoreTestSuite) TestPostgresStore_DeleteUserSession_RequiresID() { + err := suite.db.DeleteUserSession(suite.ctx, "") + + suite.Error(err) + suite.Contains(err.Error(), "session ID is required") +} + +func (suite *PostgresStoreTestSuite) TestPostgresStore_DeleteUserSessionsByUser() { + userID := system.GenerateUserID() + + // Create multiple sessions + var sessionIDs []string + for i := 0; i < 3; i++ { + session := &types.UserSession{ + UserID: userID, + AuthProvider: types.AuthProviderOIDC, + ExpiresAt: time.Now().Add(24 * time.Hour), + } + created, err := suite.db.CreateUserSession(suite.ctx, session) + suite.NoError(err) + sessionIDs = append(sessionIDs, created.ID) + } + + err := suite.db.DeleteUserSessionsByUser(suite.ctx, userID) + suite.NoError(err) + + // Verify all were deleted + sessions, err := suite.db.GetUserSessionsByUser(suite.ctx, userID) + suite.NoError(err) + suite.Len(sessions, 0) +} + +func (suite *PostgresStoreTestSuite) TestPostgresStore_DeleteUserSessionsByUser_RequiresUserID() { + err := suite.db.DeleteUserSessionsByUser(suite.ctx, "") + + suite.Error(err) + suite.Contains(err.Error(), "user ID is required") +} + +func (suite *PostgresStoreTestSuite) TestPostgresStore_GetUserSessionsNearOIDCExpiry() { + userID := system.GenerateUserID() + + // Create a session with OIDC token expiring soon + expiringSession := &types.UserSession{ + UserID: userID, + AuthProvider: types.AuthProviderOIDC, + ExpiresAt: time.Now().Add(24 * time.Hour), + OIDCRefreshToken: "refresh_token", + OIDCAccessToken: "access_token", + OIDCTokenExpiry: time.Now().Add(2 * time.Minute), // Expires in 2 minutes + } + + created, err := suite.db.CreateUserSession(suite.ctx, expiringSession) + suite.NoError(err) + + suite.T().Cleanup(func() { + _ = suite.db.DeleteUserSession(suite.ctx, created.ID) + }) + + // Query for sessions expiring before 5 minutes from now + sessions, err := suite.db.GetUserSessionsNearOIDCExpiry(context.Background(), time.Now().Add(5*time.Minute)) + + suite.NoError(err) + suite.NotEmpty(sessions) + + found := false + for _, s := range sessions { + if s.ID == created.ID { + found = true + break + } + } + suite.True(found, "Expected to find the expiring session") +} + +func (suite *PostgresStoreTestSuite) TestPostgresStore_DeleteExpiredUserSessions() { + userID := system.GenerateUserID() + + // Create an expired session + expiredSession := &types.UserSession{ + UserID: userID, + AuthProvider: types.AuthProviderOIDC, + ExpiresAt: time.Now().Add(-time.Hour), // Already expired + } + + created, err := suite.db.CreateUserSession(suite.ctx, expiredSession) + suite.NoError(err) + + err = suite.db.DeleteExpiredUserSessions(suite.ctx) + suite.NoError(err) + + // Verify the expired session was deleted + _, err = suite.db.GetUserSession(suite.ctx, created.ID) + suite.Error(err) +} diff --git a/api/pkg/system/uuid.go b/api/pkg/system/uuid.go index 4a4f421c19..3637cd8e66 100644 --- a/api/pkg/system/uuid.go +++ b/api/pkg/system/uuid.go @@ -42,6 +42,7 @@ const ( SpecTaskPrefix = "spt_" ProjectPrefix = "prj_" CloneGroupPrefix = "clg_" + UserSessionPrefix = "uss_" ) func GenerateUUID() string { @@ -190,3 +191,7 @@ func GenerateSpecTaskID() string { func GenerateCloneGroupID() string { return fmt.Sprintf("%s%s", CloneGroupPrefix, newID()) } + +func GenerateUserSessionID() string { + return fmt.Sprintf("%s%s", UserSessionPrefix, newID()) +} diff --git a/api/pkg/types/enums.go b/api/pkg/types/enums.go index a9b39aa9e7..d222b9d48a 100644 --- a/api/pkg/types/enums.go +++ b/api/pkg/types/enums.go @@ -284,11 +284,12 @@ func ValidateEntityType(datasetType string, acceptEmpty bool) (DataEntityType, e type TokenType string const ( - TokenTypeNone TokenType = "" - TokenTypeRunner TokenType = "runner" - TokenTypeOIDC TokenType = "oidc" - TokenTypeAPIKey TokenType = "api_key" - TokenTypeSocket TokenType = "socket" + TokenTypeNone TokenType = "" + TokenTypeRunner TokenType = "runner" + TokenTypeOIDC TokenType = "oidc" + TokenTypeAPIKey TokenType = "api_key" + TokenTypeSocket TokenType = "socket" + TokenTypeSession TokenType = "session" // BFF session for regular (email/password) auth ) type ScriptRunState string diff --git a/api/pkg/types/session.go b/api/pkg/types/session.go new file mode 100644 index 0000000000..86f544abdf --- /dev/null +++ b/api/pkg/types/session.go @@ -0,0 +1,100 @@ +package types + +import ( + "time" +) + +// Note: AuthProvider type and constants (AuthProviderRegular, AuthProviderOIDC) +// are already defined in authz.go - we reuse them here. + +// UserSession represents an authenticated user session. +// This is the core of the BFF (Backend-For-Frontend) authentication system. +// +// The frontend only sees the session ID via an HttpOnly cookie. +// All token management (OIDC refresh, etc.) happens transparently on the backend. +// +// Similar pattern to OAuthConnection but specifically for user authentication. +type UserSession struct { + ID string `json:"id" gorm:"primaryKey;type:text"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + + // User who owns this session + UserID string `json:"user_id" gorm:"not null;index"` + + // Auth provider used: "regular" or "oidc" + AuthProvider AuthProvider `json:"auth_provider" gorm:"not null;type:text"` + + // Session expiry (30 days from creation by default) + ExpiresAt time.Time `json:"expires_at" gorm:"not null;index"` + + // For OIDC sessions: store the refresh token so backend can refresh access tokens + // These are never exposed to the frontend (json:"-") + OIDCRefreshToken string `json:"-" gorm:"column:oidc_refresh_token;type:text"` + OIDCAccessToken string `json:"-" gorm:"column:oidc_access_token;type:text"` + OIDCTokenExpiry time.Time `json:"-" gorm:"column:oidc_token_expiry"` + + // Optional metadata for security/audit + UserAgent string `json:"user_agent,omitempty" gorm:"type:text"` + IPAddress string `json:"ip_address,omitempty"` + LastUsedAt time.Time `json:"last_used_at"` +} + +// TableName returns the table name for UserSession +func (UserSession) TableName() string { + return "user_sessions" +} + +// IsExpired returns true if the session has expired +func (s *UserSession) IsExpired() bool { + return time.Now().After(s.ExpiresAt) +} + +// NeedsOIDCRefresh returns true if the OIDC access token needs to be refreshed +// We refresh if the token expires within the next 5 minutes +func (s *UserSession) NeedsOIDCRefresh() bool { + if s.AuthProvider != AuthProviderOIDC { + return false + } + if s.OIDCRefreshToken == "" { + return false + } + // Refresh if token expires within 5 minutes + return time.Now().Add(5 * time.Minute).After(s.OIDCTokenExpiry) +} + +// UpdateOIDCTokens updates the OIDC tokens after a refresh +func (s *UserSession) UpdateOIDCTokens(accessToken, refreshToken string, expiry time.Time) { + s.OIDCAccessToken = accessToken + if refreshToken != "" { + s.OIDCRefreshToken = refreshToken + } + s.OIDCTokenExpiry = expiry + s.UpdatedAt = time.Now() +} + +// Touch updates the LastUsedAt timestamp +func (s *UserSession) Touch() { + s.LastUsedAt = time.Now() +} + +// SessionInfo is the public session information returned to the frontend +// It does NOT include any tokens - just session metadata +type SessionInfo struct { + ID string `json:"id"` + UserID string `json:"user_id"` + AuthProvider AuthProvider `json:"auth_provider"` + CreatedAt time.Time `json:"created_at"` + ExpiresAt time.Time `json:"expires_at"` +} + +// ToSessionInfo converts a UserSession to public SessionInfo +func (s *UserSession) ToSessionInfo() *SessionInfo { + return &SessionInfo{ + ID: s.ID, + UserID: s.UserID, + AuthProvider: s.AuthProvider, + CreatedAt: s.CreatedAt, + ExpiresAt: s.ExpiresAt, + } +} diff --git a/design/2026-02-05-bff-auth-rewrite.md b/design/2026-02-05-bff-auth-rewrite.md new file mode 100644 index 0000000000..c3cac1102f --- /dev/null +++ b/design/2026-02-05-bff-auth-rewrite.md @@ -0,0 +1,484 @@ +# BFF Authentication Rewrite + +**Date:** 2026-02-05 +**Status:** In Progress +**Author:** Claude (with Luke) + +## Problem + +The current frontend authentication system is a hybrid approach that's the worst of both worlds: + +1. **Frontend manages tokens in 5+ different locations:** + - `axios.defaults.headers.common` (global axios) + - `apiClientSingleton.setSecurityData()` (generated API client) + - `localStorage.setItem('token', ...)` (persistence) + - React state via `account.token` + - Custom event dispatch (`TOKEN_REFRESHED_EVENT`) + +2. **Frontend deals with OAuth/OIDC complexity:** + - Understands refresh tokens exist + - Has to capture `X-Token-Refreshed` headers + - Multiple interceptors on axios instances + - Race conditions when token expires + +3. **Two different auth systems with different behaviors:** + - Regular Helix Auth: Long-lived JWTs (7 days default) + - OIDC Auth: Short-lived access tokens (~1 hour) + refresh tokens + +## Solution: Backend-For-Frontend (BFF) Pattern + +### Core Principle + +The frontend should only care about ONE thing: **an HTTP-only session cookie**. + +- Frontend initiates login → gets redirected to auth flow +- On successful auth → backend sets HTTP-only session cookie +- All API requests automatically include the cookie (same-origin) +- Session expires naturally after 30 days +- **No tokens in JavaScript memory, ever** + +### Architecture + +``` +┌─────────────────────────────────────────────────────────────────┐ +│ BROWSER / FRONTEND │ +│ - No tokens in memory │ +│ - No Authorization headers │ +│ - No localStorage token storage │ +│ - Just calls APIs, cookies sent automatically │ +└─────────────────────────────────────────────────────────────────┘ + │ + │ HTTP requests (session cookie) + ▼ +┌─────────────────────────────────────────────────────────────────┐ +│ HELIX API (BFF) │ +│ │ +│ ┌─────────────────────────────────────────────────────────┐ │ +│ │ Auth Middleware │ │ +│ │ 1. Check BFF session cookie (helix_session) │ │ +│ │ 2. If no session, check Authorization header (API key) │ │ +│ │ 3. OIDC token refresh handled transparently │ │ +│ └─────────────────────────────────────────────────────────┘ │ +│ │ │ +│ Browser auth: │ CLI/API client auth: │ +│ ┌────────────────────────┴────────────────────────┐ │ +│ ▼ ▼ ▼ │ +│ ┌──────────────┐ ┌──────────────┐ ┌──────────────────┐ │ +│ │ BFF Session │ │ BFF Session │ │ API Key │ │ +│ │ (Regular) │ │ (OIDC) │ │ (unchanged) │ │ +│ │ │ │ │ │ │ │ +│ │ Email/pass │ │ Google etc │ │ Authorization: │ │ +│ │ → session │ │ → session │ │ Bearer hl-xxx │ │ +│ └──────────────┘ └──────────────┘ └──────────────────┘ │ +└─────────────────────────────────────────────────────────────────┘ + ▲ + │ + ┌───────────────────┘ + │ Authorization: Bearer hl-xxx +┌─────────────────────────────────────────────────────────────────┐ +│ CLI / API CLIENTS │ +│ - Uses API keys (hl-xxx) in Authorization header │ +│ - No cookies, no sessions │ +│ - Runner tokens for internal system auth │ +└─────────────────────────────────────────────────────────────────┘ +``` + +**Key Point:** API keys (`hl-xxx`) and runner tokens continue to work exactly as before. +The BFF session pattern only applies to browser-based authentication. + +### Session Table Schema + +We'll follow the same pattern as `OAuthConnection` (see `api/pkg/types/oauth.go`), which already has: +- Token storage (access_token, refresh_token, expires_at) +- Background refresh via `RefreshExpiredTokens()` in the OAuth manager +- Database-backed persistence + +```go +// UserSession represents an authenticated user session +// Similar pattern to OAuthConnection but for user auth +type UserSession struct { + ID string `json:"id" gorm:"primaryKey;type:uuid"` + CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"` + UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"` + DeletedAt gorm.DeletedAt `json:"deleted_at" gorm:"index"` + + // User who owns this session + UserID string `json:"user_id" gorm:"not null;index"` + + // Auth provider used: "regular" or "oidc" + AuthProvider string `json:"auth_provider" gorm:"not null;type:text"` + + // Session expiry (30 days from creation) + ExpiresAt time.Time `json:"expires_at" gorm:"not null;index"` + + // For OIDC sessions: store the refresh token so backend can refresh access tokens + // Access token is short-lived and fetched on-demand when needed + OIDCRefreshToken string `json:"-" gorm:"type:text"` // Never expose to frontend + OIDCAccessToken string `json:"-" gorm:"type:text"` // Cache for backend use + OIDCTokenExpiry time.Time `json:"-"` // When access token expires + + // Optional metadata for security/audit + UserAgent string `json:"user_agent,omitempty" gorm:"type:text"` + IPAddress string `json:"ip_address,omitempty" gorm:"type:varchar(45)"` + LastUsedAt time.Time `json:"last_used_at"` +} +``` + +### Code Reuse from OAuth Manager + +The `api/pkg/oauth/manager.go` has patterns we can reuse: + +1. **Background refresh job** (`RefreshExpiredTokens`): + ```go + // Already runs every minute, refreshes tokens approaching expiry + err := m.RefreshExpiredTokens(ctx, 5*time.Minute) + ``` + +2. **Token refresh on access** (`RefreshTokenIfNeeded`): + ```go + // Called when getting a connection, refreshes if needed + if err := provider.RefreshTokenIfNeeded(ctx, connection); err != nil { + return nil, fmt.Errorf("failed to refresh token: %w", err) + } + ``` + +For user sessions, we'll create a `SessionManager` that follows the same pattern: +- Background goroutine to refresh OIDC tokens before they expire +- `RefreshSessionIfNeeded()` called when validating session +- Reuse the existing `auth/oidc.go` OIDC client for token refresh + +### Silent Token Renewal (Back-Channel Refresh) + +The BFF pattern uses **back-channel refresh** instead of traditional Silent Renew (front-channel): + +**How It Works:** +1. User authenticates via OIDC (Google, Keycloak, etc.) +2. Backend receives access token + refresh token in callback +3. Backend stores both tokens in `UserSession` (never exposed to frontend) +4. On each request, `SessionManager.GetSessionFromRequest()` checks if access token expires soon +5. If expiring, calls `oidcClient.RefreshAccessToken(refreshToken)` to get new tokens +6. Session is updated with new tokens - user never knows this happened + +**Why This Is Better Than Front-Channel Silent Renew:** + +| Aspect | Front-Channel (iframe) | Back-Channel (refresh tokens) | +|--------|------------------------|------------------------------| +| Third-party cookies | Requires them (being blocked) | Not needed | +| Browser support | Breaking in modern browsers | Works everywhere | +| Complexity | Requires iframe + postMessage | Simple backend API call | +| Security | Token in browser | Token never in browser | +| PKCE | Required per request | Used once at login | + +**Requirements for Google OIDC:** +- Set `OIDC_OFFLINE_ACCESS=true` in environment +- This adds `offline_access` scope which grants a refresh token +- Without this, access tokens expire in ~1 hour with no renewal option + +**Refresh Token Rotation:** +- Google returns a new refresh token with each refresh (rotation) +- Backend automatically stores the new refresh token +- Old refresh tokens become invalid after use + +### Cookie Design + +**Single cookie: `helix_session`** +- Value: Session ID (UUID) +- HttpOnly: true (JavaScript cannot access) +- Secure: true (HTTPS only in production) +- SameSite: Lax (CSRF protection) +- Path: / +- MaxAge: 30 days (2592000 seconds) + +### Auth Flow + +#### Regular Helix Auth (Email/Password) + +``` +1. POST /api/v1/auth/login { email, password } +2. Backend validates credentials +3. Backend creates session in database +4. Backend sets helix_session cookie +5. Returns { authenticated: true, user: { ... } } +``` + +#### OIDC Auth (Google) + +``` +1. POST /api/v1/auth/login → returns { redirect_url: "/api/v1/auth/oidc" } +2. Frontend redirects to OIDC provider +3. User authenticates with Google +4. Callback: /api/v1/auth/oidc/callback?code=... +5. Backend exchanges code for tokens +6. Backend creates session with OIDC tokens stored +7. Backend sets helix_session cookie +8. Redirects to frontend (original URL or /) +``` + +### Session Validation Middleware + +```go +func (s *HelixAPIServer) sessionMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + sessionID, err := r.Cookie("helix_session") + if err != nil || sessionID.Value == "" { + // No session - continue without user (public endpoint) + // or return 401 for protected endpoints + next.ServeHTTP(w, r) + return + } + + session, err := s.store.GetSession(r.Context(), sessionID.Value) + if err != nil || session.IsExpired() { + // Invalid/expired session - clear cookie + clearSessionCookie(w) + next.ServeHTTP(w, r) + return + } + + // For OIDC sessions, check if access token needs refresh + if session.AuthProvider == "oidc" && session.OIDCTokenNeedsRefresh() { + newToken, err := s.oidcClient.RefreshAccessToken(r.Context(), session.OIDCRefreshToken) + if err != nil { + // Refresh failed - session is invalid + s.store.DeleteSession(r.Context(), session.ID) + clearSessionCookie(w) + next.ServeHTTP(w, r) + return + } + session.UpdateOIDCTokens(newToken) + s.store.UpdateSession(r.Context(), session) + } + + // Update last_used_at periodically (not every request) + session.TouchIfNeeded(s.store) + + // Get user from database and add to context + user, err := s.store.GetUser(r.Context(), session.UserID) + if err != nil { + clearSessionCookie(w) + next.ServeHTTP(w, r) + return + } + + ctx := setRequestUser(r.Context(), *user) + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} +``` + +### API Changes + +#### New Endpoints + +| Endpoint | Method | Description | +|----------|--------|-------------| +| `/api/v1/auth/session` | GET | Get current session info | +| `/api/v1/auth/logout` | POST | Delete session, clear cookie | + +#### Modified Endpoints + +| Endpoint | Change | +|----------|--------| +| `/api/v1/auth/login` | Create session, set cookie | +| `/api/v1/auth/oidc/callback` | Create session, set cookie | +| `/api/v1/auth/authenticated` | Check session cookie instead of access_token | +| `/api/v1/auth/user` | Get user from session, no token in response | + +#### Removed/Deprecated + +| Endpoint | Reason | +|----------|--------| +| `/api/v1/auth/refresh` | No longer needed - backend handles transparently | +| `X-Token-Refreshed` header | No longer needed | + +### Frontend Changes + +#### Files to Modify + +1. **`frontend/src/hooks/useApi.ts`** + - Remove `setToken()` function + - Remove `handleTokenRefreshHeader()` + - Remove all token-related interceptors + - Remove `localStorage.setItem('token', ...)` + - Remove `securityWorker` (no Authorization header needed) + +2. **`frontend/src/contexts/account.tsx`** + - Remove `token` from context + - Remove `tokenUrlEscaped` + - Remove `TOKEN_REFRESHED_EVENT` handling + - Simplify `initialize()` - just call `/api/v1/auth/session` + - Simplify `onLogout()` - just call `/api/v1/auth/logout` + +3. **`frontend/src/hooks/useWebsocket.ts`** + - Remove token dependency (cookies sent automatically with WS) + +4. **`frontend/src/contexts/streaming.tsx`** + - Remove Authorization header from EventSource + - Cookies are sent automatically with EventSource + +5. **`frontend/src/hooks/useKnowledge.ts`** + - Remove Authorization header from fetch calls + +6. **`frontend/src/components/auth/TokenExpiryCounter.tsx`** + - Remove entirely (no token expiry visible to frontend) + +### Migration Path + +Clean cutover - no backward compatibility with old token-based auth: + +1. **Add session table** (AutoMigrate) +2. **Add session middleware** to auth layer +3. **Update auth endpoints** to create sessions (login, OIDC callback) +4. **Update frontend** - remove all token management code +5. **Deploy** - users will need to log in again (acceptable trade-off for simplicity) + +**Note:** API keys and runner tokens continue to work unchanged. Only browser-based +authentication is affected - users just need to log in again after the update. + +### Security Considerations + +1. **CSRF Protection (Double-Submit Cookie Pattern)** + - **SameSite=Lax** on both cookies (prevents cross-site requests, allows OAuth redirects) + - **Dual-cookie approach:** + - `helix_session`: HttpOnly cookie (not readable by JS) - contains session ID + - `helix_csrf`: Regular cookie (readable by JS) - contains CSRF token + - **X-CSRF-Token header**: Frontend reads `helix_csrf` cookie and sends as `X-CSRF-Token` header + - **Validation**: Backend validates that cookie value matches header value + - **State-changing requests only**: CSRF validated for POST, PUT, DELETE, PATCH + - **Exempt paths**: Login, logout, OIDC endpoints (before session exists) + - **API keys bypass**: Requests with API keys (no session cookie) skip CSRF validation + - State parameter in OIDC flow (already implemented) + +2. **Session Hijacking** + - Secure cookies (HTTPS only in production) + - HttpOnly session cookie (no JS access to session ID) + - Session tied to IP/user-agent (recorded for audit) + +3. **Session Fixation** + - New session ID generated on each login + - New CSRF token generated with each session + +4. **Token Storage** + - OIDC tokens encrypted at rest in database (recommended) + - Only refresh token stored long-term; access token cached briefly + +### Testing Plan + +1. **Regular Auth Flow** + - Login with email/password + - Verify session cookie set + - Verify API calls work without Authorization header + - Logout and verify session deleted + +2. **OIDC Auth Flow** + - Login with Google + - Verify session cookie set + - Wait for OIDC token to expire (or simulate) + - Verify backend refreshes transparently + - Logout and verify session deleted + +3. **Page Refresh** + - Login, refresh page + - Verify still authenticated + - Verify no 401 errors + +4. **Cross-tab** + - Login in one tab + - Open new tab + - Verify authenticated in new tab + +5. **Session Expiry** + - Create session + - Wait for expiry (or simulate) + - Verify redirect to login + +## Frontend API Interaction Audit + +All places in the frontend that interact with the API, and how they handle auth: + +### 1. Generated API Client (`apiClient.xxx()`) +**Files:** All service files (`projectService.ts`, `sessionService.ts`, etc.) +**Current auth:** `securityWorker` adds Authorization header from in-memory token +**BFF change:** Remove `securityWorker` - cookies sent automatically + +### 2. Raw axios via useApi hook (`api.get()`, `api.post()`) +**Files:** `account.tsx` (`loadStatus`), various legacy code +**Current auth:** `axios.defaults.headers.common['Authorization']` set by `setToken()` +**BFF change:** Remove `setToken()` - cookies sent automatically + +### 3. Direct fetch() calls +**Files:** `useKnowledge.ts`, `filestoreService.ts`, `account.tsx` (login/logout) +**Current auth:** Manual `Authorization: Bearer ${account.token}` header +**BFF change:** Add `credentials: 'same-origin'` (default for same-origin, but explicit is clearer) + +### 4. EventSource (Server-Sent Events) +**Files:** `streaming.tsx` +**Current auth:** Cannot set custom headers on EventSource; currently passes token in URL or relies on cookies +**BFF change:** Cookies sent automatically by browser for same-origin + +### 5. WebSocket connections +**Files:** `useWebsocket.ts`, `DesignReviewContent.tsx`, `streaming.tsx` +**Current auth:** Browsers send cookies automatically with WebSocket connections +**BFF change:** No change needed - already works via cookies + +### 6. helix-stream library +**Files:** `lib/helix-stream/api.ts` +**Purpose:** Moonlight streaming host control (not Helix API auth) +**Current auth:** Separate credential system (`sessionStorage.mlCredentials`) +**BFF change:** Not affected - this is for external streaming hosts + +### Summary of Changes Needed + +| API Method | Current Token Source | BFF Change | +|------------|---------------------|------------| +| Generated client | `securityWorker` | Remove `securityWorker` | +| Raw axios | `axios.defaults.headers.common` | Remove header management | +| fetch() | Manual header | Remove Authorization header | +| EventSource | Cookies (already) | No change | +| WebSocket | Cookies (already) | No change | + +## Best Practices Applied (from industry research) + +Based on [Auth0](https://auth0.com/blog/the-backend-for-frontend-pattern-bff/), [FusionAuth](https://fusionauth.io/blog/backend-for-frontend), and [Duende BFF Framework](https://docs.duendesoftware.com/bff/): + +1. **No tokens in JavaScript, ever** - Session ID only, HttpOnly cookie +2. **Server-side token storage** - OIDC tokens stored in database, not in cookies +3. **Transparent refresh** - Backend handles OIDC token refresh; frontend never knows +4. **CSRF protection** - SameSite=Lax cookies (may add custom header for extra safety) +5. **Secure cookies** - HttpOnly + Secure + SameSite + +## Implementation Order + +1. Add session table and store methods +2. Add session middleware (in parallel with existing auth) +3. Update login endpoints to create sessions +4. Update frontend to use session-based auth +5. Remove old token code +6. Clean up and test + +## Answers to Key Questions + +**Should frontend have NO refresh logic?** +Yes. The frontend should be completely unaware of token refresh. The backend handles all of this: +- Regular auth: Long-lived session (30 days), no refresh needed +- OIDC auth: Session (30 days), backend refreshes OIDC tokens transparently + +**Should we rip out all token storage locations?** +Yes. The frontend currently stores tokens in 5+ locations. After BFF: +- No localStorage token +- No axios.defaults.headers.common['Authorization'] +- No apiClientSingleton.setSecurityData() +- No account.token React state +- No TOKEN_REFRESHED_EVENT + +The only "state" is the HttpOnly cookie, which JavaScript cannot access. + +**Do cookies work with WebSockets?** +Yes. Browsers automatically send cookies with WebSocket connections to the same origin. This is why the current system already relies on cookies for WebSocket auth. + +## Sources + +- [Auth0: The Backend for Frontend Pattern (BFF)](https://auth0.com/blog/the-backend-for-frontend-pattern-bff/) +- [FusionAuth: A Guide to Backend-for-Frontend Auth](https://fusionauth.io/blog/backend-for-frontend) +- [Duende: BFF Security Framework](https://docs.duendesoftware.com/bff/) +- [Medium: Secure Your Tokens the Right Way: BFF + Redis Explained](https://dev.to/sovannaro/the-backend-for-frontend-bff-pattern-secure-auth-done-right-fm7) diff --git a/frontend/src/components/app/KnowledgeEditor.tsx b/frontend/src/components/app/KnowledgeEditor.tsx index bc1d9296bc..27529283d8 100644 --- a/frontend/src/components/app/KnowledgeEditor.tsx +++ b/frontend/src/components/app/KnowledgeEditor.tsx @@ -503,7 +503,7 @@ const KnowledgeEditor: FC = ({ // Add functions to open files in a new tab and in the filestore const openFileInNewTab = (file: IFileStoreItem, sourcePath: string) => { - if (!account.token) { + if (!account.user) { snackbar.error('Must be logged in to view files'); return; } diff --git a/frontend/src/components/datagrid/FileStore.tsx b/frontend/src/components/datagrid/FileStore.tsx index d2e50c0eed..71cd460a46 100644 --- a/frontend/src/components/datagrid/FileStore.tsx +++ b/frontend/src/components/datagrid/FileStore.tsx @@ -54,7 +54,7 @@ const FileStoreDataGrid: FC> = ( let icon = null if(isImage(data.name)) { - icon = account.token ? ( + icon = account.user ? ( > = ( onView, onEdit, onDelete, - account.token, + account.user, ]) const theme = useTheme() diff --git a/frontend/src/components/session/InteractionInference.tsx b/frontend/src/components/session/InteractionInference.tsx index c8b5ba3494..9756d021a2 100644 --- a/frontend/src/components/session/InteractionInference.tsx +++ b/frontend/src/components/session/InteractionInference.tsx @@ -171,7 +171,7 @@ export const InteractionInference: FC<{ {serverConfig?.filestore_prefix && imageURLs .filter((file) => { - return account.token ? true : false; + return account.user ? true : false; }) .map((imageURL: string) => { const useURL = getFileURL(imageURL); diff --git a/frontend/src/components/spec-tasks/DesignReviewContent.tsx b/frontend/src/components/spec-tasks/DesignReviewContent.tsx index 4148902249..a7be98ef83 100644 --- a/frontend/src/components/spec-tasks/DesignReviewContent.tsx +++ b/frontend/src/components/spec-tasks/DesignReviewContent.tsx @@ -217,14 +217,15 @@ export default function DesignReviewContent({ // Always subscribe when viewing a spec task - that way we're already connected when comments are created useEffect(() => { // [DRWS-DEBUG] Log subscription decision + // With BFF auth, session cookie is automatically sent with WebSocket connections console.log('[DRWS-DEBUG] Subscription check:', { planningSessionId, - hasToken: !!account.token, - willSubscribe: !!(planningSessionId && account.token), + hasUser: !!account.user, + willSubscribe: !!(planningSessionId && account.user), }) - if (!planningSessionId || !account.token) { - console.log('[DRWS-DEBUG] Not subscribing - missing planningSessionId or token') + if (!planningSessionId || !account.user) { + console.log('[DRWS-DEBUG] Not subscribing - missing planningSessionId or user') return } diff --git a/frontend/src/components/system/Sidebar.tsx b/frontend/src/components/system/Sidebar.tsx index 26c90383e0..90052d4d74 100644 --- a/frontend/src/components/system/Sidebar.tsx +++ b/frontend/src/components/system/Sidebar.tsx @@ -21,7 +21,6 @@ import useLightTheme from '../../hooks/useLightTheme' import useRouter from '../../hooks/useRouter' import useAccount from '../../hooks/useAccount' import useApp from '../../hooks/useApp' -import useApi from '../../hooks/useApi' import { useCreateFilestoreFolder, useUploadFilestoreFiles, useFilestoreConfig } from '../../services/filestoreService' import DarkDialog from '../dialog/DarkDialog' @@ -104,13 +103,10 @@ const SidebarContentInner: React.FC<{ const router = useRouter() - const api = useApi() const account = useAccount() const appTools = useApp(params.app_id) const snackbar = useSnackbar() - const apiClient = api.getApiClient() - // New file menu state const [menuAnchorEl, setMenuAnchorEl] = useState(null) const [createFolderDialogOpen, setCreateFolderDialogOpen] = useState(false) @@ -124,25 +120,6 @@ const SidebarContentInner: React.FC<{ const uploadFilesMutation = useUploadFilestoreFiles() const { data: filestoreConfig } = useFilestoreConfig() - - - // Ensure apps are loaded when apps tab is selected - useEffect(() => { - const checkAuthAndLoad = async () => { - try { - const authResponse = await apiClient.v1AuthAuthenticatedList() - if (!authResponse.data.authenticated) { - return - } - - } catch (error) { - console.error('[SIDEBAR] Error checking authentication:', error) - } - } - - checkAuthAndLoad() - }, [router.params]) - // Handle create a new chat const handleCreateNew = () => { if (!appTools.app) { diff --git a/frontend/src/contexts/account.tsx b/frontend/src/contexts/account.tsx index 7dc75e929e..9cc7269d9e 100644 --- a/frontend/src/contexts/account.tsx +++ b/frontend/src/contexts/account.tsx @@ -1,6 +1,6 @@ import bluebird from 'bluebird' import { createContext, FC, useCallback, useEffect, useMemo, useState, useContext, ReactNode } from 'react' -import useApi, { TOKEN_REFRESHED_EVENT } from '../hooks/useApi' +import useApi from '../hooks/useApi' import { extractErrorMessage } from '../hooks/useErrorCallback' import useLoading from '../hooks/useLoading' import useRouter from '../hooks/useRouter' @@ -25,8 +25,6 @@ export interface IAccountContext { isOrgMember: boolean, user?: IKeycloakUser, userMeta?: { slug: string }, // User metadata including slug for GitHub-style URLs - token?: string, - tokenUrlEscaped?: string, loggingOut?: boolean, serverConfig: IServerConfig, userConfig: IUserConfig, @@ -114,15 +112,6 @@ export const useAccountContext = (): IAccountContext => { const [providerEndpoints, setProviderEndpoints] = useState([]) const [hasImageModels, setHasImageModels] = useState(false) - const token = useMemo(() => { - if (user && user.token) { - return user.token - } else { - return '' - } - }, [ - user, - ]) const isOrgAdmin = useMemo(() => { if(admin) return true @@ -152,10 +141,6 @@ export const useAccountContext = (): IAccountContext => { isOrgAdmin, ]) - const tokenUrlEscaped = useMemo(() => { - if (!token) return ''; - return encodeURIComponent(token); - }, [token]); const loadStatus = useCallback(async () => { try { @@ -291,10 +276,6 @@ export const useAccountContext = (): IAccountContext => { const onLogout = useCallback(async () => { setLoggingOut(true) - - // Clear the in-memory token BEFORE redirecting to logout - // This prevents stale tokens from being used if the redirect doesn't fully reload the page - api.setToken('') setUser(undefined) try { @@ -346,7 +327,6 @@ export const useAccountContext = (): IAccountContext => { if (authenticated.data.authenticated) { const userResponse = await client.v1AuthUserList() const user = userResponse.data as IKeycloakUser - api.setToken(user.token) const win = (window as any) if (win.setUser) { @@ -361,11 +341,6 @@ export const useAccountContext = (): IAccountContext => { } setUser(user) - - // Token refresh is handled transparently by the backend - // When the backend refreshes a token, it sends X-Token-Refreshed header - // The useApi interceptor catches this and dispatches TOKEN_REFRESHED_EVENT - // We listen for that event here to update React state } } catch (e) { const errorMessage = extractErrorMessage(e) @@ -445,22 +420,6 @@ export const useAccountContext = (): IAccountContext => { initialize() }, []) - // Listen for token refresh events from useApi interceptor - // When the backend transparently refreshes a token, we need to update React state - useEffect(() => { - const handleTokenRefreshed = (event: Event) => { - const customEvent = event as CustomEvent<{ token: string }> - const newToken = customEvent.detail?.token - if (newToken && user) { - console.log('[AUTH] Token refreshed by backend, updating React state') - setUser(prevUser => prevUser ? { ...prevUser, token: newToken } : prevUser) - api.setToken(newToken) - } - } - - window.addEventListener(TOKEN_REFRESHED_EVENT, handleTokenRefreshed) - return () => window.removeEventListener(TOKEN_REFRESHED_EVENT, handleTokenRefreshed) - }, [user, api]) useEffect(() => { try { @@ -479,8 +438,6 @@ export const useAccountContext = (): IAccountContext => { initialized, user, userMeta, - token, - tokenUrlEscaped, admin, loggingOut, serverConfig, diff --git a/frontend/src/contexts/streaming.tsx b/frontend/src/contexts/streaming.tsx index 031c0b9bea..6365af8dd8 100644 --- a/frontend/src/contexts/streaming.tsx +++ b/frontend/src/contexts/streaming.tsx @@ -1,12 +1,17 @@ import React, { createContext, useContext, ReactNode, useState, useCallback, useEffect, useRef } from 'react'; import ReconnectingWebSocket from 'reconnecting-websocket'; import { IWebsocketEvent, WEBSOCKET_EVENT_TYPE_WORKER_TASK_RESPONSE, WORKER_TASK_RESPONSE_TYPE_PROGRESS, ISessionChatRequest, ISessionType, IAgentType } from '../types'; -import useAccount from '../hooks/useAccount'; import { TypesInteraction, TypesMessage, TypesSession } from '../api/api'; import { GET_SESSION_QUERY_KEY, SESSION_STEPS_QUERY_KEY } from '../services/sessionService'; import { useQueryClient } from '@tanstack/react-query'; import { invalidateSessionsQuery } from '../services/sessionService'; +// CSRF helper - reads the CSRF token from the helix_csrf cookie +const getCSRFToken = (): string | null => { + const match = document.cookie.match(/(^| )helix_csrf=([^;]+)/) + return match ? decodeURIComponent(match[2]) : null +} + interface NewInferenceParams { regenerate?: boolean; type: ISessionType; @@ -47,7 +52,6 @@ export const useStreaming = (): StreamingContextType => { }; export const StreamingContextProvider: React.FC<{ children: ReactNode }> = ({ children }) => { - const account = useAccount(); const queryClient = useQueryClient(); const [currentResponses, setCurrentResponses] = useState>>(new Map()); const [currentSessionId, setCurrentSessionId] = useState(null); @@ -288,7 +292,8 @@ export const StreamingContextProvider: React.FC<{ children: ReactNode }> = ({ ch }, [currentSessionId]); useEffect(() => { - if (!account.token || !currentSessionId) return; + // With BFF auth, session cookie is automatically sent with WebSocket connections + if (!currentSessionId) return; const wsProtocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:'; const wsHost = window.location.host; @@ -329,7 +334,7 @@ export const StreamingContextProvider: React.FC<{ children: ReactNode }> = ({ ch invalidateTimerRef.current = null; } }; - }, [account.token, currentSessionId, handleWebsocketEvent, queryClient]); + }, [currentSessionId]); const NewInference = async ({ regenerate = false, @@ -457,12 +462,20 @@ export const StreamingContextProvider: React.FC<{ children: ReactNode }> = ({ ch }); try { + // With BFF auth, session cookie is sent automatically with same-origin requests + // Include CSRF token for protection against cross-site request forgery + const csrfToken = getCSRFToken() + const headers: Record = { + 'Content-Type': 'application/json', + } + if (csrfToken) { + headers['X-CSRF-Token'] = csrfToken + } + const response = await fetch('/api/v1/sessions/chat', { method: 'POST', - headers: { - 'Content-Type': 'application/json', - 'Authorization': `Bearer ${account.token}`, - }, + headers, + credentials: 'same-origin', body: JSON.stringify(sessionChatRequest), }); diff --git a/frontend/src/hooks/useApi.test.ts b/frontend/src/hooks/useApi.test.ts deleted file mode 100644 index 420db63678..0000000000 --- a/frontend/src/hooks/useApi.test.ts +++ /dev/null @@ -1,374 +0,0 @@ -import { describe, it, expect, vi, beforeEach } from 'vitest'; - -/** - * Integration test for the 401 token refresh interceptor. - * - * This simulates the exact scenario: user opens laptop after overnight sleep, - * access token is expired, but refresh token is still valid. - * - * Flow: - * 1. API call with expired token → 401 - * 2. Interceptor catches 401 - * 3. Calls /api/v1/auth/refresh (backend uses refresh_token cookie) - * 4. Calls /api/v1/auth/user to get new token - * 5. Updates in-memory token (axios.defaults + securityData) - * 6. Deletes old Authorization header - * 7. Retries original request - * 8. Success - */ -describe('Token Refresh Interceptor - Integration Test', () => { - - it('should refresh token and retry request when 401 is received (laptop resume scenario)', async () => { - // Track what happened during the test - const callLog: string[] = []; - let retryCount = 0; - - // Simulate the axios instance and interceptor behavior - const mockAxiosInstance = { - request: vi.fn(), - }; - - // Mock API responses - const mockRefreshCreate = vi.fn().mockImplementation(async () => { - callLog.push('refresh-called'); - return { status: 204 }; // Backend sets new cookies - }); - - const mockUserList = vi.fn().mockImplementation(async () => { - callLog.push('user-list-called'); - return { data: { token: 'new-fresh-token', id: 'user-123' } }; - }); - - // Simulate the interceptor logic (extracted for testing) - const simulateInterceptor = async (error: any) => { - const originalRequest = error.config; - - if (error.response?.status === 401) { - const url = originalRequest?.url || ''; - const isAuthEndpoint = url.includes('/api/v1/auth/'); - - if (!isAuthEndpoint && !originalRequest._retry) { - originalRequest._retry = true; - callLog.push('interceptor-caught-401'); - - // Attempt refresh - try { - await mockRefreshCreate(); - callLog.push('refresh-succeeded'); - - // Get new token - const userResponse = await mockUserList(); - const newToken = userResponse.data?.token; - - if (newToken) { - callLog.push(`token-updated:${newToken}`); - } - - // Delete old Authorization header - delete originalRequest.headers['Authorization']; - callLog.push('old-auth-header-deleted'); - - // Retry the original request - retryCount++; - callLog.push('request-retried'); - - // Simulate successful retry - return { data: { success: true }, status: 200 }; - } catch (refreshError) { - callLog.push('refresh-failed'); - throw error; - } - } - } - - throw error; - }; - - // Simulate the original request that gets a 401 - const originalRequest = { - url: '/api/v1/sessions', - method: 'GET', - headers: { - 'Authorization': 'Bearer expired-token-from-8-hours-ago', - 'Content-Type': 'application/json', - }, - _retry: false, - }; - - const error401 = { - response: { status: 401, data: { error: 'Token expired' } }, - config: originalRequest, - }; - - // Execute the interceptor - const result = await simulateInterceptor(error401); - - // Verify the complete flow happened in order - expect(callLog).toEqual([ - 'interceptor-caught-401', - 'refresh-called', - 'refresh-succeeded', - 'user-list-called', - 'token-updated:new-fresh-token', - 'old-auth-header-deleted', - 'request-retried', - ]); - - // Verify the request was retried exactly once - expect(retryCount).toBe(1); - - // Verify the retry succeeded - expect(result.status).toBe(200); - expect(result.data.success).toBe(true); - - // Verify old Authorization header was removed - expect(originalRequest.headers['Authorization']).toBeUndefined(); - - // Verify _retry flag was set (prevents infinite loop) - expect(originalRequest._retry).toBe(true); - }); - - it('should NOT retry if refresh token is also expired (hard logout scenario)', async () => { - const callLog: string[] = []; - - const mockRefreshCreate = vi.fn().mockImplementation(async () => { - callLog.push('refresh-called'); - throw new Error('Refresh token expired'); // This happens when refresh token is invalid - }); - - const simulateInterceptor = async (error: any) => { - const originalRequest = error.config; - - if (error.response?.status === 401) { - const url = originalRequest?.url || ''; - const isAuthEndpoint = url.includes('/api/v1/auth/'); - - if (!isAuthEndpoint && !originalRequest._retry) { - originalRequest._retry = true; - callLog.push('interceptor-caught-401'); - - try { - await mockRefreshCreate(); - } catch (refreshError) { - callLog.push('refresh-failed'); - throw error; // Propagate original 401 - will trigger logout - } - } - } - - throw error; - }; - - const originalRequest = { - url: '/api/v1/sessions', - method: 'GET', - headers: { 'Authorization': 'Bearer expired-token' }, - _retry: false, - }; - - const error401 = { - response: { status: 401, data: { error: 'Token expired' } }, - config: originalRequest, - }; - - // Should throw the original error (triggers logout in account context) - await expect(simulateInterceptor(error401)).rejects.toEqual(error401); - - expect(callLog).toEqual([ - 'interceptor-caught-401', - 'refresh-called', - 'refresh-failed', - ]); - }); - - it('should NOT attempt refresh for auth endpoints (prevents infinite loop)', async () => { - const callLog: string[] = []; - const mockRefreshCreate = vi.fn(); - - const simulateInterceptor = async (error: any) => { - const originalRequest = error.config; - - if (error.response?.status === 401) { - const url = originalRequest?.url || ''; - const isAuthEndpoint = url.includes('/api/v1/auth/'); - - if (isAuthEndpoint) { - callLog.push('skipped-auth-endpoint'); - } - - if (!isAuthEndpoint && !originalRequest._retry) { - await mockRefreshCreate(); - callLog.push('refresh-called'); - } - } - - throw error; - }; - - // Simulate 401 on the refresh endpoint itself - const refreshEndpointError = { - response: { status: 401 }, - config: { - url: '/api/v1/auth/refresh', - headers: {}, - _retry: false, - }, - }; - - await expect(simulateInterceptor(refreshEndpointError)).rejects.toBeDefined(); - - expect(callLog).toEqual(['skipped-auth-endpoint']); - expect(mockRefreshCreate).not.toHaveBeenCalled(); - }); -}); - -// Test the core refresh logic in isolation -describe('Token Refresh Logic - Unit Tests', () => { - // These tests verify the expected behavior of individual pieces - - describe('attemptTokenRefresh behavior', () => { - it('should return true when refresh succeeds', async () => { - // This tests the pattern: refresh -> fetch user -> update token - const mockRefreshCreate = vi.fn().mockResolvedValue({ status: 204 }); - const mockUserList = vi.fn().mockResolvedValue({ - data: { token: 'new-access-token', id: 'user-123' } - }); - - // Simulate the refresh flow - await mockRefreshCreate(); - const userResponse = await mockUserList(); - const newToken = userResponse.data?.token; - - expect(newToken).toBe('new-access-token'); - expect(mockRefreshCreate).toHaveBeenCalledTimes(1); - expect(mockUserList).toHaveBeenCalledTimes(1); - }); - - it('should handle refresh failure gracefully', async () => { - const mockRefreshCreate = vi.fn().mockRejectedValue( - new Error('Refresh token expired') - ); - - let refreshSucceeded = true; - try { - await mockRefreshCreate(); - } catch { - refreshSucceeded = false; - } - - expect(refreshSucceeded).toBe(false); - }); - }); - - describe('401 interceptor behavior', () => { - it('should not retry auth endpoints to prevent infinite loops', () => { - const url = '/api/v1/auth/refresh'; - const isAuthEndpoint = url.includes('/api/v1/auth/'); - - expect(isAuthEndpoint).toBe(true); - // Auth endpoints should be skipped - }); - - it('should retry non-auth endpoints', () => { - const url = '/api/v1/sessions'; - const isAuthEndpoint = url.includes('/api/v1/auth/'); - - expect(isAuthEndpoint).toBe(false); - // Non-auth endpoints should trigger refresh - }); - - it('should use _retry flag to prevent double retry', () => { - const originalRequest = { _retry: false, url: '/api/v1/sessions' }; - - // First 401 - should retry - expect(originalRequest._retry).toBe(false); - originalRequest._retry = true; - - // Second 401 - should not retry - expect(originalRequest._retry).toBe(true); - }); - - it('should delete Authorization header before retry', () => { - const originalRequest = { - headers: { - 'Authorization': 'Bearer old-expired-token', - 'Content-Type': 'application/json' - } - }; - - // Simulate what the interceptor does after refresh - delete originalRequest.headers['Authorization']; - - expect(originalRequest.headers['Authorization']).toBeUndefined(); - expect(originalRequest.headers['Content-Type']).toBe('application/json'); - }); - }); - - describe('race condition handling', () => { - it('should deduplicate concurrent refresh attempts', async () => { - let isRefreshing = false; - let refreshPromise: Promise | null = null; - let refreshCallCount = 0; - - const attemptRefresh = async (): Promise => { - refreshCallCount++; - await new Promise(resolve => setTimeout(resolve, 100)); - return true; - }; - - const handleRequest = async (): Promise => { - if (isRefreshing && refreshPromise) { - // Wait for existing refresh - return refreshPromise; - } - - isRefreshing = true; - refreshPromise = attemptRefresh(); - - try { - return await refreshPromise; - } finally { - isRefreshing = false; - refreshPromise = null; - } - }; - - // Simulate 5 concurrent 401 responses - const results = await Promise.all([ - handleRequest(), - handleRequest(), - handleRequest(), - handleRequest(), - handleRequest(), - ]); - - // All should succeed - expect(results.every(r => r === true)).toBe(true); - // But only 1-2 actual refresh calls (first one + possibly one more if timing is tight) - expect(refreshCallCount).toBeLessThanOrEqual(2); - }); - }); - - describe('token update after refresh', () => { - it('should update both axios defaults and security data', () => { - const newToken = 'new-access-token'; - const axiosDefaults = { headers: { common: {} as Record } }; - let securityData: { token: string } | null = null; - - const setSecurityData = (data: { token: string }) => { - securityData = data; - }; - - const getTokenHeaders = (token: string) => ({ - Authorization: `Bearer ${token}`, - }); - - // Simulate what attemptTokenRefresh does after getting new token - axiosDefaults.headers.common = getTokenHeaders(newToken); - setSecurityData({ token: newToken }); - - expect(axiosDefaults.headers.common['Authorization']).toBe('Bearer new-access-token'); - expect(securityData?.token).toBe('new-access-token'); - }); - }); -}); diff --git a/frontend/src/hooks/useApi.ts b/frontend/src/hooks/useApi.ts index 382a1b2b62..cb64740963 100644 --- a/frontend/src/hooks/useApi.ts +++ b/frontend/src/hooks/useApi.ts @@ -26,92 +26,41 @@ export interface IApiOptions { errorCapture?: (err: string) => void, } -export const getTokenHeaders = (token: string) => { - return { - Authorization: `Bearer ${token}`, - } -} - -type SecurityDataType = { token: string } - // Create a singleton instance of the API client -// This ensures it's only initialized once, regardless of how many components use the hook +// With BFF pattern, no security worker needed - cookies are sent automatically const apiClientSingleton = new Api({ baseURL: window.location.origin, secure: true, - securityWorker: (securityData: SecurityDataType | null) => { - if (securityData && securityData.token) { - return { - headers: { - Authorization: `Bearer ${securityData.token}`, - } - } - } - return {} - } + withCredentials: true, // Required for BFF pattern - send session cookies with requests + // No securityWorker needed - session cookie is sent automatically }) -// Custom event name for token refresh - account.tsx listens for this -export const TOKEN_REFRESHED_EVENT = 'helix-token-refreshed' +// Configure axios to send cookies with requests (same-origin) +axios.defaults.withCredentials = true -// Helper function to handle X-Token-Refreshed header from backend -// This is called for both successful and error responses -const handleTokenRefreshHeader = (headers: Record | undefined) => { - if (!headers) return - const newToken = headers['x-token-refreshed'] - if (newToken) { - console.log('[API] Token refreshed transparently by backend, updating all token locations') +// CSRF Protection: Add X-CSRF-Token header for state-changing requests +// The CSRF token is stored in the helix_csrf cookie (readable by JS) +const CSRF_COOKIE_NAME = 'helix_csrf' +const CSRF_HEADER_NAME = 'X-CSRF-Token' - // Update axios defaults (for raw axios calls) - axios.defaults.headers.common = getTokenHeaders(newToken) - - // Update OpenAPI client security data - apiClientSingleton.setSecurityData({ token: newToken }) - - // Also update the client instance headers directly (matches setToken behavior) - try { - apiClientSingleton.instance.defaults.headers.common['Authorization'] = `Bearer ${newToken}` - } catch (e) { - console.error('[API] Failed to set token directly on client instance:', e) - } - - // Update localStorage for direct fetch() calls - localStorage.setItem('token', newToken) - - // Dispatch event so account.tsx can update React state - window.dispatchEvent(new CustomEvent(TOKEN_REFRESHED_EVENT, { detail: { token: newToken } })) - } +// Helper to read a cookie value by name +const getCookie = (name: string): string | null => { + const match = document.cookie.match(new RegExp('(^| )' + name + '=([^;]+)')) + return match ? decodeURIComponent(match[2]) : null } -// Add response interceptor to handle X-Token-Refreshed header from backend -// The backend transparently refreshes expired tokens and sends the new token in this header -// We update all frontend token storage locations and dispatch an event for React state -apiClientSingleton.instance.interceptors.response.use( - (response) => { - handleTokenRefreshHeader(response.headers as Record) - return response - }, - (error) => { - // Also check for X-Token-Refreshed in error responses - // The backend might refresh the token but the request could still fail for other reasons - // (e.g., user not authorized for a specific resource) - handleTokenRefreshHeader(error.response?.headers as Record) - return Promise.reject(error) - } -) - -// Also add interceptor to global axios instance for raw axios.get/post/etc calls -// These are used by some legacy code paths (e.g., loadStatus in account.tsx) -axios.interceptors.response.use( - (response) => { - handleTokenRefreshHeader(response.headers as Record) - return response - }, - (error) => { - handleTokenRefreshHeader(error.response?.headers as Record) - return Promise.reject(error) +// Add CSRF token to state-changing requests +axios.interceptors.request.use((config) => { + const method = config.method?.toUpperCase() + // Only add CSRF header for state-changing methods + if (method === 'POST' || method === 'PUT' || method === 'DELETE' || method === 'PATCH') { + const csrfToken = getCookie(CSRF_COOKIE_NAME) + if (csrfToken) { + config.headers[CSRF_HEADER_NAME] = csrfToken + } } -) + return config +}) // Helper function to check if an error is auth-related const isAuthError = (error: any): boolean => { @@ -119,7 +68,7 @@ const isAuthError = (error: any): boolean => { if (error.response?.status === 401 || error.response?.status === 403) { return true } - + // Check error message for common auth failure patterns const errorMessage = extractErrorMessage(error).toLowerCase() const authErrorPatterns = [ @@ -133,7 +82,7 @@ const isAuthError = (error: any): boolean => { 'invalid token', 'expired token' ] - + return authErrorPatterns.some(pattern => errorMessage.includes(pattern)) } @@ -185,19 +134,13 @@ export const useApi = () => { const put = useCallback(async function(url: string, data: ReqT, axiosConfig?: AxiosRequestConfig, options?: IApiOptions): Promise { if(options?.loading === true) loading.setLoading(true) try { - console.log('Sending PUT request to:', `${API_MOUNT}${url}`); - console.log('Request data:', data); const res = await axios.put(`${API_MOUNT}${url}`, data, axiosConfig) if(res.status >= 400) { - console.error(`API Error: ${res.status} ${res.statusText}`); - console.error('Response data:', res.data); throw new Error(`${res.status} ${res.statusText}`) } if(options?.loading === true) loading.setLoading(false) return res.data } catch (e: any) { - console.error('Full error object:', e); - console.error('Error response:', e.response); const errorMessage = extractErrorMessage(e) console.error(errorMessage) options?.errorCapture?.(errorMessage) @@ -205,7 +148,6 @@ export const useApi = () => { const safeErrorMsg = typeof errorMessage === 'string' ? errorMessage : 'An error occurred' snackbar.setSnackbar(safeErrorMsg, 'error') reportError(new Error(safeErrorMsg)) - // Throw the error anyways throw e } if(options?.loading === true) loading.setLoading(false) @@ -233,34 +175,6 @@ export const useApi = () => { } }, []) - // this will work globally because we are applying this to the root import of axios - // therefore we don't need to worry about passing the token around to other contexts - // we can just call useApi() from anywhere and we will get the token injected into the request - // because the top level account context has called this - const setToken = useCallback(function(token: string) { - axios.defaults.headers.common = token ? getTokenHeaders(token) : {} - - // Set token for OpenAPI client - apiClientSingleton.setSecurityData({ - token: token, - }); - - // Force a direct modification of the client instance's default headers as a fallback - try { - apiClientSingleton.instance.defaults.headers.common['Authorization'] = `Bearer ${token}`; - } catch (e) { - console.error('Failed to set token directly on client instance:', e); - } - - // Also set in localStorage for direct fetch() calls (e.g., filestoreService.ts) - // This ensures all API call paths have access to the current token - if (token) { - localStorage.setItem('token', token) - } else { - localStorage.removeItem('token') - } - }, []) - const getApiClient = useCallback(() => { return apiClientSingleton.api }, []) @@ -274,10 +188,9 @@ export const useApi = () => { post, put, delete: del, - setToken, getApiClient, getV1Client, } } -export default useApi \ No newline at end of file +export default useApi diff --git a/frontend/src/hooks/useKnowledge.ts b/frontend/src/hooks/useKnowledge.ts index 016edf12fd..102e17822a 100644 --- a/frontend/src/hooks/useKnowledge.ts +++ b/frontend/src/hooks/useKnowledge.ts @@ -44,9 +44,9 @@ export const useKnowledge = ({ }) => { const api = useApi() const snackbar = useSnackbar() - const filestore = useFilestore() const account = useAccount() - + const filestore = useFilestore() + const [expanded, setExpanded] = useState(false); const [errors, setErrors] = useState>({}); @@ -577,23 +577,16 @@ export const useKnowledge = ({ } try { - if (!account.token) { - snackbar.error('Must be logged in to download files') - return - } - // Create a temporary link to trigger the download const downloadUrl = `/api/v1/knowledge/${id}/download` const link = document.createElement('a') link.href = downloadUrl link.setAttribute('download', `${source.name}-files.zip`) - - // Set auth header by creating a fetch request instead of direct link + + // With BFF auth, session cookie is sent automatically with same-origin requests const response = await fetch(downloadUrl, { method: 'GET', - headers: { - 'Authorization': `Bearer ${account.token}`, - }, + credentials: 'same-origin', }) if (!response.ok) { diff --git a/frontend/src/hooks/useWebsocket.ts b/frontend/src/hooks/useWebsocket.ts index 7facb192b3..73d2bc71d9 100644 --- a/frontend/src/hooks/useWebsocket.ts +++ b/frontend/src/hooks/useWebsocket.ts @@ -1,6 +1,5 @@ import React, { useEffect, useRef } from 'react' import ReconnectingWebSocket from 'reconnecting-websocket' -import useAccount from '../hooks/useAccount' import { IWebsocketEvent, @@ -12,7 +11,6 @@ export const useWebsocket = ( (ev: IWebsocketEvent): void, }, ) => { - const account = useAccount() const wsRef = useRef() const messageQueue = useRef([]) const processingRef = useRef(false) @@ -33,12 +31,12 @@ export const useWebsocket = ( } useEffect(() => { - if(!account.token) return + // With BFF auth, session cookie is automatically sent with WebSocket connections if(!session_id) return const wsProtocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:' const wsHost = window.location.host const url = `${wsProtocol}//${wsHost}/api/v1/ws/user?session_id=${session_id}` - + const rws = new ReconnectingWebSocket(url, [], { maxRetries: 10, reconnectionDelayGrowFactor: 1.3, @@ -50,7 +48,7 @@ export const useWebsocket = ( const messageHandler = (event: MessageEvent) => { const parsedData = JSON.parse(event.data) as IWebsocketEvent - + if(parsedData.session_id != session_id) { return } @@ -70,10 +68,7 @@ export const useWebsocket = ( wsRef.current.close() } } - }, [ - account.token, - session_id, - ]) + }, [session_id]) } export default useWebsocket \ No newline at end of file diff --git a/frontend/src/pages/Account.tsx b/frontend/src/pages/Account.tsx index 72475f5c4c..f2afcf255e 100644 --- a/frontend/src/pages/Account.tsx +++ b/frontend/src/pages/Account.tsx @@ -171,13 +171,11 @@ const Account: FC = () => { }, []) useEffect(() => { - if (!account.token) { + if (!account.user) { return } // API keys are now loaded automatically via React Query hooks - }, [ - account.token, - ]) + }, [account.user]) useEffect(() => { setFullName(account.user?.name || '') diff --git a/frontend/src/pages/Layout.tsx b/frontend/src/pages/Layout.tsx index b38d9796d6..2828ab335d 100644 --- a/frontend/src/pages/Layout.tsx +++ b/frontend/src/pages/Layout.tsx @@ -37,7 +37,6 @@ import useLightTheme from '../hooks/useLightTheme' import useThemeConfig from '../hooks/useThemeConfig' import useIsBigScreen from '../hooks/useIsBigScreen' import useApps from '../hooks/useApps' -import useApi from '../hooks/useApi' import useUserMenuHeight from '../hooks/useUserMenuHeight' import { useGetConfig } from '../services/userService' import { TypesAuthProvider } from '../api/api' @@ -52,13 +51,11 @@ const Layout: FC<{ const lightTheme = useLightTheme() const isBigScreen = useIsBigScreen() const router = useRouter() - const api = useApi() const account = useAccount() const apps = useApps() const floatingRunnerState = useFloatingRunnerState() const floatingModal = useFloatingModal() const [showVersionBanner, setShowVersionBanner] = useState(true) - const [isAuthenticated, setIsAuthenticated] = useState(false) const userMenuHeight = useUserMenuHeight() const { data: config } = useGetConfig() @@ -152,25 +149,11 @@ const Layout: FC<{ let sidebarMenu = null const isOrgMenu = router.meta.menu == 'orgs' - const apiClient = api.getApiClient() - // Determine which resource type to use // 1. Use resource_type from URL params if available // 2. If app_id is present in the URL, default to 'apps' // 3. Otherwise default to 'chat' - const resourceType = router.params.resource_type || (router.params.app_id ? 'apps' : 'chat') - - // This useEffect handles registering/updating the menu - React.useEffect(() => { - const checkAuthAndLoad = async () => { - const authResponse = await apiClient.v1AuthAuthenticatedList() - if (!authResponse.data.authenticated) { - return - } - setIsAuthenticated(true) - } - checkAuthAndLoad() - }, [resourceType]) + const resourceType = router.params.resource_type || (router.params.app_id ? 'apps' : 'chat')