Skip to content

Commit 0b831e5

Browse files
committed
Bubble back err if role fetch fails
1 parent 0d71f75 commit 0b831e5

File tree

1 file changed

+56
-45
lines changed

1 file changed

+56
-45
lines changed

rest/handler.go

Lines changed: 56 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -908,8 +908,8 @@ func getSGUserRolesForAudit(db *db.DatabaseContext, user auth.User) ([]string, e
908908
return roleNames, nil
909909
}
910910

911-
// checkPublicAuth verifies that the current request is authenticated for the given database.
912-
//
911+
// checkPublicAuth verifies that the current request is authenticated for the given database. Returns an HTTPError if
912+
// authentication fails.
913913
// NOTE: checkPublicAuth is not used for the admin interface.
914914
func (h *handler) checkPublicAuth(dbCtx *db.DatabaseContext) (err error) {
915915

@@ -918,59 +918,69 @@ func (h *handler) checkPublicAuth(dbCtx *db.DatabaseContext) (err error) {
918918
return nil
919919
}
920920

921-
var auditFields base.AuditFields
922-
923-
// Record Auth stats
924-
defer func(t time.Time) {
925-
delta := time.Since(t).Nanoseconds()
926-
dbCtx.DbStats.Security().TotalAuthTime.Add(delta)
927-
if err != nil {
928-
dbCtx.DbStats.Security().AuthFailedCount.Add(1)
929-
if errors.Is(err, ErrInvalidLogin) {
930-
base.Audit(h.ctx(), base.AuditIDPublicUserAuthenticationFailed, auditFields)
931-
}
932-
} else {
933-
dbCtx.DbStats.Security().AuthSuccessCount.Add(1)
934-
935-
username := ""
936-
if h.isGuest() {
937-
username = base.GuestUsername
938-
} else if h.user != nil {
939-
username = h.user.Name()
940-
}
941-
roleNames, err := getSGUserRolesForAudit(dbCtx, h.user)
942-
if err != nil {
943-
base.InfofCtx(h.ctx(), base.KeyHTTP, "Error getting user roles for audit logging: %v", err)
944-
}
945-
h.rqCtx = base.UserLogCtx(h.ctx(), username, base.UserDomainSyncGateway, roleNames)
946-
base.Audit(h.ctx(), base.AuditIDPublicUserAuthenticated, auditFields)
921+
start := time.Now()
922+
auditFields, err := h.setUserForPublicAuth(dbCtx)
923+
dbCtx.DbStats.Security().TotalAuthTime.Add(time.Since(start).Nanoseconds())
924+
if err != nil {
925+
dbCtx.DbStats.Security().AuthFailedCount.Add(1)
926+
if errors.Is(err, ErrInvalidLogin) {
927+
base.Audit(h.ctx(), base.AuditIDPublicUserAuthenticationFailed, auditFields)
947928
}
948-
}(time.Now())
929+
return
930+
}
931+
dbCtx.DbStats.Security().AuthSuccessCount.Add(1)
932+
933+
username := ""
934+
if h.isGuest() {
935+
username = base.GuestUsername
936+
} else if h.user != nil {
937+
username = h.user.Name()
938+
}
939+
roleNames, err := getSGUserRolesForAudit(dbCtx, h.user)
940+
if err != nil {
941+
base.InfofCtx(h.ctx(), base.KeyHTTP, "Error getting user roles for audit logging: %v", err)
942+
}
943+
h.rqCtx = base.UserLogCtx(h.ctx(), username, base.UserDomainSyncGateway, roleNames)
944+
base.Audit(h.ctx(), base.AuditIDPublicUserAuthenticated, auditFields)
945+
return err
946+
}
949947

948+
// setUserForPublicAuth sets h.user based on the authentication information in the request. Returns an error if the user
949+
// can not authenticate successfully, and returns AuditFields even in the case that there is an error in the request.
950+
//
951+
// Uses:
952+
//
953+
// 1. Bearer token (OIDC JWT) if present and OIDC is enabled
954+
// 2. Basic auth if present and password authentication is not disabled
955+
// 3. Cookie auth if present
956+
// 4. Guest access if enabled
957+
func (h *handler) setUserForPublicAuth(dbCtx *db.DatabaseContext) (base.AuditFields, error) {
958+
var auditFields base.AuditFields
950959
// If oidc enabled, check for bearer ID token
951960
if dbCtx.Options.OIDCOptions != nil || len(dbCtx.LocalJWTProviders) > 0 {
952961
if token := h.getBearerToken(); token != "" {
953962
auditFields = base.AuditFields{base.AuditFieldAuthMethod: "bearer"}
954963
var updates auth.PrincipalConfig
964+
var err error
955965
h.user, updates, err = dbCtx.Authenticator(h.ctx()).AuthenticateUntrustedJWT(token, dbCtx.OIDCProviders, dbCtx.LocalJWTProviders, h.getOIDCCallbackURL)
956966
if h.user == nil || err != nil {
957-
return ErrInvalidLogin
967+
return auditFields, ErrInvalidLogin
958968
}
959969
if issuer := h.user.JWTIssuer(); issuer != "" {
960970
auditFields["oidc_issuer"] = issuer
961971
}
962972
if changes := checkJWTIssuerStillValid(h.ctx(), dbCtx, h.user); changes != nil {
963973
updates = updates.Merge(*changes)
964974
}
965-
_, _, err := dbCtx.UpdatePrincipal(h.ctx(), &updates, true, true)
975+
_, _, err = dbCtx.UpdatePrincipal(h.ctx(), &updates, true, true)
966976
if err != nil {
967-
return fmt.Errorf("failed to update OIDC user after sign-in: %w", err)
977+
return auditFields, fmt.Errorf("failed to update OIDC user after sign-in: %w", err)
968978
}
969979
// TODO: could avoid this extra fetch if UpdatePrincipal returned the newly updated principal
970980
if updates.Name != nil {
971981
h.user, err = dbCtx.Authenticator(h.ctx()).GetUser(*updates.Name)
972982
}
973-
return err
983+
return auditFields, err
974984
}
975985

976986
/*
@@ -985,8 +995,7 @@ func (h *handler) checkPublicAuth(dbCtx *db.DatabaseContext) (err error) {
985995
provider := dbCtx.Options.OIDCOptions.Providers.GetProviderForIssuer(h.ctx(), issuerUrlForDB(h, dbCtx.Name), testProviderAudiences)
986996
if provider != nil && provider.ValidationKey != nil {
987997
if base.ValDefault(provider.ClientID, "") == username && *provider.ValidationKey == password {
988-
auditFields = base.AuditFields{base.AuditFieldAuthMethod: "basic"}
989-
return nil
998+
return base.AuditFields{base.AuditFieldAuthMethod: "basic"}, nil
990999
}
9911000
}
9921001
}
@@ -996,47 +1005,49 @@ func (h *handler) checkPublicAuth(dbCtx *db.DatabaseContext) (err error) {
9961005
// Check basic auth first
9971006
if !dbCtx.Options.DisablePasswordAuthentication {
9981007
if userName, password := h.getBasicAuth(); userName != "" {
999-
auditFields = base.AuditFields{base.AuditFieldAuthMethod: "basic"}
1008+
auditFields := base.AuditFields{base.AuditFieldAuthMethod: "basic"}
1009+
var err error
10001010
h.user, err = dbCtx.Authenticator(h.ctx()).AuthenticateUser(userName, password)
10011011
if err != nil {
1002-
return err
1012+
return auditFields, err
10031013
}
10041014
if h.user == nil {
10051015
auditFields["username"] = userName
10061016
if dbCtx.Options.SendWWWAuthenticateHeader == nil || *dbCtx.Options.SendWWWAuthenticateHeader {
10071017
h.response.Header().Set("WWW-Authenticate", wwwAuthenticateHeader)
10081018
}
1009-
return ErrInvalidLogin
1019+
return auditFields, ErrInvalidLogin
10101020
}
1011-
return nil
1021+
return auditFields, nil
10121022
}
10131023
}
10141024

10151025
// Check cookie
10161026
auditFields = base.AuditFields{base.AuditFieldAuthMethod: "cookie"}
1027+
var err error
10171028
h.user, err = dbCtx.Authenticator(h.ctx()).AuthenticateCookie(h.rq, h.response)
10181029
if err != nil && h.privs != publicPrivs {
1019-
return err
1030+
return auditFields, err
10201031
} else if h.user != nil {
1021-
return nil
1032+
return auditFields, nil
10221033
}
10231034

10241035
// No auth given -- check guest access
10251036
auditFields = base.AuditFields{base.AuditFieldAuthMethod: "guest"}
10261037
if h.user, err = dbCtx.Authenticator(h.ctx()).GetUser(""); err != nil {
1027-
return err
1038+
return auditFields, err
10281039
}
10291040
if h.privs == regularPrivs && h.user.Disabled() {
10301041
if dbCtx.Options.SendWWWAuthenticateHeader == nil || *dbCtx.Options.SendWWWAuthenticateHeader {
10311042
h.response.Header().Set("WWW-Authenticate", wwwAuthenticateHeader)
10321043
}
10331044
if h.providedAuthCredentials() {
1034-
return ErrInvalidLogin
1045+
return auditFields, ErrInvalidLogin
10351046
}
1036-
return ErrLoginRequired
1047+
return auditFields, ErrLoginRequired
10371048
}
10381049

1039-
return nil
1050+
return auditFields, nil
10401051
}
10411052

10421053
func checkJWTIssuerStillValid(ctx context.Context, dbCtx *db.DatabaseContext, user auth.User) *auth.PrincipalConfig {

0 commit comments

Comments
 (0)