Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds ConnectionOptions for createConnectionPool parameters #337

Merged
merged 2 commits into from
Oct 27, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions orville-postgresql/src/Orville/PostgreSQL.hs
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,19 @@ module Orville.PostgreSQL
, Orville.runOrvilleWithState

-- * Creating a connection pool
, Connection.ConnectionOptions
( ConnectionOptions
, connectionString
, connectionNoticeReporting
, connectionPoolStripes
, connectionPoolLingerTime
, connectionPoolMaxConnectionsPerStripe
)
, Connection.createConnectionPool
, Connection.Connection
, Connection.Pool
, Connection.NoticeReporting (EnableNoticeReporting, DisableNoticeReporting)
, Connection.StripeOption (OneStripePerCapability, StripeCount)
, Connection.Connection
, Connection.ConnectionPool

-- * Opening transactions and savepoints
, Transaction.withTransaction
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,11 @@ module Orville.PostgreSQL.Internal.OrvilleState
where

import qualified Data.Map.Strict as Map
import Data.Pool (Pool)

import Orville.PostgreSQL.ErrorDetailLevel (ErrorDetailLevel)
import Orville.PostgreSQL.Execution.QueryType (QueryType)
import qualified Orville.PostgreSQL.Expr as Expr
import Orville.PostgreSQL.Raw.Connection (Connection)
import Orville.PostgreSQL.Raw.Connection (Connection, ConnectionPool)
import qualified Orville.PostgreSQL.Raw.RawSql as RawSql
import qualified Orville.PostgreSQL.Raw.SqlCommenter as SqlCommenter

Expand All @@ -57,7 +56,7 @@ import qualified Orville.PostgreSQL.Raw.SqlCommenter as SqlCommenter
@since 1.0.0.0
-}
data OrvilleState = OrvilleState
{ _orvilleConnectionPool :: Pool Connection
{ _orvilleConnectionPool :: ConnectionPool
, _orvilleConnectionState :: ConnectionState
, _orvilleErrorDetailLevel :: ErrorDetailLevel
, _orvilleTransactionCallback :: TransactionEvent -> IO ()
Expand All @@ -71,7 +70,7 @@ data OrvilleState = OrvilleState

@since 1.0.0.0
-}
orvilleConnectionPool :: OrvilleState -> Pool Connection
orvilleConnectionPool :: OrvilleState -> ConnectionPool
orvilleConnectionPool =
_orvilleConnectionPool

Expand Down Expand Up @@ -166,7 +165,7 @@ addTransactionCallback newCallback state =

@since 1.0.0.0
-}
newOrvilleState :: ErrorDetailLevel -> Pool Connection -> OrvilleState
newOrvilleState :: ErrorDetailLevel -> ConnectionPool -> OrvilleState
newOrvilleState errorDetailLevel pool =
OrvilleState
{ _orvilleConnectionPool = pool
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ where
import Control.Exception (Exception)
import Control.Monad.IO.Class (MonadIO)
import Control.Monad.Trans.Reader (ReaderT (ReaderT), mapReaderT, runReaderT)
import Data.Pool (withResource)

import Orville.PostgreSQL.Internal.OrvilleState
( ConnectedState (ConnectedState, connectedConnection, connectedTransaction)
Expand All @@ -31,7 +30,7 @@ import Orville.PostgreSQL.Internal.OrvilleState
, orvilleConnectionState
)
import Orville.PostgreSQL.Monad.HasOrvilleState (HasOrvilleState (askOrvilleState, localOrvilleState))
import Orville.PostgreSQL.Raw.Connection (Connection)
import Orville.PostgreSQL.Raw.Connection (Connection, withPoolConnection)

{- |
'MonadOrville' is the typeclass that most Orville operations require to
Expand Down Expand Up @@ -188,7 +187,7 @@ withConnectedState connectedAction = do
let
pool = orvilleConnectionPool state
in
liftWithConnection (withResource pool) $ \conn ->
liftWithConnection (withPoolConnection pool) $ \conn ->
let
connectedState =
ConnectedState
Expand Down
5 changes: 2 additions & 3 deletions orville-postgresql/src/Orville/PostgreSQL/Monad/Orville.hs
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,12 @@ where
import qualified Control.Exception.Safe as ExSafe
import Control.Monad.IO.Class (MonadIO)
import Control.Monad.Trans.Reader (ReaderT, runReaderT)
import Data.Pool (Pool)

import qualified Orville.PostgreSQL.ErrorDetailLevel as ErrorDetailLevel
import qualified Orville.PostgreSQL.Monad.HasOrvilleState as HasOrvilleState
import qualified Orville.PostgreSQL.Monad.MonadOrville as MonadOrville
import qualified Orville.PostgreSQL.OrvilleState as OrvilleState
import Orville.PostgreSQL.Raw.Connection (Connection)
import Orville.PostgreSQL.Raw.Connection (ConnectionPool)

{- |
The 'Orville' Monad provides a easy starter implementation of
Expand Down Expand Up @@ -61,7 +60,7 @@ newtype Orville a = Orville

@since 1.0.0.0
-}
runOrville :: Pool Connection -> Orville a -> IO a
runOrville :: ConnectionPool -> Orville a -> IO a
runOrville =
runOrvilleWithState
. OrvilleState.newOrvilleState ErrorDetailLevel.defaultErrorDetailLevel
Expand Down
143 changes: 110 additions & 33 deletions orville-postgresql/src/Orville/PostgreSQL/Raw/Connection.hs
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,30 @@ Stability : Stable
@since 1.0.0.0
-}
module Orville.PostgreSQL.Raw.Connection
( -- * Orville definitions
Connection
, createConnectionPool
( ConnectionOptions
( ConnectionOptions
, connectionString
, connectionNoticeReporting
, connectionPoolStripes
, connectionPoolLingerTime
, connectionPoolMaxConnectionsPerStripe
)
, NoticeReporting (EnableNoticeReporting, DisableNoticeReporting)
, StripeOption (OneStripePerCapability, StripeCount)
, ConnectionPool
, createConnectionPool
, Connection
, withPoolConnection
, executeRaw
, quoteStringLiteral
, quoteIdentifier
, ConnectionUsedAfterCloseError
, ConnectionError
, SqlExecutionError (..)

-- * Re-exports from "Data.Pool" for convenience
, Pool
)
where

import Control.Concurrent (threadWaitRead, threadWaitWrite)
import Control.Concurrent (getNumCapabilities, threadWaitRead, threadWaitWrite)
import Control.Concurrent.MVar (MVar, newMVar, tryReadMVar, tryTakeMVar)
import Control.Exception (Exception, mask, throwIO)
import Control.Monad (void)
Expand All @@ -34,9 +41,9 @@ import qualified Data.ByteString.Builder as BSB
import qualified Data.ByteString.Char8 as B8
import Data.Maybe (fromMaybe)
#if MIN_VERSION_resource_pool(0,4,0)
import Data.Pool (Pool, newPool, defaultPoolConfig, setNumStripes)
import Data.Pool (Pool, newPool, defaultPoolConfig, setNumStripes, withResource)
#else
import Data.Pool (Pool, createPool)
import Data.Pool (Pool, createPool, withResource)
#endif
import qualified Data.Text as T
import qualified Data.Text.Encoding as Enc
Expand All @@ -55,43 +62,113 @@ data NoticeReporting
= EnableNoticeReporting
| DisableNoticeReporting

{- |
Orville always uses a connection pool to manage the number of open connections
to the database. See 'ConnectionConfig' and 'createConnectionPool' to find how
to create a 'ConnectionPool'.

@since 1.0.0.0
-}
newtype ConnectionPool
= ConnectionPool (Pool Connection)

{- |
'createConnectionPool' allocates a pool of connections to a PostgreSQL server.

@since 1.0.0.0
-}
createConnectionPool ::
-- | Whether or not notice reporting from LibPQ should be enabled
NoticeReporting ->
-- | Number of stripes in the connection pool
Int ->
-- | Linger time before closing an idle connection
NominalDiffTime ->
-- | Max number of connections to allocate per stripe
Int ->
-- | A PostgreSQL connection string
BS.ByteString ->
IO (Pool Connection)
createConnectionPool :: ConnectionOptions -> IO ConnectionPool
createConnectionPool options = do
let
open =
connect
(connectionNoticeReporting options)
(B8.pack $ connectionString options)

connPerStripe =
connectionPoolMaxConnectionsPerStripe options

linger =
connectionPoolLingerTime options

stripes <- determineStripeCount (connectionPoolStripes options)

#if MIN_VERSION_resource_pool(0,4,0)
createConnectionPool noticeReporting stripes linger maxRes connectionString =
newPool . setNumStripes (Just stripes) $
fmap ConnectionPool . newPool . setNumStripes (Just stripes) $
defaultPoolConfig
(connect noticeReporting connectionString)
open
close
(realToFrac linger)
(stripes * maxRes)
(stripes * connPerStripe)
#else
createConnectionPool noticeReporting stripes linger maxRes connectionString =
createPool (connect noticeReporting connectionString) close stripes linger maxRes
ConnectionPool <$>
createPool
open
close
stripes
linger
connPerStripe
#endif

{- |
'executeRaw' runs a given SQL statement returning the raw underlying result.
Values for the 'connectionPoolStripes' field of 'ConnectionOptions'

@since 1.0.0.0
-}
data StripeOption
= -- | 'OneStripePerCapability' will cause the connection pool to be set up
-- with one stripe for each capability (processor thread) available to the
-- runtime. This is the best option for multi-threaded connectin pool
-- performance.
OneStripePerCapability
| -- | 'StripeCount' will cause the connection pool to be set up with
-- the specified number of stripes, regardless of how many capabilities
-- the runtime has
StripeCount Int

{- |
Configuration options to pass to 'createConnectionPool' to specify the
parameters for the pool and the connections that it creates.

@since 1.0.0.0
-}
data ConnectionOptions = ConnectionOptions
{ connectionString :: String
-- ^ A PostgreSQL connection string
, connectionNoticeReporting :: NoticeReporting
-- ^ Whether or not notice reporting from LibPQ should be enabled
, connectionPoolStripes :: StripeOption
-- ^ Number of stripes in the connection pool
, connectionPoolLingerTime :: NominalDiffTime
-- ^ Linger time before closing an idle connection
, connectionPoolMaxConnectionsPerStripe :: Int
-- ^ Max number of connections to allocate per stripe
}

{- |
INTERNAL: Resolves the 'StripeOption' to the actual number of stripes to use.
-}
determineStripeCount :: StripeOption -> IO Int
determineStripeCount stripeOption =
case stripeOption of
OneStripePerCapability -> getNumCapabilities
StripeCount n -> pure n

{- |
Allocates a connection from the pool and performs an action with it. This
function will block if the maximum number of connections is reached.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing the since annotation here.

-}
withPoolConnection :: ConnectionPool -> (Connection -> IO a) -> IO a
withPoolConnection (ConnectionPool pool) =
withResource pool

{- |
'executeRaw' runs a given SQL statement returning the raw underlying result.

All handling of stepping through the result set is left to the caller. This
potentially leaves connections open much longer than one would expect if all
of the results are not iterated through immediately *and* the data copied.
Use with caution.
of the results are not iterated through immediately *and* the data copied.
Use with caution.

@since 1.0.0.0
-}
Expand All @@ -115,7 +192,7 @@ executeRaw connection bs params =
newtype Connection = Connection (MVar LibPQ.Connection)

{- |
'connect' is the internal, primitive connection function.
'connect' is the internal, primitive connection function.

This should not be exposed to end users, but instead wrapped in something to create a pool.

Expand All @@ -125,7 +202,7 @@ newtype Connection = Connection (MVar LibPQ.Connection)
@since 1.0.0.0
-}
connect :: NoticeReporting -> BS.ByteString -> IO Connection
connect noticeReporting connectionString =
connect noticeReporting connString =
let
checkSocketAndThreadWait conn threadWaitFn = do
fd <- LibPQ.socket conn
Expand All @@ -150,7 +227,7 @@ connect noticeReporting connectionString =
pure (Connection connectionHandle)
in
do
connection <- LibPQ.connectStart connectionString
connection <- LibPQ.connectStart connString
case noticeReporting of
DisableNoticeReporting -> LibPQ.disableNoticeReporting connection
EnableNoticeReporting -> LibPQ.enableNoticeReporting connection
Expand Down
17 changes: 12 additions & 5 deletions orville-postgresql/test/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ module Main
where

import qualified Control.Monad as Monad
import qualified Data.ByteString.Char8 as B8
import qualified Hedgehog as HH
import qualified System.Environment as Env
import qualified System.Exit as SE
Expand Down Expand Up @@ -86,22 +85,30 @@ main = do

Monad.unless (Property.allPassed summary) SE.exitFailure

createTestConnectionPool :: IO (Orville.Pool Orville.Connection)
createTestConnectionPool :: IO Orville.ConnectionPool
createTestConnectionPool = do
connStr <- lookupConnStr
-- Some tests use more than one connection, so the pool size must be greater
Orville.createConnectionPool Orville.DisableNoticeReporting 1 10 2 connStr
-- than 1
Orville.createConnectionPool $
Orville.ConnectionOptions
{ Orville.connectionString = connStr
, Orville.connectionNoticeReporting = Orville.DisableNoticeReporting
, Orville.connectionPoolStripes = Orville.OneStripePerCapability
, Orville.connectionPoolLingerTime = 10
, Orville.connectionPoolMaxConnectionsPerStripe = 2
}

recheckDBProperty :: HH.Size -> HH.Seed -> Property.NamedDBProperty -> IO ()
recheckDBProperty size seed namedProperty = do
pool <- createTestConnectionPool
HH.recheck size seed (snd $ namedProperty pool)

lookupConnStr :: IO B8.ByteString
lookupConnStr :: IO String
lookupConnStr = do
mbConnHostStr <- Env.lookupEnv "TEST_CONN_HOST"
let
connStrUserPass = " user=orville_test password=orville"
case mbConnHostStr of
Nothing -> fail "TEST_CONN_HOST not set, so we don't know what database to connect to!"
Just connHost -> pure . B8.pack $ connHost <> connStrUserPass
Just connHost -> pure $ connHost <> connStrUserPass
Loading