From 07fef25591268fb3afcf235688233f58b223d0e8 Mon Sep 17 00:00:00 2001 From: Taimoor Zaeem Date: Sat, 2 Sep 2023 00:02:03 +0500 Subject: [PATCH] feat: allow full response control when raising exceptions --- CHANGELOG.md | 1 + src/PostgREST/Error.hs | 80 ++++++++++++++++++++++++++++-- test/spec/Feature/Query/RpcSpec.hs | 61 +++++++++++++++++++++++ test/spec/fixtures/schema.sql | 69 ++++++++++++++++++++++++++ 4 files changed, 207 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index df1618f1a2..a058f1df3c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ This project adheres to [Semantic Versioning](http://semver.org/). ### Added - #1614, Add `db-pool-automatic-recovery` configuration to disable connection retrying - @taimoorzaeem + - #2492, Allow full response control when raising exceptions - @taimoorzaeem, @laurenceisla ### Fixed diff --git a/src/PostgREST/Error.hs b/src/PostgREST/Error.hs index 1283725dee..7ce061a00b 100644 --- a/src/PostgREST/Error.hs +++ b/src/PostgREST/Error.hs @@ -16,8 +16,10 @@ module PostgREST.Error import qualified Data.Aeson as JSON import qualified Data.ByteString.Char8 as BS +import qualified Data.CaseInsensitive as CI import qualified Data.FuzzySet as Fuzzy import qualified Data.HashMap.Strict as HM +import qualified Data.Map.Internal as M import qualified Data.Text as T import qualified Data.Text.Encoding as T import qualified Data.Text.Encoding.Error as T @@ -25,7 +27,7 @@ import qualified Hasql.Pool as SQL import qualified Hasql.Session as SQL import qualified Network.HTTP.Types.Status as HTTP -import Data.Aeson ((.=)) +import Data.Aeson ((.:), (.:?), (.=)) import Network.Wai (Response, responseLBS) import Network.HTTP.Types.Header (Header) @@ -359,6 +361,13 @@ type Authenticated = Bool instance PgrstError PgError where status (PgError authed usageError) = pgErrorStatus authed usageError + headers (PgError _ (SQL.SessionUsageError (SQL.QueryError _ _ (SQL.ResultError (SQL.ServerError "PGRST" m d _ _p))))) = + case (parseMessage m, parseDetails d) of + (Just _, Just r) -> headers JSONParseError ++ map intoHeader (M.toList $ getHeaders r) + _ -> headers JSONParseError + where + intoHeader (k,v) = (CI.mk $ T.encodeUtf8 k, T.encodeUtf8 v) + headers err = if status err == HTTP.status401 then [MediaType.toContentType MTApplicationJSON, ("WWW-Authenticate", "Bearer") :: Header] @@ -384,6 +393,18 @@ instance JSON.ToJSON SQL.QueryError where toJSON (SQL.QueryError _ _ e) = JSON.toJSON e instance JSON.ToJSON SQL.CommandError where + -- Special error raised with code PGRST, to allow full response control + toJSON (SQL.ResultError (SQL.ServerError "PGRST" m d _ _p)) = + case (parseMessage m, parseDetails d) of + (Just r, Just _) -> JSON.object [ + "code" .= getCode r, + "message" .= getMessage r, + "details" .= checkMaybe (getDetails r), + "hint" .= checkMaybe (getHint r)] + _ -> JSON.toJSON JSONParseError + where + checkMaybe = maybe JSON.Null JSON.String + toJSON (SQL.ResultError (SQL.ServerError c m d h _p)) = JSON.object [ "code" .= (T.decodeUtf8 c :: Text), "message" .= (T.decodeUtf8 m :: Text), @@ -402,14 +423,13 @@ instance JSON.ToJSON SQL.CommandError where "details" .= (fmap T.decodeUtf8 d :: Maybe Text), "hint" .= JSON.Null] - pgErrorStatus :: Bool -> SQL.UsageError -> HTTP.Status pgErrorStatus _ (SQL.ConnectionUsageError _) = HTTP.status503 pgErrorStatus _ SQL.AcquisitionTimeoutUsageError = HTTP.status504 pgErrorStatus _ (SQL.SessionUsageError (SQL.QueryError _ _ (SQL.ClientError _))) = HTTP.status503 pgErrorStatus authed (SQL.SessionUsageError (SQL.QueryError _ _ (SQL.ResultError rError))) = case rError of - (SQL.ServerError c m _ _ _) -> + (SQL.ServerError c m d _ _) -> case BS.unpack c of '0':'8':_ -> HTTP.status503 -- pg connection err '0':'9':_ -> HTTP.status500 -- triggered action exception @@ -442,12 +462,15 @@ pgErrorStatus authed (SQL.SessionUsageError (SQL.QueryError _ _ (SQL.ResultError "42P01" -> HTTP.status404 -- undefined table "42501" -> if authed then HTTP.status403 else HTTP.status401 -- insufficient privilege 'P':'T':n -> fromMaybe HTTP.status500 (HTTP.mkStatus <$> readMaybe n <*> pure m) + "PGRST" -> + case (parseMessage m, parseDetails d) of + (Just _, Just r) -> maybe (toEnum $ getStatus r) (HTTP.mkStatus (getStatus r) . T.encodeUtf8) (getStatusText r) + _ -> status JSONParseError _ -> HTTP.status400 _ -> HTTP.status500 - data Error = ApiRequestError ApiRequestError | GucHeadersError @@ -460,6 +483,7 @@ data Error | PgErr PgError | PutMatchingPkError | SingularityError Integer + | JSONParseError instance PgrstError Error where status (ApiRequestError err) = status err @@ -473,6 +497,7 @@ instance PgrstError Error where status (PgErr err) = status err status PutMatchingPkError = HTTP.status400 status SingularityError{} = HTTP.status406 + status JSONParseError = HTTP.status500 headers (ApiRequestError err) = headers err headers (JwtTokenInvalid m) = [MediaType.toContentType MTApplicationJSON, invalidTokenHeader m] @@ -533,6 +558,12 @@ instance JSON.ToJSON Error where "details" .= T.unwords ["The result contains", show n, "rows"], "hint" .= JSON.Null] + toJSON JSONParseError = JSON.object [ + "code" .= ApiRequestErrorCode21, + "message" .= ("The message and detail field of RAISE 'PGRST' error expects JSON" :: Text), + "details" .= JSON.Null, + "hint" .= JSON.Null] + toJSON (PgErr err) = JSON.toJSON err toJSON (ApiRequestError err) = JSON.toJSON err @@ -546,6 +577,45 @@ requiredTokenHeader = ("WWW-Authenticate", "Bearer") singularityError :: (Integral a) => a -> Error singularityError = SingularityError . toInteger +-- For parsing byteString to JSON Object, used for allowing full response control +data PgRaiseErrMessage = PgRaiseErrMessage { + getCode :: Text, + getMessage :: Text, + getDetails :: Maybe Text, + getHint :: Maybe Text +} + +data PgRaiseErrDetails = PgRaiseErrDetails { + getStatus :: Int, + getStatusText :: Maybe Text, + getHeaders :: Map Text Text +} + +instance JSON.FromJSON PgRaiseErrMessage where + parseJSON (JSON.Object m) = + PgRaiseErrMessage + <$> m .: "code" + <*> m .: "message" + <*> m .:? "details" + <*> m .:? "hint" + + parseJSON _ = mzero + +instance JSON.FromJSON PgRaiseErrDetails where + parseJSON (JSON.Object d) = + PgRaiseErrDetails + <$> d .: "status" + <*> d .:? "status_text" + <*> d .: "headers" + + parseJSON _ = mzero + +parseMessage :: ByteString -> Maybe PgRaiseErrMessage +parseMessage = JSON.decodeStrict + +parseDetails :: Maybe ByteString -> Maybe PgRaiseErrDetails +parseDetails d = JSON.decodeStrict =<< d + -- Error codes are grouped by common modules or characteristics data ErrorCode -- PostgreSQL connection errors @@ -575,6 +645,7 @@ data ErrorCode | ApiRequestErrorCode18 | ApiRequestErrorCode19 | ApiRequestErrorCode20 + | ApiRequestErrorCode21 -- Schema Cache errors | SchemaCacheErrorCode00 | SchemaCacheErrorCode01 @@ -621,6 +692,7 @@ buildErrorCode code = "PGRST" <> case code of ApiRequestErrorCode18 -> "118" ApiRequestErrorCode19 -> "119" ApiRequestErrorCode20 -> "120" + ApiRequestErrorCode21 -> "121" SchemaCacheErrorCode00 -> "200" SchemaCacheErrorCode01 -> "201" diff --git a/test/spec/Feature/Query/RpcSpec.hs b/test/spec/Feature/Query/RpcSpec.hs index 98a73a002b..881f57cdf3 100644 --- a/test/spec/Feature/Query/RpcSpec.hs +++ b/test/spec/Feature/Query/RpcSpec.hs @@ -1490,3 +1490,64 @@ spec actualPgVersion = `shouldRespondWith` [json| {"code":"22026","details":null,"hint":null,"message":"bit string length 6 does not match type bit(5)"} |] { matchStatus = 400 } + + context "get message and details from raise sqlstate" $ do + it "gets message and details from raise sqlstate PGRST" $ do + r <- request methodGet "/rpc/raise_sqlstate_test1" + [] "" + + let resStatus = simpleStatus r + resHeaders = simpleHeaders r + resBody = simpleBody r + + liftIO $ do + resStatus `shouldBe` Status { statusCode = 332, statusMessage = "My Custom Status" } + resHeaders `shouldSatisfy` elem ("X-Header", "str") + resBody `shouldBe` [json|{"code":"123","message":"ABC","details":"DEF","hint":"XYZ"}|] + + get "/rpc/raise_sqlstate_test2" `shouldRespondWith` + [json|{"code":"123","message":"ABC","details":null,"hint":null}|] + { matchStatus = 332 + , matchHeaders = ["X-Header" <:> "str"] } + + it "get message and details from PGRST raise and checks standard status message" $ do + r <- request methodGet "/rpc/raise_sqlstate_test3" + [] "" + + let resStatus = simpleStatus r + resHeaders = simpleHeaders r + resBody = simpleBody r + + liftIO $ do + resStatus `shouldBe` Status { statusCode = 404, statusMessage = "Not Found" } + resHeaders `shouldSatisfy` elem ("X-Header", "str") + resBody `shouldBe` [json|{"code":"123","message":"ABC","details":null,"hint":null}|] + + + it "get message and details from PGRST raise and checks custom status message" $ do + r <- request methodGet "/rpc/raise_sqlstate_test4" + [] "" + + let resStatus = simpleStatus r + resHeaders = simpleHeaders r + resBody = simpleBody r + + liftIO $ do + resStatus `shouldBe` Status { statusCode = 404, statusMessage = "My Not Found" } + resHeaders `shouldSatisfy` elem ("X-Header", "str") + resBody `shouldBe` [json|{"code":"123","message":"ABC","details":null,"hint":null}|] + + it "returns JSONParseError for invalid JSON in RAISE Message field" $ + get "/rpc/raise_sqlstate_invalid_json_message" `shouldRespondWith` + [json|{"code":"PGRST121","message":"The message and detail field of RAISE 'PGRST' error expects JSON","details":null,"hint":null}|] + { matchStatus = 500 } + + it "returns JSONParseError for invalid JSON in RAISE Details field" $ + get "/rpc/raise_sqlstate_invalid_json_details" `shouldRespondWith` + [json|{"code":"PGRST121","message":"The message and detail field of RAISE 'PGRST' error expects JSON","details":null,"hint":null}|] + { matchStatus = 500 } + + it "returns JSONParseError for missing Details field in RAISE" $ + get "/rpc/raise_sqlstate_missing_details" `shouldRespondWith` + [json|{"code":"PGRST121","message":"The message and detail field of RAISE 'PGRST' error expects JSON","details":null,"hint":null}|] + { matchStatus = 500 } diff --git a/test/spec/fixtures/schema.sql b/test/spec/fixtures/schema.sql index 8d19d7676b..f01b5efb20 100644 --- a/test/spec/fixtures/schema.sql +++ b/test/spec/fixtures/schema.sql @@ -3368,3 +3368,72 @@ $$ language sql; create function returns_setof_record_params(id int, name text) returns setof record as $$ select * from projects p where p.id >= $1 and p.name like $2; $$ language sql; + +create function raise_sqlstate_test1() returns void + language plpgsql + as $$ +begin + raise sqlstate 'PGRST' USING + message = '{"code":"123","message":"ABC","details":"DEF","hint":"XYZ"}', + detail = '{"status":332,"status_text":"My Custom Status","headers":{"X-Header":"str"}}'; +end +$$; + +create function raise_sqlstate_test2() returns void + language plpgsql + as $$ +begin + raise sqlstate 'PGRST' USING + message = '{"code":"123","message":"ABC"}', + detail = '{"status":332,"headers":{"X-Header":"str"}}'; +end +$$; + +create function raise_sqlstate_test3() returns void + language plpgsql + as $$ +begin + raise sqlstate 'PGRST' USING + message = '{"code":"123","message":"ABC"}', + detail = '{"status":404,"headers":{"X-Header":"str"}}'; +end +$$; + +create function raise_sqlstate_test4() returns void + language plpgsql + as $$ +begin + raise sqlstate 'PGRST' USING + message = '{"code":"123","message":"ABC"}', + detail = '{"status":404,"status_text":"My Not Found","headers":{"X-Header":"str"}}'; +end +$$; + +create function raise_sqlstate_invalid_json_message() returns void + language plpgsql + as $$ +begin + raise sqlstate 'PGRST' USING + message = 'INVALID', + detail = '{"status":332,"headers":{"X-Header":"str"}}'; +end +$$; + +create function raise_sqlstate_invalid_json_details() returns void + language plpgsql + as $$ +begin + raise sqlstate 'PGRST' USING + message = '{"code":"123","message":"ABC","details":"DEF"}', + detail = 'INVALID'; +end +$$; + +create function raise_sqlstate_missing_details() returns void + language plpgsql + as $$ +begin + raise sqlstate 'PGRST' USING + message = '{"code":"123","message":"ABC","details":"DEF"}'; +end +$$;