Skip to content

Commit

Permalink
feat: implement JWT Caching
Browse files Browse the repository at this point in the history
  • Loading branch information
taimoorzaeem committed Sep 19, 2023
1 parent fd7e03e commit c83befe
Show file tree
Hide file tree
Showing 5 changed files with 112 additions and 16 deletions.
3 changes: 3 additions & 0 deletions postgrest.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,10 @@ library
, auto-update >= 0.1.4 && < 0.2
, base64-bytestring >= 1 && < 1.3
, bytestring >= 0.10.8 && < 0.12
, cache >= 0.1.3 && < 0.2.0
, case-insensitive >= 1.2 && < 1.3
, cassava >= 0.4.5 && < 0.6
, clock >= 0.8.3 && < 0.9.0
, configurator-pg >= 0.2 && < 0.3
, containers >= 0.5.7 && < 0.7
, contravariant-extras >= 0.3.3 && < 0.4
Expand Down Expand Up @@ -184,6 +186,7 @@ test-suite spec
Feature.Auth.AudienceJwtSecretSpec
Feature.Auth.AuthSpec
Feature.Auth.BinaryJwtSecretSpec
Feature.Auth.JwtCachingSpec
Feature.Auth.NoAnonSpec
Feature.Auth.NoJwtSpec
Feature.ConcurrentSpec
Expand Down
16 changes: 16 additions & 0 deletions src/PostgREST/AppState.hs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

module PostgREST.AppState
( AppState
, AuthResult(..)
, destroy
, getConfig
, getSchemaCache
Expand All @@ -12,6 +13,7 @@ module PostgREST.AppState
, getPgVersion
, getRetryNextIn
, getTime
, getJwtCache
, init
, initWithPool
, logWithZTime
Expand All @@ -24,8 +26,11 @@ module PostgREST.AppState
, runListener
) where

import qualified Data.Aeson as JSON
import qualified Data.Aeson.KeyMap as KM
import qualified Data.ByteString.Char8 as BS
import qualified Data.ByteString.Lazy as LBS
import qualified Data.Cache as C
import Data.Either.Combinators (whenLeft)
import qualified Data.Text.Encoding as T
import Hasql.Connection (acquire)
Expand Down Expand Up @@ -62,6 +67,11 @@ import PostgREST.SchemaCache.Identifiers (dumpQi)
import Protolude


data AuthResult = AuthResult
{ authClaims :: KM.KeyMap JSON.Value
, authRole :: BS.ByteString
}

data AppState = AppState
-- | Database connection pool
{ statePool :: SQL.Pool
Expand All @@ -87,6 +97,8 @@ data AppState = AppState
, stateRetryNextIn :: IORef Int
-- | Logs a pool error with a debounce
, debounceLogAcquisitionTimeout :: IO ()
-- | JWT Cache
, jwtCache :: C.Cache ByteString AuthResult
}

init :: AppConfig -> IO AppState
Expand All @@ -108,6 +120,7 @@ initWithPool pool conf = do
<*> myThreadId
<*> newIORef 0
<*> pure (pure ())
<*> C.newCache Nothing


debLogTimeout <-
Expand Down Expand Up @@ -188,6 +201,9 @@ putConfig = atomicWriteIORef . stateConf
getTime :: AppState -> IO UTCTime
getTime = stateGetTime

getJwtCache :: AppState -> C.Cache ByteString AuthResult
getJwtCache = jwtCache

-- | Log to stderr with local time
logWithZTime :: AppState -> Text -> IO ()
logWithZTime appState txt = do
Expand Down
64 changes: 48 additions & 16 deletions src/PostgREST/Auth.hs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ import qualified Data.Aeson.KeyMap as KM
import qualified Data.Aeson.Types as JSON
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy.Char8 as LBS
import qualified Data.Cache as C
import qualified Data.Scientific as Sci
import qualified Data.Vault.Lazy as Vault
import qualified Data.Vector as V
import qualified Network.HTTP.Types.Header as HTTP
Expand All @@ -37,21 +39,18 @@ import Control.Monad.Except (liftEither)
import Data.Either.Combinators (mapLeft)
import Data.List (lookup)
import Data.Time.Clock (UTCTime)
import System.Clock (TimeSpec (..))
import System.IO.Unsafe (unsafePerformIO)
import System.TimeIt (timeItT)

import PostgREST.AppState (AppState, getConfig, getTime)
import PostgREST.AppState (AppState, AuthResult (..), getConfig,
getJwtCache, getTime)
import PostgREST.Config (AppConfig (..), JSPath, JSPathExp (..))
import PostgREST.Error (Error (..))

import Protolude


data AuthResult = AuthResult
{ authClaims :: KM.KeyMap JSON.Value
, authRole :: BS.ByteString
}

-- | Receives the JWT secret and audience (from config) and a JWT and returns a
-- JSON object of JWT claims.
parseToken :: Monad m =>
Expand Down Expand Up @@ -107,16 +106,49 @@ middleware appState app req respond = do
let token = fromMaybe "" $ Wai.extractBearerAuth =<< lookup HTTP.hAuthorization (Wai.requestHeaders req)
parseJwt = runExceptT $ parseToken conf (LBS.fromStrict token) time >>= parseClaims conf

if configDbPlanEnabled conf
then do
(dur,authResult) <- timeItT parseJwt
let req' = req { Wai.vault = Wai.vault req & Vault.insert authResultKey authResult & Vault.insert jwtDurKey dur }
app req' respond
else do
authResult <- parseJwt
let req' = req { Wai.vault = Wai.vault req & Vault.insert authResultKey authResult }
app req' respond

-- If DbPlanEnabled -> calculate JWT validation time
-- If JwtCaching -> cache JWT validation result
case (configDbPlanEnabled conf, configJwtCaching conf) of
(True, True) -> do
(dur, authResult) <- timeItT $ getJWTFromCache appState token parseJwt
let req' = req { Wai.vault = Wai.vault req & Vault.insert authResultKey authResult & Vault.insert jwtDurKey dur }
app req' respond

(True, False) -> do
(dur, authResult) <- timeItT parseJwt
let req' = req { Wai.vault = Wai.vault req & Vault.insert authResultKey authResult & Vault.insert jwtDurKey dur }
app req' respond

(False, True) -> do
authResult <- getJWTFromCache appState token parseJwt
let req' = req { Wai.vault = Wai.vault req & Vault.insert authResultKey authResult }
app req' respond

(False, False) -> do
authResult <- parseJwt
let req' = req { Wai.vault = Wai.vault req & Vault.insert authResultKey authResult }
app req' respond

-- Used to extract JWT exp claim and add to JWT Cache
getTimeSpec :: AuthResult -> Maybe TimeSpec
getTimeSpec res = do
let sciToInt = fromMaybe 0 . Sci.toBoundedInteger
expireJSON <- KM.lookup "exp" (authClaims res)
case expireJSON of
JSON.Number seconds -> Just $ TimeSpec (sciToInt seconds) 0
_ -> Just $ TimeSpec 0 0 -- set timeSpec to 0 so it expires immediately, hence not cached

-- | Used to retrieve and insert JWT to JWT Cache
getJWTFromCache :: AppState -> ByteString -> IO (Either Error AuthResult) -> IO (Either Error AuthResult)
getJWTFromCache appState token parseJwt = do
checkCache <- C.lookup (getJwtCache appState) token
authResult <- maybe parseJwt (pure . Right) checkCache

case (authResult,checkCache) of
(Right res, Nothing) -> C.insert' (getJwtCache appState) (getTimeSpec res) token res
_ -> pure ()

return authResult

authResultKey :: Vault.Key (Either Error AuthResult)
authResultKey = unsafePerformIO Vault.newKey
Expand Down
39 changes: 39 additions & 0 deletions test/spec/Feature/Auth/JwtCachingSpec.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
module Feature.Auth.JwtCachingSpec where

import Network.Wai (Application)

import Network.HTTP.Types
import Test.Hspec
import Test.Hspec.Wai
import Test.Hspec.Wai.JSON

import Protolude hiding (get)
import SpecHelper

spec :: SpecWith ((), Application)
spec = describe "JWT Caching" $ do
let auth = authHeaderJWT "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJyb2xlIjoicG9zdGdyZXN0X3Rlc3RfYXV0aG9yIn0.Xod-F15qsGL0WhdOCr2j3DdKuTw9QJERVgoFD3vGaWA"

it "jwt Cache warmup" $
request methodPost "/rpc/privileged_hello" [auth] [json|{"name": "jdoe"}|]
`shouldRespondWith` [json|"Privileged hello to jdoe"|]
{ matchStatus = 200
, matchHeaders = ["Server-Timing" <:> "jwt;dur="
, matchContentTypeJson]
}

it "JWT Cache, token should be cached at this request" $
request methodPost "/rpc/privileged_hello" [auth] [json|{"name": "jdoe"}|]
`shouldRespondWith` [json|"Privileged hello to jdoe"|]
{ matchStatus = 200
, matchHeaders = ["Server-Timing" <:> "jwt;dur="
, matchContentTypeJson]
}

it "JWT Claims should be retrieved from the cache" $
request methodPost "/rpc/privileged_hello" [auth] [json|{"name": "jdoe"}|]
`shouldRespondWith` [json|"Privileged hello to jdoe"|]
{ matchStatus = 200
, matchHeaders = ["Server-Timing" <:> "jwt;dur="
, matchContentTypeJson]
}
6 changes: 6 additions & 0 deletions test/spec/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import qualified Feature.Auth.AsymmetricJwtSpec
import qualified Feature.Auth.AudienceJwtSecretSpec
import qualified Feature.Auth.AuthSpec
import qualified Feature.Auth.BinaryJwtSecretSpec
import qualified Feature.Auth.JwtCachingSpec
import qualified Feature.Auth.NoAnonSpec
import qualified Feature.Auth.NoJwtSpec
import qualified Feature.ConcurrentSpec
Expand Down Expand Up @@ -111,6 +112,7 @@ main = do
pgSafeUpdateApp = app testPgSafeUpdateEnabledCfg
obsApp = app testObservabilityCfg
serverTiming = app testCfgServerTiming
jwtCaching = app testCfgJwtCaching

extraSearchPathApp = appDbs testCfgExtraSearchPath
unicodeApp = appDbs testUnicodeCfg
Expand Down Expand Up @@ -210,6 +212,10 @@ main = do
parallel $ before asymJwkSetApp $
describe "Feature.Auth.AsymmetricJwtSpec" Feature.Auth.AsymmetricJwtSpec.spec

-- this test runs with jwtCaching
parallel $ before jwtCaching $
describe "Feature.Auth.JwtCachingSpec" Feature.Auth.JwtCachingSpec.spec

-- this test runs with a nonexistent db-schema
parallel $ before nonexistentSchemaApp $
describe "Feature.Query.ErrorSpec" Feature.Query.ErrorSpec.spec
Expand Down

0 comments on commit c83befe

Please sign in to comment.