From 125f10a60f9928981224557464365a723f7fc232 Mon Sep 17 00:00:00 2001 From: Taimoor Zaeem Date: Fri, 17 Nov 2023 22:36:51 +0500 Subject: [PATCH] feat: add statement_timeout set on functions (#3056) --- CHANGELOG.md | 1 + src/PostgREST/App.hs | 18 +++++++++--------- src/PostgREST/Query.hs | 7 ++++--- src/PostgREST/SchemaCache.hs | 7 +++++-- src/PostgREST/SchemaCache/Routine.hs | 8 +++++--- test/io/fixtures.sql | 8 ++++++++ test/io/test_io.py | 22 ++++++++++++++++++++++ 7 files changed, 54 insertions(+), 17 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4f7223bcb8..5c7a11d031 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,7 @@ This project adheres to [Semantic Versioning](http://semver.org/). - #2825, SQL handlers for custom media types - @steve-chavez + Solves #1548, #2699, #2763, #2170, #1462, #1102, #1374, #2901 - #2799, Add timezone in Prefer header - @taimoorzaeem + - #3001, Add `statement_timeout` set on functions - @taimoorzaeem ### Fixed diff --git a/src/PostgREST/App.hs b/src/PostgREST/App.hs index c344201270..cd23c59b0c 100644 --- a/src/PostgREST/App.hs +++ b/src/PostgREST/App.hs @@ -179,49 +179,49 @@ handleRequest AuthResult{..} conf appState authenticated prepared pgVer apiReq@A case (iAction, iTarget) of (ActionRead headersOnly, TargetIdent identifier) -> do (planTime', wrPlan) <- withTiming $ liftEither $ Plan.wrappedReadPlan identifier conf sCache apiReq - (rsTime', resultSet) <- withTiming $ runQuery roleIsoLvl (Plan.wrTxMode wrPlan) $ Query.readQuery wrPlan conf apiReq + (rsTime', resultSet) <- withTiming $ runQuery roleIsoLvl Nothing (Plan.wrTxMode wrPlan) $ Query.readQuery wrPlan conf apiReq (renderTime', pgrst) <- withTiming $ liftEither $ Response.readResponse wrPlan headersOnly identifier apiReq resultSet let metrics = Map.fromList [(SMPlan, planTime'), (SMQuery, rsTime'), (SMRender, renderTime'), jwtTime] return $ pgrstResponse metrics pgrst (ActionMutate MutationCreate, TargetIdent identifier) -> do (planTime', mrPlan) <- withTiming $ liftEither $ Plan.mutateReadPlan MutationCreate apiReq identifier conf sCache - (rsTime', resultSet) <- withTiming $ runQuery roleIsoLvl (Plan.mrTxMode mrPlan) $ Query.createQuery mrPlan apiReq conf + (rsTime', resultSet) <- withTiming $ runQuery roleIsoLvl Nothing (Plan.mrTxMode mrPlan) $ Query.createQuery mrPlan apiReq conf (renderTime', pgrst) <- withTiming $ liftEither $ Response.createResponse identifier mrPlan apiReq resultSet let metrics = Map.fromList [(SMPlan, planTime'), (SMQuery, rsTime'), (SMRender, renderTime'), jwtTime] return $ pgrstResponse metrics pgrst (ActionMutate MutationUpdate, TargetIdent identifier) -> do (planTime', mrPlan) <- withTiming $ liftEither $ Plan.mutateReadPlan MutationUpdate apiReq identifier conf sCache - (rsTime', resultSet) <- withTiming $ runQuery roleIsoLvl (Plan.mrTxMode mrPlan) $ Query.updateQuery mrPlan apiReq conf + (rsTime', resultSet) <- withTiming $ runQuery roleIsoLvl Nothing (Plan.mrTxMode mrPlan) $ Query.updateQuery mrPlan apiReq conf (renderTime', pgrst) <- withTiming $ liftEither $ Response.updateResponse mrPlan apiReq resultSet let metrics = Map.fromList [(SMPlan, planTime'), (SMQuery, rsTime'), (SMRender, renderTime'), jwtTime] return $ pgrstResponse metrics pgrst (ActionMutate MutationSingleUpsert, TargetIdent identifier) -> do (planTime', mrPlan) <- withTiming $ liftEither $ Plan.mutateReadPlan MutationSingleUpsert apiReq identifier conf sCache - (rsTime', resultSet) <- withTiming $ runQuery roleIsoLvl (Plan.mrTxMode mrPlan) $ Query.singleUpsertQuery mrPlan apiReq conf + (rsTime', resultSet) <- withTiming $ runQuery roleIsoLvl Nothing (Plan.mrTxMode mrPlan) $ Query.singleUpsertQuery mrPlan apiReq conf (renderTime', pgrst) <- withTiming $ liftEither $ Response.singleUpsertResponse mrPlan apiReq resultSet let metrics = Map.fromList [(SMPlan, planTime'), (SMQuery, rsTime'), (SMRender, renderTime'), jwtTime] return $ pgrstResponse metrics pgrst (ActionMutate MutationDelete, TargetIdent identifier) -> do (planTime', mrPlan) <- withTiming $ liftEither $ Plan.mutateReadPlan MutationDelete apiReq identifier conf sCache - (rsTime', resultSet) <- withTiming $ runQuery roleIsoLvl (Plan.mrTxMode mrPlan) $ Query.deleteQuery mrPlan apiReq conf + (rsTime', resultSet) <- withTiming $ runQuery roleIsoLvl Nothing (Plan.mrTxMode mrPlan) $ Query.deleteQuery mrPlan apiReq conf (renderTime', pgrst) <- withTiming $ liftEither $ Response.deleteResponse mrPlan apiReq resultSet let metrics = Map.fromList [(SMPlan, planTime'), (SMQuery, rsTime'), (SMRender, renderTime'), jwtTime] return $ pgrstResponse metrics pgrst (ActionInvoke invMethod, TargetProc identifier _) -> do (planTime', cPlan) <- withTiming $ liftEither $ Plan.callReadPlan identifier conf sCache apiReq invMethod - (rsTime', resultSet) <- withTiming $ runQuery (fromMaybe roleIsoLvl $ pdIsoLvl (Plan.crProc cPlan)) (Plan.crTxMode cPlan) $ Query.invokeQuery (Plan.crProc cPlan) cPlan apiReq conf pgVer + (rsTime', resultSet) <- withTiming $ runQuery (fromMaybe roleIsoLvl $ pdIsoLvl (Plan.crProc cPlan)) (pdTimeout $ Plan.crProc cPlan) (Plan.crTxMode cPlan) $ Query.invokeQuery (Plan.crProc cPlan) cPlan apiReq conf pgVer (renderTime', pgrst) <- withTiming $ liftEither $ Response.invokeResponse cPlan invMethod (Plan.crProc cPlan) apiReq resultSet let metrics = Map.fromList [(SMPlan, planTime'), (SMQuery, rsTime'), (SMRender, renderTime'), jwtTime] return $ pgrstResponse metrics pgrst (ActionInspect headersOnly, TargetDefaultSpec tSchema) -> do (planTime', iPlan) <- withTiming $ liftEither $ Plan.inspectPlan apiReq - (rsTime', oaiResult) <- withTiming $ runQuery roleIsoLvl (Plan.ipTxmode iPlan) $ Query.openApiQuery sCache pgVer conf tSchema + (rsTime', oaiResult) <- withTiming $ runQuery roleIsoLvl Nothing (Plan.ipTxmode iPlan) $ Query.openApiQuery sCache pgVer conf tSchema (renderTime', pgrst) <- withTiming $ liftEither $ Response.openApiResponse (T.decodeUtf8 prettyVersion, docsVersion) headersOnly oaiResult conf sCache iSchema iNegotiatedByProfile let metrics = Map.fromList [(SMPlan, planTime'), (SMQuery, rsTime'), (SMRender, renderTime'), jwtTime] return $ pgrstResponse metrics pgrst @@ -249,9 +249,9 @@ handleRequest AuthResult{..} conf appState authenticated prepared pgVer apiReq@A where roleSettings = fromMaybe mempty (HM.lookup authRole $ configRoleSettings conf) roleIsoLvl = HM.findWithDefault SQL.ReadCommitted authRole $ configRoleIsoLvl conf - runQuery isoLvl mode query = + runQuery isoLvl timeout mode query = runDbHandler appState isoLvl mode authenticated prepared $ do - Query.setPgLocals conf authClaims authRole (HM.toList roleSettings) apiReq + Query.setPgLocals conf authClaims authRole (HM.toList roleSettings) apiReq timeout Query.runPreReq conf query diff --git a/src/PostgREST/Query.hs b/src/PostgREST/Query.hs index 835559a52f..6733c7b31f 100644 --- a/src/PostgREST/Query.hs +++ b/src/PostgREST/Query.hs @@ -235,10 +235,10 @@ optionalRollback AppConfig{..} ApiRequest{iPreferences=Preferences{..}} = do -- | Runs local (transaction scoped) GUCs for every request. setPgLocals :: AppConfig -> KM.KeyMap JSON.Value -> BS.ByteString -> [(ByteString, ByteString)] -> - ApiRequest -> DbHandler () -setPgLocals AppConfig{..} claims role roleSettings ApiRequest{..} = lift $ + ApiRequest -> Maybe Text -> DbHandler () +setPgLocals AppConfig{..} claims role roleSettings ApiRequest{..} tout = lift $ SQL.statement mempty $ SQL.dynamicallyParameterized - ("select " <> intercalateSnippet ", " (searchPathSql : roleSql ++ roleSettingsSql ++ claimsSql ++ [methodSql, pathSql] ++ headersSql ++ cookiesSql ++ timezoneSql ++ appSettingsSql)) + ("select " <> intercalateSnippet ", " (searchPathSql : roleSql ++ roleSettingsSql ++ claimsSql ++ [methodSql, pathSql] ++ headersSql ++ cookiesSql ++ timezoneSql ++ timeoutSql ++ appSettingsSql)) HD.noResult configDbPreparedStatements where methodSql = setConfigWithConstantName ("request.method", iMethod) @@ -250,6 +250,7 @@ setPgLocals AppConfig{..} claims role roleSettings ApiRequest{..} = lift $ roleSettingsSql = setConfigWithDynamicName <$> roleSettings appSettingsSql = setConfigWithDynamicName <$> (join bimap toUtf8 <$> configAppSettings) timezoneSql = maybe mempty (\(PreferTimezone tz) -> [setConfigWithConstantName ("timezone", tz)]) $ preferTimezone iPreferences + timeoutSql = maybe mempty ((\t -> [setConfigWithConstantName ("statement_timeout", t)]) . encodeUtf8) tout searchPathSql = let schemas = escapeIdentList (iSchema : configDbExtraSearchPath) in setConfigWithConstantName ("search_path", schemas) diff --git a/src/PostgREST/SchemaCache.hs b/src/PostgREST/SchemaCache.hs index 9f80e19bd8..9228cc0a56 100644 --- a/src/PostgREST/SchemaCache.hs +++ b/src/PostgREST/SchemaCache.hs @@ -297,6 +297,7 @@ decodeFuncs = <*> (parseVolatility <$> column HD.char) <*> column HD.bool <*> nullableColumn (toIsolationLevel <$> HD.text) + <*> nullableColumn HD.text addKey :: Routine -> (QualifiedIdentifier, Routine) addKey pd = (QualifiedIdentifier (pdSchema pd) (pdName pd), pd) @@ -430,7 +431,8 @@ funcsSqlQuery pgVer = [q| bt.oid <> bt.base as rettype_is_composite_alias, p.provolatile, p.provariadic > 0 as hasvariadic, - lower((regexp_split_to_array((regexp_split_to_array(config, '='))[2], ','))[1]) AS transaction_isolation_level + lower((regexp_split_to_array((regexp_split_to_array(iso_config, '='))[2], ','))[1]) AS transaction_isolation_level, + lower((regexp_split_to_array((regexp_split_to_array(timeout_config, '='))[2], ','))[1]) AS statement_timeout FROM pg_proc p LEFT JOIN arguments a ON a.oid = p.oid JOIN pg_namespace pn ON pn.oid = p.pronamespace @@ -439,7 +441,8 @@ funcsSqlQuery pgVer = [q| JOIN pg_namespace tn ON tn.oid = t.typnamespace LEFT JOIN pg_class comp ON comp.oid = t.typrelid LEFT JOIN pg_description as d ON d.objoid = p.oid - LEFT JOIN LATERAL unnest(proconfig) config ON config like 'default_transaction_isolation%' + LEFT JOIN LATERAL unnest(proconfig) iso_config ON iso_config like 'default_transaction_isolation%' + LEFT JOIN LATERAL unnest(proconfig) timeout_config ON timeout_config like 'statement_timeout%' WHERE t.oid <> 'trigger'::regtype AND COALESCE(a.callable, true) |] <> (if pgVer >= pgVersion110 then "AND prokind = 'f'" else "AND NOT (proisagg OR proiswindow)") diff --git a/src/PostgREST/SchemaCache/Routine.hs b/src/PostgREST/SchemaCache/Routine.hs index 572ccc684c..8445ad2cf9 100644 --- a/src/PostgREST/SchemaCache/Routine.hs +++ b/src/PostgREST/SchemaCache/Routine.hs @@ -57,11 +57,12 @@ data Routine = Function , pdVolatility :: FuncVolatility , pdHasVariadic :: Bool , pdIsoLvl :: Maybe SQL.IsolationLevel + , pdTimeout :: Maybe Text } deriving (Eq, Show, Generic) -- need to define JSON manually bc SQL.IsolationLevel doesn't have a JSON instance(and we can't define one for that type without getting a compiler error) instance JSON.ToJSON Routine where - toJSON (Function sch nam desc params ret vol hasVar _) = JSON.object + toJSON (Function sch nam desc params ret vol hasVar _ tout) = JSON.object [ "pdSchema" .= sch , "pdName" .= nam @@ -70,6 +71,7 @@ instance JSON.ToJSON Routine where , "pdReturnType" .= JSON.toJSON ret , "pdVolatility" .= JSON.toJSON vol , "pdHasVariadic" .= JSON.toJSON hasVar + , "pdTimeout" .= tout ] data RoutineParam = RoutineParam @@ -83,10 +85,10 @@ data RoutineParam = RoutineParam -- Order by least number of params in the case of overloaded functions instance Ord Routine where - Function schema1 name1 des1 prms1 rt1 vol1 hasVar1 iso1 `compare` Function schema2 name2 des2 prms2 rt2 vol2 hasVar2 iso2 + Function schema1 name1 des1 prms1 rt1 vol1 hasVar1 iso1 tout1 `compare` Function schema2 name2 des2 prms2 rt2 vol2 hasVar2 iso2 tout2 | schema1 == schema2 && name1 == name2 && length prms1 < length prms2 = LT | schema2 == schema2 && name1 == name2 && length prms1 > length prms2 = GT - | otherwise = (schema1, name1, des1, prms1, rt1, vol1, hasVar1, iso1) `compare` (schema2, name2, des2, prms2, rt2, vol2, hasVar2, iso2) + | otherwise = (schema1, name1, des1, prms1, rt1, vol1, hasVar1, iso1, tout1) `compare` (schema2, name2, des2, prms2, rt2, vol2, hasVar2, iso2, tout2) -- | A map of all procs, all of which can be overloaded(one entry will have more than one Routine). -- | It uses a HashMap for a faster lookup. diff --git a/test/io/fixtures.sql b/test/io/fixtures.sql index bc1dd96707..92ad23c5bb 100644 --- a/test/io/fixtures.sql +++ b/test/io/fixtures.sql @@ -178,3 +178,11 @@ $$; create function terminate_pgrst() returns setof record as $$ select pg_terminate_backend(pid) from pg_stat_activity where application_name iLIKE '%postgrest%'; $$ language sql security definer; + +create or replace function one_sec_timeout() returns void as $$ + select pg_sleep(3); +$$ language sql set statement_timeout = '1s'; + +create or replace function four_sec_timeout() returns void as $$ + select pg_sleep(3); +$$ language sql set statement_timeout = '4s'; diff --git a/test/io/test_io.py b/test/io/test_io.py index 91b5421ada..21202fdc5d 100644 --- a/test/io/test_io.py +++ b/test/io/test_io.py @@ -1294,3 +1294,25 @@ def test_no_preflight_request_with_CORS_config_should_not_return_header(defaulte with run(env=env) as postgrest: response = postgrest.session.get("/items", headers=headers) assert "Access-Control-Allow-Origin" not in response.headers + + +def test_fail_with_3_sec_statement_and_1_sec_statement_timeout(defaultenv): + "statement that takes three seconds to execute should fail with one second timeout" + + with run(env=defaultenv) as postgrest: + response = postgrest.session.post("/rpc/one_sec_timeout") + + assert response.status_code == 500 + assert ( + response.text + == '{"code":"57014","details":null,"hint":null,"message":"canceling statement due to statement timeout"}' + ) + + +def test_passes_with_3_sec_statement_and_4_sec_statement_timeout(defaultenv): + "statement that takes three seconds to execute should succeed with four second timeout" + + with run(env=defaultenv) as postgrest: + response = postgrest.session.post("/rpc/four_sec_timeout") + + assert response.status_code == 204