fix: jwt cache is not purged #3801

Draft · wants to merge 1 commit into base: main
14 changes: 14 additions & 0 deletions docs/references/observability.rst
@@ -169,6 +169,20 @@ pgrst_db_pool_max

Max pool connections.

JWT Cache Metric
----------------

Related to :ref:`jwt_caching`.

pgrst_jwt_cache_size
~~~~~~~~~~~~~~~~~~~~

======== =======
**Type** Gauge
======== =======

JWT Cache Size.

Traces
======

24 changes: 18 additions & 6 deletions src/PostgREST/Auth.hs
@@ -44,10 +44,11 @@ import System.Clock (TimeSpec (..))
import System.IO.Unsafe (unsafePerformIO)
import System.TimeIt (timeItT)

import PostgREST.AppState (AppState, AuthResult (..), getConfig,
getJwtCache, getTime)
import PostgREST.Config (AppConfig (..), JSPath, JSPathExp (..))
import PostgREST.Error (Error (..))
import PostgREST.AppState (AppState, AuthResult (..), getConfig,
getJwtCache, getObserver, getTime)
import PostgREST.Config (AppConfig (..), JSPath, JSPathExp (..))
import PostgREST.Error (Error (..))
import PostgREST.Observation (Observation (..))

import Protolude

@@ -163,14 +164,25 @@ middleware appState app req respond = do
-- | Used to retrieve and insert JWT to JWT Cache
getJWTFromCache :: AppState -> ByteString -> Int -> IO (Either Error AuthResult) -> UTCTime -> IO (Either Error AuthResult)
getJWTFromCache appState token maxLifetime parseJwt utc = do
checkCache <- C.lookup (getJwtCache appState) token

-- Purge expired tokens in a separate thread
_ <- forkIO $ C.purgeExpired jwtCache

Member

Hm, this is going to start a new thread on every request? That doesn't look efficient.

I think we should have a background thread that is notified for doing the purge.
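
A minimal sketch of that kind of notified worker, assuming a plain MVar () as the wake-up signal (startPurgeWorker and requestPurge are illustrative names, nothing here exists in PostgREST yet):

import           Control.Concurrent (MVar, forkIO, newEmptyMVar, takeMVar, tryPutMVar)
import           Control.Monad      (forever, void)
import           Data.ByteString    (ByteString)
import qualified Data.Cache         as C

-- One long-lived worker that purges only when a request signals it,
-- so no thread is spawned per request.
startPurgeWorker :: C.Cache ByteString a -> IO (MVar ())
startPurgeWorker cache = do
  wakePurge <- newEmptyMVar
  _ <- forkIO $ forever $ do
    takeMVar wakePurge      -- block until a request asks for a purge
    C.purgeExpired cache    -- cleanup runs off the request path
  pure wakePurge

-- On the request path, signalling is cheap and non-blocking;
-- tryPutMVar is a no-op when a purge is already pending.
requestPurge :: MVar () -> IO ()
requestPurge wakePurge = void $ tryPutMVar wakePurge ()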

Contributor

To be honest - for starters I would implement a simple solution that executes purgeExpired periodically (frequency possibly subject to configuration - but not necessarily).

Triggering a purge upon every request might mean the purging thread runs constantly, wasting cycles, as there is nothing to do most of the time.

The ultimate solution would be to have some kind of scheduler that triggers purging upon the next nearest expiry - but that would require more complex bookkeeping.
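
For comparison, a sketch of that simple periodic variant (periodicPurge is an illustrative name, and the interval would presumably come from configuration if we go this way):

import           Control.Concurrent (forkIO, threadDelay)
import           Control.Monad      (forever, void)
import           Data.ByteString    (ByteString)
import qualified Data.Cache         as C

-- Hypothetical helper; the 60-second interval in the usage note is arbitrary.
periodicPurge :: Int -> C.Cache ByteString a -> IO ()
periodicPurge intervalSeconds cache = void $ forkIO $ forever $ do
  threadDelay (intervalSeconds * 1000000)  -- threadDelay takes microseconds
  C.purgeExpired cache

-- e.g. at startup: periodicPurge 60 jwtCache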

Contributor

Thinking about it some more - it seems to me the best moment to trigger a purge is upon a cache miss - just before/after inserting a new entry (see the sketch after this list).

The reasoning is that:

  1. We expect cache misses to be rare (otherwise there is no point in having the cache).
  2. It makes sure the cache does not keep growing (as inserting new entries does garbage collection).
  3. Since this is a time-expiration-based cache there is no real risk of starvation - sooner or later we are going to have a cache miss.
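
A minimal sketch of how that could look in getJWTFromCache, reusing the names and imports already in this diff (a hypothetical rearrangement, not what the PR currently does):

getJWTFromCache :: AppState -> ByteString -> Int -> IO (Either Error AuthResult) -> UTCTime -> IO (Either Error AuthResult)
getJWTFromCache appState token maxLifetime parseJwt utc = do
  checkCache <- C.lookup jwtCache token
  authResult <- maybe parseJwt (pure . Right) checkCache

  case (authResult, checkCache) of
    (Right res, Nothing) -> do
      -- cache miss: purge expired entries while we are here, then insert the new one
      C.purgeExpired jwtCache
      C.insert' jwtCache (getTimeSpec res maxLifetime utc) token res
    _ -> pure ()

  return authResult
  where
    jwtCache = getJwtCache appState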

Collaborator Author

This seems like a much better idea. This way, we wouldn't need to maintain any new thread state, hence avoiding extra overhead. @steve-chavez WDYT?

Member

Agree, looks like a great idea 👍


checkCache <- C.lookup jwtCache token
authResult <- maybe parseJwt (pure . Right) checkCache


case (authResult,checkCache) of
(Right res, Nothing) -> C.insert' (getJwtCache appState) (getTimeSpec res maxLifetime utc) token res
(Right res, Nothing) -> C.insert' jwtCache (getTimeSpec res maxLifetime utc) token res
_ -> pure ()

jwtCacheSize <- C.size jwtCache
observer $ JWTCache jwtCacheSize

return authResult
where
observer = getObserver appState
jwtCache = getJwtCache appState

-- Used to extract JWT exp claim and add to JWT Cache
getTimeSpec :: AuthResult -> Int -> UTCTime -> Maybe TimeSpec
11 changes: 7 additions & 4 deletions src/PostgREST/Metrics.hs
@@ -1,5 +1,5 @@
{-|
Module : PostgREST.Logger
Module : PostgREST.Metrics
Description : Metrics based on the Observation module. See Observation.hs.
-}
module PostgREST.Metrics
@@ -19,7 +19,7 @@ import PostgREST.Observation
import Protolude

data MetricsState =
MetricsState Counter Gauge Gauge Gauge (Vector Label1 Counter) Gauge
MetricsState Counter Gauge Gauge Gauge (Vector Label1 Counter) Gauge Gauge

init :: Int -> IO MetricsState
init configDbPoolSize = do
@@ -29,12 +29,13 @@ init configDbPoolSize = do
poolMaxSize <- register $ gauge (Info "pgrst_db_pool_max" "Max pool connections")
schemaCacheLoads <- register $ vector "status" $ counter (Info "pgrst_schema_cache_loads_total" "The total number of times the schema cache was loaded")
schemaCacheQueryTime <- register $ gauge (Info "pgrst_schema_cache_query_time_seconds" "The query time in seconds of the last schema cache load")
jwtCacheSize <- register $ gauge (Info "pgrst_jwt_cache_size" "The number of cached JWTs")
setGauge poolMaxSize (fromIntegral configDbPoolSize)
pure $ MetricsState poolTimeouts poolAvailable poolWaiting poolMaxSize schemaCacheLoads schemaCacheQueryTime
pure $ MetricsState poolTimeouts poolAvailable poolWaiting poolMaxSize schemaCacheLoads schemaCacheQueryTime jwtCacheSize

-- Only some observations are used as metrics
observationMetrics :: MetricsState -> ObservationHandler
observationMetrics (MetricsState poolTimeouts poolAvailable poolWaiting _ schemaCacheLoads schemaCacheQueryTime) obs = case obs of
observationMetrics (MetricsState poolTimeouts poolAvailable poolWaiting _ schemaCacheLoads schemaCacheQueryTime jwtCacheSize) obs = case obs of
(PoolAcqTimeoutObs _) -> do
incCounter poolTimeouts
(HasqlPoolObs (SQL.ConnectionObservation _ status)) -> case status of
@@ -54,6 +55,8 @@ observationMetrics (MetricsState poolTimeouts poolAvailable poolWaiting _ schema
setGauge schemaCacheQueryTime resTime
SchemaCacheErrorObs _ -> do
withLabel schemaCacheLoads "FAIL" incCounter
JWTCache cacheSize -> do
setGauge jwtCacheSize $ fromIntegral cacheSize
_ ->
pure ()

7 changes: 6 additions & 1 deletion src/PostgREST/Observation.hs
@@ -57,8 +57,13 @@ data Observation
| HasqlPoolObs SQL.Observation
| PoolRequest
| PoolRequestFullfilled
| JWTCache Int

data ObsFatalError = ServerAuthError | ServerPgrstBug | ServerError42P05 | ServerError08P01
data ObsFatalError
= ServerAuthError
| ServerPgrstBug
| ServerError42P05
| ServerError08P01

type ObservationHandler = Observation -> IO ()

57 changes: 57 additions & 0 deletions test/io/test_io.py
@@ -1632,6 +1632,8 @@ def test_admin_metrics(defaultenv):
assert "pgrst_db_pool_available" in response.text
assert "pgrst_db_pool_timeouts_total" in response.text

assert "pgrst_jwt_cache_size" in response.text


def test_schema_cache_startup_load_with_in_db_config(defaultenv, metapostgrest):
"verify that the Schema Cache loads correctly at startup, using the in-db `pgrst.db_schemas` config"
@@ -1648,3 +1650,58 @@ def test_schema_cache_startup_load_with_in_db_config(defaultenv, metapostgrest):
response = metapostgrest.session.post("/rpc/reset_db_schemas_config")
assert response.text == ""
assert response.status_code == 204


def test_jwt_cache_size_decreases_after_expiry(defaultenv):
"verify that JWT purges expired JWTs"

relativeSeconds = lambda sec: int(
(datetime.now(timezone.utc) + timedelta(seconds=sec)).timestamp()
)

headers = lambda sec: jwtauthheader(
{"role": "postgrest_test_author", "exp": relativeSeconds(sec)},
SECRET,
)

env = {
**defaultenv,
"PGRST_JWT_CACHE_MAX_LIFETIME": "86400",
"PGRST_JWT_SECRET": SECRET,
"PGRST_DB_CONFIG": "false",
}

with run(env=env, port=freeport()) as postgrest:

# Generate three unique JWT tokens
# The 1 second sleep is needed for it to generate a unique token
hdrs1 = headers(5)
postgrest.session.get("/authors_only", headers=hdrs1)

time.sleep(1)

hdrs2 = headers(5)
postgrest.session.get("/authors_only", headers=hdrs2)

time.sleep(1)

hdrs3 = headers(5)
postgrest.session.get("/authors_only", headers=hdrs3)

# the cache should now have three tokens
response = postgrest.admin.get("/metrics")
assert response.status_code == 200
assert "pgrst_jwt_cache_size 3.0" in response.text

Member

@taimoorzaeem Is it possible to have the size measured in bytes?

Collaborator Author

Maybe. There is no built-in function for this in Data.Cache. I have tried using https://hackage.haskell.org/package/ghc-datasize to calculate size in bytes at runtime, but I have run into problems with that too.

We can try manually writing some logic to calculate the size in bytes, but I guess that would require some effort.

Member

There doesn't seem to be much code in that library: https://github.com/def-/ghc-datasize/blob/master/src/GHC/DataSize.hs. So maybe we can just vendor the code and use unsafePerformIO if it's cumbersome to integrate it.
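
In case it helps, a rough sketch of that direction, assuming ghc-datasize exposes GHC.DataSize.recursiveSize :: a -> IO Word (given the problems mentioned above, treat it as a starting point rather than something known to work):

import qualified Data.Cache   as C
import           GHC.DataSize (recursiveSize)

-- Hypothetical helper: walk the heap graph of the cache structure and sum closure sizes.
-- The result is approximate (unevaluated thunks skew it) and the traversal itself
-- is not free on a large cache.
jwtCacheSizeInBytes :: C.Cache k v -> IO Word
jwtCacheSizeInBytes = recursiveSize

The number could then feed the same JWTCache observation, with the gauge documented as bytes instead of an entry count.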

Member

Btw, this metric could be a new PR so it's easier to review.


# Wait 5 seconds for the tokens to expire
time.sleep(5)

hdrs4 = headers(5)

# Make another request to force a call to the purgeExpired method
# This should remove the 3 expired tokens and add 1 to the cache
postgrest.session.get("/authors_only", headers=hdrs4)

response = postgrest.admin.get("/metrics")
assert response.status_code == 200
assert "pgrst_jwt_cache_size 1.0" in response.text