Skip to content

Commit

Permalink
feat: custom media types for Accept
Browse files Browse the repository at this point in the history
  • Loading branch information
steve-chavez committed Jun 20, 2023
1 parent 0ceb1ca commit 4243c4a
Show file tree
Hide file tree
Showing 12 changed files with 246 additions and 108 deletions.
2 changes: 1 addition & 1 deletion src/PostgREST/App.hs
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ handleRequest AuthResult{..} conf appState authenticated prepared pgVer apiReq@A
return $ Response.invokeResponse cPlan invMethod (Plan.crProc cPlan) apiReq resultSet

(ActionInspect headersOnly, TargetDefaultSpec tSchema) -> do
iPlan <- liftEither $ Plan.inspectPlan conf apiReq
iPlan <- liftEither $ Plan.inspectPlan apiReq
oaiResult <- runQuery roleIsoLvl (Plan.ipTxmode iPlan) $ Query.openApiQuery sCache pgVer conf tSchema
return $ Response.openApiResponse headersOnly oaiResult conf sCache iSchema iNegotiatedByProfile

Expand Down
43 changes: 23 additions & 20 deletions src/PostgREST/MediaType.hs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DuplicateRecordFields #-}

module PostgREST.MediaType
Expand All @@ -7,12 +8,10 @@ module PostgREST.MediaType
, toContentType
, toMime
, decodeMediaType
, getMediaType
) where

import qualified Data.ByteString as BS
import qualified Data.ByteString.Internal as BS (c2w)
import Data.Maybe (fromJust)

import Network.HTTP.Types.Header (Header, hContentType)

Expand All @@ -39,7 +38,8 @@ data MediaType
| MTOctetStream
| MTAny
| MTOther ByteString
| MTPlan (Maybe MediaType) (Maybe MTPlanFormat) [MTPlanOption]
| MTPlan MediaType MTPlanFormat [MTPlanOption]
deriving Generic
instance Eq MediaType where
MTApplicationJSON == MTApplicationJSON = True
MTSingularJSON == MTSingularJSON = True
Expand All @@ -54,12 +54,17 @@ instance Eq MediaType where
MTOther x == MTOther y = x == y
MTPlan{} == MTPlan{} = True
_ == _ = False
instance Hashable MediaType

data MTPlanOption
= PlanAnalyze | PlanVerbose | PlanSettings | PlanBuffers | PlanWAL
deriving (Eq, Generic)
instance Hashable MTPlanOption

data MTPlanFormat
= PlanJSON | PlanText
deriving (Eq, Generic)
instance Hashable MTPlanFormat

-- | Convert MediaType to a Content-Type HTTP Header
toContentType :: MediaType -> Header
Expand All @@ -84,8 +89,8 @@ toMime MTOctetStream = "application/octet-stream"
toMime MTAny = "*/*"
toMime (MTOther ct) = ct
toMime (MTPlan mt fmt opts) =
"application/vnd.pgrst.plan" <> maybe mempty (\x -> "+" <> toMimePlanFormat x) fmt <>
(if isNothing mt then mempty else "; for=\"" <> toMime (fromJust mt) <> "\"") <>
"application/vnd.pgrst.plan+" <> toMimePlanFormat fmt <>
("; for=\"" <> toMime mt <> "\"") <>
(if null opts then mempty else "; options=" <> BS.intercalate "|" (toMimePlanOption <$> opts))

toMimePlanOption :: MTPlanOption -> ByteString
Expand All @@ -105,13 +110,13 @@ toMimePlanFormat PlanText = "text"
-- MTApplicationJSON
--
-- >>> decodeMediaType "application/vnd.pgrst.plan;"
-- MTPlan Nothing Nothing []
-- MTPlan MTApplicationJSON PlanText []
--
-- >>> decodeMediaType "application/vnd.pgrst.plan;for=\"application/json\""
-- MTPlan (Just MTApplicationJSON) Nothing []
-- MTPlan MTApplicationJSON PlanText []
--
-- >>> decodeMediaType "application/vnd.pgrst.plan+text;for=\"text/csv\""
-- MTPlan (Just MTTextCSV) (Just PlanText) []
-- >>> decodeMediaType "application/vnd.pgrst.plan+json;for=\"text/csv\""
-- MTPlan MTTextCSV PlanJSON []
decodeMediaType :: BS.ByteString -> MediaType
decodeMediaType mt =
case BS.split (BS.c2w ';') mt of
Expand All @@ -125,9 +130,9 @@ decodeMediaType mt =
"application/vnd.pgrst.object":_ -> MTSingularJSON
"application/x-www-form-urlencoded":_ -> MTUrlEncoded
"application/octet-stream":_ -> MTOctetStream
"application/vnd.pgrst.plan":rest -> getPlan Nothing rest
"application/vnd.pgrst.plan+text":rest -> getPlan (Just PlanText) rest
"application/vnd.pgrst.plan+json":rest -> getPlan (Just PlanJSON) rest
"application/vnd.pgrst.plan":rest -> getPlan PlanText rest
"application/vnd.pgrst.plan+text":rest -> getPlan PlanText rest
"application/vnd.pgrst.plan+json":rest -> getPlan PlanJSON rest
"*/*":_ -> MTAny
other:_ -> MTOther other
_ -> MTAny
Expand All @@ -136,17 +141,15 @@ decodeMediaType mt =
let
opts = BS.split (BS.c2w '|') $ fromMaybe mempty (BS.stripPrefix "options=" =<< find (BS.isPrefixOf "options=") rest)
inOpts str = str `elem` opts
mtFor = decodeMediaType . dropAround (== BS.c2w '"') <$> (BS.stripPrefix "for=" =<< find (BS.isPrefixOf "for=") rest)
dropAround p = BS.dropWhile p . BS.dropWhileEnd p in
dropAround p = BS.dropWhile p . BS.dropWhileEnd p
mtFor = fromMaybe MTApplicationJSON $ do
foundFor <- find (BS.isPrefixOf "for=") rest
strippedFor <- BS.stripPrefix "for=" foundFor
pure . decodeMediaType $ dropAround (== BS.c2w '"') strippedFor
in
MTPlan mtFor fmt $
[PlanAnalyze | inOpts "analyze" ] ++
[PlanVerbose | inOpts "verbose" ] ++
[PlanSettings | inOpts "settings"] ++
[PlanBuffers | inOpts "buffers" ] ++
[PlanWAL | inOpts "wal" ]

getMediaType :: MediaType -> MediaType
getMediaType mt = case mt of
MTPlan (Just mType) _ _ -> mType
MTPlan Nothing _ _ -> MTApplicationJSON
other -> other
61 changes: 31 additions & 30 deletions src/PostgREST/Plan.hs
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,14 @@ import PostgREST.RangeQuery (NonnegRange, allRange,
import PostgREST.SchemaCache (SchemaCache (..))
import PostgREST.SchemaCache.Identifiers (FieldName,
QualifiedIdentifier (..),
Schema)
RelIdentifier (..), Schema)
import PostgREST.SchemaCache.Relationship (Cardinality (..),
Junction (..),
Relationship (..),
RelationshipsMap,
relIsToOne)
import PostgREST.SchemaCache.Routine (Routine (..), RoutineMap,
import PostgREST.SchemaCache.Routine (ResultAggregate (..),
Routine (..), RoutineMap,
RoutineParam (..),
funcReturnsCompositeAlias,
funcReturnsScalar,
Expand All @@ -84,13 +85,15 @@ data WrappedReadPlan = WrappedReadPlan {
wrReadPlan :: ReadPlanTree
, wrTxMode :: SQL.Mode
, wrMedia :: MediaType
, wrResAgg :: ResultAggregate
}

data MutateReadPlan = MutateReadPlan {
mrReadPlan :: ReadPlanTree
, mrMutatePlan :: MutatePlan
, mrTxMode :: SQL.Mode
, mrMedia :: MediaType
, mrResAgg :: ResultAggregate
}

data CallReadPlan = CallReadPlan {
Expand All @@ -99,6 +102,7 @@ data CallReadPlan = CallReadPlan {
, crTxMode :: SQL.Mode
, crProc :: Routine
, crMedia :: MediaType
, crResAgg :: ResultAggregate
}

data InspectPlan = InspectPlan {
Expand All @@ -109,15 +113,15 @@ data InspectPlan = InspectPlan {
wrappedReadPlan :: QualifiedIdentifier -> AppConfig -> SchemaCache -> ApiRequest -> Either Error WrappedReadPlan
wrappedReadPlan identifier conf sCache apiRequest = do
rPlan <- readPlan identifier conf sCache apiRequest
mediaType <- mapLeft ApiRequestError $ negotiateContent conf (iAction apiRequest) (iAcceptMediaType apiRequest)
return $ WrappedReadPlan rPlan SQL.Read mediaType
(rAgg, mediaType) <- mapLeft ApiRequestError $ negotiateContent conf identifier (iAcceptMediaType apiRequest) (dbMediaAggs sCache)
return $ WrappedReadPlan rPlan SQL.Read mediaType rAgg

mutateReadPlan :: Mutation -> ApiRequest -> QualifiedIdentifier -> AppConfig -> SchemaCache -> Either Error MutateReadPlan
mutateReadPlan mutation apiRequest identifier conf sCache = do
rPlan <- readPlan identifier conf sCache apiRequest
mPlan <- mutatePlan mutation identifier apiRequest sCache rPlan
mediaType <- mapLeft ApiRequestError $ negotiateContent conf (iAction apiRequest) (iAcceptMediaType apiRequest)
return $ MutateReadPlan rPlan mPlan SQL.Write mediaType
(rAgg, mediaType) <- mapLeft ApiRequestError $ negotiateContent conf identifier (iAcceptMediaType apiRequest) (dbMediaAggs sCache)
return $ MutateReadPlan rPlan mPlan SQL.Write mediaType rAgg

callReadPlan :: QualifiedIdentifier -> AppConfig -> SchemaCache -> ApiRequest -> InvokeMethod -> Either Error CallReadPlan
callReadPlan identifier conf sCache apiRequest invMethod = do
Expand All @@ -141,15 +145,19 @@ callReadPlan identifier conf sCache apiRequest invMethod = do
(InvPost, Routine.Immutable) -> SQL.Read
(InvPost, Routine.Volatile) -> SQL.Write
cPlan = callPlan proc apiRequest paramKeys args rPlan
mediaType <- mapLeft ApiRequestError $ negotiateContent conf (iAction apiRequest) (iAcceptMediaType apiRequest)
return $ CallReadPlan rPlan cPlan txMode proc mediaType
(rAgg, mediaType) <- mapLeft ApiRequestError $ negotiateContent conf identifier (iAcceptMediaType apiRequest) (dbMediaAggs sCache)
return $ CallReadPlan rPlan cPlan txMode proc mediaType rAgg
where
Preferences{..} = iPreferences apiRequest
qsParams' = QueryParams.qsParams (iQueryParams apiRequest)

inspectPlan :: AppConfig -> ApiRequest -> Either Error InspectPlan
inspectPlan conf apiRequest = do
mediaType <- mapLeft ApiRequestError $ negotiateContent conf (iAction apiRequest) (iAcceptMediaType apiRequest)
inspectPlan :: ApiRequest -> Either Error InspectPlan
inspectPlan apiRequest = do
let producedMTs = [MTOpenAPI, MTApplicationJSON, MTAny]
accepts = iAcceptMediaType apiRequest
mediaType <- if not . null $ L.intersect accepts producedMTs
then Right MTOpenAPI
else Left . ApiRequestError . MediaTypeError $ MediaType.toMime <$> accepts
return $ InspectPlan mediaType SQL.Read

{-|
Expand Down Expand Up @@ -624,25 +632,18 @@ addFilterToLogicForest :: Filter -> [LogicTree] -> [LogicTree]
addFilterToLogicForest flt lf = Stmnt flt : lf

-- | Do content negotiation. i.e. choose a media type based on the intersection of accepted/produced media types.
negotiateContent :: AppConfig -> Action -> [MediaType] -> Either ApiRequestError MediaType
negotiateContent conf action accepts =
negotiateContent :: AppConfig -> QualifiedIdentifier -> [MediaType] -> HM.HashMap (RelIdentifier, MediaType) ResultAggregate -> Either ApiRequestError (ResultAggregate, MediaType)
negotiateContent conf identifier accepts produces =
case firstAcceptedPick of
Just MTAny -> Right MTApplicationJSON -- by default(for */*) we respond with json
Just mt -> Right mt
Nothing -> Left . MediaTypeError $ map MediaType.toMime accepts
Just (x, MTAny) -> Right (x, MTApplicationJSON) -- by default(for */*) we respond with json
Just (x, mt) -> Right (x, mt)
Nothing -> Left . MediaTypeError $ map MediaType.toMime accepts
where
-- if there are multiple accepted media types, pick the first
firstAcceptedPick = listToMaybe $ L.intersect accepts $ producedMediaTypes conf action

producedMediaTypes :: AppConfig -> Action -> [MediaType]
producedMediaTypes conf action =
case action of
ActionRead _ -> defaultMediaTypes
ActionInvoke _ -> defaultMediaTypes
ActionInspect _ -> [MTOpenAPI, MTApplicationJSON, MTAny]
ActionInfo -> defaultMediaTypes
ActionMutate _ -> defaultMediaTypes
where
defaultMediaTypes =
[MTApplicationJSON, MTSingularJSON, MTGeoJSON, MTTextCSV] ++
[MTPlan Nothing Nothing mempty | configDbPlanEnabled conf] ++ [MTAny]
firstAcceptedPick = listToMaybe $ mapMaybe searchMT accepts
lookupIdent mt = -- first search for an aggregate that applies to the particular relation, then for one that applies to anyelement
HM.lookup (RelId identifier, mt) produces <|> HM.lookup (RelAnyElement, mt) produces
searchMT mt = case mt of
MTPlan mType _ _ | configDbPlanEnabled conf -> (,) <$> lookupIdent mType <*> pure mt
| otherwise -> Nothing
x -> (,) <$> lookupIdent x <*> pure x
9 changes: 6 additions & 3 deletions src/PostgREST/Query.hs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ import Protolude hiding (Handler)
type DbHandler = ExceptT Error SQL.Transaction

readQuery :: WrappedReadPlan -> AppConfig -> ApiRequest -> DbHandler ResultSet
readQuery WrappedReadPlan{wrReadPlan, wrMedia} conf@AppConfig{..} apiReq@ApiRequest{iPreferences=Preferences{..}} = do
readQuery WrappedReadPlan{wrReadPlan, wrMedia, wrResAgg} conf@AppConfig{..} apiReq@ApiRequest{iPreferences=Preferences{..}} = do
let countQuery = QueryBuilder.readPlanToCountQuery wrReadPlan
resultSet <-
lift . SQL.statement mempty $
Expand All @@ -81,6 +81,7 @@ readQuery WrappedReadPlan{wrReadPlan, wrMedia} conf@AppConfig{..} apiReq@ApiRequ
)
(shouldCount preferCount)
wrMedia
wrResAgg
configDbPreparedStatements
failNotSingular wrMedia resultSet
optionalRollback conf apiReq
Expand Down Expand Up @@ -151,7 +152,7 @@ deleteQuery mrPlan@MutateReadPlan{mrMedia} apiReq@ApiRequest{..} conf = do
pure resultSet

invokeQuery :: Routine -> CallReadPlan -> ApiRequest -> AppConfig -> PgVersion -> DbHandler ResultSet
invokeQuery rout CallReadPlan{crReadPlan, crCallPlan, crMedia} apiReq@ApiRequest{iPreferences=Preferences{..}} conf@AppConfig{..} pgVer = do
invokeQuery rout CallReadPlan{crReadPlan, crCallPlan, crMedia, crResAgg} apiReq@ApiRequest{iPreferences=Preferences{..}} conf@AppConfig{..} pgVer = do
resultSet <-
lift . SQL.statement mempty $
Statements.prepareCall
Expand All @@ -161,6 +162,7 @@ invokeQuery rout CallReadPlan{crReadPlan, crCallPlan, crMedia} apiReq@ApiRequest
(QueryBuilder.readPlanToCountQuery crReadPlan)
(shouldCount preferCount)
crMedia
crResAgg
configDbPreparedStatements

optionalRollback conf apiReq
Expand All @@ -185,7 +187,7 @@ openApiQuery sCache pgVer AppConfig{..} tSchema =
pure Nothing

writeQuery :: MutateReadPlan -> ApiRequest -> AppConfig -> DbHandler ResultSet
writeQuery MutateReadPlan{mrReadPlan, mrMutatePlan, mrMedia} ApiRequest{iPreferences=Preferences{..}} conf =
writeQuery MutateReadPlan{mrReadPlan, mrMutatePlan, mrMedia, mrResAgg} ApiRequest{iPreferences=Preferences{..}} conf =
let
(isInsert, pkCols) = case mrMutatePlan of {Insert{insPkCols} -> (True, insPkCols); _ -> (False, mempty);}
in
Expand All @@ -195,6 +197,7 @@ writeQuery MutateReadPlan{mrReadPlan, mrMutatePlan, mrMedia} ApiRequest{iPrefere
(QueryBuilder.mutatePlanToQuery mrMutatePlan)
isInsert
mrMedia
mrResAgg
preferRepresentation
pkCols
(configDbPreparedStatements conf)
Expand Down
26 changes: 17 additions & 9 deletions src/PostgREST/Query/SqlFragment.hs
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,7 @@ Any function that outputs a SqlFragment should be in this module.
module PostgREST.Query.SqlFragment
( noLocationF
, SqlFragment
, asCsvF
, asGeoJsonF
, asJsonF
, asJsonSingleF
, aggF
, countF
, fromQi
, limitOffsetF
Expand Down Expand Up @@ -78,7 +75,8 @@ import PostgREST.RangeQuery (NonnegRange, allRange,
rangeLimit, rangeOffset)
import PostgREST.SchemaCache.Identifiers (FieldName,
QualifiedIdentifier (..))
import PostgREST.SchemaCache.Routine (Routine (..),
import PostgREST.SchemaCache.Routine (ResultAggregate (..),
Routine (..),
funcReturnsScalar,
funcReturnsSetOfScalar,
funcReturnsSingleComposite)
Expand Down Expand Up @@ -171,6 +169,14 @@ trimNullChars = T.takeWhile (/= '\x0')
pgFmtIdentList :: [Text] -> SqlFragment
pgFmtIdentList schemas = BS.intercalate ", " $ pgFmtIdent <$> schemas

aggF :: Maybe Routine -> ResultAggregate -> SqlFragment
aggF rout = \case
BuiltinAggJson -> asJsonF rout
BuiltinAggSingleJson -> asJsonSingleF rout
BuiltinAggGeoJson -> asGeoJsonF
BuiltinAggCsv -> asCsvF
CustomAgg qi -> customAggF qi

asCsvF :: SqlFragment
asCsvF = asCsvHeaderF <> " || '\n' || " <> asCsvBodyF
where
Expand Down Expand Up @@ -206,6 +212,9 @@ asJsonF rout
asGeoJsonF :: SqlFragment
asGeoJsonF = "json_build_object('type', 'FeatureCollection', 'features', coalesce(json_agg(ST_AsGeoJSON(_postgrest_t)::json), '[]'))"

customAggF :: QualifiedIdentifier -> SqlFragment
customAggF qi = "coalesce(" <> fromQi qi <> "(_postgrest_t), '')"

locationF :: [Text] -> SqlFragment
locationF pKeys = [qc|(
WITH data AS (SELECT row_to_json(_) AS row FROM {sourceCTEName} AS _ LIMIT 1)
Expand Down Expand Up @@ -423,7 +432,7 @@ intercalateSnippet :: ByteString -> [SQL.Snippet] -> SQL.Snippet
intercalateSnippet _ [] = mempty
intercalateSnippet frag snippets = foldr1 (\a b -> a <> SQL.sql frag <> b) snippets

explainF :: Maybe MTPlanFormat -> [MTPlanOption] -> SQL.Snippet -> SQL.Snippet
explainF :: MTPlanFormat -> [MTPlanOption] -> SQL.Snippet -> SQL.Snippet
explainF fmt opts snip =
"EXPLAIN (" <>
SQL.sql (BS.intercalate ", " (fmtPlanFmt fmt : (fmtPlanOpt <$> opts))) <>
Expand All @@ -436,9 +445,8 @@ explainF fmt opts snip =
fmtPlanOpt PlanBuffers = "BUFFERS"
fmtPlanOpt PlanWAL = "WAL"

fmtPlanFmt Nothing = "FORMAT TEXT"
fmtPlanFmt (Just PlanJSON) = "FORMAT JSON"
fmtPlanFmt (Just PlanText) = "FORMAT TEXT"
fmtPlanFmt PlanJSON = "FORMAT JSON"
fmtPlanFmt PlanText = "FORMAT TEXT"

-- | Do a pg set_config(setting, value, true) call. This is equivalent to a SET LOCAL.
setConfigLocal :: ByteString -> (ByteString, ByteString) -> SQL.Snippet
Expand Down
Loading

0 comments on commit 4243c4a

Please sign in to comment.