From 7ffca041a3c9891760fbd38674482202ab0ad67c Mon Sep 17 00:00:00 2001 From: "M. Taimoor Zaeem" Date: Sat, 23 Nov 2024 00:41:55 +0500 Subject: [PATCH] fix: jwt cache is not purged Fixed the JWT cache not purging expired tokens in cache. This is tested by adding a new metric pgrst_jwt_cache_size. --- docs/references/observability.rst | 14 ++++++++ src/PostgREST/Auth.hs | 24 +++++++++---- src/PostgREST/Metrics.hs | 11 +++--- src/PostgREST/Observation.hs | 7 +++- test/io/test_io.py | 57 +++++++++++++++++++++++++++++++ 5 files changed, 102 insertions(+), 11 deletions(-) diff --git a/docs/references/observability.rst b/docs/references/observability.rst index 04a261ba48..bcb7158780 100644 --- a/docs/references/observability.rst +++ b/docs/references/observability.rst @@ -169,6 +169,20 @@ pgrst_db_pool_max Max pool connections. +JWT Cache Metric +---------------- + +Related to the :ref:`jwt_caching`. + +pgrst_jwt_cache_size +~~~~~~~~~~~~~~~~~~~~ + +======== ======= +**Type** Gauge +======== ======= + +JWT Cache Size. + Traces ====== diff --git a/src/PostgREST/Auth.hs b/src/PostgREST/Auth.hs index 46afa801c2..89d1a116a4 100644 --- a/src/PostgREST/Auth.hs +++ b/src/PostgREST/Auth.hs @@ -44,10 +44,11 @@ import System.Clock (TimeSpec (..)) import System.IO.Unsafe (unsafePerformIO) import System.TimeIt (timeItT) -import PostgREST.AppState (AppState, AuthResult (..), getConfig, - getJwtCache, getTime) -import PostgREST.Config (AppConfig (..), JSPath, JSPathExp (..)) -import PostgREST.Error (Error (..)) +import PostgREST.AppState (AppState, AuthResult (..), getConfig, + getJwtCache, getObserver, getTime) +import PostgREST.Config (AppConfig (..), JSPath, JSPathExp (..)) +import PostgREST.Error (Error (..)) +import PostgREST.Observation (Observation (..)) import Protolude @@ -163,14 +164,25 @@ middleware appState app req respond = do -- | Used to retrieve and insert JWT to JWT Cache getJWTFromCache :: AppState -> ByteString -> Int -> IO (Either Error AuthResult) -> UTCTime -> IO (Either Error AuthResult) getJWTFromCache appState token maxLifetime parseJwt utc = do - checkCache <- C.lookup (getJwtCache appState) token + + -- Purge expired tokens in a separate thread + _ <- forkIO $ C.purgeExpired jwtCache + + checkCache <- C.lookup jwtCache token authResult <- maybe parseJwt (pure . Right) checkCache + case (authResult,checkCache) of - (Right res, Nothing) -> C.insert' (getJwtCache appState) (getTimeSpec res maxLifetime utc) token res + (Right res, Nothing) -> C.insert' jwtCache (getTimeSpec res maxLifetime utc) token res _ -> pure () + jwtCacheSize <- C.size jwtCache + observer $ JWTCache jwtCacheSize + return authResult + where + observer = getObserver appState + jwtCache = getJwtCache appState -- Used to extract JWT exp claim and add to JWT Cache getTimeSpec :: AuthResult -> Int -> UTCTime -> Maybe TimeSpec diff --git a/src/PostgREST/Metrics.hs b/src/PostgREST/Metrics.hs index 3999e43d83..84974e970c 100644 --- a/src/PostgREST/Metrics.hs +++ b/src/PostgREST/Metrics.hs @@ -1,5 +1,5 @@ {-| -Module : PostgREST.Logger +Module : PostgREST.Metrics Description : Metrics based on the Observation module. See Observation.hs. -} module PostgREST.Metrics @@ -19,7 +19,7 @@ import PostgREST.Observation import Protolude data MetricsState = - MetricsState Counter Gauge Gauge Gauge (Vector Label1 Counter) Gauge + MetricsState Counter Gauge Gauge Gauge (Vector Label1 Counter) Gauge Gauge init :: Int -> IO MetricsState init configDbPoolSize = do @@ -29,12 +29,13 @@ init configDbPoolSize = do poolMaxSize <- register $ gauge (Info "pgrst_db_pool_max" "Max pool connections") schemaCacheLoads <- register $ vector "status" $ counter (Info "pgrst_schema_cache_loads_total" "The total number of times the schema cache was loaded") schemaCacheQueryTime <- register $ gauge (Info "pgrst_schema_cache_query_time_seconds" "The query time in seconds of the last schema cache load") + jwtCacheSize <- register $ gauge (Info "pgrst_jwt_cache_size" "The number of cached JWTs") setGauge poolMaxSize (fromIntegral configDbPoolSize) - pure $ MetricsState poolTimeouts poolAvailable poolWaiting poolMaxSize schemaCacheLoads schemaCacheQueryTime + pure $ MetricsState poolTimeouts poolAvailable poolWaiting poolMaxSize schemaCacheLoads schemaCacheQueryTime jwtCacheSize -- Only some observations are used as metrics observationMetrics :: MetricsState -> ObservationHandler -observationMetrics (MetricsState poolTimeouts poolAvailable poolWaiting _ schemaCacheLoads schemaCacheQueryTime) obs = case obs of +observationMetrics (MetricsState poolTimeouts poolAvailable poolWaiting _ schemaCacheLoads schemaCacheQueryTime jwtCacheSize) obs = case obs of (PoolAcqTimeoutObs _) -> do incCounter poolTimeouts (HasqlPoolObs (SQL.ConnectionObservation _ status)) -> case status of @@ -54,6 +55,8 @@ observationMetrics (MetricsState poolTimeouts poolAvailable poolWaiting _ schema setGauge schemaCacheQueryTime resTime SchemaCacheErrorObs _ -> do withLabel schemaCacheLoads "FAIL" incCounter + JWTCache cacheSize -> do + setGauge jwtCacheSize $ fromIntegral cacheSize _ -> pure () diff --git a/src/PostgREST/Observation.hs b/src/PostgREST/Observation.hs index 18fbf558d7..5a3d6da42a 100644 --- a/src/PostgREST/Observation.hs +++ b/src/PostgREST/Observation.hs @@ -57,8 +57,13 @@ data Observation | HasqlPoolObs SQL.Observation | PoolRequest | PoolRequestFullfilled + | JWTCache Int -data ObsFatalError = ServerAuthError | ServerPgrstBug | ServerError42P05 | ServerError08P01 +data ObsFatalError + = ServerAuthError + | ServerPgrstBug + | ServerError42P05 + | ServerError08P01 type ObservationHandler = Observation -> IO () diff --git a/test/io/test_io.py b/test/io/test_io.py index 75c2ad8617..eeac008487 100644 --- a/test/io/test_io.py +++ b/test/io/test_io.py @@ -1632,6 +1632,8 @@ def test_admin_metrics(defaultenv): assert "pgrst_db_pool_available" in response.text assert "pgrst_db_pool_timeouts_total" in response.text + assert "pgrst_jwt_cache_size" in response.text + def test_schema_cache_startup_load_with_in_db_config(defaultenv, metapostgrest): "verify that the Schema Cache loads correctly at startup, using the in-db `pgrst.db_schemas` config" @@ -1648,3 +1650,58 @@ def test_schema_cache_startup_load_with_in_db_config(defaultenv, metapostgrest): response = metapostgrest.session.post("/rpc/reset_db_schemas_config") assert response.text == "" assert response.status_code == 204 + + +def test_jwt_cache_size_decreases_after_expiry(defaultenv): + "verify that JWT purges expired JWTs" + + relativeSeconds = lambda sec: int( + (datetime.now(timezone.utc) + timedelta(seconds=sec)).timestamp() + ) + + headers = lambda sec: jwtauthheader( + {"role": "postgrest_test_author", "exp": relativeSeconds(sec)}, + SECRET, + ) + + env = { + **defaultenv, + "PGRST_JWT_CACHE_MAX_LIFETIME": "86400", + "PGRST_JWT_SECRET": SECRET, + "PGRST_DB_CONFIG": "false", + } + + with run(env=env, port=freeport()) as postgrest: + + # Generate three unique JWT tokens + # The 1 second sleep is needed for it generate a unique token + hdrs1 = headers(5) + postgrest.session.get("/authors_only", headers=hdrs1) + + time.sleep(1) + + hdrs2 = headers(5) + postgrest.session.get("/authors_only", headers=hdrs2) + + time.sleep(1) + + hdrs3 = headers(5) + postgrest.session.get("/authors_only", headers=hdrs3) + + # the cache should now have three tokens + response = postgrest.admin.get("/metrics") + assert response.status_code == 200 + assert "pgrst_jwt_cache_size 3.0" in response.text + + # Wait 5 seconds for the tokens to expire + time.sleep(5) + + hdrs4 = headers(5) + + # Make another request to force call the purgeExpired method + # This should remove the 3 expired tokens and adds 1 to cache + postgrest.session.get("/authors_only", headers=hdrs4) + + response = postgrest.admin.get("/metrics") + assert response.status_code == 200 + assert "pgrst_jwt_cache_size 1.0" in response.text