Skip to content

Commit d3395b8

Browse files
committed
feat: change JWT cache to limited LRU based cache
BREAKING CHANGE
1 parent a1009d1 commit d3395b8

25 files changed

+131
-108
lines changed

postgrest.cabal

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -99,10 +99,8 @@ library
9999
, auto-update >= 0.1.4 && < 0.2
100100
, base64-bytestring >= 1 && < 1.3
101101
, bytestring >= 0.10.8 && < 0.13
102-
, cache >= 0.1.3 && < 0.2.0
103102
, case-insensitive >= 1.2 && < 1.3
104103
, cassava >= 0.4.5 && < 0.6
105-
, clock >= 0.8.3 && < 0.9.0
106104
, configurator-pg >= 0.2 && < 0.3
107105
, containers >= 0.5.7 && < 0.7
108106
, cookie >= 0.4.2 && < 0.5
@@ -122,6 +120,7 @@ library
122120
, jose-jwt >= 0.9.6 && < 0.11
123121
, lens >= 4.14 && < 5.3
124122
, lens-aeson >= 1.0.1 && < 1.3
123+
, lrucache >= 1.2.0.1 && < 1.3
125124
, mtl >= 2.2.2 && < 2.4
126125
, neat-interpolation >= 0.5 && < 0.6
127126
, network >= 2.6 && < 3.2
@@ -158,7 +157,7 @@ library
158157
-- -optP-Wno-nonportable-include-path
159158
-- prevents build failures on case-insensitive filesystems (macos),
160159
-- see https://github.com/commercialhaskell/stack/issues/3918
161-
ghc-options: -Werror -Wall -fwarn-identities
160+
ghc-options: -Wall -fwarn-identities
162161
-fno-spec-constr -optP-Wno-nonportable-include-path
163162

164163
if flag(dev)

src/PostgREST/AppState.hs

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,6 @@ import Data.IORef (IORef, atomicWriteIORef, newIORef,
5757
readIORef)
5858
import Data.Time.Clock (UTCTime, getCurrentTime)
5959

60-
import PostgREST.Auth.JwtCache (JwtCacheState)
6160
import PostgREST.Config (AppConfig (..),
6261
addFallbackAppName,
6362
readAppConfig)
@@ -105,8 +104,8 @@ data AppState = AppState
105104
, stateSocketAdmin :: Maybe NS.Socket
106105
-- | Observation handler
107106
, stateObserver :: ObservationHandler
108-
-- | JWT Cache
109-
, stateJwtCache :: JwtCache.JwtCacheState
107+
-- | JWT Cache, disabled when config jwt-cache-max-entries is set to 0
108+
, stateJwtCache :: Maybe JwtCache.JwtCacheState
110109
, stateLogger :: Logger.LoggerState
111110
, stateMetrics :: Metrics.MetricsState
112111
}
@@ -120,20 +119,20 @@ data SchemaCacheStatus
120119
type AppSockets = (NS.Socket, Maybe NS.Socket)
121120

122121
init :: AppConfig -> IO AppState
123-
init conf@AppConfig{configLogLevel, configDbPoolSize} = do
122+
init conf = do
124123
loggerState <- Logger.init
125-
metricsState <- Metrics.init configDbPoolSize
126-
let observer = liftA2 (>>) (Logger.observationLogger loggerState configLogLevel) (Metrics.observationMetrics metricsState)
124+
metricsState <- Metrics.init (configDbPoolSize conf)
125+
let observer = liftA2 (>>) (Logger.observationLogger loggerState (configLogLevel conf)) (Metrics.observationMetrics metricsState)
127126

128127
observer $ AppStartObs prettyVersion
129128

130-
jwtCacheState <- JwtCache.init
129+
jwtCacheState <- JwtCache.init (configJwtCacheMaxEntries conf)
131130
pool <- initPool conf observer
132131
(sock, adminSock) <- initSockets conf
133132
state' <- initWithPool (sock, adminSock) pool conf jwtCacheState loggerState metricsState observer
134133
pure state' { stateSocketREST = sock, stateSocketAdmin = adminSock}
135134

136-
initWithPool :: AppSockets -> SQL.Pool -> AppConfig -> JwtCache.JwtCacheState -> Logger.LoggerState -> Metrics.MetricsState -> ObservationHandler -> IO AppState
135+
initWithPool :: AppSockets -> SQL.Pool -> AppConfig -> Maybe JwtCache.JwtCacheState -> Logger.LoggerState -> Metrics.MetricsState -> ObservationHandler -> IO AppState
137136
initWithPool (sock, adminSock) pool conf jwtCacheState loggerState metricsState observer = do
138137

139138
appState <- AppState pool
@@ -311,7 +310,7 @@ putConfig = atomicWriteIORef . stateConf
311310
getTime :: AppState -> IO UTCTime
312311
getTime = stateGetTime
313312

314-
getJwtCacheState :: AppState -> JwtCacheState
313+
getJwtCacheState :: AppState -> Maybe JwtCache.JwtCacheState
315314
getJwtCacheState = stateJwtCache
316315

317316
getSocketREST :: AppState -> NS.Socket

src/PostgREST/Auth.hs

Lines changed: 15 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -163,27 +163,23 @@ middleware appState app req respond = do
163163
jwtCacheState = getJwtCacheState appState
164164

165165
-- If ServerTimingEnabled -> calculate JWT validation time
166-
-- If JwtCacheMaxLifetime -> cache JWT validation result
167-
req' <- case (configServerTimingEnabled conf, configJwtCacheMaxLifetime conf) of
168-
(True, 0) -> do
169-
(dur, authResult) <- timeItT parseJwt
170-
return $ req { Wai.vault = Wai.vault req & Vault.insert authResultKey authResult & Vault.insert jwtDurKey dur }
171-
172-
(True, maxLifetime) -> do
173-
(dur, authResult) <- timeItT $ case token of
174-
Just tkn -> lookupJwtCache jwtCacheState tkn maxLifetime parseJwt time
175-
Nothing -> parseJwt
176-
return $ req { Wai.vault = Wai.vault req & Vault.insert authResultKey authResult & Vault.insert jwtDurKey dur }
166+
req' <- if configServerTimingEnabled conf then do
167+
168+
(dur, authResult) <- timeItT $ case token of
169+
170+
Just tkn -> lookupJwtCache jwtCacheState tkn parseJwt time
171+
Nothing -> parseJwt
172+
173+
return $ req { Wai.vault = Wai.vault req & Vault.insert authResultKey authResult & Vault.insert jwtDurKey dur }
177174

178-
(False, 0) -> do
179-
authResult <- parseJwt
180-
return $ req { Wai.vault = Wai.vault req & Vault.insert authResultKey authResult }
175+
else do
176+
177+
authResult <- case token of
178+
179+
Just tkn -> lookupJwtCache jwtCacheState tkn parseJwt time
180+
Nothing -> parseJwt
181181

182-
(False, maxLifetime) -> do
183-
authResult <- case token of
184-
Just tkn -> lookupJwtCache jwtCacheState tkn maxLifetime parseJwt time
185-
Nothing -> parseJwt
186-
return $ req { Wai.vault = Wai.vault req & Vault.insert authResultKey authResult }
182+
return $ req { Wai.vault = Wai.vault req & Vault.insert authResultKey authResult }
187183

188184
app req' respond
189185

src/PostgREST/Auth/JwtCache.hs

Lines changed: 77 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -13,67 +13,96 @@ module PostgREST.Auth.JwtCache
1313

1414
import qualified Data.Aeson as JSON
1515
import qualified Data.Aeson.KeyMap as KM
16-
import qualified Data.Cache as C
16+
import qualified Data.Cache.LRU as C
17+
import qualified Data.IORef as I
1718
import qualified Data.Scientific as Sci
1819

1920
import Data.Time.Clock (UTCTime, nominalDiffTimeToSeconds)
2021
import Data.Time.Clock.POSIX (utcTimeToPOSIXSeconds)
21-
import System.Clock (TimeSpec (..))
22+
import GHC.Num (integerFromInt)
2223

2324
import PostgREST.Auth.Types (AuthResult (..))
2425
import PostgREST.Error (Error (..))
2526

2627
import Protolude
2728

2829
newtype JwtCacheState = JwtCacheState
29-
{ jwtCache :: C.Cache ByteString AuthResult
30+
{ jwtCacheIORef :: I.IORef (C.LRU ByteString AuthResult)
3031
}
3132

3233
-- | Initialize JwtCacheState
33-
init :: IO JwtCacheState
34-
init = do
35-
cache <- C.newCache Nothing -- no default expiration
36-
return $ JwtCacheState cache
34+
init :: Int -> IO (Maybe JwtCacheState)
35+
init 0 = return Nothing
36+
init maxEntries = do
37+
cache <- I.newIORef $ C.newLRU (Just $ integerFromInt maxEntries)
38+
return $ Just $ JwtCacheState cache
39+
3740

3841
-- | Used to retrieve and insert JWT to JWT Cache
39-
lookupJwtCache :: JwtCacheState -> ByteString -> Int -> IO (Either Error AuthResult) -> UTCTime -> IO (Either Error AuthResult)
40-
lookupJwtCache JwtCacheState{jwtCache} token maxLifetime parseJwt utc = do
41-
checkCache <- C.lookup jwtCache token
42-
authResult <- maybe parseJwt (pure . Right) checkCache
43-
44-
case (authResult,checkCache) of
45-
-- From comment:
46-
-- https://github.com/PostgREST/postgrest/pull/3801#discussion_r1857987914
47-
--
48-
-- We purge expired cache entries on a cache miss
49-
-- The reasoning is that:
50-
--
51-
-- 1. We expect it to be rare (otherwise there is no point of the cache)
52-
-- 2. It makes sure the cache is not growing (as inserting new entries
53-
-- does garbage collection)
54-
-- 3. Since this is time expiration based cache there is no real risk of
55-
-- starvation - sooner or later we are going to have a cache miss.
56-
57-
(Right res, Nothing) -> do -- cache miss
58-
59-
let timeSpec = getTimeSpec res maxLifetime utc
60-
61-
-- purge expired cache entries
62-
C.purgeExpired jwtCache
63-
64-
-- insert new cache entry
65-
C.insert' jwtCache (Just timeSpec) token res
66-
67-
_ -> pure ()
68-
69-
return authResult
70-
71-
-- Used to extract JWT exp claim and add to JWT Cache
72-
getTimeSpec :: AuthResult -> Int -> UTCTime -> TimeSpec
73-
getTimeSpec res maxLifetime utc = do
74-
let expireJSON = KM.lookup "exp" (authClaims res)
75-
utcToSecs = floor . nominalDiffTimeToSeconds . utcTimeToPOSIXSeconds
76-
sciToInt = fromMaybe 0 . Sci.toBoundedInteger
77-
case expireJSON of
78-
Just (JSON.Number seconds) -> TimeSpec (sciToInt seconds - utcToSecs utc) 0
79-
_ -> TimeSpec (fromIntegral maxLifetime :: Int64) 0
42+
lookupJwtCache :: Maybe JwtCacheState -> ByteString -> IO (Either Error AuthResult) -> UTCTime -> IO (Either Error AuthResult)
43+
lookupJwtCache Nothing _ parseJwt _ = parseJwt
44+
lookupJwtCache (Just JwtCacheState{jwtCacheIORef}) token parseJwt utc = do
45+
-- get cache from IORef
46+
jwtCache <- I.readIORef jwtCacheIORef
47+
48+
-- lookup key = token
49+
let (jwtCache', maybeVal) = C.lookup token jwtCache
50+
51+
case maybeVal of
52+
53+
Nothing -> do -- CACHE MISS
54+
55+
-- When we get a cache miss, we get the parse result, insert it
56+
-- into the cache. After that, we write the cache IO ref with
57+
-- updated cache
58+
59+
authResult <- parseJwt
60+
61+
case authResult of
62+
63+
Right res -> do
64+
65+
let jwtCache'' = C.insert token res jwtCache'
66+
67+
-- update IORef
68+
I.writeIORef jwtCacheIORef jwtCache''
69+
70+
-- return the result
71+
return $ Right res
72+
73+
Left e -> return $ Left e
74+
75+
Just res -> -- CACHE HIT
76+
77+
-- For cache hit, we get the result from cache, we check the
78+
-- exp claim. If it expired, we delete it from cache and parse
79+
-- the jwt. Otherwise, the hit result is valid, so we return it
80+
81+
if isExpClaimExpired res utc then do
82+
83+
let (jwtCache'',_) = C.delete token jwtCache'
84+
85+
I.writeIORef jwtCacheIORef jwtCache''
86+
87+
parseJwt
88+
89+
else do
90+
91+
I.writeIORef jwtCacheIORef jwtCache'
92+
93+
return $ Right res
94+
95+
96+
type Expired = Bool
97+
98+
-- | Check if exp claim is expired when looked up from cache
99+
isExpClaimExpired :: AuthResult -> UTCTime -> Expired
100+
isExpClaimExpired res utc =
101+
case expireJSON of
102+
Nothing -> False -- if exp not present then it is valid
103+
Just (JSON.Number expiredAt) -> (sciToInt expiredAt - now) < 0
104+
Just _ -> False -- if exp is not a number then valid
105+
where
106+
expireJSON = KM.lookup "exp" (authClaims res)
107+
now = (floor . nominalDiffTimeToSeconds . utcTimeToPOSIXSeconds) utc :: Int
108+
sciToInt = fromMaybe 0 . Sci.toBoundedInteger

src/PostgREST/CLI.hs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -203,8 +203,8 @@ exampleConfigFile =
203203
|# jwt-secret = "secret_with_at_least_32_characters"
204204
|jwt-secret-is-base64 = false
205205
|
206-
|## Enables and set JWT Cache max lifetime, disables caching with 0
207-
|# jwt-cache-max-lifetime = 0
206+
|## Enables and set JWT Cache max entries, disables caching with 0
207+
|# jwt-cache-max-entries = 0
208208
|
209209
|## Logging level, the admitted values are: crit, error, warn, info and debug.
210210
|log-level = "error"

src/PostgREST/Config.hs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ data AppConfig = AppConfig
9797
, configJwtRoleClaimKey :: JSPath
9898
, configJwtSecret :: Maybe BS.ByteString
9999
, configJwtSecretIsBase64 :: Bool
100-
, configJwtCacheMaxLifetime :: Int
100+
, configJwtCacheMaxEntries :: Int
101101
, configLogLevel :: LogLevel
102102
, configLogQuery :: LogQuery
103103
, configOpenApiMode :: OpenAPIMode
@@ -177,7 +177,7 @@ toText conf =
177177
,("jwt-role-claim-key", q . T.intercalate mempty . fmap dumpJSPath . configJwtRoleClaimKey)
178178
,("jwt-secret", q . T.decodeUtf8 . showJwtSecret)
179179
,("jwt-secret-is-base64", T.toLower . show . configJwtSecretIsBase64)
180-
,("jwt-cache-max-lifetime", show . configJwtCacheMaxLifetime)
180+
,("jwt-cache-max-entries", show . configJwtCacheMaxEntries)
181181
,("log-level", q . dumpLogLevel . configLogLevel)
182182
,("log-query", q . dumpLogQuery . configLogQuery)
183183
,("openapi-mode", q . dumpOpenApiMode . configOpenApiMode)
@@ -287,7 +287,7 @@ parser optPath env dbSettings roleSettings roleIsolationLvl =
287287
<*> (fromMaybe False <$> optWithAlias
288288
(optBool "jwt-secret-is-base64")
289289
(optBool "secret-is-base64"))
290-
<*> (fromMaybe 0 <$> optInt "jwt-cache-max-lifetime")
290+
<*> (fromMaybe 1000 <$> optInt "jwt-cache-max-entries")
291291
<*> parseLogLevel "log-level"
292292
<*> parseLogQuery "log-query"
293293
<*> parseOpenAPIMode "openapi-mode"

src/PostgREST/Config/Database.hs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ dbSettingsNames =
6161
,"jwt_role_claim_key"
6262
,"jwt_secret"
6363
,"jwt_secret_is_base64"
64-
,"jwt_cache_max_lifetime"
64+
,"jwt-cache-max-entries"
6565
,"openapi_mode"
6666
,"openapi_security_active"
6767
,"openapi_server_proxy_uri"

test/io/configs/expected/aliases.config

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ jwt-aud = ""
2323
jwt-role-claim-key = ".\"aliased\""
2424
jwt-secret = ""
2525
jwt-secret-is-base64 = true
26-
jwt-cache-max-lifetime = 0
26+
jwt-cache-max-entries = 1000
2727
log-level = "error"
2828
log-query = "disabled"
2929
openapi-mode = "follow-privileges"

test/io/configs/expected/boolean-numeric.config

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ jwt-aud = ""
2323
jwt-role-claim-key = ".\"role\""
2424
jwt-secret = ""
2525
jwt-secret-is-base64 = true
26-
jwt-cache-max-lifetime = 0
26+
jwt-cache-max-entries = 1000
2727
log-level = "error"
2828
log-query = "disabled"
2929
openapi-mode = "follow-privileges"

test/io/configs/expected/boolean-string.config

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ jwt-aud = ""
2323
jwt-role-claim-key = ".\"role\""
2424
jwt-secret = ""
2525
jwt-secret-is-base64 = true
26-
jwt-cache-max-lifetime = 0
26+
jwt-cache-max-entries = 1000
2727
log-level = "error"
2828
log-query = "disabled"
2929
openapi-mode = "follow-privileges"

0 commit comments

Comments
 (0)