@@ -908,8 +908,8 @@ func getSGUserRolesForAudit(db *db.DatabaseContext, user auth.User) ([]string, e
908908 return roleNames , nil
909909}
910910
911- // checkPublicAuth verifies that the current request is authenticated for the given database.
912- //
911+ // checkPublicAuth verifies that the current request is authenticated for the given database. Returns an HTTPError if
912+ // authentication fails.
913913// NOTE: checkPublicAuth is not used for the admin interface.
914914func (h * handler ) checkPublicAuth (dbCtx * db.DatabaseContext ) (err error ) {
915915
@@ -918,59 +918,69 @@ func (h *handler) checkPublicAuth(dbCtx *db.DatabaseContext) (err error) {
918918 return nil
919919 }
920920
921- var auditFields base.AuditFields
922-
923- // Record Auth stats
924- defer func (t time.Time ) {
925- delta := time .Since (t ).Nanoseconds ()
926- dbCtx .DbStats .Security ().TotalAuthTime .Add (delta )
927- if err != nil {
928- dbCtx .DbStats .Security ().AuthFailedCount .Add (1 )
929- if errors .Is (err , ErrInvalidLogin ) {
930- base .Audit (h .ctx (), base .AuditIDPublicUserAuthenticationFailed , auditFields )
931- }
932- } else {
933- dbCtx .DbStats .Security ().AuthSuccessCount .Add (1 )
934-
935- username := ""
936- if h .isGuest () {
937- username = base .GuestUsername
938- } else if h .user != nil {
939- username = h .user .Name ()
940- }
941- roleNames , err := getSGUserRolesForAudit (dbCtx , h .user )
942- if err != nil {
943- base .InfofCtx (h .ctx (), base .KeyHTTP , "Error getting user roles for audit logging: %v" , err )
944- }
945- h .rqCtx = base .UserLogCtx (h .ctx (), username , base .UserDomainSyncGateway , roleNames )
946- base .Audit (h .ctx (), base .AuditIDPublicUserAuthenticated , auditFields )
921+ start := time .Now ()
922+ auditFields , err := h .setUserForPublicAuth (dbCtx )
923+ dbCtx .DbStats .Security ().TotalAuthTime .Add (time .Since (start ).Nanoseconds ())
924+ if err != nil {
925+ dbCtx .DbStats .Security ().AuthFailedCount .Add (1 )
926+ if errors .Is (err , ErrInvalidLogin ) {
927+ base .Audit (h .ctx (), base .AuditIDPublicUserAuthenticationFailed , auditFields )
947928 }
948- }(time .Now ())
929+ return
930+ }
931+ dbCtx .DbStats .Security ().AuthSuccessCount .Add (1 )
932+
933+ username := ""
934+ if h .isGuest () {
935+ username = base .GuestUsername
936+ } else if h .user != nil {
937+ username = h .user .Name ()
938+ }
939+ roleNames , err := getSGUserRolesForAudit (dbCtx , h .user )
940+ if err != nil {
941+ base .InfofCtx (h .ctx (), base .KeyHTTP , "Error getting user roles for audit logging: %v" , err )
942+ }
943+ h .rqCtx = base .UserLogCtx (h .ctx (), username , base .UserDomainSyncGateway , roleNames )
944+ base .Audit (h .ctx (), base .AuditIDPublicUserAuthenticated , auditFields )
945+ return err
946+ }
949947
948+ // setUserForPublicAuth sets h.user based on the authentication information in the request. Returns an error if the user
949+ // can not authenticate successfully, and returns AuditFields even in the case that there is an error in the request.
950+ //
951+ // Uses:
952+ //
953+ // 1. Bearer token (OIDC JWT) if present and OIDC is enabled
954+ // 2. Basic auth if present and password authentication is not disabled
955+ // 3. Cookie auth if present
956+ // 4. Guest access if enabled
957+ func (h * handler ) setUserForPublicAuth (dbCtx * db.DatabaseContext ) (base.AuditFields , error ) {
958+ var auditFields base.AuditFields
950959 // If oidc enabled, check for bearer ID token
951960 if dbCtx .Options .OIDCOptions != nil || len (dbCtx .LocalJWTProviders ) > 0 {
952961 if token := h .getBearerToken (); token != "" {
953962 auditFields = base.AuditFields {base .AuditFieldAuthMethod : "bearer" }
954963 var updates auth.PrincipalConfig
964+ var err error
955965 h .user , updates , err = dbCtx .Authenticator (h .ctx ()).AuthenticateUntrustedJWT (token , dbCtx .OIDCProviders , dbCtx .LocalJWTProviders , h .getOIDCCallbackURL )
956966 if h .user == nil || err != nil {
957- return ErrInvalidLogin
967+ return auditFields , ErrInvalidLogin
958968 }
959969 if issuer := h .user .JWTIssuer (); issuer != "" {
960970 auditFields ["oidc_issuer" ] = issuer
961971 }
962972 if changes := checkJWTIssuerStillValid (h .ctx (), dbCtx , h .user ); changes != nil {
963973 updates = updates .Merge (* changes )
964974 }
965- _ , _ , err : = dbCtx .UpdatePrincipal (h .ctx (), & updates , true , true )
975+ _ , _ , err = dbCtx .UpdatePrincipal (h .ctx (), & updates , true , true )
966976 if err != nil {
967- return fmt .Errorf ("failed to update OIDC user after sign-in: %w" , err )
977+ return auditFields , fmt .Errorf ("failed to update OIDC user after sign-in: %w" , err )
968978 }
969979 // TODO: could avoid this extra fetch if UpdatePrincipal returned the newly updated principal
970980 if updates .Name != nil {
971981 h .user , err = dbCtx .Authenticator (h .ctx ()).GetUser (* updates .Name )
972982 }
973- return err
983+ return auditFields , err
974984 }
975985
976986 /*
@@ -985,8 +995,7 @@ func (h *handler) checkPublicAuth(dbCtx *db.DatabaseContext) (err error) {
985995 provider := dbCtx .Options .OIDCOptions .Providers .GetProviderForIssuer (h .ctx (), issuerUrlForDB (h , dbCtx .Name ), testProviderAudiences )
986996 if provider != nil && provider .ValidationKey != nil {
987997 if base .ValDefault (provider .ClientID , "" ) == username && * provider .ValidationKey == password {
988- auditFields = base.AuditFields {base .AuditFieldAuthMethod : "basic" }
989- return nil
998+ return base.AuditFields {base .AuditFieldAuthMethod : "basic" }, nil
990999 }
9911000 }
9921001 }
@@ -996,47 +1005,49 @@ func (h *handler) checkPublicAuth(dbCtx *db.DatabaseContext) (err error) {
9961005 // Check basic auth first
9971006 if ! dbCtx .Options .DisablePasswordAuthentication {
9981007 if userName , password := h .getBasicAuth (); userName != "" {
999- auditFields = base.AuditFields {base .AuditFieldAuthMethod : "basic" }
1008+ auditFields := base.AuditFields {base .AuditFieldAuthMethod : "basic" }
1009+ var err error
10001010 h .user , err = dbCtx .Authenticator (h .ctx ()).AuthenticateUser (userName , password )
10011011 if err != nil {
1002- return err
1012+ return auditFields , err
10031013 }
10041014 if h .user == nil {
10051015 auditFields ["username" ] = userName
10061016 if dbCtx .Options .SendWWWAuthenticateHeader == nil || * dbCtx .Options .SendWWWAuthenticateHeader {
10071017 h .response .Header ().Set ("WWW-Authenticate" , wwwAuthenticateHeader )
10081018 }
1009- return ErrInvalidLogin
1019+ return auditFields , ErrInvalidLogin
10101020 }
1011- return nil
1021+ return auditFields , nil
10121022 }
10131023 }
10141024
10151025 // Check cookie
10161026 auditFields = base.AuditFields {base .AuditFieldAuthMethod : "cookie" }
1027+ var err error
10171028 h .user , err = dbCtx .Authenticator (h .ctx ()).AuthenticateCookie (h .rq , h .response )
10181029 if err != nil && h .privs != publicPrivs {
1019- return err
1030+ return auditFields , err
10201031 } else if h .user != nil {
1021- return nil
1032+ return auditFields , nil
10221033 }
10231034
10241035 // No auth given -- check guest access
10251036 auditFields = base.AuditFields {base .AuditFieldAuthMethod : "guest" }
10261037 if h .user , err = dbCtx .Authenticator (h .ctx ()).GetUser ("" ); err != nil {
1027- return err
1038+ return auditFields , err
10281039 }
10291040 if h .privs == regularPrivs && h .user .Disabled () {
10301041 if dbCtx .Options .SendWWWAuthenticateHeader == nil || * dbCtx .Options .SendWWWAuthenticateHeader {
10311042 h .response .Header ().Set ("WWW-Authenticate" , wwwAuthenticateHeader )
10321043 }
10331044 if h .providedAuthCredentials () {
1034- return ErrInvalidLogin
1045+ return auditFields , ErrInvalidLogin
10351046 }
1036- return ErrLoginRequired
1047+ return auditFields , ErrLoginRequired
10371048 }
10381049
1039- return nil
1050+ return auditFields , nil
10401051}
10411052
10421053func checkJWTIssuerStillValid (ctx context.Context , dbCtx * db.DatabaseContext , user auth.User ) * auth.PrincipalConfig {
0 commit comments