Skip to content

Commit 5b00cbb

Browse files
committed
Initial version of Sieve JWT Caching
1 parent 1609e32 commit 5b00cbb

File tree

9 files changed

+474
-237
lines changed

9 files changed

+474
-237
lines changed

postgrest.cabal

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,10 @@ library
4949
PostgREST.App
5050
PostgREST.AppState
5151
PostgREST.Auth
52+
PostgREST.Auth.Jwt
5253
PostgREST.Auth.JwtCache
5354
PostgREST.Auth.Types
55+
PostgREST.Cache.Sieve
5456
PostgREST.CLI
5557
PostgREST.Config
5658
PostgREST.Config.Database
@@ -153,6 +155,9 @@ library
153155
-- https://github.com/kazu-yamamoto/logger/commit/3a71ca70afdbb93d4ecf0083eeba1fbbbcab3fc3
154156
, wai-logger >= 2.4.0
155157
, warp >= 3.3.19 && < 3.4
158+
, stm >= 2.5 && < 3
159+
, stm-hamt >= 1.2 && < 2
160+
, focus >= 1.0 && < 2
156161
-- -fno-spec-constr may help keep compile time memory use in check,
157162
-- see https://gitlab.haskell.org/ghc/ghc/issues/16017#note_219304
158163
-- -optP-Wno-nonportable-include-path

src/PostgREST/AppState.hs

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ import Data.IORef (IORef, atomicWriteIORef, newIORef,
5757
readIORef)
5858
import Data.Time.Clock (UTCTime, getCurrentTime)
5959

60-
import PostgREST.Auth.JwtCache (JwtCacheState)
60+
import PostgREST.Auth.JwtCache (JwtCacheState, update)
6161
import PostgREST.Config (AppConfig (..),
6262
addFallbackAppName,
6363
readAppConfig)
@@ -127,7 +127,7 @@ init conf@AppConfig{configLogLevel, configDbPoolSize} = do
127127

128128
observer $ AppStartObs prettyVersion
129129

130-
jwtCacheState <- JwtCache.init
130+
jwtCacheState <- JwtCache.init conf
131131
pool <- initPool conf observer
132132
(sock, adminSock) <- initSockets conf
133133
state' <- initWithPool (sock, adminSock) pool conf jwtCacheState loggerState metricsState observer
@@ -471,10 +471,7 @@ readInDbConfig startingUp appState@AppState{stateObserver=observer} = do
471471
-- After the config has reloaded, jwt-secret might have changed, so
472472
-- if it has changed, it is important to invalidate the jwt cache
473473
-- entries, because they were cached using the old secret
474-
if configJwtSecret conf == configJwtSecret newConf then
475-
pass
476-
else
477-
JwtCache.emptyCache (getJwtCacheState appState) -- atomic O(1) operation
474+
update newConf $ getJwtCacheState appState
478475

479476
if startingUp then
480477
pass

src/PostgREST/Auth.hs

Lines changed: 16 additions & 149 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
{-# LANGUAGE RecordWildCards #-}
12
{-|
23
Module : PostgREST.Auth
34
Description : PostgREST authentication functions.
@@ -10,186 +11,52 @@ Authentication should always be implemented in an external service.
1011
In the test suite there is an example of simple login function that can be used for a
1112
very simple authentication system inside the PostgreSQL database.
1213
-}
13-
{-# LANGUAGE LambdaCase #-}
14-
{-# LANGUAGE RecordWildCards #-}
14+
1515
module PostgREST.Auth
1616
( getResult
1717
, getJwtDur
1818
, getRole
1919
, middleware
2020
) where
2121

22-
import qualified Data.Aeson as JSON
23-
import qualified Data.Aeson.Key as K
24-
import qualified Data.Aeson.KeyMap as KM
25-
import qualified Data.Aeson.Types as JSON
2622
import qualified Data.ByteString as BS
27-
import qualified Data.ByteString.Internal as BS
28-
import qualified Data.ByteString.Lazy.Char8 as LBS
29-
import qualified Data.Scientific as Sci
30-
import qualified Data.Text as T
3123
import qualified Data.Vault.Lazy as Vault
32-
import qualified Data.Vector as V
33-
import qualified Jose.Jwk as JWT
34-
import qualified Jose.Jwt as JWT
3524
import qualified Network.HTTP.Types.Header as HTTP
3625
import qualified Network.Wai as Wai
3726
import qualified Network.Wai.Middleware.HttpAuth as Wai
3827

39-
import Control.Monad.Except (liftEither)
40-
import Data.Either.Combinators (mapLeft)
4128
import Data.List (lookup)
42-
import Data.Time.Clock (UTCTime, nominalDiffTimeToSeconds)
43-
import Data.Time.Clock.POSIX (utcTimeToPOSIXSeconds)
4429
import System.IO.Unsafe (unsafePerformIO)
4530
import System.TimeIt (timeItT)
4631

4732
import PostgREST.AppState (AppState, getConfig, getJwtCacheState,
4833
getTime)
4934
import PostgREST.Auth.JwtCache (lookupJwtCache)
35+
import PostgREST.Auth.Jwt (parseClaims)
5036
import PostgREST.Auth.Types (AuthResult (..))
51-
import PostgREST.Config (AppConfig (..), FilterExp (..),
52-
JSPath, JSPathExp (..))
53-
import PostgREST.Error (Error (..), JwtError (..))
37+
import PostgREST.Config (AppConfig (..))
38+
import PostgREST.Error (Error (..))
5439

5540
import Protolude
5641

57-
-- | Receives the JWT secret and audience (from config) and a JWT and returns a
58-
-- JSON object of JWT claims.
59-
parseToken :: AppConfig -> Maybe ByteString -> UTCTime -> ExceptT Error IO JSON.Value
60-
parseToken _ Nothing _ = return JSON.emptyObject
61-
parseToken _ (Just "") _ = throwE . JwtErr $ JwtDecodeError "Empty JWT is sent in Authorization header"
62-
parseToken AppConfig{..} (Just tkn) time = do
63-
secret <- liftEither . maybeToRight (JwtErr JwtSecretMissing) $ configJWKS
64-
tknWith3Parts <- liftEither $ hasThreeParts tkn
65-
eitherContent <- liftIO $ JWT.decode (JWT.keys secret) Nothing tknWith3Parts
66-
content <- liftEither . mapLeft (JwtErr . jwtDecodeError) $ eitherContent
67-
liftEither $ mapLeft JwtErr $ verifyClaims content
68-
where
69-
hasThreeParts :: ByteString -> Either Error ByteString
70-
hasThreeParts token = case length $ BS.split (BS.c2w '.') token of
71-
3 -> Right token
72-
n -> Left $ JwtErr $ JwtDecodeError ("Expected 3 parts in JWT; got " <> show n)
73-
jwtDecodeError :: JWT.JwtError -> JwtError
74-
-- The only errors we can get from JWT.decode function are:
75-
-- BadAlgorithm
76-
-- KeyError
77-
-- BadCrypto
78-
jwtDecodeError (JWT.KeyError _) = JwtDecodeError "No suitable key or wrong key type"
79-
jwtDecodeError (JWT.BadAlgorithm _) = JwtDecodeError "Wrong or unsupported encoding algorithm"
80-
jwtDecodeError JWT.BadCrypto = JwtDecodeError "JWT cryptographic operation failed"
81-
-- Control never reaches here, the decode function only returns the above three
82-
jwtDecodeError _ = JwtDecodeError "JWT couldn't be decoded"
83-
84-
verifyClaims :: JWT.JwtContent -> Either JwtError JSON.Value
85-
verifyClaims (JWT.Jws (_, claims)) = case JSON.decodeStrict claims of
86-
Just jclaims@(JSON.Object mclaims) ->
87-
verifyClaim mclaims "exp" isValidExpClaim "JWT expired" >>
88-
verifyClaim mclaims "nbf" isValidNbfClaim "JWT not yet valid" >>
89-
verifyClaim mclaims "iat" isValidIatClaim "JWT issued at future" >>
90-
verifyClaim mclaims "aud" isValidAudClaim "JWT not in audience" >>
91-
return jclaims
92-
_ -> Left $ JwtClaimsError "Parsing claims failed"
93-
-- TODO: We could enable JWE support here (encrypted tokens)
94-
verifyClaims _ = Left $ JwtDecodeError "Unsupported token type"
95-
96-
verifyClaim mclaims claim func err = do
97-
isValid <- maybe (Right True) func (KM.lookup claim mclaims)
98-
unless isValid $ Left $ JwtClaimsError err
99-
100-
allowedSkewSeconds = 30 :: Int64
101-
now = floor . nominalDiffTimeToSeconds $ utcTimeToPOSIXSeconds time
102-
sciToInt = fromMaybe 0 . Sci.toBoundedInteger
103-
allStrings = all (\case (JSON.String _) -> True; _ -> False)
104-
105-
isValidExpClaim :: JSON.Value -> Either JwtError Bool
106-
isValidExpClaim (JSON.Number secs) = Right $ now <= (sciToInt secs + allowedSkewSeconds)
107-
isValidExpClaim _ = Left $ JwtClaimsError "The JWT 'exp' claim must be a number"
108-
109-
isValidNbfClaim :: JSON.Value -> Either JwtError Bool
110-
isValidNbfClaim (JSON.Number secs) = Right $ now >= (sciToInt secs - allowedSkewSeconds)
111-
isValidNbfClaim _ = Left $ JwtClaimsError "The JWT 'nbf' claim must be a number"
112-
113-
isValidIatClaim :: JSON.Value -> Either JwtError Bool
114-
isValidIatClaim (JSON.Number secs) = Right $ now >= (sciToInt secs - allowedSkewSeconds)
115-
isValidIatClaim _ = Left $ JwtClaimsError "The JWT 'iat' claim must be a number"
116-
117-
isValidAudClaim :: JSON.Value -> Either JwtError Bool
118-
isValidAudClaim JSON.Null = Right True -- {"aud": null} is valid for all audiences
119-
isValidAudClaim (JSON.String str) = Right $ maybe (const True) (==) configJwtAudience str
120-
isValidAudClaim (JSON.Array arr)
121-
| null arr = Right True -- {"aud": []} is valid for all audiences
122-
| allStrings arr = Right $ maybe True (\a -> JSON.String a `elem` arr) configJwtAudience
123-
isValidAudClaim _ = Left $ JwtClaimsError "The JWT 'aud' claim must be a string or an array of strings"
124-
125-
parseClaims :: Monad m =>
126-
AppConfig -> JSON.Value -> ExceptT Error m AuthResult
127-
parseClaims AppConfig{..} jclaims@(JSON.Object mclaims) = do
128-
-- role defaults to anon if not specified in jwt
129-
role <- liftEither . maybeToRight (JwtErr JwtTokenRequired) $
130-
unquoted <$> walkJSPath (Just jclaims) configJwtRoleClaimKey <|> configDbAnonRole
131-
return AuthResult
132-
{ authClaims = mclaims & KM.insert "role" (JSON.toJSON $ decodeUtf8 role)
133-
, authRole = role
134-
}
135-
where
136-
walkJSPath :: Maybe JSON.Value -> JSPath -> Maybe JSON.Value
137-
walkJSPath x [] = x
138-
walkJSPath (Just (JSON.Object o)) (JSPKey key:rest) = walkJSPath (KM.lookup (K.fromText key) o) rest
139-
walkJSPath (Just (JSON.Array ar)) (JSPIdx idx:rest) = walkJSPath (ar V.!? idx) rest
140-
walkJSPath (Just (JSON.Array ar)) [JSPFilter (EqualsCond txt)] = findFirstMatch (==) txt ar
141-
walkJSPath (Just (JSON.Array ar)) [JSPFilter (NotEqualsCond txt)] = findFirstMatch (/=) txt ar
142-
walkJSPath (Just (JSON.Array ar)) [JSPFilter (StartsWithCond txt)] = findFirstMatch T.isPrefixOf txt ar
143-
walkJSPath (Just (JSON.Array ar)) [JSPFilter (EndsWithCond txt)] = findFirstMatch T.isSuffixOf txt ar
144-
walkJSPath (Just (JSON.Array ar)) [JSPFilter (ContainsCond txt)] = findFirstMatch T.isInfixOf txt ar
145-
walkJSPath _ _ = Nothing
146-
147-
findFirstMatch matchWith pattern = foldr checkMatch Nothing
148-
where
149-
checkMatch (JSON.String txt) acc
150-
| pattern `matchWith` txt = Just $ JSON.String txt
151-
| otherwise = acc
152-
checkMatch _ acc = acc
153-
154-
unquoted :: JSON.Value -> BS.ByteString
155-
unquoted (JSON.String t) = encodeUtf8 t
156-
unquoted v = LBS.toStrict $ JSON.encode v
157-
-- impossible case - just added to please -Wincomplete-patterns
158-
parseClaims _ _ = return AuthResult { authClaims = KM.empty, authRole = mempty }
159-
160-
-- | Validate authorization header.
42+
-- | Validate authorization header
16143
-- Parse and store JWT claims for future use in the request.
16244
middleware :: AppState -> Wai.Middleware
16345
middleware appState app req respond = do
164-
conf <- getConfig appState
46+
conf@AppConfig{..} <- getConfig appState
16547
time <- getTime appState
16648

16749
let token = Wai.extractBearerAuth =<< lookup HTTP.hAuthorization (Wai.requestHeaders req)
168-
parseJwt = runExceptT $ parseToken conf token time >>= parseClaims conf
50+
parseJwt = runExceptT $ lookupJwtCache jwtCacheState token >>= parseClaims conf time
16951
jwtCacheState = getJwtCacheState appState
170-
171-
-- If ServerTimingEnabled -> calculate JWT validation time
172-
-- If JwtCacheMaxLifetime -> cache JWT validation result
173-
req' <- case (configServerTimingEnabled conf, configJwtCacheMaxLifetime conf) of
174-
(True, 0) -> do
175-
(dur, authResult) <- timeItT parseJwt
176-
return $ req { Wai.vault = Wai.vault req & Vault.insert authResultKey authResult & Vault.insert jwtDurKey dur }
177-
178-
(True, maxLifetime) -> do
179-
(dur, authResult) <- timeItT $ case token of
180-
Just tkn -> lookupJwtCache jwtCacheState tkn maxLifetime parseJwt time
181-
Nothing -> parseJwt
182-
return $ req { Wai.vault = Wai.vault req & Vault.insert authResultKey authResult & Vault.insert jwtDurKey dur }
183-
184-
(False, 0) -> do
185-
authResult <- parseJwt
186-
return $ req { Wai.vault = Wai.vault req & Vault.insert authResultKey authResult }
187-
188-
(False, maxLifetime) -> do
189-
authResult <- case token of
190-
Just tkn -> lookupJwtCache jwtCacheState tkn maxLifetime parseJwt time
191-
Nothing -> parseJwt
192-
return $ req { Wai.vault = Wai.vault req & Vault.insert authResultKey authResult }
52+
53+
-- If ServerTimingEnabled -> calculate JWT validation time
54+
req' <- if configServerTimingEnabled then do
55+
(dur, authResult) <- timeItT parseJwt
56+
pure $ req { Wai.vault = Wai.vault req & Vault.insert authResultKey authResult & Vault.insert jwtDurKey dur }
57+
else do
58+
authResult <- parseJwt
59+
pure $ req { Wai.vault = Wai.vault req & Vault.insert authResultKey authResult }
19360

19461
app req' respond
19562

0 commit comments

Comments
 (0)