Skip to content

Commit

Permalink
test: overriding anyelement with particular agg
Browse files Browse the repository at this point in the history
  • Loading branch information
steve-chavez committed Oct 24, 2023
1 parent 1114f0f commit 8ac8067
Show file tree
Hide file tree
Showing 6 changed files with 63 additions and 38 deletions.
9 changes: 6 additions & 3 deletions src/PostgREST/Plan.hs
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ data WrappedReadPlan = WrappedReadPlan {
, wrTxMode :: SQL.Mode
, wrResAgg :: ResultAggregate
, wrMedia :: MediaType
, wrIdent :: QualifiedIdentifier
}

data MutateReadPlan = MutateReadPlan {
Expand All @@ -103,6 +104,7 @@ data MutateReadPlan = MutateReadPlan {
, mrTxMode :: SQL.Mode
, mrResAgg :: ResultAggregate
, mrMedia :: MediaType
, mrIdent :: QualifiedIdentifier
}

data CallReadPlan = CallReadPlan {
Expand All @@ -112,6 +114,7 @@ data CallReadPlan = CallReadPlan {
, crProc :: Routine
, crResAgg :: ResultAggregate
, crMedia :: MediaType
, crIdent :: QualifiedIdentifier
}

data InspectPlan = InspectPlan {
Expand All @@ -124,15 +127,15 @@ wrappedReadPlan identifier conf sCache apiRequest@ApiRequest{iPreferences=Prefe
rPlan <- readPlan identifier conf sCache apiRequest
(rAgg, mediaType) <- mapLeft ApiRequestError $ negotiateContent conf apiRequest identifier iAcceptMediaType (dbMediaAggs sCache)
if not (null invalidPrefs) && preferHandling == Just Strict then Left $ ApiRequestError $ InvalidPreferences invalidPrefs else Right ()
return $ WrappedReadPlan rPlan SQL.Read rAgg mediaType
return $ WrappedReadPlan rPlan SQL.Read rAgg mediaType identifier

mutateReadPlan :: Mutation -> ApiRequest -> QualifiedIdentifier -> AppConfig -> SchemaCache -> Either Error MutateReadPlan
mutateReadPlan mutation apiRequest@ApiRequest{iPreferences=Preferences{..},..} identifier conf sCache = do
rPlan <- readPlan identifier conf sCache apiRequest
mPlan <- mutatePlan mutation identifier apiRequest sCache rPlan
if not (null invalidPrefs) && preferHandling == Just Strict then Left $ ApiRequestError $ InvalidPreferences invalidPrefs else Right ()
(rAgg, mediaType) <- mapLeft ApiRequestError $ negotiateContent conf apiRequest identifier iAcceptMediaType (dbMediaAggs sCache)
return $ MutateReadPlan rPlan mPlan SQL.Write rAgg mediaType
return $ MutateReadPlan rPlan mPlan SQL.Write rAgg mediaType identifier

callReadPlan :: QualifiedIdentifier -> AppConfig -> SchemaCache -> ApiRequest -> InvokeMethod -> Either Error CallReadPlan
callReadPlan identifier conf sCache apiRequest@ApiRequest{iPreferences=Preferences{..},..} invMethod = do
Expand All @@ -158,7 +161,7 @@ callReadPlan identifier conf sCache apiRequest@ApiRequest{iPreferences=Preferenc
cPlan = callPlan proc apiRequest paramKeys args rPlan
(rAgg, mediaType) <- mapLeft ApiRequestError $ negotiateContent conf apiRequest relIdentifier iAcceptMediaType (dbMediaAggs sCache)
if not (null invalidPrefs) && preferHandling == Just Strict then Left $ ApiRequestError $ InvalidPreferences invalidPrefs else Right ()
return $ CallReadPlan rPlan cPlan txMode proc rAgg mediaType
return $ CallReadPlan rPlan cPlan txMode proc rAgg mediaType relIdentifier
where
qsParams' = QueryParams.qsParams iQueryParams

Expand Down
5 changes: 4 additions & 1 deletion src/PostgREST/Query.hs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ readQuery WrappedReadPlan{..} conf@AppConfig{..} apiReq@ApiRequest{iPreferences=
resultSet <-
lift . SQL.statement mempty $
Statements.prepareRead
wrIdent
(QueryBuilder.readPlanToQuery wrReadPlan)
(if preferCount == Just EstimatedCount then
-- LIMIT maxRows + 1 so we can determine below that maxRows was surpassed
Expand Down Expand Up @@ -155,6 +156,7 @@ invokeQuery rout CallReadPlan{..} apiReq@ApiRequest{iPreferences=Preferences{..}
resultSet <-
lift . SQL.statement mempty $
Statements.prepareCall
crIdent
rout
(QueryBuilder.callPlanToQuery crCallPlan pgVer)
(QueryBuilder.readPlanToQuery crReadPlan)
Expand Down Expand Up @@ -186,12 +188,13 @@ openApiQuery sCache pgVer AppConfig{..} tSchema =
pure Nothing

writeQuery :: MutateReadPlan -> ApiRequest -> AppConfig -> DbHandler ResultSet
writeQuery MutateReadPlan{mrReadPlan, mrMutatePlan, mrResAgg, mrMedia} ApiRequest{iPreferences=Preferences{..}} conf =
writeQuery MutateReadPlan{mrReadPlan, mrMutatePlan, mrResAgg, mrMedia, mrIdent} ApiRequest{iPreferences=Preferences{..}} conf =
let
(isInsert, pkCols) = case mrMutatePlan of {Insert{insPkCols} -> (True, insPkCols); _ -> (False, mempty);}
in
lift . SQL.statement mempty $
Statements.prepareWrite
mrIdent
(QueryBuilder.readPlanToQuery mrReadPlan)
(QueryBuilder.mutatePlanToQuery mrMutatePlan)
isInsert
Expand Down
20 changes: 7 additions & 13 deletions src/PostgREST/Query/SqlFragment.hs
Original file line number Diff line number Diff line change
Expand Up @@ -212,16 +212,10 @@ asJsonF rout strip
asGeoJsonF :: SQL.Snippet
asGeoJsonF = "json_build_object('type', 'FeatureCollection', 'features', coalesce(json_agg(ST_AsGeoJSON(_postgrest_t)::json), '[]'))"

customAggF :: Maybe Routine -> QualifiedIdentifier -> SQL.Snippet
customAggF rout qi
| returnsSingleComposite = fromQi qi <> "(_postgrest_t)"
| returnsScalar = fromQi qi <> "(_postgrest_t.pgrst_scalar)"
| returnsSetOfScalar = fromQi qi <> "(_postgrest_t.pgrst_scalar)"
| otherwise = fromQi qi <> "(_postgrest_t)"
where
(returnsSingleComposite, returnsScalar, returnsSetOfScalar) = case rout of
Just r -> (funcReturnsSingleComposite r, funcReturnsScalar r, funcReturnsSetOfScalar r)
Nothing -> (False, False, False)
customAggF :: Maybe Routine -> QualifiedIdentifier -> QualifiedIdentifier -> SQL.Snippet
customAggF rout funcQi target
| (funcReturnsScalar <$> rout) == Just True = fromQi funcQi <> "(_postgrest_t.pgrst_scalar)"
| otherwise = fromQi funcQi <> "(_postgrest_t::" <> fromQi target <> ")"

locationF :: [Text] -> SQL.Snippet
locationF pKeys = [qc|(
Expand Down Expand Up @@ -497,12 +491,12 @@ setConfigLocalJson prefix keyVals = [setConfigLocal mempty (prefix, gucJsonVal k
arrayByteStringToText :: [(ByteString, ByteString)] -> [(Text,Text)]
arrayByteStringToText keyVal = (T.decodeUtf8 *** T.decodeUtf8) <$> keyVal

aggF :: Maybe Routine -> ResultAggregate -> SQL.Snippet
aggF rout = \case
aggF :: Maybe Routine -> QualifiedIdentifier -> ResultAggregate -> SQL.Snippet
aggF rout target = \case
BuiltinAggJson -> asJsonF rout False
BuiltinAggArrayJsonStrip -> asJsonF rout True
BuiltinAggSingleJson strip -> asJsonSingleF rout strip
BuiltinAggGeoJson -> asGeoJsonF
BuiltinAggCsv -> asCsvF
CustomAgg qi -> customAggF rout qi
CustomAgg funcQi -> customAggF rout funcQi target
NoAgg -> "''::text"
27 changes: 14 additions & 13 deletions src/PostgREST/Query/Statements.hs
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,12 @@ import qualified Hasql.Statement as SQL
import Control.Lens ((^?))

import PostgREST.ApiRequest.Preferences
import PostgREST.MediaType (MTPlanFormat (..),
MediaType (..))
import PostgREST.MediaType (MTPlanFormat (..),
MediaType (..))
import PostgREST.Query.SqlFragment
import PostgREST.SchemaCache.Routine (ResultAggregate (..),
Routine, funcReturnsSingle)
import PostgREST.SchemaCache.Identifiers (QualifiedIdentifier)
import PostgREST.SchemaCache.Routine (ResultAggregate (..),
Routine, funcReturnsSingle)

import Protolude

Expand All @@ -53,9 +54,9 @@ data ResultSet
| RSPlan BS.ByteString -- ^ the plan of the query


prepareWrite :: SQL.Snippet -> SQL.Snippet -> Bool -> MediaType -> ResultAggregate ->
prepareWrite :: QualifiedIdentifier -> SQL.Snippet -> SQL.Snippet -> Bool -> MediaType -> ResultAggregate ->
Maybe PreferRepresentation -> [Text] -> Bool -> SQL.Statement () ResultSet
prepareWrite selectQuery mutateQuery isInsert mt rAgg rep pKeys =
prepareWrite qi selectQuery mutateQuery isInsert mt rAgg rep pKeys =
SQL.dynamicallyParameterized (mtSnippet mt snippet) decodeIt
where
snippet =
Expand All @@ -64,7 +65,7 @@ prepareWrite selectQuery mutateQuery isInsert mt rAgg rep pKeys =
"'' AS total_result_set, " <>
"pg_catalog.count(_postgrest_t) AS page_total, " <>
locF <> " AS header, " <>
aggF Nothing rAgg <> " AS body, " <>
aggF Nothing qi rAgg <> " AS body, " <>
responseHeadersF <> " AS response_headers, " <>
responseStatusF <> " AS response_status " <>
"FROM (" <> selectF <> ") _postgrest_t"
Expand All @@ -88,8 +89,8 @@ prepareWrite selectQuery mutateQuery isInsert mt rAgg rep pKeys =
MTPlan{} -> planRow
_ -> fromMaybe (RSStandard Nothing 0 mempty mempty Nothing Nothing) <$> HD.rowMaybe (standardRow False)

prepareRead :: SQL.Snippet -> SQL.Snippet -> Bool -> MediaType -> ResultAggregate -> Bool -> SQL.Statement () ResultSet
prepareRead selectQuery countQuery countTotal mt rAgg =
prepareRead :: QualifiedIdentifier -> SQL.Snippet -> SQL.Snippet -> Bool -> MediaType -> ResultAggregate -> Bool -> SQL.Statement () ResultSet
prepareRead qi selectQuery countQuery countTotal mt rAgg =
SQL.dynamicallyParameterized (mtSnippet mt snippet) decodeIt
where
snippet =
Expand All @@ -98,7 +99,7 @@ prepareRead selectQuery countQuery countTotal mt rAgg =
"SELECT " <>
countResultF <> " AS total_result_set, " <>
"pg_catalog.count(_postgrest_t) AS page_total, " <>
aggF Nothing rAgg <> " AS body, " <>
aggF Nothing qi rAgg <> " AS body, " <>
responseHeadersF <> " AS response_headers, " <>
responseStatusF <> " AS response_status " <>
"FROM ( SELECT * FROM " <> sourceCTE <> " ) _postgrest_t"
Expand All @@ -110,10 +111,10 @@ prepareRead selectQuery countQuery countTotal mt rAgg =
MTPlan{} -> planRow
_ -> HD.singleRow $ standardRow True

prepareCall :: Routine -> SQL.Snippet -> SQL.Snippet -> SQL.Snippet -> Bool ->
prepareCall :: QualifiedIdentifier -> Routine -> SQL.Snippet -> SQL.Snippet -> SQL.Snippet -> Bool ->
MediaType -> ResultAggregate -> Bool ->
SQL.Statement () ResultSet
prepareCall rout callProcQuery selectQuery countQuery countTotal mt rAgg =
prepareCall qi rout callProcQuery selectQuery countQuery countTotal mt rAgg =
SQL.dynamicallyParameterized (mtSnippet mt snippet) decodeIt
where
snippet =
Expand All @@ -124,7 +125,7 @@ prepareCall rout callProcQuery selectQuery countQuery countTotal mt rAgg =
(if funcReturnsSingle rout
then "1"
else "pg_catalog.count(_postgrest_t)") <> " AS page_total, " <>
aggF (Just rout) rAgg <> " AS body, " <>
aggF (Just rout) qi rAgg <> " AS body, " <>
responseHeadersF <> " AS response_headers, " <>
responseStatusF <> " AS response_status " <>
"FROM (" <> selectQuery <> ") _postgrest_t"
Expand Down
11 changes: 9 additions & 2 deletions test/spec/Feature/Query/CustomMediaSpec.hs
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,17 @@ spec = describe "custom media types" $ do
, matchHeaders = ["Content-Type" <:> "application/vnd.geo2+json"]
}

it "will use the application/vnd.geo2+json media type for any table" $
it "will use the more specific application/vnd.geo2 handler for this table" $ do
request methodGet "/shop_bles" (acceptHdrs "application/vnd.geo2+json") ""
`shouldRespondWith`
"\SOH{\"type\": \"FeatureCollection\", \"hello\": \"world\"}"
"\SOH\"anyelement overridden\""
{ matchStatus = 200
, matchHeaders = ["Content-Type" <:> "application/vnd.geo2+json"]
}

request methodGet "/rpc/get_shop_bles" (acceptHdrs "application/vnd.geo2+json") ""
`shouldRespondWith`
"\SOH\"anyelement overridden\""
{ matchStatus = 200
, matchHeaders = ["Content-Type" <:> "application/vnd.geo2+json"]
}
Expand Down
29 changes: 23 additions & 6 deletions test/spec/fixtures/schema.sql
Original file line number Diff line number Diff line change
Expand Up @@ -3534,6 +3534,11 @@ returns setof test.lines as $$
select * from lines;
$$ language sql;

create or replace function test.get_shop_bles ()
returns setof test.shop_bles as $$
select * from shop_bles;
$$ language sql;

-- it can work without a final function too if the stype is already the media type
DO $do$
BEGIN
Expand Down Expand Up @@ -3568,22 +3573,34 @@ BEGIN
END
$do$;

create or replace function test.geoson_trans (state "application/vnd.geo2+json", next anyelement)
create or replace function test.geo2json_trans (state "application/vnd.geo2+json", next anyelement)
returns "application/vnd.geo2+json" as $$
select (state || extensions.ST_AsGeoJSON(next)::jsonb)::"application/vnd.geo2+json";
$$ language sql;

create or replace function test.geojson_final (data "application/vnd.geo2+json")
create or replace function test.geo2json_final (data "application/vnd.geo2+json")
returns "application/vnd.geo2+json" as $$
select (jsonb_build_object('type', 'FeatureCollection', 'hello', 'world'))::"application/vnd.geo2+json";
$$ language sql;

drop aggregate if exists test.geojson_agg(anyelement);
create aggregate test.geojson_agg(anyelement) (
drop aggregate if exists test.geo2json_agg(anyelement);
create aggregate test.geo2json_agg(anyelement) (
initcond = '[]'
, stype = "application/vnd.geo2+json"
, sfunc = geo2json_trans
, finalfunc = geo2json_final
);

create or replace function test.geo2json_trans (state "application/vnd.geo2+json", next test.shop_bles)
returns "application/vnd.geo2+json" as $$
select '"anyelement overridden"'::"application/vnd.geo2+json";
$$ language sql;

drop aggregate if exists test.geo2json_agg(test.shop_bles);
create aggregate test.geo2json_agg(test.shop_bles) (
initcond = '[]'
, stype = "application/vnd.geo2+json"
, sfunc = geoson_trans
, finalfunc = geojson_final
, sfunc = geo2json_trans
);

create table ov_json ();
Expand Down

0 comments on commit 8ac8067

Please sign in to comment.