Skip to content

Commit

Permalink
test: overriding anyelement with particular agg
Browse files Browse the repository at this point in the history
  • Loading branch information
steve-chavez committed Oct 24, 2023
1 parent 1114f0f commit 6462f64
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 34 deletions.
9 changes: 6 additions & 3 deletions src/PostgREST/Plan.hs
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ data WrappedReadPlan = WrappedReadPlan {
, wrTxMode :: SQL.Mode
, wrResAgg :: ResultAggregate
, wrMedia :: MediaType
, wrIdent :: QualifiedIdentifier
}

data MutateReadPlan = MutateReadPlan {
Expand All @@ -103,6 +104,7 @@ data MutateReadPlan = MutateReadPlan {
, mrTxMode :: SQL.Mode
, mrResAgg :: ResultAggregate
, mrMedia :: MediaType
, mrIdent :: QualifiedIdentifier
}

data CallReadPlan = CallReadPlan {
Expand All @@ -112,6 +114,7 @@ data CallReadPlan = CallReadPlan {
, crProc :: Routine
, crResAgg :: ResultAggregate
, crMedia :: MediaType
, crIdent :: QualifiedIdentifier
}

data InspectPlan = InspectPlan {
Expand All @@ -124,15 +127,15 @@ wrappedReadPlan identifier conf sCache apiRequest@ApiRequest{iPreferences=Prefe
rPlan <- readPlan identifier conf sCache apiRequest
(rAgg, mediaType) <- mapLeft ApiRequestError $ negotiateContent conf apiRequest identifier iAcceptMediaType (dbMediaAggs sCache)
if not (null invalidPrefs) && preferHandling == Just Strict then Left $ ApiRequestError $ InvalidPreferences invalidPrefs else Right ()
return $ WrappedReadPlan rPlan SQL.Read rAgg mediaType
return $ WrappedReadPlan rPlan SQL.Read rAgg mediaType identifier

mutateReadPlan :: Mutation -> ApiRequest -> QualifiedIdentifier -> AppConfig -> SchemaCache -> Either Error MutateReadPlan
mutateReadPlan mutation apiRequest@ApiRequest{iPreferences=Preferences{..},..} identifier conf sCache = do
rPlan <- readPlan identifier conf sCache apiRequest
mPlan <- mutatePlan mutation identifier apiRequest sCache rPlan
if not (null invalidPrefs) && preferHandling == Just Strict then Left $ ApiRequestError $ InvalidPreferences invalidPrefs else Right ()
(rAgg, mediaType) <- mapLeft ApiRequestError $ negotiateContent conf apiRequest identifier iAcceptMediaType (dbMediaAggs sCache)
return $ MutateReadPlan rPlan mPlan SQL.Write rAgg mediaType
return $ MutateReadPlan rPlan mPlan SQL.Write rAgg mediaType identifier

callReadPlan :: QualifiedIdentifier -> AppConfig -> SchemaCache -> ApiRequest -> InvokeMethod -> Either Error CallReadPlan
callReadPlan identifier conf sCache apiRequest@ApiRequest{iPreferences=Preferences{..},..} invMethod = do
Expand All @@ -158,7 +161,7 @@ callReadPlan identifier conf sCache apiRequest@ApiRequest{iPreferences=Preferenc
cPlan = callPlan proc apiRequest paramKeys args rPlan
(rAgg, mediaType) <- mapLeft ApiRequestError $ negotiateContent conf apiRequest relIdentifier iAcceptMediaType (dbMediaAggs sCache)
if not (null invalidPrefs) && preferHandling == Just Strict then Left $ ApiRequestError $ InvalidPreferences invalidPrefs else Right ()
return $ CallReadPlan rPlan cPlan txMode proc rAgg mediaType
return $ CallReadPlan rPlan cPlan txMode proc rAgg mediaType relIdentifier
where
qsParams' = QueryParams.qsParams iQueryParams

Expand Down
5 changes: 4 additions & 1 deletion src/PostgREST/Query.hs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ readQuery WrappedReadPlan{..} conf@AppConfig{..} apiReq@ApiRequest{iPreferences=
resultSet <-
lift . SQL.statement mempty $
Statements.prepareRead
wrIdent
(QueryBuilder.readPlanToQuery wrReadPlan)
(if preferCount == Just EstimatedCount then
-- LIMIT maxRows + 1 so we can determine below that maxRows was surpassed
Expand Down Expand Up @@ -155,6 +156,7 @@ invokeQuery rout CallReadPlan{..} apiReq@ApiRequest{iPreferences=Preferences{..}
resultSet <-
lift . SQL.statement mempty $
Statements.prepareCall
crIdent
rout
(QueryBuilder.callPlanToQuery crCallPlan pgVer)
(QueryBuilder.readPlanToQuery crReadPlan)
Expand Down Expand Up @@ -186,12 +188,13 @@ openApiQuery sCache pgVer AppConfig{..} tSchema =
pure Nothing

writeQuery :: MutateReadPlan -> ApiRequest -> AppConfig -> DbHandler ResultSet
writeQuery MutateReadPlan{mrReadPlan, mrMutatePlan, mrResAgg, mrMedia} ApiRequest{iPreferences=Preferences{..}} conf =
writeQuery MutateReadPlan{mrReadPlan, mrMutatePlan, mrResAgg, mrMedia, mrIdent} ApiRequest{iPreferences=Preferences{..}} conf =
let
(isInsert, pkCols) = case mrMutatePlan of {Insert{insPkCols} -> (True, insPkCols); _ -> (False, mempty);}
in
lift . SQL.statement mempty $
Statements.prepareWrite
mrIdent
(QueryBuilder.readPlanToQuery mrReadPlan)
(QueryBuilder.mutatePlanToQuery mrMutatePlan)
isInsert
Expand Down
20 changes: 7 additions & 13 deletions src/PostgREST/Query/SqlFragment.hs
Original file line number Diff line number Diff line change
Expand Up @@ -212,16 +212,10 @@ asJsonF rout strip
asGeoJsonF :: SQL.Snippet
asGeoJsonF = "json_build_object('type', 'FeatureCollection', 'features', coalesce(json_agg(ST_AsGeoJSON(_postgrest_t)::json), '[]'))"

customAggF :: Maybe Routine -> QualifiedIdentifier -> SQL.Snippet
customAggF rout qi
| returnsSingleComposite = fromQi qi <> "(_postgrest_t)"
| returnsScalar = fromQi qi <> "(_postgrest_t.pgrst_scalar)"
| returnsSetOfScalar = fromQi qi <> "(_postgrest_t.pgrst_scalar)"
| otherwise = fromQi qi <> "(_postgrest_t)"
where
(returnsSingleComposite, returnsScalar, returnsSetOfScalar) = case rout of
Just r -> (funcReturnsSingleComposite r, funcReturnsScalar r, funcReturnsSetOfScalar r)
Nothing -> (False, False, False)
customAggF :: Maybe Routine -> QualifiedIdentifier -> QualifiedIdentifier -> SQL.Snippet
customAggF rout funcQi target
| (funcReturnsScalar <$> rout) == Just True = fromQi funcQi <> "(_postgrest_t.pgrst_scalar)"
| otherwise = fromQi funcQi <> "(_postgrest_t::" <> fromQi target <> ")"

locationF :: [Text] -> SQL.Snippet
locationF pKeys = [qc|(
Expand Down Expand Up @@ -497,12 +491,12 @@ setConfigLocalJson prefix keyVals = [setConfigLocal mempty (prefix, gucJsonVal k
arrayByteStringToText :: [(ByteString, ByteString)] -> [(Text,Text)]
arrayByteStringToText keyVal = (T.decodeUtf8 *** T.decodeUtf8) <$> keyVal

aggF :: Maybe Routine -> ResultAggregate -> SQL.Snippet
aggF rout = \case
aggF :: Maybe Routine -> QualifiedIdentifier -> ResultAggregate -> SQL.Snippet
aggF rout target = \case
BuiltinAggJson -> asJsonF rout False
BuiltinAggArrayJsonStrip -> asJsonF rout True
BuiltinAggSingleJson strip -> asJsonSingleF rout strip
BuiltinAggGeoJson -> asGeoJsonF
BuiltinAggCsv -> asCsvF
CustomAgg qi -> customAggF rout qi
CustomAgg funcQi -> customAggF rout funcQi target
NoAgg -> "''::text"
19 changes: 10 additions & 9 deletions src/PostgREST/Query/Statements.hs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import PostgREST.MediaType (MTPlanFormat (..),
import PostgREST.Query.SqlFragment
import PostgREST.SchemaCache.Routine (ResultAggregate (..),
Routine, funcReturnsSingle)
import PostgREST.SchemaCache.Identifiers (QualifiedIdentifier)

import Protolude

Expand All @@ -53,9 +54,9 @@ data ResultSet
| RSPlan BS.ByteString -- ^ the plan of the query


prepareWrite :: SQL.Snippet -> SQL.Snippet -> Bool -> MediaType -> ResultAggregate ->
prepareWrite :: QualifiedIdentifier -> SQL.Snippet -> SQL.Snippet -> Bool -> MediaType -> ResultAggregate ->
Maybe PreferRepresentation -> [Text] -> Bool -> SQL.Statement () ResultSet
prepareWrite selectQuery mutateQuery isInsert mt rAgg rep pKeys =
prepareWrite qi selectQuery mutateQuery isInsert mt rAgg rep pKeys =
SQL.dynamicallyParameterized (mtSnippet mt snippet) decodeIt
where
snippet =
Expand All @@ -64,7 +65,7 @@ prepareWrite selectQuery mutateQuery isInsert mt rAgg rep pKeys =
"'' AS total_result_set, " <>
"pg_catalog.count(_postgrest_t) AS page_total, " <>
locF <> " AS header, " <>
aggF Nothing rAgg <> " AS body, " <>
aggF Nothing qi rAgg <> " AS body, " <>
responseHeadersF <> " AS response_headers, " <>
responseStatusF <> " AS response_status " <>
"FROM (" <> selectF <> ") _postgrest_t"
Expand All @@ -88,8 +89,8 @@ prepareWrite selectQuery mutateQuery isInsert mt rAgg rep pKeys =
MTPlan{} -> planRow
_ -> fromMaybe (RSStandard Nothing 0 mempty mempty Nothing Nothing) <$> HD.rowMaybe (standardRow False)

prepareRead :: SQL.Snippet -> SQL.Snippet -> Bool -> MediaType -> ResultAggregate -> Bool -> SQL.Statement () ResultSet
prepareRead selectQuery countQuery countTotal mt rAgg =
prepareRead :: QualifiedIdentifier -> SQL.Snippet -> SQL.Snippet -> Bool -> MediaType -> ResultAggregate -> Bool -> SQL.Statement () ResultSet
prepareRead qi selectQuery countQuery countTotal mt rAgg =
SQL.dynamicallyParameterized (mtSnippet mt snippet) decodeIt
where
snippet =
Expand All @@ -98,7 +99,7 @@ prepareRead selectQuery countQuery countTotal mt rAgg =
"SELECT " <>
countResultF <> " AS total_result_set, " <>
"pg_catalog.count(_postgrest_t) AS page_total, " <>
aggF Nothing rAgg <> " AS body, " <>
aggF Nothing qi rAgg <> " AS body, " <>
responseHeadersF <> " AS response_headers, " <>
responseStatusF <> " AS response_status " <>
"FROM ( SELECT * FROM " <> sourceCTE <> " ) _postgrest_t"
Expand All @@ -110,10 +111,10 @@ prepareRead selectQuery countQuery countTotal mt rAgg =
MTPlan{} -> planRow
_ -> HD.singleRow $ standardRow True

prepareCall :: Routine -> SQL.Snippet -> SQL.Snippet -> SQL.Snippet -> Bool ->
prepareCall :: QualifiedIdentifier -> Routine -> SQL.Snippet -> SQL.Snippet -> SQL.Snippet -> Bool ->
MediaType -> ResultAggregate -> Bool ->
SQL.Statement () ResultSet
prepareCall rout callProcQuery selectQuery countQuery countTotal mt rAgg =
prepareCall qi rout callProcQuery selectQuery countQuery countTotal mt rAgg =
SQL.dynamicallyParameterized (mtSnippet mt snippet) decodeIt
where
snippet =
Expand All @@ -124,7 +125,7 @@ prepareCall rout callProcQuery selectQuery countQuery countTotal mt rAgg =
(if funcReturnsSingle rout
then "1"
else "pg_catalog.count(_postgrest_t)") <> " AS page_total, " <>
aggF (Just rout) rAgg <> " AS body, " <>
aggF (Just rout) qi rAgg <> " AS body, " <>
responseHeadersF <> " AS response_headers, " <>
responseStatusF <> " AS response_status " <>
"FROM (" <> selectQuery <> ") _postgrest_t"
Expand Down
11 changes: 9 additions & 2 deletions test/spec/Feature/Query/CustomMediaSpec.hs
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,17 @@ spec = describe "custom media types" $ do
, matchHeaders = ["Content-Type" <:> "application/vnd.geo2+json"]
}

it "will use the application/vnd.geo2+json media type for any table" $
it "will use the more specific application/vnd.geo2 handler for this table" $ do
request methodGet "/shop_bles" (acceptHdrs "application/vnd.geo2+json") ""
`shouldRespondWith`
"\SOH{\"type\": \"FeatureCollection\", \"hello\": \"world\"}"
"\SOH\"anyelement overridden\""
{ matchStatus = 200
, matchHeaders = ["Content-Type" <:> "application/vnd.geo2+json"]
}

request methodGet "/rpc/get_shop_bles" (acceptHdrs "application/vnd.geo2+json") ""
`shouldRespondWith`
"\SOH\"anyelement overridden\""
{ matchStatus = 200
, matchHeaders = ["Content-Type" <:> "application/vnd.geo2+json"]
}
Expand Down
29 changes: 23 additions & 6 deletions test/spec/fixtures/schema.sql
Original file line number Diff line number Diff line change
Expand Up @@ -3534,6 +3534,11 @@ returns setof test.lines as $$
select * from lines;
$$ language sql;

create or replace function test.get_shop_bles ()
returns setof test.shop_bles as $$
select * from shop_bles;
$$ language sql;

-- it can work without a final function too if the stype is already the media type
DO $do$
BEGIN
Expand Down Expand Up @@ -3568,22 +3573,34 @@ BEGIN
END
$do$;

create or replace function test.geoson_trans (state "application/vnd.geo2+json", next anyelement)
create or replace function test.geo2json_trans (state "application/vnd.geo2+json", next anyelement)
returns "application/vnd.geo2+json" as $$
select (state || extensions.ST_AsGeoJSON(next)::jsonb)::"application/vnd.geo2+json";
$$ language sql;

create or replace function test.geojson_final (data "application/vnd.geo2+json")
create or replace function test.geo2json_final (data "application/vnd.geo2+json")
returns "application/vnd.geo2+json" as $$
select (jsonb_build_object('type', 'FeatureCollection', 'hello', 'world'))::"application/vnd.geo2+json";
$$ language sql;

drop aggregate if exists test.geojson_agg(anyelement);
create aggregate test.geojson_agg(anyelement) (
drop aggregate if exists test.geo2json_agg(anyelement);
create aggregate test.geo2json_agg(anyelement) (
initcond = '[]'
, stype = "application/vnd.geo2+json"
, sfunc = geo2json_trans
, finalfunc = geo2json_final
);

create or replace function test.geo2json_trans (state "application/vnd.geo2+json", next test.shop_bles)
returns "application/vnd.geo2+json" as $$
select '"anyelement overridden"'::"application/vnd.geo2+json";
$$ language sql;

drop aggregate if exists test.geo2json_agg(test.shop_bles);
create aggregate test.geo2json_agg(test.shop_bles) (
initcond = '[]'
, stype = "application/vnd.geo2+json"
, sfunc = geoson_trans
, finalfunc = geojson_final
, sfunc = geo2json_trans
);

create table ov_json ();
Expand Down

0 comments on commit 6462f64

Please sign in to comment.