Skip to content

Commit

Permalink
feat: allow full response control when raising exceptions
Browse files Browse the repository at this point in the history
  • Loading branch information
taimoorzaeem authored Sep 1, 2023
1 parent 7dc6e2b commit 07fef25
Show file tree
Hide file tree
Showing 4 changed files with 207 additions and 4 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ This project adheres to [Semantic Versioning](http://semver.org/).
### Added

- #1614, Add `db-pool-automatic-recovery` configuration to disable connection retrying - @taimoorzaeem
- #2492, Allow full response control when raising exceptions - @taimoorzaeem, @laurenceisla

### Fixed

Expand Down
80 changes: 76 additions & 4 deletions src/PostgREST/Error.hs
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,18 @@ module PostgREST.Error

import qualified Data.Aeson as JSON
import qualified Data.ByteString.Char8 as BS
import qualified Data.CaseInsensitive as CI
import qualified Data.FuzzySet as Fuzzy
import qualified Data.HashMap.Strict as HM
import qualified Data.Map.Internal as M
import qualified Data.Text as T
import qualified Data.Text.Encoding as T
import qualified Data.Text.Encoding.Error as T
import qualified Hasql.Pool as SQL
import qualified Hasql.Session as SQL
import qualified Network.HTTP.Types.Status as HTTP

import Data.Aeson ((.=))
import Data.Aeson ((.:), (.:?), (.=))
import Network.Wai (Response, responseLBS)

import Network.HTTP.Types.Header (Header)
Expand Down Expand Up @@ -359,6 +361,13 @@ type Authenticated = Bool
instance PgrstError PgError where
status (PgError authed usageError) = pgErrorStatus authed usageError

headers (PgError _ (SQL.SessionUsageError (SQL.QueryError _ _ (SQL.ResultError (SQL.ServerError "PGRST" m d _ _p))))) =
case (parseMessage m, parseDetails d) of
(Just _, Just r) -> headers JSONParseError ++ map intoHeader (M.toList $ getHeaders r)
_ -> headers JSONParseError
where
intoHeader (k,v) = (CI.mk $ T.encodeUtf8 k, T.encodeUtf8 v)

headers err =
if status err == HTTP.status401
then [MediaType.toContentType MTApplicationJSON, ("WWW-Authenticate", "Bearer") :: Header]
Expand All @@ -384,6 +393,18 @@ instance JSON.ToJSON SQL.QueryError where
toJSON (SQL.QueryError _ _ e) = JSON.toJSON e

instance JSON.ToJSON SQL.CommandError where
-- Special error raised with code PGRST, to allow full response control
toJSON (SQL.ResultError (SQL.ServerError "PGRST" m d _ _p)) =
case (parseMessage m, parseDetails d) of
(Just r, Just _) -> JSON.object [
"code" .= getCode r,
"message" .= getMessage r,
"details" .= checkMaybe (getDetails r),
"hint" .= checkMaybe (getHint r)]
_ -> JSON.toJSON JSONParseError
where
checkMaybe = maybe JSON.Null JSON.String

toJSON (SQL.ResultError (SQL.ServerError c m d h _p)) = JSON.object [
"code" .= (T.decodeUtf8 c :: Text),
"message" .= (T.decodeUtf8 m :: Text),
Expand All @@ -402,14 +423,13 @@ instance JSON.ToJSON SQL.CommandError where
"details" .= (fmap T.decodeUtf8 d :: Maybe Text),
"hint" .= JSON.Null]


pgErrorStatus :: Bool -> SQL.UsageError -> HTTP.Status
pgErrorStatus _ (SQL.ConnectionUsageError _) = HTTP.status503
pgErrorStatus _ SQL.AcquisitionTimeoutUsageError = HTTP.status504
pgErrorStatus _ (SQL.SessionUsageError (SQL.QueryError _ _ (SQL.ClientError _))) = HTTP.status503
pgErrorStatus authed (SQL.SessionUsageError (SQL.QueryError _ _ (SQL.ResultError rError))) =
case rError of
(SQL.ServerError c m _ _ _) ->
(SQL.ServerError c m d _ _) ->
case BS.unpack c of
'0':'8':_ -> HTTP.status503 -- pg connection err
'0':'9':_ -> HTTP.status500 -- triggered action exception
Expand Down Expand Up @@ -442,12 +462,15 @@ pgErrorStatus authed (SQL.SessionUsageError (SQL.QueryError _ _ (SQL.ResultError
"42P01" -> HTTP.status404 -- undefined table
"42501" -> if authed then HTTP.status403 else HTTP.status401 -- insufficient privilege
'P':'T':n -> fromMaybe HTTP.status500 (HTTP.mkStatus <$> readMaybe n <*> pure m)
"PGRST" ->
case (parseMessage m, parseDetails d) of
(Just _, Just r) -> maybe (toEnum $ getStatus r) (HTTP.mkStatus (getStatus r) . T.encodeUtf8) (getStatusText r)
_ -> status JSONParseError
_ -> HTTP.status400

_ -> HTTP.status500



data Error
= ApiRequestError ApiRequestError
| GucHeadersError
Expand All @@ -460,6 +483,7 @@ data Error
| PgErr PgError
| PutMatchingPkError
| SingularityError Integer
| JSONParseError

instance PgrstError Error where
status (ApiRequestError err) = status err
Expand All @@ -473,6 +497,7 @@ instance PgrstError Error where
status (PgErr err) = status err
status PutMatchingPkError = HTTP.status400
status SingularityError{} = HTTP.status406
status JSONParseError = HTTP.status500

headers (ApiRequestError err) = headers err
headers (JwtTokenInvalid m) = [MediaType.toContentType MTApplicationJSON, invalidTokenHeader m]
Expand Down Expand Up @@ -533,6 +558,12 @@ instance JSON.ToJSON Error where
"details" .= T.unwords ["The result contains", show n, "rows"],
"hint" .= JSON.Null]

toJSON JSONParseError = JSON.object [
"code" .= ApiRequestErrorCode21,
"message" .= ("The message and detail field of RAISE 'PGRST' error expects JSON" :: Text),
"details" .= JSON.Null,
"hint" .= JSON.Null]

toJSON (PgErr err) = JSON.toJSON err
toJSON (ApiRequestError err) = JSON.toJSON err

Expand All @@ -546,6 +577,45 @@ requiredTokenHeader = ("WWW-Authenticate", "Bearer")
singularityError :: (Integral a) => a -> Error
singularityError = SingularityError . toInteger

-- For parsing byteString to JSON Object, used for allowing full response control
data PgRaiseErrMessage = PgRaiseErrMessage {
getCode :: Text,
getMessage :: Text,
getDetails :: Maybe Text,
getHint :: Maybe Text
}

data PgRaiseErrDetails = PgRaiseErrDetails {
getStatus :: Int,
getStatusText :: Maybe Text,
getHeaders :: Map Text Text
}

instance JSON.FromJSON PgRaiseErrMessage where
parseJSON (JSON.Object m) =
PgRaiseErrMessage
<$> m .: "code"
<*> m .: "message"
<*> m .:? "details"
<*> m .:? "hint"

parseJSON _ = mzero

instance JSON.FromJSON PgRaiseErrDetails where
parseJSON (JSON.Object d) =
PgRaiseErrDetails
<$> d .: "status"
<*> d .:? "status_text"
<*> d .: "headers"

parseJSON _ = mzero

parseMessage :: ByteString -> Maybe PgRaiseErrMessage
parseMessage = JSON.decodeStrict

parseDetails :: Maybe ByteString -> Maybe PgRaiseErrDetails
parseDetails d = JSON.decodeStrict =<< d

-- Error codes are grouped by common modules or characteristics
data ErrorCode
-- PostgreSQL connection errors
Expand Down Expand Up @@ -575,6 +645,7 @@ data ErrorCode
| ApiRequestErrorCode18
| ApiRequestErrorCode19
| ApiRequestErrorCode20
| ApiRequestErrorCode21
-- Schema Cache errors
| SchemaCacheErrorCode00
| SchemaCacheErrorCode01
Expand Down Expand Up @@ -621,6 +692,7 @@ buildErrorCode code = "PGRST" <> case code of
ApiRequestErrorCode18 -> "118"
ApiRequestErrorCode19 -> "119"
ApiRequestErrorCode20 -> "120"
ApiRequestErrorCode21 -> "121"

SchemaCacheErrorCode00 -> "200"
SchemaCacheErrorCode01 -> "201"
Expand Down
61 changes: 61 additions & 0 deletions test/spec/Feature/Query/RpcSpec.hs
Original file line number Diff line number Diff line change
Expand Up @@ -1490,3 +1490,64 @@ spec actualPgVersion =
`shouldRespondWith`
[json| {"code":"22026","details":null,"hint":null,"message":"bit string length 6 does not match type bit(5)"} |]
{ matchStatus = 400 }

context "get message and details from raise sqlstate" $ do
it "gets message and details from raise sqlstate PGRST" $ do
r <- request methodGet "/rpc/raise_sqlstate_test1"
[] ""

let resStatus = simpleStatus r
resHeaders = simpleHeaders r
resBody = simpleBody r

liftIO $ do
resStatus `shouldBe` Status { statusCode = 332, statusMessage = "My Custom Status" }
resHeaders `shouldSatisfy` elem ("X-Header", "str")
resBody `shouldBe` [json|{"code":"123","message":"ABC","details":"DEF","hint":"XYZ"}|]

get "/rpc/raise_sqlstate_test2" `shouldRespondWith`
[json|{"code":"123","message":"ABC","details":null,"hint":null}|]
{ matchStatus = 332
, matchHeaders = ["X-Header" <:> "str"] }

it "get message and details from PGRST raise and checks standard status message" $ do
r <- request methodGet "/rpc/raise_sqlstate_test3"
[] ""

let resStatus = simpleStatus r
resHeaders = simpleHeaders r
resBody = simpleBody r

liftIO $ do
resStatus `shouldBe` Status { statusCode = 404, statusMessage = "Not Found" }
resHeaders `shouldSatisfy` elem ("X-Header", "str")
resBody `shouldBe` [json|{"code":"123","message":"ABC","details":null,"hint":null}|]


it "get message and details from PGRST raise and checks custom status message" $ do
r <- request methodGet "/rpc/raise_sqlstate_test4"
[] ""

let resStatus = simpleStatus r
resHeaders = simpleHeaders r
resBody = simpleBody r

liftIO $ do
resStatus `shouldBe` Status { statusCode = 404, statusMessage = "My Not Found" }
resHeaders `shouldSatisfy` elem ("X-Header", "str")
resBody `shouldBe` [json|{"code":"123","message":"ABC","details":null,"hint":null}|]

it "returns JSONParseError for invalid JSON in RAISE Message field" $
get "/rpc/raise_sqlstate_invalid_json_message" `shouldRespondWith`
[json|{"code":"PGRST121","message":"The message and detail field of RAISE 'PGRST' error expects JSON","details":null,"hint":null}|]
{ matchStatus = 500 }

it "returns JSONParseError for invalid JSON in RAISE Details field" $
get "/rpc/raise_sqlstate_invalid_json_details" `shouldRespondWith`
[json|{"code":"PGRST121","message":"The message and detail field of RAISE 'PGRST' error expects JSON","details":null,"hint":null}|]
{ matchStatus = 500 }

it "returns JSONParseError for missing Details field in RAISE" $
get "/rpc/raise_sqlstate_missing_details" `shouldRespondWith`
[json|{"code":"PGRST121","message":"The message and detail field of RAISE 'PGRST' error expects JSON","details":null,"hint":null}|]
{ matchStatus = 500 }
69 changes: 69 additions & 0 deletions test/spec/fixtures/schema.sql
Original file line number Diff line number Diff line change
Expand Up @@ -3368,3 +3368,72 @@ $$ language sql;
create function returns_setof_record_params(id int, name text) returns setof record as $$
select * from projects p where p.id >= $1 and p.name like $2;
$$ language sql;

create function raise_sqlstate_test1() returns void
language plpgsql
as $$
begin
raise sqlstate 'PGRST' USING
message = '{"code":"123","message":"ABC","details":"DEF","hint":"XYZ"}',
detail = '{"status":332,"status_text":"My Custom Status","headers":{"X-Header":"str"}}';
end
$$;

create function raise_sqlstate_test2() returns void
language plpgsql
as $$
begin
raise sqlstate 'PGRST' USING
message = '{"code":"123","message":"ABC"}',
detail = '{"status":332,"headers":{"X-Header":"str"}}';
end
$$;

create function raise_sqlstate_test3() returns void
language plpgsql
as $$
begin
raise sqlstate 'PGRST' USING
message = '{"code":"123","message":"ABC"}',
detail = '{"status":404,"headers":{"X-Header":"str"}}';
end
$$;

create function raise_sqlstate_test4() returns void
language plpgsql
as $$
begin
raise sqlstate 'PGRST' USING
message = '{"code":"123","message":"ABC"}',
detail = '{"status":404,"status_text":"My Not Found","headers":{"X-Header":"str"}}';
end
$$;

create function raise_sqlstate_invalid_json_message() returns void
language plpgsql
as $$
begin
raise sqlstate 'PGRST' USING
message = 'INVALID',
detail = '{"status":332,"headers":{"X-Header":"str"}}';
end
$$;

create function raise_sqlstate_invalid_json_details() returns void
language plpgsql
as $$
begin
raise sqlstate 'PGRST' USING
message = '{"code":"123","message":"ABC","details":"DEF"}',
detail = 'INVALID';
end
$$;

create function raise_sqlstate_missing_details() returns void
language plpgsql
as $$
begin
raise sqlstate 'PGRST' USING
message = '{"code":"123","message":"ABC","details":"DEF"}';
end
$$;

0 comments on commit 07fef25

Please sign in to comment.