Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: implement JWT caching #2928

Merged
merged 2 commits into from
Sep 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ This project adheres to [Semantic Versioning](http://semver.org/).
- #1614, Add `db-pool-automatic-recovery` configuration to disable connection retrying - @taimoorzaeem
- #2492, Allow full response control when raising exceptions - @taimoorzaeem, @laurenceisla
- #2771, Add `Server-Timing` header with JWT duration - @taimoorzaeem
- #2698, Add config `jwt-cache-max-lifetime` and implement JWT caching - @taimoorzaeem

### Fixed

Expand Down
2 changes: 2 additions & 0 deletions postgrest.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,10 @@ library
, auto-update >= 0.1.4 && < 0.2
, base64-bytestring >= 1 && < 1.3
, bytestring >= 0.10.8 && < 0.12
, cache >= 0.1.3 && < 0.2.0
, case-insensitive >= 1.2 && < 1.3
, cassava >= 0.4.5 && < 0.6
, clock >= 0.8.3 && < 0.9.0
, configurator-pg >= 0.2 && < 0.3
, containers >= 0.5.7 && < 0.7
, contravariant-extras >= 0.3.3 && < 0.4
Expand Down
16 changes: 16 additions & 0 deletions src/PostgREST/AppState.hs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

module PostgREST.AppState
( AppState
, AuthResult(..)
, destroy
, getConfig
, getSchemaCache
Expand All @@ -12,6 +13,7 @@ module PostgREST.AppState
, getPgVersion
, getRetryNextIn
, getTime
, getJwtCache
, init
, initWithPool
, logWithZTime
Expand All @@ -24,8 +26,11 @@ module PostgREST.AppState
, runListener
) where

import qualified Data.Aeson as JSON
import qualified Data.Aeson.KeyMap as KM
import qualified Data.ByteString.Char8 as BS
import qualified Data.ByteString.Lazy as LBS
import qualified Data.Cache as C
import Data.Either.Combinators (whenLeft)
import qualified Data.Text.Encoding as T
import Hasql.Connection (acquire)
Expand Down Expand Up @@ -62,6 +67,11 @@ import PostgREST.SchemaCache.Identifiers (dumpQi)
import Protolude


data AuthResult = AuthResult
{ authClaims :: KM.KeyMap JSON.Value
, authRole :: BS.ByteString
}

data AppState = AppState
-- | Database connection pool
{ statePool :: SQL.Pool
Expand All @@ -87,6 +97,8 @@ data AppState = AppState
, stateRetryNextIn :: IORef Int
-- | Logs a pool error with a debounce
, debounceLogAcquisitionTimeout :: IO ()
-- | JWT Cache
, jwtCache :: C.Cache ByteString AuthResult
}

init :: AppConfig -> IO AppState
Expand All @@ -108,6 +120,7 @@ initWithPool pool conf = do
<*> myThreadId
<*> newIORef 0
<*> pure (pure ())
<*> C.newCache Nothing


debLogTimeout <-
Expand Down Expand Up @@ -188,6 +201,9 @@ putConfig = atomicWriteIORef . stateConf
getTime :: AppState -> IO UTCTime
getTime = stateGetTime

getJwtCache :: AppState -> C.Cache ByteString AuthResult
getJwtCache = jwtCache

-- | Log to stderr with local time
logWithZTime :: AppState -> Text -> IO ()
logWithZTime appState txt = do
Expand Down
66 changes: 49 additions & 17 deletions src/PostgREST/Auth.hs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ import qualified Data.Aeson.KeyMap as KM
import qualified Data.Aeson.Types as JSON
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy.Char8 as LBS
import qualified Data.Cache as C
import qualified Data.Scientific as Sci
import qualified Data.Vault.Lazy as Vault
import qualified Data.Vector as V
import qualified Network.HTTP.Types.Header as HTTP
Expand All @@ -36,22 +38,20 @@ import Control.Lens (set)
import Control.Monad.Except (liftEither)
import Data.Either.Combinators (mapLeft)
import Data.List (lookup)
import Data.Time.Clock (UTCTime)
import Data.Time.Clock (UTCTime, nominalDiffTimeToSeconds)
import Data.Time.Clock.POSIX (utcTimeToPOSIXSeconds)
import System.Clock (TimeSpec (..))
import System.IO.Unsafe (unsafePerformIO)
import System.TimeIt (timeItT)

import PostgREST.AppState (AppState, getConfig, getTime)
import PostgREST.AppState (AppState, AuthResult (..), getConfig,
getJwtCache, getTime)
import PostgREST.Config (AppConfig (..), JSPath, JSPathExp (..))
import PostgREST.Error (Error (..))

import Protolude


data AuthResult = AuthResult
{ authClaims :: KM.KeyMap JSON.Value
, authRole :: BS.ByteString
}

-- | Receives the JWT secret and audience (from config) and a JWT and returns a
-- JSON object of JWT claims.
parseToken :: Monad m =>
Expand Down Expand Up @@ -107,16 +107,48 @@ middleware appState app req respond = do
let token = fromMaybe "" $ Wai.extractBearerAuth =<< lookup HTTP.hAuthorization (Wai.requestHeaders req)
parseJwt = runExceptT $ parseToken conf (LBS.fromStrict token) time >>= parseClaims conf

if configDbPlanEnabled conf
then do
(dur,authResult) <- timeItT parseJwt
let req' = req { Wai.vault = Wai.vault req & Vault.insert authResultKey authResult & Vault.insert jwtDurKey dur }
app req' respond
else do
authResult <- parseJwt
let req' = req { Wai.vault = Wai.vault req & Vault.insert authResultKey authResult }
app req' respond

-- If DbPlanEnabled -> calculate JWT validation time
-- If JwtCacheMaxLifetime -> cache JWT validation result
req' <- case (configDbPlanEnabled conf, configJwtCacheMaxLifetime conf) of
(True, 0) -> do
(dur, authResult) <- timeItT parseJwt
return $ req { Wai.vault = Wai.vault req & Vault.insert authResultKey authResult & Vault.insert jwtDurKey dur }

(True, maxLifetime) -> do
(dur, authResult) <- timeItT $ getJWTFromCache appState token maxLifetime parseJwt time
return $ req { Wai.vault = Wai.vault req & Vault.insert authResultKey authResult & Vault.insert jwtDurKey dur }

(False, 0) -> do
authResult <- parseJwt
return $ req { Wai.vault = Wai.vault req & Vault.insert authResultKey authResult }

(False, maxLifetime) -> do
authResult <- getJWTFromCache appState token maxLifetime parseJwt time
return $ req { Wai.vault = Wai.vault req & Vault.insert authResultKey authResult }

app req' respond

-- | Used to retrieve and insert JWT to JWT Cache
getJWTFromCache :: AppState -> ByteString -> Int -> IO (Either Error AuthResult) -> UTCTime -> IO (Either Error AuthResult)
getJWTFromCache appState token maxLifetime parseJwt utc = do
checkCache <- C.lookup (getJwtCache appState) token
authResult <- maybe parseJwt (pure . Right) checkCache

case (authResult,checkCache) of
(Right res, Nothing) -> C.insert' (getJwtCache appState) (getTimeSpec res maxLifetime utc) token res
_ -> pure ()

return authResult

-- Used to extract JWT exp claim and add to JWT Cache
getTimeSpec :: AuthResult -> Int -> UTCTime -> Maybe TimeSpec
getTimeSpec res maxLifetime utc = do
let expireJSON = KM.lookup "exp" (authClaims res)
utcToSecs = floor . nominalDiffTimeToSeconds . utcTimeToPOSIXSeconds
sciToInt = fromMaybe 0 . Sci.toBoundedInteger
case expireJSON of
Just (JSON.Number seconds) -> Just $ TimeSpec (sciToInt seconds - utcToSecs utc) 0
_ -> Just $ TimeSpec (fromIntegral maxLifetime :: Int64) 0

authResultKey :: Vault.Key (Either Error AuthResult)
authResultKey = unsafePerformIO Vault.newKey
Expand Down
5 changes: 4 additions & 1 deletion src/PostgREST/CLI.hs
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ exampleConfigFile =
|## Time in seconds after which to recycle unused pool connections
|# db-pool-max-idletime = 30
|
|## Allow autmatic database connection retrying
|## Allow automatic database connection retrying
|# db-pool-automatic-recovery = true
|
|## Stored proc to exec immediately after auth
Expand Down Expand Up @@ -205,6 +205,9 @@ exampleConfigFile =
|# jwt-secret = "secret_with_at_least_32_characters"
|jwt-secret-is-base64 = false
|
|## Enables and set JWT Cache max lifetime, disables caching with 0
|# jwt-cache-max-lifetime = 0
|
|## Logging level, the admitted values are: crit, error, warn and info.
|log-level = "error"
|
Expand Down
3 changes: 3 additions & 0 deletions src/PostgREST/Config.hs
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ data AppConfig = AppConfig
, configJwtRoleClaimKey :: JSPath
, configJwtSecret :: Maybe BS.ByteString
, configJwtSecretIsBase64 :: Bool
, configJwtCacheMaxLifetime :: Int
, configLogLevel :: LogLevel
, configOpenApiMode :: OpenAPIMode
, configOpenApiSecurityActive :: Bool
Expand Down Expand Up @@ -162,6 +163,7 @@ toText conf =
,("jwt-role-claim-key", q . T.intercalate mempty . fmap dumpJSPath . configJwtRoleClaimKey)
,("jwt-secret", q . T.decodeUtf8 . showJwtSecret)
,("jwt-secret-is-base64", T.toLower . show . configJwtSecretIsBase64)
,("jwt-cache-max-lifetime", show . configJwtCacheMaxLifetime)
,("log-level", q . dumpLogLevel . configLogLevel)
,("openapi-mode", q . dumpOpenApiMode . configOpenApiMode)
,("openapi-security-active", T.toLower . show . configOpenApiSecurityActive)
Expand Down Expand Up @@ -265,6 +267,7 @@ parser optPath env dbSettings roleSettings roleIsolationLvl =
<*> (fromMaybe False <$> optWithAlias
(optBool "jwt-secret-is-base64")
(optBool "secret-is-base64"))
<*> (fromMaybe 0 <$> optInt "jwt-cache-max-lifetime")
<*> parseLogLevel "log-level"
<*> parseOpenAPIMode "openapi-mode"
<*> (fromMaybe False <$> optBool "openapi-security-active")
Expand Down
1 change: 1 addition & 0 deletions test/io/configs/expected/aliases.config
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ jwt-aud = ""
jwt-role-claim-key = ".\"aliased\""
jwt-secret = ""
jwt-secret-is-base64 = true
jwt-cache-max-lifetime = 0
log-level = "error"
openapi-mode = "follow-privileges"
openapi-security-active = false
Expand Down
1 change: 1 addition & 0 deletions test/io/configs/expected/boolean-numeric.config
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ jwt-aud = ""
jwt-role-claim-key = ".\"role\""
jwt-secret = ""
jwt-secret-is-base64 = true
jwt-cache-max-lifetime = 0
log-level = "error"
openapi-mode = "follow-privileges"
openapi-security-active = false
Expand Down
1 change: 1 addition & 0 deletions test/io/configs/expected/boolean-string.config
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ jwt-aud = ""
jwt-role-claim-key = ".\"role\""
jwt-secret = ""
jwt-secret-is-base64 = true
jwt-cache-max-lifetime = 0
log-level = "error"
openapi-mode = "follow-privileges"
openapi-security-active = false
Expand Down
1 change: 1 addition & 0 deletions test/io/configs/expected/defaults.config
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ jwt-aud = ""
jwt-role-claim-key = ".\"role\""
jwt-secret = ""
jwt-secret-is-base64 = false
jwt-cache-max-lifetime = 0
log-level = "error"
openapi-mode = "follow-privileges"
openapi-security-active = false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ jwt-aud = "https://otherexample.org"
jwt-role-claim-key = ".\"other\".\"pre_config_role\""
jwt-secret = "ODERREALLYREALLYREALLYREALLYVERYSAFE"
jwt-secret-is-base64 = true
jwt-cache-max-lifetime = 86400
log-level = "info"
openapi-mode = "disabled"
openapi-security-active = false
Expand Down
1 change: 1 addition & 0 deletions test/io/configs/expected/no-defaults-with-db.config
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ jwt-aud = "https://example.org"
jwt-role-claim-key = ".\"a\".\"role\""
jwt-secret = "OVERRIDE=REALLY=REALLY=REALLY=REALLY=VERY=SAFE"
jwt-secret-is-base64 = false
jwt-cache-max-lifetime = 86400
log-level = "info"
openapi-mode = "ignore-privileges"
openapi-security-active = true
Expand Down
1 change: 1 addition & 0 deletions test/io/configs/expected/no-defaults.config
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ jwt-aud = "https://postgrest.org"
jwt-role-claim-key = ".\"user\"[0].\"real-role\""
jwt-secret = "c2VjdXJpdHl0aHJvdWdob2JzY3VyaXR5"
jwt-secret-is-base64 = true
jwt-cache-max-lifetime = 86400
log-level = "info"
openapi-mode = "ignore-privileges"
openapi-security-active = true
Expand Down
1 change: 1 addition & 0 deletions test/io/configs/expected/types.config
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ jwt-aud = ""
jwt-role-claim-key = ".\"role\""
jwt-secret = ""
jwt-secret-is-base64 = false
jwt-cache-max-lifetime = 0
log-level = "error"
openapi-mode = "follow-privileges"
openapi-security-active = false
Expand Down
1 change: 1 addition & 0 deletions test/io/configs/no-defaults-env.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ PGRST_JWT_AUD: 'https://postgrest.org'
PGRST_JWT_ROLE_CLAIM_KEY: '.user[0]."real-role"'
PGRST_JWT_SECRET: c2VjdXJpdHl0aHJvdWdob2JzY3VyaXR5
PGRST_JWT_SECRET_IS_BASE64: true
PGRST_JWT_CACHE_MAX_LIFETIME: 86400
PGRST_LOG_LEVEL: info
PGRST_OPENAPI_MODE: 'ignore-privileges'
PGRST_OPENAPI_SECURITY_ACTIVE: true
Expand Down
1 change: 1 addition & 0 deletions test/io/configs/no-defaults.config
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ jwt-aud = "https://postgrest.org"
jwt-role-claim-key = ".user[0].\"real-role\""
jwt-secret = "c2VjdXJpdHl0aHJvdWdob2JzY3VyaXR5"
jwt-secret-is-base64 = true
jwt-cache-max-lifetime = 86400
log-level = "info"
openapi-mode = "ignore-privileges"
openapi-security-active = true
Expand Down
Loading
Loading