Skip to content

Commit

Permalink
fix: Avoid casting to table type when select= and media type handler …
Browse files Browse the repository at this point in the history
…are used

Previously using a generic mimetype handler failed when any kind of select= was given, because
we tried to cast the select-result to the original table type. With this change, this cast is
only applied when select=* is given implicitly or explicitly. This is the only case where this
makes sense, because this guarantees that correct columns are selected in the correct order for
this cast to succeed.

Resolves PostgREST#3160
  • Loading branch information
wolfgangwalther committed Feb 12, 2024
1 parent b22bb74 commit 21bc0e7
Show file tree
Hide file tree
Showing 6 changed files with 91 additions and 19 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ This project adheres to [Semantic Versioning](http://semver.org/).
- #3149, Misleading "Starting PostgREST.." logs on schema cache reloading - @steve-chavez
- #2815, Build static executable with GSSAPI support - @wolfgangwalther
- #3205, Fix wrong subquery error returning a status of 400 Bad Request - @steve-chavez
- #3160, Fix using select= query parameter for custom media type handlers - @wolfgangwalther

### Deprecated

Expand Down
28 changes: 16 additions & 12 deletions src/PostgREST/Plan.hs
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ data WrappedReadPlan = WrappedReadPlan {
, wrTxMode :: SQL.Mode
, wrHandler :: MediaHandler
, wrMedia :: MediaType
, wrIdent :: QualifiedIdentifier
, wrIdent :: Maybe QualifiedIdentifier
}

data MutateReadPlan = MutateReadPlan {
Expand All @@ -106,7 +106,7 @@ data MutateReadPlan = MutateReadPlan {
, mrTxMode :: SQL.Mode
, mrHandler :: MediaHandler
, mrMedia :: MediaType
, mrIdent :: QualifiedIdentifier
, mrIdent :: Maybe QualifiedIdentifier
}

data CallReadPlan = CallReadPlan {
Expand All @@ -116,7 +116,7 @@ data CallReadPlan = CallReadPlan {
, crProc :: Routine
, crHandler :: MediaHandler
, crMedia :: MediaType
, crIdent :: QualifiedIdentifier
, crIdent :: Maybe QualifiedIdentifier
}

data InspectPlan = InspectPlan {
Expand All @@ -126,18 +126,18 @@ data InspectPlan = InspectPlan {

wrappedReadPlan :: QualifiedIdentifier -> AppConfig -> SchemaCache -> ApiRequest -> Either Error WrappedReadPlan
wrappedReadPlan identifier conf sCache apiRequest@ApiRequest{iPreferences=Preferences{..},..} = do
rPlan <- readPlan identifier conf sCache apiRequest
(hdler, mediaType) <- mapLeft ApiRequestError $ negotiateContent conf apiRequest identifier iAcceptMediaType (dbMediaHandlers sCache)
rPlan@(Node ReadPlan{select} forest) <- readPlan identifier conf sCache apiRequest
(handler, mediaType) <- mapLeft ApiRequestError $ negotiateContent conf apiRequest identifier iAcceptMediaType (dbMediaHandlers sCache)
if not (null invalidPrefs) && preferHandling == Just Strict then Left $ ApiRequestError $ InvalidPreferences invalidPrefs else Right ()
return $ WrappedReadPlan rPlan SQL.Read hdler mediaType identifier
return $ WrappedReadPlan rPlan SQL.Read handler mediaType (if isStarOnly select && null forest then Just identifier else Nothing)

mutateReadPlan :: Mutation -> ApiRequest -> QualifiedIdentifier -> AppConfig -> SchemaCache -> Either Error MutateReadPlan
mutateReadPlan mutation apiRequest@ApiRequest{iPreferences=Preferences{..},..} identifier conf sCache = do
rPlan <- readPlan identifier conf sCache apiRequest
rPlan@(Node ReadPlan{select} forest) <- readPlan identifier conf sCache apiRequest
mPlan <- mutatePlan mutation identifier apiRequest sCache rPlan
if not (null invalidPrefs) && preferHandling == Just Strict then Left $ ApiRequestError $ InvalidPreferences invalidPrefs else Right ()
(hdler, mediaType) <- mapLeft ApiRequestError $ negotiateContent conf apiRequest identifier iAcceptMediaType (dbMediaHandlers sCache)
return $ MutateReadPlan rPlan mPlan SQL.Write hdler mediaType identifier
(handler, mediaType) <- mapLeft ApiRequestError $ negotiateContent conf apiRequest identifier iAcceptMediaType (dbMediaHandlers sCache)
return $ MutateReadPlan rPlan mPlan SQL.Write handler mediaType (if isStarOnly select && null forest then Just identifier else Nothing)

callReadPlan :: QualifiedIdentifier -> AppConfig -> SchemaCache -> ApiRequest -> InvokeMethod -> Either Error CallReadPlan
callReadPlan identifier conf sCache apiRequest@ApiRequest{iPreferences=Preferences{..},..} invMethod = do
Expand All @@ -148,7 +148,7 @@ callReadPlan identifier conf sCache apiRequest@ApiRequest{iPreferences=Preferenc
proc@Function{..} <- mapLeft ApiRequestError $
findProc identifier paramKeys (preferParameters == Just SingleObject) (dbRoutines sCache) iContentMediaType (invMethod == InvPost)
let relIdentifier = QualifiedIdentifier pdSchema (fromMaybe pdName $ Routine.funcTableName proc) -- done so a set returning function can embed other relations
rPlan <- readPlan relIdentifier conf sCache apiRequest
rPlan@(Node ReadPlan{select} forest) <- readPlan relIdentifier conf sCache apiRequest
let args = case (invMethod, iContentMediaType) of
(InvGet, _) -> jsonRpcParams proc qsParams'
(InvHead, _) -> jsonRpcParams proc qsParams'
Expand All @@ -161,12 +161,16 @@ callReadPlan identifier conf sCache apiRequest@ApiRequest{iPreferences=Preferenc
(InvPost, Routine.Immutable) -> SQL.Read
(InvPost, Routine.Volatile) -> SQL.Write
cPlan = callPlan proc apiRequest paramKeys args rPlan
(hdler, mediaType) <- mapLeft ApiRequestError $ negotiateContent conf apiRequest relIdentifier iAcceptMediaType (dbMediaHandlers sCache)
(handler, mediaType) <- mapLeft ApiRequestError $ negotiateContent conf apiRequest relIdentifier iAcceptMediaType (dbMediaHandlers sCache)
if not (null invalidPrefs) && preferHandling == Just Strict then Left $ ApiRequestError $ InvalidPreferences invalidPrefs else Right ()
return $ CallReadPlan rPlan cPlan txMode proc hdler mediaType relIdentifier
return $ CallReadPlan rPlan cPlan txMode proc handler mediaType (if isStarOnly select && null forest then Just relIdentifier else Nothing)
where
qsParams' = QueryParams.qsParams iQueryParams

isStarOnly :: [CoercibleSelectField] -> Bool
isStarOnly [CoercibleSelectField{csField=CoercibleField{cfName}}] = cfName == "*"
isStarOnly _ = False

inspectPlan :: ApiRequest -> Either Error InspectPlan
inspectPlan apiRequest = do
let producedMTs = [MTOpenAPI, MTApplicationJSON, MTAny]
Expand Down
9 changes: 5 additions & 4 deletions src/PostgREST/Query/SqlFragment.hs
Original file line number Diff line number Diff line change
Expand Up @@ -221,10 +221,11 @@ asJsonF rout strip
asGeoJsonF :: SQL.Snippet
asGeoJsonF = "json_build_object('type', 'FeatureCollection', 'features', coalesce(json_agg(ST_AsGeoJSON(_postgrest_t)::json), '[]'))"

customFuncF :: Maybe Routine -> QualifiedIdentifier -> QualifiedIdentifier -> SQL.Snippet
customFuncF rout funcQi target
customFuncF :: Maybe Routine -> QualifiedIdentifier -> Maybe QualifiedIdentifier -> SQL.Snippet
customFuncF rout funcQi _
| (funcReturnsScalar <$> rout) == Just True = fromQi funcQi <> "(_postgrest_t.pgrst_scalar)"
| otherwise = fromQi funcQi <> "(_postgrest_t::" <> fromQi target <> ")"
customFuncF _ funcQi (Just target) = fromQi funcQi <> "(_postgrest_t::" <> fromQi target <> ")"
customFuncF _ funcQi Nothing = fromQi funcQi <> "(_postgrest_t)"

locationF :: [Text] -> SQL.Snippet
locationF pKeys = [qc|(
Expand Down Expand Up @@ -559,7 +560,7 @@ setConfigWithConstantNameJSON prefix keyVals = [setConfigWithConstantName (prefi
arrayByteStringToText :: [(ByteString, ByteString)] -> [(Text,Text)]
arrayByteStringToText keyVal = (T.decodeUtf8 *** T.decodeUtf8) <$> keyVal

handlerF :: Maybe Routine -> QualifiedIdentifier -> MediaHandler -> SQL.Snippet
handlerF :: Maybe Routine -> Maybe QualifiedIdentifier -> MediaHandler -> SQL.Snippet
handlerF rout target = \case
BuiltinAggArrayJsonStrip -> asJsonF rout True
BuiltinAggSingleJson strip -> asJsonSingleF rout strip
Expand Down
6 changes: 3 additions & 3 deletions src/PostgREST/Query/Statements.hs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ data ResultSet
| RSPlan BS.ByteString -- ^ the plan of the query


prepareWrite :: QualifiedIdentifier -> SQL.Snippet -> SQL.Snippet -> Bool -> Bool -> MediaType -> MediaHandler ->
prepareWrite :: Maybe QualifiedIdentifier -> SQL.Snippet -> SQL.Snippet -> Bool -> Bool -> MediaType -> MediaHandler ->
Maybe PreferRepresentation -> Maybe PreferResolution -> [Text] -> Bool -> SQL.Statement () ResultSet
prepareWrite qi selectQuery mutateQuery isInsert isPut mt handler rep resolution pKeys =
SQL.dynamicallyParameterized (mtSnippet mt snippet) decodeIt
Expand Down Expand Up @@ -94,7 +94,7 @@ prepareWrite qi selectQuery mutateQuery isInsert isPut mt handler rep resolution
MTVndPlan{} -> planRow
_ -> fromMaybe (RSStandard Nothing 0 mempty mempty Nothing Nothing Nothing) <$> HD.rowMaybe (standardRow False)

prepareRead :: QualifiedIdentifier -> SQL.Snippet -> SQL.Snippet -> Bool -> MediaType -> MediaHandler -> Bool -> SQL.Statement () ResultSet
prepareRead :: Maybe QualifiedIdentifier -> SQL.Snippet -> SQL.Snippet -> Bool -> MediaType -> MediaHandler -> Bool -> SQL.Statement () ResultSet
prepareRead qi selectQuery countQuery countTotal mt handler =
SQL.dynamicallyParameterized (mtSnippet mt snippet) decodeIt
where
Expand All @@ -117,7 +117,7 @@ prepareRead qi selectQuery countQuery countTotal mt handler =
MTVndPlan{} -> planRow
_ -> HD.singleRow $ standardRow True

prepareCall :: QualifiedIdentifier -> Routine -> SQL.Snippet -> SQL.Snippet -> SQL.Snippet -> Bool ->
prepareCall :: Maybe QualifiedIdentifier -> Routine -> SQL.Snippet -> SQL.Snippet -> SQL.Snippet -> Bool ->
MediaType -> MediaHandler -> Bool ->
SQL.Statement () ResultSet
prepareCall qi rout callProcQuery selectQuery countQuery countTotal mt handler =
Expand Down
53 changes: 53 additions & 0 deletions test/spec/Feature/Query/CustomMediaSpec.hs
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,59 @@ spec = describe "custom media types" $ do
simpleHeaders r `shouldContain` [("Content-Type", "text/csv; charset=utf-8")]
simpleHeaders r `shouldContain` [("Content-Disposition", "attachment; filename=\"lines.csv\"")]

-- https://github.com/PostgREST/postgrest/issues/3160
context "using select query parameter" $ do
it "without select" $ do
request methodGet "/projects?id=in.(1,2)" (acceptHdrs "pg/outfunc") ""
`shouldRespondWith`
[str|(1,"Windows 7",1)
|(2,"Windows 10",1)
|]
{ matchStatus = 200
, matchHeaders = ["Content-Type" <:> "pg/outfunc"]
}

it "with fewer columns selected" $ do
request methodGet "/projects?id=in.(1,2)&select=id,name" (acceptHdrs "pg/outfunc") ""
`shouldRespondWith`
[str|(1,"Windows 7")
|(2,"Windows 10")
|]
{ matchStatus = 200
, matchHeaders = ["Content-Type" <:> "pg/outfunc"]
}

it "with columns in different order" $ do
request methodGet "/projects?id=in.(1,2)&select=name,id,client_id" (acceptHdrs "pg/outfunc") ""
`shouldRespondWith`
[str|("Windows 7",1,1)
|("Windows 10",2,1)
|]
{ matchStatus = 200
, matchHeaders = ["Content-Type" <:> "pg/outfunc"]
}

it "with computed columns" $ do
request methodGet "/items?id=in.(1,2)&select=id,always_true" (acceptHdrs "pg/outfunc") ""
`shouldRespondWith`
[str|(1,t)
|(2,t)
|]
{ matchStatus = 200
, matchHeaders = ["Content-Type" <:> "pg/outfunc"]
}

-- TODO: Embeddings should not return JSON. Arrays of record would be much better.
it "with embedding" $ do
request methodGet "/projects?id=in.(1,2)&select=*,clients(id)" (acceptHdrs "pg/outfunc") ""
`shouldRespondWith`
[str|(1,"Windows 7",1,"{""id"": 1}")
|(2,"Windows 10",1,"{""id"": 1}")
|]
{ matchStatus = 200
, matchHeaders = ["Content-Type" <:> "pg/outfunc"]
}

context "any media type" $ do
context "on functions" $ do
it "returns application/json for */* if not explicitly set" $ do
Expand Down
13 changes: 13 additions & 0 deletions test/spec/fixtures/schema.sql
Original file line number Diff line number Diff line change
Expand Up @@ -3755,3 +3755,16 @@ create aggregate test.some_agg (some_numbers) (

create view bad_subquery as
select * from projects where id = (select id from projects);

-- custom generic mimetype
create domain "pg/outfunc" as text;
create function test.outfunc_trans (state text, next anyelement)
returns "pg/outfunc" as $$
select (state || next::text || E'\n')::"pg/outfunc";
$$ language sql;

create aggregate test.outfunc_agg (anyelement) (
initcond = ''
, stype = "pg/outfunc"
, sfunc = outfunc_trans
);

0 comments on commit 21bc0e7

Please sign in to comment.