Skip to content

Commit

Permalink
fix: Avoid casting to table type when select= and media type handler …
Browse files Browse the repository at this point in the history
…are used

Previously using a generic mimetype handler failed when any kind of select= was given, because
we tried to cast the select-result to the original table type. With this change, this cast is
only applied when select=* is given implicitly or explicitly. This is the only case where this
makes sense, because this guarantees that correct columns are selected in the correct order for
this cast to succeed.

Resolves PostgREST#3160
  • Loading branch information
wolfgangwalther committed May 9, 2024
1 parent 8d4468e commit 27cb28b
Show file tree
Hide file tree
Showing 10 changed files with 131 additions and 43 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ This project adheres to [Semantic Versioning](http://semver.org/).
- #3149, Misleading "Starting PostgREST.." logs on schema cache reloading - @steve-chavez
- #3205, Fix wrong subquery error returning a status of 400 Bad Request - @steve-chavez
- #3224, Return status code 406 for non-accepted media type instead of code 415 - @wolfgangwalther
- #3160, Fix using select= query parameter for custom media type handlers - @wolfgangwalther

## [12.0.2] - 2023-12-20

Expand Down
32 changes: 18 additions & 14 deletions src/PostgREST/Plan.hs
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ data WrappedReadPlan = WrappedReadPlan {
, wrTxMode :: SQL.Mode
, wrHandler :: MediaHandler
, wrMedia :: MediaType
, wrIdent :: QualifiedIdentifier
}

data MutateReadPlan = MutateReadPlan {
Expand All @@ -106,7 +105,6 @@ data MutateReadPlan = MutateReadPlan {
, mrTxMode :: SQL.Mode
, mrHandler :: MediaHandler
, mrMedia :: MediaType
, mrIdent :: QualifiedIdentifier
}

data CallReadPlan = CallReadPlan {
Expand All @@ -116,7 +114,6 @@ data CallReadPlan = CallReadPlan {
, crProc :: Routine
, crHandler :: MediaHandler
, crMedia :: MediaType
, crIdent :: QualifiedIdentifier
}

data InspectPlan = InspectPlan {
Expand All @@ -127,17 +124,17 @@ data InspectPlan = InspectPlan {
wrappedReadPlan :: QualifiedIdentifier -> AppConfig -> SchemaCache -> ApiRequest -> Either Error WrappedReadPlan
wrappedReadPlan identifier conf sCache apiRequest@ApiRequest{iPreferences=Preferences{..},..} = do
rPlan <- readPlan identifier conf sCache apiRequest
(hdler, mediaType) <- mapLeft ApiRequestError $ negotiateContent conf apiRequest identifier iAcceptMediaType (dbMediaHandlers sCache)
(handler, mediaType) <- mapLeft ApiRequestError $ negotiateContent conf apiRequest identifier iAcceptMediaType (dbMediaHandlers sCache) (hasDefaultSelect rPlan)
if not (null invalidPrefs) && preferHandling == Just Strict then Left $ ApiRequestError $ InvalidPreferences invalidPrefs else Right ()
return $ WrappedReadPlan rPlan SQL.Read hdler mediaType identifier
return $ WrappedReadPlan rPlan SQL.Read handler mediaType

mutateReadPlan :: Mutation -> ApiRequest -> QualifiedIdentifier -> AppConfig -> SchemaCache -> Either Error MutateReadPlan
mutateReadPlan mutation apiRequest@ApiRequest{iPreferences=Preferences{..},..} identifier conf sCache = do
rPlan <- readPlan identifier conf sCache apiRequest
mPlan <- mutatePlan mutation identifier apiRequest sCache rPlan
if not (null invalidPrefs) && preferHandling == Just Strict then Left $ ApiRequestError $ InvalidPreferences invalidPrefs else Right ()
(hdler, mediaType) <- mapLeft ApiRequestError $ negotiateContent conf apiRequest identifier iAcceptMediaType (dbMediaHandlers sCache)
return $ MutateReadPlan rPlan mPlan SQL.Write hdler mediaType identifier
(handler, mediaType) <- mapLeft ApiRequestError $ negotiateContent conf apiRequest identifier iAcceptMediaType (dbMediaHandlers sCache) (hasDefaultSelect rPlan)
return $ MutateReadPlan rPlan mPlan SQL.Write handler mediaType

callReadPlan :: QualifiedIdentifier -> AppConfig -> SchemaCache -> ApiRequest -> InvokeMethod -> Either Error CallReadPlan
callReadPlan identifier conf sCache apiRequest@ApiRequest{iPreferences=Preferences{..},..} invMethod = do
Expand All @@ -161,12 +158,16 @@ callReadPlan identifier conf sCache apiRequest@ApiRequest{iPreferences=Preferenc
(InvPost, Routine.Immutable) -> SQL.Read
(InvPost, Routine.Volatile) -> SQL.Write
cPlan = callPlan proc apiRequest paramKeys args rPlan
(hdler, mediaType) <- mapLeft ApiRequestError $ negotiateContent conf apiRequest relIdentifier iAcceptMediaType (dbMediaHandlers sCache)
(handler, mediaType) <- mapLeft ApiRequestError $ negotiateContent conf apiRequest relIdentifier iAcceptMediaType (dbMediaHandlers sCache) (hasDefaultSelect rPlan)
if not (null invalidPrefs) && preferHandling == Just Strict then Left $ ApiRequestError $ InvalidPreferences invalidPrefs else Right ()
return $ CallReadPlan rPlan cPlan txMode proc hdler mediaType relIdentifier
return $ CallReadPlan rPlan cPlan txMode proc handler mediaType
where
qsParams' = QueryParams.qsParams iQueryParams

hasDefaultSelect :: ReadPlanTree -> Bool
hasDefaultSelect (Node ReadPlan{select=[CoercibleSelectField{csField=CoercibleField{cfName}}]} []) = cfName == "*"
hasDefaultSelect _ = False

inspectPlan :: ApiRequest -> Either Error InspectPlan
inspectPlan apiRequest = do
let producedMTs = [MTOpenAPI, MTApplicationJSON, MTAny]
Expand Down Expand Up @@ -993,8 +994,8 @@ addFilterToLogicForest :: CoercibleFilter -> [CoercibleLogicTree] -> [CoercibleL
addFilterToLogicForest flt lf = CoercibleStmnt flt : lf

-- | Do content negotiation. i.e. choose a media type based on the intersection of accepted/produced media types.
negotiateContent :: AppConfig -> ApiRequest -> QualifiedIdentifier -> [MediaType] -> MediaHandlerMap -> Either ApiRequestError ResolvedHandler
negotiateContent conf ApiRequest{iAction=act, iPreferences=Preferences{preferRepresentation=rep}} identifier accepts produces =
negotiateContent :: AppConfig -> ApiRequest -> QualifiedIdentifier -> [MediaType] -> MediaHandlerMap -> Bool -> Either ApiRequestError ResolvedHandler
negotiateContent conf ApiRequest{iAction=act, iPreferences=Preferences{preferRepresentation=rep}} identifier accepts produces defaultSelect =
case (act, firstAcceptedPick) of
(_, Nothing) -> Left . MediaTypeError $ map MediaType.toMime accepts
(ActionMutate _, Just (x, mt)) -> Right (if rep == Just Full then x else NoAgg, mt)
Expand All @@ -1017,6 +1018,9 @@ negotiateContent conf ApiRequest{iAction=act, iPreferences=Preferences{preferRep
x -> lookupHandler x
mtPlanToNothing x = if configDbPlanEnabled conf then x else Nothing -- don't find anything if the plan media type is not allowed
lookupHandler mt =
HM.lookup (RelId identifier, MTAny) produces <|> -- lookup for identifier and `*/*`
HM.lookup (RelId identifier, mt) produces <|> -- lookup for identifier and a particular media type
HM.lookup (RelAnyElement, mt) produces -- lookup for anyelement and a particular media type
when' defaultSelect (HM.lookup (RelId identifier, MTAny) produces) <|> -- lookup for identifier and `*/*`
when' defaultSelect (HM.lookup (RelId identifier, mt) produces) <|> -- lookup for identifier and a particular media type
HM.lookup (RelAnyElement, mt) produces -- lookup for anyelement and a particular media type
when' :: Bool -> Maybe a -> Maybe a
when' True (Just a) = Just a
when' _ _ = Nothing
3 changes: 0 additions & 3 deletions src/PostgREST/Query.hs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ readQuery WrappedReadPlan{..} conf@AppConfig{..} apiReq@ApiRequest{iPreferences=
resultSet <-
lift . SQL.statement mempty $
Statements.prepareRead
wrIdent
(QueryBuilder.readPlanToQuery wrReadPlan)
(if preferCount == Just EstimatedCount then
-- LIMIT maxRows + 1 so we can determine below that maxRows was surpassed
Expand Down Expand Up @@ -153,7 +152,6 @@ invokeQuery rout CallReadPlan{..} apiReq@ApiRequest{iPreferences=Preferences{..}
resultSet <-
lift . SQL.statement mempty $
Statements.prepareCall
crIdent
rout
(QueryBuilder.callPlanToQuery crCallPlan pgVer)
(QueryBuilder.readPlanToQuery crReadPlan)
Expand Down Expand Up @@ -191,7 +189,6 @@ writeQuery MutateReadPlan{..} ApiRequest{iPreferences=Preferences{..}} conf =
in
lift . SQL.statement mempty $
Statements.prepareWrite
mrIdent
(QueryBuilder.readPlanToQuery mrReadPlan)
(QueryBuilder.mutatePlanToQuery mrMutatePlan)
isInsert
Expand Down
16 changes: 9 additions & 7 deletions src/PostgREST/Query/SqlFragment.hs
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ import PostgREST.Plan.Types (CoercibleField (..),
import PostgREST.RangeQuery (NonnegRange, allRange,
rangeLimit, rangeOffset)
import PostgREST.SchemaCache.Identifiers (FieldName,
QualifiedIdentifier (..))
QualifiedIdentifier (..),
RelIdentifier (..))
import PostgREST.SchemaCache.Routine (MediaHandler (..),
Routine (..),
funcReturnsScalar,
Expand Down Expand Up @@ -221,10 +222,11 @@ asJsonF rout strip
asGeoJsonF :: SQL.Snippet
asGeoJsonF = "json_build_object('type', 'FeatureCollection', 'features', coalesce(json_agg(ST_AsGeoJSON(_postgrest_t)::json), '[]'))"

customFuncF :: Maybe Routine -> QualifiedIdentifier -> QualifiedIdentifier -> SQL.Snippet
customFuncF rout funcQi target
customFuncF :: Maybe Routine -> QualifiedIdentifier -> RelIdentifier -> SQL.Snippet
customFuncF rout funcQi _
| (funcReturnsScalar <$> rout) == Just True = fromQi funcQi <> "(_postgrest_t.pgrst_scalar)"
| otherwise = fromQi funcQi <> "(_postgrest_t::" <> fromQi target <> ")"
customFuncF _ funcQi RelAnyElement = fromQi funcQi <> "(_postgrest_t)"
customFuncF _ funcQi (RelId target) = fromQi funcQi <> "(_postgrest_t::" <> fromQi target <> ")"

locationF :: [Text] -> SQL.Snippet
locationF pKeys = [qc|(
Expand Down Expand Up @@ -559,12 +561,12 @@ setConfigWithConstantNameJSON prefix keyVals = [setConfigWithConstantName (prefi
arrayByteStringToText :: [(ByteString, ByteString)] -> [(Text,Text)]
arrayByteStringToText keyVal = (T.decodeUtf8 *** T.decodeUtf8) <$> keyVal

handlerF :: Maybe Routine -> QualifiedIdentifier -> MediaHandler -> SQL.Snippet
handlerF rout target = \case
handlerF :: Maybe Routine -> MediaHandler -> SQL.Snippet
handlerF rout = \case
BuiltinAggArrayJsonStrip -> asJsonF rout True
BuiltinAggSingleJson strip -> asJsonSingleF rout strip
BuiltinOvAggJson -> asJsonF rout False
BuiltinOvAggGeoJson -> asGeoJsonF
BuiltinOvAggCsv -> asCsvF
CustomFunc funcQi -> customFuncF rout funcQi target
CustomFunc funcQi target -> customFuncF rout funcQi target
NoAgg -> "''::text"
27 changes: 13 additions & 14 deletions src/PostgREST/Query/Statements.hs
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,11 @@ import qualified Hasql.Statement as SQL
import Control.Lens ((^?))

import PostgREST.ApiRequest.Preferences
import PostgREST.MediaType (MTVndPlanFormat (..),
MediaType (..))
import PostgREST.MediaType (MTVndPlanFormat (..),
MediaType (..))
import PostgREST.Query.SqlFragment
import PostgREST.SchemaCache.Identifiers (QualifiedIdentifier)
import PostgREST.SchemaCache.Routine (MediaHandler (..), Routine,
funcReturnsSingle)
import PostgREST.SchemaCache.Routine (MediaHandler (..), Routine,
funcReturnsSingle)

import Protolude

Expand All @@ -56,9 +55,9 @@ data ResultSet
| RSPlan BS.ByteString -- ^ the plan of the query


prepareWrite :: QualifiedIdentifier -> SQL.Snippet -> SQL.Snippet -> Bool -> Bool -> MediaType -> MediaHandler ->
prepareWrite :: SQL.Snippet -> SQL.Snippet -> Bool -> Bool -> MediaType -> MediaHandler ->
Maybe PreferRepresentation -> Maybe PreferResolution -> [Text] -> Bool -> SQL.Statement () ResultSet
prepareWrite qi selectQuery mutateQuery isInsert isPut mt handler rep resolution pKeys =
prepareWrite selectQuery mutateQuery isInsert isPut mt handler rep resolution pKeys =
SQL.dynamicallyParameterized (mtSnippet mt snippet) decodeIt
where
checkUpsert snip = if isInsert && (isPut || resolution == Just MergeDuplicates) then snip else "''"
Expand All @@ -69,7 +68,7 @@ prepareWrite qi selectQuery mutateQuery isInsert isPut mt handler rep resolution
"'' AS total_result_set, " <>
"pg_catalog.count(_postgrest_t) AS page_total, " <>
locF <> " AS header, " <>
handlerF Nothing qi handler <> " AS body, " <>
handlerF Nothing handler <> " AS body, " <>
responseHeadersF <> " AS response_headers, " <>
responseStatusF <> " AS response_status, " <>
pgrstInsertedF <> " AS response_inserted " <>
Expand All @@ -94,8 +93,8 @@ prepareWrite qi selectQuery mutateQuery isInsert isPut mt handler rep resolution
MTVndPlan{} -> planRow
_ -> fromMaybe (RSStandard Nothing 0 mempty mempty Nothing Nothing Nothing) <$> HD.rowMaybe (standardRow False)

prepareRead :: QualifiedIdentifier -> SQL.Snippet -> SQL.Snippet -> Bool -> MediaType -> MediaHandler -> Bool -> SQL.Statement () ResultSet
prepareRead qi selectQuery countQuery countTotal mt handler =
prepareRead :: SQL.Snippet -> SQL.Snippet -> Bool -> MediaType -> MediaHandler -> Bool -> SQL.Statement () ResultSet
prepareRead selectQuery countQuery countTotal mt handler =
SQL.dynamicallyParameterized (mtSnippet mt snippet) decodeIt
where
snippet =
Expand All @@ -104,7 +103,7 @@ prepareRead qi selectQuery countQuery countTotal mt handler =
"SELECT " <>
countResultF <> " AS total_result_set, " <>
"pg_catalog.count(_postgrest_t) AS page_total, " <>
handlerF Nothing qi handler <> " AS body, " <>
handlerF Nothing handler <> " AS body, " <>
responseHeadersF <> " AS response_headers, " <>
responseStatusF <> " AS response_status, " <>
"''" <> " AS response_inserted " <>
Expand All @@ -117,10 +116,10 @@ prepareRead qi selectQuery countQuery countTotal mt handler =
MTVndPlan{} -> planRow
_ -> HD.singleRow $ standardRow True

prepareCall :: QualifiedIdentifier -> Routine -> SQL.Snippet -> SQL.Snippet -> SQL.Snippet -> Bool ->
prepareCall :: Routine -> SQL.Snippet -> SQL.Snippet -> SQL.Snippet -> Bool ->
MediaType -> MediaHandler -> Bool ->
SQL.Statement () ResultSet
prepareCall qi rout callProcQuery selectQuery countQuery countTotal mt handler =
prepareCall rout callProcQuery selectQuery countQuery countTotal mt handler =
SQL.dynamicallyParameterized (mtSnippet mt snippet) decodeIt
where
snippet =
Expand All @@ -131,7 +130,7 @@ prepareCall qi rout callProcQuery selectQuery countQuery countTotal mt handler =
(if funcReturnsSingle rout
then "1"
else "pg_catalog.count(_postgrest_t)") <> " AS page_total, " <>
handlerF (Just rout) qi handler <> " AS body, " <>
handlerF (Just rout) handler <> " AS body, " <>
responseHeadersF <> " AS response_headers, " <>
responseStatusF <> " AS response_status, " <>
"''" <> " AS response_inserted " <>
Expand Down
4 changes: 3 additions & 1 deletion src/PostgREST/SchemaCache.hs
Original file line number Diff line number Diff line change
Expand Up @@ -1188,7 +1188,9 @@ mediaHandlers pgVer =

decodeMediaHandlers :: HD.Result MediaHandlerMap
decodeMediaHandlers =
HM.fromList . fmap (\(x, y, z, w) -> ((if isAnyElement y then RelAnyElement else RelId y, z), (CustomFunc x, w)) ) <$> HD.rowList caggRow
HM.fromList . fmap (\(x, y, z, w) ->
let rel = if isAnyElement y then RelAnyElement else RelId y
in ((rel, z), (CustomFunc x rel, w)) ) <$> HD.rowList caggRow
where
caggRow = (,,,)
<$> (QualifiedIdentifier <$> column HD.text <*> column HD.text)
Expand Down
2 changes: 1 addition & 1 deletion src/PostgREST/SchemaCache/Identifiers.hs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import qualified Data.Text as T
import Protolude

data RelIdentifier = RelId QualifiedIdentifier | RelAnyElement
deriving (Eq, Ord, Generic, JSON.ToJSON, JSON.ToJSONKey)
deriving (Eq, Ord, Generic, JSON.ToJSON, JSON.ToJSONKey, Show)
instance Hashable RelIdentifier

-- | Represents a pg identifier with a prepended schema name "schema.table".
Expand Down
2 changes: 1 addition & 1 deletion src/PostgREST/SchemaCache/Routine.hs
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ data MediaHandler
| BuiltinOvAggGeoJson
| BuiltinOvAggCsv
-- custom
| CustomFunc QualifiedIdentifier
| CustomFunc QualifiedIdentifier RelIdentifier
| NoAgg
deriving (Eq, Show)

Expand Down
70 changes: 70 additions & 0 deletions test/spec/Feature/Query/CustomMediaSpec.hs
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,76 @@ spec = describe "custom media types" $ do
simpleHeaders r `shouldContain` [("Content-Type", "text/csv; charset=utf-8")]
simpleHeaders r `shouldContain` [("Content-Disposition", "attachment; filename=\"lines.csv\"")]

-- https://github.com/PostgREST/postgrest/issues/3160
context "using select query parameter" $ do
it "without select" $ do
request methodGet "/projects?id=in.(1,2)" (acceptHdrs "pg/outfunc") ""
`shouldRespondWith`
[str|(1,"Windows 7",1)
|(2,"Windows 10",1)
|]
{ matchStatus = 200
, matchHeaders = ["Content-Type" <:> "pg/outfunc"]
}

it "with fewer columns selected" $ do
request methodGet "/projects?id=in.(1,2)&select=id,name" (acceptHdrs "pg/outfunc") ""
`shouldRespondWith`
[str|(1,"Windows 7")
|(2,"Windows 10")
|]
{ matchStatus = 200
, matchHeaders = ["Content-Type" <:> "pg/outfunc"]
}

it "with columns in different order" $ do
request methodGet "/projects?id=in.(1,2)&select=name,id,client_id" (acceptHdrs "pg/outfunc") ""
`shouldRespondWith`
[str|("Windows 7",1,1)
|("Windows 10",2,1)
|]
{ matchStatus = 200
, matchHeaders = ["Content-Type" <:> "pg/outfunc"]
}

it "with computed columns" $ do
request methodGet "/items?id=in.(1,2)&select=id,always_true" (acceptHdrs "pg/outfunc") ""
`shouldRespondWith`
[str|(1,t)
|(2,t)
|]
{ matchStatus = 200
, matchHeaders = ["Content-Type" <:> "pg/outfunc"]
}

-- TODO: Embeddings should not return JSON. Arrays of record would be much better.
it "with embedding" $ do
request methodGet "/projects?id=in.(1,2)&select=*,clients(id)" (acceptHdrs "pg/outfunc") ""
`shouldRespondWith`
[str|(1,"Windows 7",1,"{""id"": 1}")
|(2,"Windows 10",1,"{""id"": 1}")
|]
{ matchStatus = 200
, matchHeaders = ["Content-Type" <:> "pg/outfunc"]
}

it "will fail for specific aggregate with fewer columns" $ do
request methodGet "/lines?select=id" (acceptHdrs "application/vnd.twkb") ""
`shouldRespondWith` 406

it "will fail for specific aggregate with more columns" $ do
request methodGet "/lines?select=id,name,geom,id" (acceptHdrs "application/vnd.twkb") ""
`shouldRespondWith` 406

it "will fail for specific aggregate with columns in different order" $ do
request methodGet "/lines?select=name,id,geom" (acceptHdrs "application/vnd.twkb") ""
`shouldRespondWith` 406

-- This is just because it would be hard to detect this case, so we better error in this case, too.
it "will fail for specific aggregate with columns in same order" $ do
request methodGet "/lines?select=id,name,geom" (acceptHdrs "application/vnd.twkb") ""
`shouldRespondWith` 406

context "any media type" $ do
context "on functions" $ do
it "returns application/json for */* if not explicitly set" $ do
Expand Down
17 changes: 15 additions & 2 deletions test/spec/fixtures/schema.sql
Original file line number Diff line number Diff line change
Expand Up @@ -3550,8 +3550,8 @@ returns "application/vnd.geo2+json" as $$
select (jsonb_build_object('type', 'FeatureCollection', 'hello', 'world'))::"application/vnd.geo2+json";
$$ language sql;

drop aggregate if exists test.geo2json_agg(anyelement);
create aggregate test.geo2json_agg(anyelement) (
drop aggregate if exists test.geo2json_agg_any(anyelement);
create aggregate test.geo2json_agg_any(anyelement) (
initcond = '[]'
, stype = "application/vnd.geo2+json"
, sfunc = geo2json_trans
Expand Down Expand Up @@ -3755,3 +3755,16 @@ create aggregate test.some_agg (some_numbers) (

create view bad_subquery as
select * from projects where id = (select id from projects);

-- custom generic mimetype
create domain "pg/outfunc" as text;
create function test.outfunc_trans (state text, next anyelement)
returns "pg/outfunc" as $$
select (state || next::text || E'\n')::"pg/outfunc";
$$ language sql;

create aggregate test.outfunc_agg (anyelement) (
initcond = ''
, stype = "pg/outfunc"
, sfunc = outfunc_trans
);

0 comments on commit 27cb28b

Please sign in to comment.