Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify session vault key handling #1930

Merged
merged 2 commits into from
Mar 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions IHP/ApplicationContext.hs
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
module IHP.ApplicationContext where

import IHP.Prelude
import Network.Wai.Session (Session)
import qualified Data.Vault.Lazy as Vault
import IHP.AutoRefresh.Types (AutoRefreshServer)
import IHP.FrameworkConfig (FrameworkConfig)
import IHP.PGListener (PGListener)

data ApplicationContext = ApplicationContext
{ modelContext :: !ModelContext
, session :: !(Vault.Key (Session IO ByteString ByteString))
, autoRefreshServer :: !(IORef AutoRefreshServer)
, frameworkConfig :: !FrameworkConfig
, pgListener :: PGListener
Expand Down
3 changes: 0 additions & 3 deletions IHP/Controller/RequestContext.hs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@ import ClassyPrelude
import qualified Data.ByteString.Lazy as LBS
import Network.Wai (Request, Response, ResponseReceived)
import Network.Wai.Parse (File, Param)
import qualified Data.Vault.Lazy as Vault
import Network.Wai.Session (Session)
import IHP.FrameworkConfig
import qualified Data.Aeson as Aeson

Expand All @@ -24,6 +22,5 @@ data RequestContext = RequestContext
{ request :: Request
, respond :: Respond
, requestBody :: RequestBody
, vault :: (Vault.Key (Session IO ByteString ByteString))
, frameworkConfig :: FrameworkConfig
}
5 changes: 2 additions & 3 deletions IHP/Controller/Response.hs
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,11 @@ where
import ClassyPrelude
import Network.HTTP.Types.Header
import qualified IHP.Controller.Context as Context
import IHP.Controller.Context (ControllerContext(ControllerContext))
import qualified Network.Wai
import Network.Wai (Response)
import qualified Control.Exception as Exception

respondAndExit :: (?context::ControllerContext) => Response -> IO ()
respondAndExit :: (?context :: Context.ControllerContext) => Response -> IO ()
respondAndExit response = do
responseWithHeaders <- addResponseHeadersFromContext response
Exception.throwIO (ResponseException responseWithHeaders)
Expand All @@ -35,7 +34,7 @@ addResponseHeaders headers = Network.Wai.mapResponseHeaders (\hs -> headers <> h
-- > addResponseHeadersFromContext response
-- You probabaly want `setHeader`
--
addResponseHeadersFromContext :: (?context :: ControllerContext) => Response -> IO Response
addResponseHeadersFromContext :: (?context :: Context.ControllerContext) => Response -> IO Response
addResponseHeadersFromContext response = do
maybeHeaders <- Context.maybeFromContext @[Header]
let headers = fromMaybe [] maybeHeaders
Expand Down
11 changes: 9 additions & 2 deletions IHP/Controller/Session.hs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ module IHP.Controller.Session
, getSessionEither
, deleteSession
, getSessionAndClear
, sessionVaultKey
) where

import IHP.Prelude
Expand All @@ -36,6 +37,8 @@ import qualified Network.Wai as Wai
import qualified Data.Serialize as Serialize
import Data.Serialize (Serialize)
import Data.Serialize.Text ()
import qualified Network.Wai.Session
import System.IO.Unsafe (unsafePerformIO)

-- | Types of possible errors as a result of
-- requesting a value from the session storage
Expand Down Expand Up @@ -161,5 +164,9 @@ sessionVault = case vaultLookup of
Just session -> session
Nothing -> error "sessionInsert: The session vault is missing in the request"
where
RequestContext { request, vault } = ?context.requestContext
vaultLookup = Vault.lookup vault (Wai.vault request)
RequestContext { request } = ?context.requestContext
vaultLookup = Vault.lookup sessionVaultKey request.vault

sessionVaultKey :: Vault.Key (Network.Wai.Session.Session IO ByteString ByteString)
sessionVaultKey = unsafePerformIO Vault.newKey
{-# NOINLINE sessionVaultKey #-}
7 changes: 3 additions & 4 deletions IHP/ControllerSupport.hs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ module IHP.ControllerSupport

import ClassyPrelude
import IHP.HaskellSupport
import Network.Wai (Response, Request, ResponseReceived, responseLBS, requestHeaders)
import Network.Wai (Request, ResponseReceived, responseLBS, requestHeaders)
import qualified Network.HTTP.Types as HTTP
import qualified Network.Wai
import IHP.ModelSupport
Expand All @@ -39,7 +39,6 @@ import qualified Data.ByteString.Lazy
import qualified IHP.Controller.RequestContext as RequestContext
import IHP.Controller.RequestContext (RequestContext, Respond)
import qualified Data.CaseInsensitive
import qualified Control.Exception as Exception
import qualified IHP.ErrorController as ErrorController
import qualified Data.Typeable as Typeable
import IHP.FrameworkConfig (FrameworkConfig (..), ConfigProvider(..))
Expand Down Expand Up @@ -259,7 +258,7 @@ requestBodyJSON =

{-# INLINE createRequestContext #-}
createRequestContext :: ApplicationContext -> Request -> Respond -> IO RequestContext
createRequestContext ApplicationContext { session, frameworkConfig } request respond = do
createRequestContext ApplicationContext { frameworkConfig } request respond = do
let contentType = lookup hContentType (requestHeaders request)
requestBody <- case contentType of
"application/json" -> do
Expand All @@ -270,7 +269,7 @@ createRequestContext ApplicationContext { session, frameworkConfig } request res
(params, files) <- WaiParse.parseRequestBodyEx frameworkConfig.parseRequestBodyOptions WaiParse.lbsBackEnd request
pure RequestContext.FormBody { .. }

pure RequestContext.RequestContext { request, respond, requestBody, vault = session, frameworkConfig }
pure RequestContext.RequestContext { request, respond, requestBody, frameworkConfig }


-- | Returns a custom config parameter
Expand Down
1 change: 0 additions & 1 deletion IHP/EnvVar.hs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import IHP.Prelude
import Data.String.Interpolate.IsString (i)
import qualified System.Posix.Env.ByteString as Posix
import Network.Socket (PortNumber)
import Data.Word (Word16)
import IHP.Mail.Types
import IHP.Environment

Expand Down
1 change: 0 additions & 1 deletion IHP/IDE/FileWatcher.hs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import System.Directory (listDirectory, doesDirectoryExist)
import qualified Data.Map as Map
import qualified System.FSNotify as FS
import IHP.IDE.Types
import qualified Data.Time.Clock as Clock
import qualified Data.List as List
import IHP.IDE.LiveReloadNotificationServer (notifyAssetChange)
import qualified Control.Debounce as Debounce
Expand Down
8 changes: 3 additions & 5 deletions IHP/IDE/ToolServer.hs
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,11 @@ import qualified Network.Wai.Handler.Warp as Warp
import IHP.IDE.Types
import IHP.IDE.PortConfig
import qualified IHP.ControllerSupport as ControllerSupport
import qualified IHP.ErrorController as ErrorController
import IHP.ApplicationContext
import IHP.ModelSupport
import IHP.RouterSupport hiding (get)
import Network.Wai.Session.ClientSession (clientsessionStore)
import qualified Web.ClientSession as ClientSession
import qualified Data.Vault.Lazy as Vault
import Network.Wai.Middleware.MethodOverridePost (methodOverridePost)
import Network.Wai.Session (withSession)
import qualified Network.WebSockets as Websocket
Expand Down Expand Up @@ -49,6 +47,7 @@ import qualified IHP.PGListener as PGListener
import qualified Network.Wai.Application.Static as Static
import qualified WaiAppStatic.Types as Static
import IHP.Controller.NotFound (handleNotFound)
import IHP.Controller.Session (sessionVaultKey)

withToolServer :: (?context :: Context) => IO () -> IO ()
withToolServer inner = withAsyncBound async (\_ -> inner)
Expand All @@ -72,15 +71,14 @@ startToolServer' port isDebugMode = do
Just baseUrl -> Config.option $ Config.BaseUrl baseUrl
Nothing -> pure ()

session <- Vault.newKey
store <- fmap clientsessionStore (ClientSession.getKey "Config/client_session_key.aes")
let sessionMiddleware :: Wai.Middleware = withSession store "SESSION" (frameworkConfig.sessionCookie) session
let sessionMiddleware :: Wai.Middleware = withSession store "SESSION" (frameworkConfig.sessionCookie) sessionVaultKey
let modelContext = notConnectedModelContext undefined
pgListener <- PGListener.init modelContext
autoRefreshServer <- newIORef (AutoRefresh.newAutoRefreshServer pgListener)
staticApp <- initStaticApp

let applicationContext = ApplicationContext { modelContext, session, autoRefreshServer, frameworkConfig, pgListener }
let applicationContext = ApplicationContext { modelContext, autoRefreshServer, frameworkConfig, pgListener }
let toolServerApplication = ToolServerApplication { devServerContext = ?context }
let application :: Wai.Application = \request respond -> do
let ?applicationContext = applicationContext
Expand Down
2 changes: 1 addition & 1 deletion IHP/RouterSupport.hs
Original file line number Diff line number Diff line change
Expand Up @@ -838,7 +838,7 @@ withPrefix prefix routes = string prefix >> choice (map (\r -> r <* endOfInput)

frontControllerToWAIApp :: forall app (autoRefreshApp :: Type). (?applicationContext :: ApplicationContext, FrontController app, WSApp autoRefreshApp, Typeable autoRefreshApp, InitControllerContext ()) => Middleware -> app -> Application -> Application
frontControllerToWAIApp middleware application notFoundAction request respond = do
let requestContext = RequestContext { request, respond, requestBody = FormBody { params = [], files = [] }, vault = ?applicationContext.session, frameworkConfig = ?applicationContext.frameworkConfig }
let requestContext = RequestContext { request, respond, requestBody = FormBody { params = [], files = [] }, frameworkConfig = ?applicationContext.frameworkConfig }

let ?context = requestContext

Expand Down
20 changes: 8 additions & 12 deletions IHP/Server.hs
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,16 @@ import IHP.Prelude
import qualified Network.Wai.Handler.Warp as Warp
import Network.Wai
import Network.Wai.Middleware.MethodOverridePost (methodOverridePost)
import Network.Wai.Session (withSession, Session)
import Network.Wai.Session (withSession)
import Network.Wai.Session.ClientSession (clientsessionStore)
import qualified Web.ClientSession as ClientSession
import qualified Data.Vault.Lazy as Vault
import IHP.Controller.Session (sessionVaultKey)
import IHP.ApplicationContext
import qualified IHP.ControllerSupport as ControllerSupport
import qualified IHP.Environment as Env
import qualified IHP.PGListener as PGListener

import IHP.FrameworkConfig
import IHP.RouterSupport (frontControllerToWAIApp, FrontController, webSocketApp, webSocketAppWithCustomPath)
import IHP.ErrorController
import IHP.RouterSupport (frontControllerToWAIApp, FrontController)
import qualified IHP.AutoRefresh as AutoRefresh
import qualified IHP.AutoRefresh.Types as AutoRefresh
import IHP.LibDir
Expand Down Expand Up @@ -48,14 +46,12 @@ run configBuilder = do

withInitalizers frameworkConfig modelContext do
withPGListener \pgListener -> do
sessionVault <- Vault.newKey

autoRefreshServer <- newIORef (AutoRefresh.newAutoRefreshServer pgListener)

let ?modelContext = modelContext
let ?applicationContext = ApplicationContext { modelContext = ?modelContext, session = sessionVault, autoRefreshServer, frameworkConfig, pgListener }
let ?applicationContext = ApplicationContext { modelContext = ?modelContext, autoRefreshServer, frameworkConfig, pgListener }

sessionMiddleware <- initSessionMiddleware sessionVault frameworkConfig
sessionMiddleware <- initSessionMiddleware frameworkConfig
staticApp <- initStaticApp frameworkConfig
let corsMiddleware = initCorsMiddleware frameworkConfig
let requestLoggerMiddleware = frameworkConfig.requestLoggerMiddleware
Expand Down Expand Up @@ -108,8 +104,8 @@ initStaticApp frameworkConfig = do

pure (Static.staticApp appSettings)

initSessionMiddleware :: Vault.Key (Session IO ByteString ByteString) -> FrameworkConfig -> IO Middleware
initSessionMiddleware sessionVault FrameworkConfig { sessionCookie } = do
initSessionMiddleware :: FrameworkConfig -> IO Middleware
initSessionMiddleware FrameworkConfig { sessionCookie } = do
let path = "Config/client_session_key.aes"

hasSessionSecretEnvVar <- EnvVar.hasEnvVar "IHP_SESSION_SECRET"
Expand All @@ -118,7 +114,7 @@ initSessionMiddleware sessionVault FrameworkConfig { sessionCookie } = do
if hasSessionSecretEnvVar || not doesConfigDirectoryExist
then ClientSession.getKeyEnv "IHP_SESSION_SECRET"
else ClientSession.getKey path
let sessionMiddleware :: Middleware = withSession store "SESSION" sessionCookie sessionVault
let sessionMiddleware :: Middleware = withSession store "SESSION" sessionCookie sessionVaultKey
pure sessionMiddleware

initCorsMiddleware :: FrameworkConfig -> Middleware
Expand Down
17 changes: 7 additions & 10 deletions IHP/Test/Mocking.hs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import qualified Network.Wai.Session
import qualified Data.Serialize as Serialize
import qualified Control.Exception as Exception
import qualified IHP.PGListener as PGListener
import IHP.Controller.Session (sessionVaultKey)

type ContextParameters application = (?applicationContext :: ApplicationContext, ?context :: RequestContext, ?modelContext :: ModelContext, ?application :: application, InitControllerContext application, ?mocking :: MockContext application)

Expand All @@ -58,17 +59,15 @@ withIHPApp application configBuilder hspecAction = do
withTestDatabase \testDatabase -> do
modelContext <- createModelContext dbPoolIdleTime dbPoolMaxConnections (testDatabase.url) logger

session <- Vault.newKey
pgListener <- PGListener.init modelContext
autoRefreshServer <- newIORef (AutoRefresh.newAutoRefreshServer pgListener)
let sessionVault = Vault.insert session mempty Vault.empty
let applicationContext = ApplicationContext { modelContext = modelContext, session, autoRefreshServer, frameworkConfig, pgListener }
let sessionVault = Vault.insert sessionVaultKey mempty Vault.empty
let applicationContext = ApplicationContext { modelContext = modelContext, autoRefreshServer, frameworkConfig, pgListener }

let requestContext = RequestContext
{ request = defaultRequest {vault = sessionVault}
, requestBody = FormBody [] []
, respond = const (pure ResponseReceived)
, vault = session
, frameworkConfig = frameworkConfig }

(hspecAction MockContext { .. })
Expand All @@ -81,17 +80,15 @@ mockContextNoDatabase application configBuilder = do
logger <- newLogger def { level = Warn } -- don't log queries
modelContext <- createModelContext dbPoolIdleTime dbPoolMaxConnections databaseUrl logger

session <- Vault.newKey
let sessionVault = Vault.insert session mempty Vault.empty
let sessionVault = Vault.insert sessionVaultKey mempty Vault.empty
pgListener <- PGListener.init modelContext
autoRefreshServer <- newIORef (AutoRefresh.newAutoRefreshServer pgListener)
let applicationContext = ApplicationContext { modelContext = modelContext, session, autoRefreshServer, frameworkConfig, pgListener }
let applicationContext = ApplicationContext { modelContext = modelContext, autoRefreshServer, frameworkConfig, pgListener }

let requestContext = RequestContext
{ request = defaultRequest {vault = sessionVault}
, requestBody = FormBody [] []
, respond = \resp -> pure ResponseReceived
, vault = session
, frameworkConfig = frameworkConfig }

pure MockContext{..}
Expand Down Expand Up @@ -230,8 +227,8 @@ withUser user callback =

insertSession key value = pure ()

newVault = Vault.insert vaultKey newSession (Wai.vault request)
RequestContext { request, vault = vaultKey } = ?mocking.requestContext
newVault = Vault.insert sessionVaultKey newSession (Wai.vault request)
RequestContext { request } = ?mocking.requestContext

sessionValue = Serialize.encode (user.id)
sessionKey = cs (Session.sessionKey @user)
Expand Down
2 changes: 1 addition & 1 deletion Test/Controller/AccessDeniedSpec.hs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ config = do
makeApplication :: (?applicationContext :: ApplicationContext) => IO Application
makeApplication = do
store <- Session.mapStore_
let sessionMiddleware :: Middleware = Session.withSession store "SESSION" ?applicationContext.frameworkConfig.sessionCookie ?applicationContext.session
let sessionMiddleware :: Middleware = Session.withSession store "SESSION" ?applicationContext.frameworkConfig.sessionCookie sessionVaultKey
pure (sessionMiddleware $ (Server.application handleNotFound) (\app -> app))

assertAccessDenied :: SResponse -> IO ()
Expand Down
2 changes: 1 addition & 1 deletion Test/Controller/CookieSpec.hs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,6 @@ createControllerContext = do
let
requestBody = FormBody { params = [], files = [] }
request = Wai.defaultRequest
requestContext = RequestContext { request, respond = error "respond", requestBody, vault = error "vault", frameworkConfig = error "frameworkConfig" }
requestContext = RequestContext { request, respond = error "respond", requestBody, frameworkConfig = error "frameworkConfig" }
let ?requestContext = requestContext
newControllerContext
2 changes: 1 addition & 1 deletion Test/Controller/NotFoundSpec.hs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ config = do
makeApplication :: (?applicationContext :: ApplicationContext) => IO Application
makeApplication = do
store <- Session.mapStore_
let sessionMiddleware :: Middleware = Session.withSession store "SESSION" ?applicationContext.frameworkConfig.sessionCookie ?applicationContext.session
let sessionMiddleware :: Middleware = Session.withSession store "SESSION" ?applicationContext.frameworkConfig.sessionCookie sessionVaultKey
pure (sessionMiddleware $ (Server.application handleNotFound) (\app -> app))

assertNotFound :: SResponse -> IO ()
Expand Down
4 changes: 2 additions & 2 deletions Test/Controller/ParamSpec.hs
Original file line number Diff line number Diff line change
Expand Up @@ -434,14 +434,14 @@ createControllerContextWithParams params =
let
requestBody = FormBody { params, files = [] }
request = Wai.defaultRequest
requestContext = RequestContext { request, respond = error "respond", requestBody, vault = error "vault", frameworkConfig = error "frameworkConfig" }
requestContext = RequestContext { request, respond = error "respond", requestBody, frameworkConfig = error "frameworkConfig" }
in FrozenControllerContext { requestContext, customFields = TypeMap.empty }

createControllerContextWithJson params =
let
requestBody = JSONBody { jsonPayload = Just (json params), rawPayload = cs params }
request = Wai.defaultRequest
requestContext = RequestContext { request, respond = error "respond", requestBody, vault = error "vault", frameworkConfig = error "frameworkConfig" }
requestContext = RequestContext { request, respond = error "respond", requestBody, frameworkConfig = error "frameworkConfig" }
in FrozenControllerContext { requestContext, customFields = TypeMap.empty }

json :: Text -> Aeson.Value
Expand Down
2 changes: 1 addition & 1 deletion Test/View/CSSFrameworkSpec.hs
Original file line number Diff line number Diff line change
Expand Up @@ -721,5 +721,5 @@ createControllerContextWithCSSFramework cssFramework = do
option cssFramework
let requestBody = FormBody { params = [], files = [] }
let request = Wai.defaultRequest
let requestContext = RequestContext { request, respond = error "respond", requestBody, vault = error "vault", frameworkConfig = frameworkConfig }
let requestContext = RequestContext { request, respond = error "respond", requestBody, frameworkConfig = frameworkConfig }
pure FrozenControllerContext { requestContext, customFields = TypeMap.empty }
2 changes: 1 addition & 1 deletion Test/View/FormSpec.hs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ createControllerContext = do
frameworkConfig <- FrameworkConfig.buildFrameworkConfig (pure ())
let requestBody = FormBody { params = [], files = [] }
let request = Wai.defaultRequest
let requestContext = RequestContext { request, respond = undefined, requestBody, vault = undefined, frameworkConfig = frameworkConfig }
let requestContext = RequestContext { request, respond = undefined, requestBody, frameworkConfig = frameworkConfig }
pure FrozenControllerContext { requestContext, customFields = mempty }

data Project' = Project {id :: (Id' "projects"), title :: Text, meta :: MetaBag} deriving (Eq, Show)
Expand Down
2 changes: 1 addition & 1 deletion Test/ViewSupportSpec.hs
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ config = do
makeApplication :: (?applicationContext :: ApplicationContext) => IO Application
makeApplication = do
store <- Session.mapStore_
let sessionMiddleware :: Middleware = Session.withSession store "SESSION" ?applicationContext.frameworkConfig.sessionCookie ?applicationContext.session
let sessionMiddleware :: Middleware = Session.withSession store "SESSION" ?applicationContext.frameworkConfig.sessionCookie sessionVaultKey
pure (sessionMiddleware $ (Server.application handleNotFound (\app -> app)))

tests :: Spec
Expand Down
Loading