From bec77df4a4e8cec28f1ef0542200a11a470e5d02 Mon Sep 17 00:00:00 2001 From: Rory Tyler Hayford Date: Wed, 2 Oct 2024 12:55:28 +0700 Subject: [PATCH 01/13] Have `/status` actually make sense --- .../src/Inferno/ML/Server/Client.hs | 2 +- .../src/Inferno/ML/Server/Types.hs | 15 ++++++++------- inferno-ml-server/src/Inferno/ML/Server.hs | 10 ++++++---- .../src/Inferno/ML/Server/Inference.hs | 2 +- inferno-ml-server/src/Inferno/ML/Server/Log.hs | 2 +- inferno-ml-server/src/Inferno/ML/Server/Types.hs | 2 +- 6 files changed, 18 insertions(+), 15 deletions(-) diff --git a/inferno-ml-server-types/src/Inferno/ML/Server/Client.hs b/inferno-ml-server-types/src/Inferno/ML/Server/Client.hs index b8d2ad7..5bc1992 100644 --- a/inferno-ml-server-types/src/Inferno/ML/Server/Client.hs +++ b/inferno-ml-server-types/src/Inferno/ML/Server/Client.hs @@ -18,7 +18,7 @@ import Servant.Client.Streaming (ClientM, client) -- | Get the status of the server. @Nothing@ indicates that an inference job -- is being evaluated. 
@Just ()@ means the server is idle -statusC :: ClientM (Maybe ()) +statusC :: ClientM ServerStatus -- | Run an inference parameter inferenceC :: diff --git a/inferno-ml-server-types/src/Inferno/ML/Server/Types.hs b/inferno-ml-server-types/src/Inferno/ML/Server/Types.hs index 7aae8ce..da08487 100644 --- a/inferno-ml-server-types/src/Inferno/ML/Server/Types.hs +++ b/inferno-ml-server-types/src/Inferno/ML/Server/Types.hs @@ -122,13 +122,8 @@ import Web.HttpApiData -- API type for `inferno-ml-server` type InfernoMlServerAPI uid gid p s t = - -- Check if the server is up and if any job is currently running: - -- - -- * `Nothing` -> The server is evaluating a script - -- * `Just ()` -> The server is not doing anything and can be killed - -- - -- This can be implemented using an `MVar ()` - "status" :> Get '[JSON] (Maybe ()) + -- Check if the server is up and if any job is currently running + "status" :> Get '[JSON] ServerStatus -- Evaluate an inference script :<|> "inference" :> Capture "id" (Id (InferenceParam uid gid p s)) @@ -167,6 +162,12 @@ type BridgeAPI p t = -- This means the same output may appear more than once in the stream type WriteStream m = ConduitT () (Int, [(EpochTime, IValue)]) m () +data ServerStatus + = Idle + | EvaluatingParam + deriving stock (Show, Eq, Generic) + deriving anyclass (FromJSON, ToJSON) + -- | Information for contacting a bridge server that implements the 'BridgeAPI' data BridgeInfo uid gid p s = BridgeInfo { id :: Id (InferenceParam uid gid p s), diff --git a/inferno-ml-server/src/Inferno/ML/Server.hs b/inferno-ml-server/src/Inferno/ML/Server.hs index 5026613..1534fd0 100644 --- a/inferno-ml-server/src/Inferno/ML/Server.hs +++ b/inferno-ml-server/src/Inferno/ML/Server.hs @@ -126,10 +126,12 @@ api = Proxy server :: ServerT InfernoMlServerAPI RemoteM server = getStatus :<|> runInferenceParam :<|> cancelInference where - -- If the server is currently evaluating a script, this will return `Nothing`, - -- otherwise `Just ()` - getStatus :: 
RemoteM (Maybe ()) - getStatus = tryReadMVar =<< view #lock + -- If the server is currently evaluating a script, the var will be taken, + -- i.e. evaluate to `Nothing`, otherwise `Just ()` + getStatus :: RemoteM ServerStatus + getStatus = + fmap (maybe EvaluatingScript (const Idle)) $ + tryReadMVar =<< view #lock -- When an inference request is run, the server will store the `Async` in -- the `job` `MVar`. Canceling the request throws to the `Async` thread diff --git a/inferno-ml-server/src/Inferno/ML/Server/Inference.hs b/inferno-ml-server/src/Inferno/ML/Server/Inference.hs index ed14cad..2bf5604 100644 --- a/inferno-ml-server/src/Inferno/ML/Server/Inference.hs +++ b/inferno-ml-server/src/Inferno/ML/Server/Inference.hs @@ -147,7 +147,7 @@ runInferenceParam ipid mres uuid = -- need to be updated to use an absolute path to a versioned model, -- e.g. `loadModel "~/inferno/.cache/..."`) withCurrentDirectory (view #path cache) $ do - logInfo $ EvaluatingScript ipid + logInfo $ EvaluatingParam ipid traverse_ linkVersionedModel =<< getAndCacheModels cache (view #models param) runEval interpreter param t obj diff --git a/inferno-ml-server/src/Inferno/ML/Server/Log.hs b/inferno-ml-server/src/Inferno/ML/Server/Log.hs index 6a59c6e..41f423d 100644 --- a/inferno-ml-server/src/Inferno/ML/Server/Log.hs +++ b/inferno-ml-server/src/Inferno/ML/Server/Log.hs @@ -29,7 +29,7 @@ traceRemote = \case tshow $ t `div` 1000000, "(seconds)" ] - EvaluatingScript s -> + EvaluatingParam s -> Text.unwords [ "Evaluating inferno script for parameter:", tshow s diff --git a/inferno-ml-server/src/Inferno/ML/Server/Types.hs b/inferno-ml-server/src/Inferno/ML/Server/Types.hs index c9d7bd4..ce5af16 100644 --- a/inferno-ml-server/src/Inferno/ML/Server/Types.hs +++ b/inferno-ml-server/src/Inferno/ML/Server/Types.hs @@ -375,7 +375,7 @@ data RemoteTrace data TraceInfo = StartingServer | RunningInference (Id InferenceParam) Int - | EvaluatingScript (Id InferenceParam) + | EvaluatingParam (Id 
InferenceParam) | CopyingModel (Id ModelVersion) | OtherInfo Text deriving stock (Show, Eq, Generic) From a0a196c3802adb38c0c7b4e6e83c4d26b4695deb Mon Sep 17 00:00:00 2001 From: Rory Tyler Hayford Date: Wed, 2 Oct 2024 13:33:34 +0700 Subject: [PATCH 02/13] WIP: Testing route --- .../src/Inferno/ML/Server/Client.hs | 29 ++++++---- .../src/Inferno/ML/Server/Types.hs | 53 +++++++++++++++---- inferno-ml-server/src/Inferno/ML/Server.hs | 2 +- .../src/Inferno/ML/Server/Types.hs | 1 - 4 files changed, 65 insertions(+), 20 deletions(-) diff --git a/inferno-ml-server-types/src/Inferno/ML/Server/Client.hs b/inferno-ml-server-types/src/Inferno/ML/Server/Client.hs index 5bc1992..d952108 100644 --- a/inferno-ml-server-types/src/Inferno/ML/Server/Client.hs +++ b/inferno-ml-server-types/src/Inferno/ML/Server/Client.hs @@ -1,27 +1,33 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE NoMonomorphismRestriction #-} module Inferno.ML.Server.Client ( statusC, inferenceC, + inferenceTestC, cancelC, ) where +import Data.Aeson (ToJSON) import Data.Int (Int64) import Data.Proxy (Proxy (Proxy)) import Data.UUID (UUID) import Inferno.ML.Server.Types -import Servant ((:<|>) ((:<|>))) import Servant.Client.Streaming (ClientM, client) -- | Get the status of the server. @Nothing@ indicates that an inference job -- is being evaluated. @Just ()@ means the server is idle statusC :: ClientM ServerStatus +statusC = client $ Proxy @StatusAPI + +-- | Cancel the existing inference job, if it exists +cancelC :: ClientM () +cancelC = client $ Proxy @CancelAPI -- | Run an inference parameter inferenceC :: + forall uid gid p s. -- | SQL identifier of the inference parameter to be run Id (InferenceParam uid gid p s) -> -- | Optional resolution for scripts that use e.g. 
@valueAt@; defaults to @@ -38,11 +44,16 @@ inferenceC :: -- (not defined in this repository) to verify this before directing -- the writes to their final destination ClientM (WriteStream IO) +inferenceC = client $ Proxy @(InferenceAPI uid gid p s) --- | Cancel the existing inference job, if it exists -cancelC :: ClientM () -statusC :<|> inferenceC :<|> cancelC = - client api - -api :: Proxy (InfernoMlServerAPI uid gid p s t) -api = Proxy +-- | Run an inference parameter +inferenceTestC :: + forall uid gid p s. + ToJSON p => + -- | SQL identifier of the inference parameter to be run + Id (InferenceParam uid gid p s) -> + Maybe Int64 -> + UUID -> + EvaluationEnv gid p -> + ClientM (WriteStream IO) +inferenceTestC = client $ Proxy @(InferenceTestAPI uid gid p s) diff --git a/inferno-ml-server-types/src/Inferno/ML/Server/Types.hs b/inferno-ml-server-types/src/Inferno/ML/Server/Types.hs index da08487..4f8198b 100644 --- a/inferno-ml-server-types/src/Inferno/ML/Server/Types.hs +++ b/inferno-ml-server-types/src/Inferno/ML/Server/Types.hs @@ -84,6 +84,7 @@ import Servant Put, QueryParam, QueryParam', + ReqBody, Required, StreamPost, (:<|>), @@ -121,16 +122,34 @@ import Web.HttpApiData ) -- API type for `inferno-ml-server` -type InfernoMlServerAPI uid gid p s t = +type InfernoMlServerAPI uid gid p s = + StatusAPI + -- Evaluate an inference script + :<|> InferenceAPI uid gid p s + :<|> InferenceTestAPI uid gid p s + :<|> CancelAPI + +type StatusAPI = -- Check if the server is up and if any job is currently running "status" :> Get '[JSON] ServerStatus - -- Evaluate an inference script - :<|> "inference" - :> Capture "id" (Id (InferenceParam uid gid p s)) - :> QueryParam "res" Int64 - :> QueryParam' '[Required] "uuid" UUID - :> StreamPost NewlineFraming JSON (WriteStream IO) - :<|> "inference" :> "cancel" :> Put '[JSON] () + +type CancelAPI = "inference" :> "cancel" :> Put '[JSON] () + +type InferenceAPI uid gid p s = + "inference" + :> Capture "id" (Id (InferenceParam uid gid 
p s)) + :> QueryParam "res" Int64 + :> QueryParam' '[Required] "uuid" UUID + :> StreamPost NewlineFraming JSON (WriteStream IO) + +type InferenceTestAPI uid gid p s = + -- Evaluate an inference script + "inference" + :> Capture "id" (Id (InferenceParam uid gid p s)) + :> QueryParam "res" Int64 + :> QueryParam' '[Required] "uuid" UUID + :> ReqBody '[JSON] (EvaluationEnv gid p) + :> StreamPost NewlineFraming JSON (WriteStream IO) -- A bridge to get or write data for use with Inferno scripts. This is implemented -- by a bridge server connected to a data source, not by `inferno-ml-server` @@ -164,7 +183,7 @@ type WriteStream m = ConduitT () (Int, [(EpochTime, IValue)]) m () data ServerStatus = Idle - | EvaluatingParam + | EvaluatingScript deriving stock (Show, Eq, Generic) deriving anyclass (FromJSON, ToJSON) @@ -1038,6 +1057,22 @@ instance Ord a => Ord (SingleOrMany a) where (Single _, Many _) -> LT (Many _, Single _) -> GT +-- | An environment that can be used to override the @inferno-ml-server@ script +-- evaluator. This allows for more interactive testing +data EvaluationEnv gid p = EvaluationEnv + { script :: VCObjectHash, + inputs :: Map Ident (SingleOrMany p, ScriptInputType), + models :: + Map + Ident + ( Id (ModelVersion gid Oid), + -- Name of parent model + Text + ) + } + deriving stock (Show, Eq, Generic) + deriving anyclass (FromJSON, ToJSON) + tshow :: Show a => a -> Text tshow = Text.pack . show diff --git a/inferno-ml-server/src/Inferno/ML/Server.hs b/inferno-ml-server/src/Inferno/ML/Server.hs index 1534fd0..2c6ef85 100644 --- a/inferno-ml-server/src/Inferno/ML/Server.hs +++ b/inferno-ml-server/src/Inferno/ML/Server.hs @@ -124,7 +124,7 @@ api :: Proxy InfernoMlServerAPI api = Proxy server :: ServerT InfernoMlServerAPI RemoteM -server = getStatus :<|> runInferenceParam :<|> cancelInference +server = getStatus :<|> runInferenceParam :<|> undefined :<|> cancelInference where -- If the server is currently evaluating a script, the var will be taken, -- i.e. 
evaluate to `Nothing`, otherwise `Just ()` diff --git a/inferno-ml-server/src/Inferno/ML/Server/Types.hs b/inferno-ml-server/src/Inferno/ML/Server/Types.hs index ce5af16..e671842 100644 --- a/inferno-ml-server/src/Inferno/ML/Server/Types.hs +++ b/inferno-ml-server/src/Inferno/ML/Server/Types.hs @@ -479,7 +479,6 @@ type InfernoMlServerAPI = (EntityId GId) PID VCObjectHash - EpochTime -- Orphans From b0024f71323d97df353aed66e9677f540dc4896c Mon Sep 17 00:00:00 2001 From: Rory Tyler Hayford Date: Thu, 3 Oct 2024 11:36:24 +0700 Subject: [PATCH 03/13] More WIP: Test endpoint --- inferno-ml-server/src/Inferno/ML/Server.hs | 6 +- .../src/Inferno/ML/Server/Inference.hs | 82 +++++++++++++++---- .../src/Inferno/ML/Server/Types.hs | 3 + 3 files changed, 72 insertions(+), 19 deletions(-) diff --git a/inferno-ml-server/src/Inferno/ML/Server.hs b/inferno-ml-server/src/Inferno/ML/Server.hs index 2c6ef85..b93629d 100644 --- a/inferno-ml-server/src/Inferno/ML/Server.hs +++ b/inferno-ml-server/src/Inferno/ML/Server.hs @@ -124,7 +124,11 @@ api :: Proxy InfernoMlServerAPI api = Proxy server :: ServerT InfernoMlServerAPI RemoteM -server = getStatus :<|> runInferenceParam :<|> undefined :<|> cancelInference +server = + getStatus + :<|> runInferenceParam + :<|> testInferenceParam + :<|> cancelInference where -- If the server is currently evaluating a script, the var will be taken, -- i.e. 
evaluate to `Nothing`, otherwise `Just ()` diff --git a/inferno-ml-server/src/Inferno/ML/Server/Inference.hs b/inferno-ml-server/src/Inferno/ML/Server/Inference.hs index 2bf5604..7067ee1 100644 --- a/inferno-ml-server/src/Inferno/ML/Server/Inference.hs +++ b/inferno-ml-server/src/Inferno/ML/Server/Inference.hs @@ -1,13 +1,14 @@ {-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE LexicalNegation #-} {-# LANGUAGE NumericUnderscores #-} {-# LANGUAGE QuasiQuotes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE ViewPatterns #-} -{-# OPTIONS_GHC -Wno-unused-local-binds #-} module Inferno.ML.Server.Inference ( runInferenceParam, + testInferenceParam, getAndCacheModels, linkVersionedModel, ) @@ -43,6 +44,7 @@ import Database.PostgreSQL.Simple import Database.PostgreSQL.Simple.Newtypes (getAeson) import Database.PostgreSQL.Simple.SqlQQ (sql) import Foreign.C (CTime) +import GHC.Generics (Generic) import Inferno.Core ( Interpreter (Interpreter, evalExpr, mkEnvFromClosure), ) @@ -110,6 +112,41 @@ runInferenceParam :: UUID -> RemoteM (WriteStream IO) runInferenceParam ipid mres uuid = + runInferenceParamWithEnv ipid uuid + =<< mkScriptEnv + =<< getParameterWithModels ipid + where + mkScriptEnv :: InferenceParamWithModels -> RemoteM ScriptEnv + mkScriptEnv pwm = + ScriptEnv param (view #models pwm) (view #inputs param) + <$> getVcObject script + <*> pure script + <*> pure mres + <*> liftIO nowCTime + where + param :: InferenceParam + param = pwm ^. #param + + script :: VCObjectHash + script = param ^. #script + + nowCTime :: IO CTime + nowCTime = fromIntegral @Int . 
round <$> getPOSIXTime + +testInferenceParam :: + Id InferenceParam -> + Maybe Int64 -> + UUID -> + EvaluationEnv -> + RemoteM (WriteStream IO) +testInferenceParam = undefined + +runInferenceParamWithEnv :: + Id InferenceParam -> + UUID -> + ScriptEnv -> + RemoteM (WriteStream IO) +runInferenceParamWithEnv ipid uuid senv = withTimeoutMillis $ \t -> do logInfo $ RunningInference ipid t maybe (throwM (ScriptTimeout t)) pure @@ -139,9 +176,7 @@ runInferenceParam ipid mres uuid = -- until the server is started again interpreter <- getOrMkInferno ipid param <- getParameterWithModels ipid - obj <- param ^. #param . #script & getVcObject cache <- view $ #config . #cache - t <- liftIO $ fromIntegral @Int . round <$> getPOSIXTime -- Change working directories to the model cache so that Hasktorch -- can find the models using relative paths (otherwise the AST would -- need to be updated to use an absolute path to a versioned model, @@ -150,16 +185,13 @@ runInferenceParam ipid mres uuid = logInfo $ EvaluatingParam ipid traverse_ linkVersionedModel =<< getAndCacheModels cache (view #models param) - runEval interpreter param t obj + runEval interpreter where runEval :: Interpreter RemoteM BridgeMlValue -> - InferenceParamWithModels -> - CTime -> - VCMeta VCObject -> RemoteM (WriteStream IO) - runEval Interpreter {evalExpr, mkEnvFromClosure} param t vcm = - vcm ^. #obj & \case + runEval Interpreter {evalExpr, mkEnvFromClosure} = + senv ^. #obj . #obj & \case VCFunction {} -> do let -- Note that this both includes inputs (i.e. readable) -- and outputs (i.e. writable, or readable/writable). @@ -171,7 +203,7 @@ runInferenceParam ipid mres uuid = -- and commits the output write object pids :: [SingleOrMany PID] pids = - param ^.. #param . #inputs . to Map.toAscList . each . _2 . _1 + senv ^.. #inputs . to Map.toAscList . each . _2 . _1 -- These are all of the models selected for use with the -- script. 
The ID of the actual model version is included, @@ -190,7 +222,7 @@ runInferenceParam ipid mres uuid = -- in the model cache (see `getAndCacheModels` below) models :: [(Id ModelVersion, Text)] models = - param ^.. #models . to Map.toAscList . each . _2 + senv ^.. #models . to Map.toAscList . each . _2 mkIdentWith :: Text -> Int -> ExtIdent mkIdentWith x = ExtIdent . Right . (x <>) . tshow @@ -231,8 +263,8 @@ runInferenceParam ipid mres uuid = closure :: Map VCObjectHash VCObject closure = - param ^. #param . #script - & ( `Map.singleton` view #obj vcm + senv ^. #script + & ( `Map.singleton` view (#obj . #obj) senv ) expr :: Expr (Maybe VCObjectHash) () @@ -241,7 +273,7 @@ runInferenceParam ipid mres uuid = Var () mhash LocalScope dummy where mhash :: Maybe VCObjectHash - mhash = param ^? #param . #script + mhash = senv ^. #script & Just -- See note above about inputs/outputs args :: [Expr (Maybe a) ()] @@ -287,15 +319,15 @@ runInferenceParam ipid mres uuid = -- use that. Otherwise, use the resolution stored in the -- parameter resolution :: InverseResolution - resolution = mres ^. non res & toResolution + resolution = senv ^. #mres . non res & toResolution where res :: Int64 - res = param ^. #param . #resolution & fromIntegral + res = senv ^. #param . #resolution & fromIntegral implEnv :: Map ExtIdent (Value BridgeMlValue m) implEnv = Map.fromList - [ (ExtIdent $ Right "now", VEpochTime t), + [ (ExtIdent $ Right "now", VEpochTime $ view #ctime senv), ( ExtIdent $ Right "resolution", VCustom . VExtended $ VResolution resolution ) @@ -305,7 +337,7 @@ runInferenceParam ipid mres uuid = . InvalidScript $ Text.unwords [ "Script identified by VC hash", - param ^. #param . #script & tshow, + senv ^. #script & tshow, "is not a function" ] @@ -569,3 +601,17 @@ modelsByAccessTime = sortByM compareAccessTime <=< listDirectory compareAccessTime f1 f2 = getAccessTime f1 >>= \t1 -> compare t1 <$> getAccessTime f2 + +-- Everything needed to evaluate an ML script. 
For the normal endpoint, all of +-- these will be derived directly from the param. For the interactive test +-- endpoint, these will be overridden +data ScriptEnv = ScriptEnv + { param :: InferenceParam, + models :: Map Ident (Id ModelVersion, Text), + inputs :: Map Ident (SingleOrMany PID, ScriptInputType), + obj :: VCMeta VCObject, + script :: VCObjectHash, + mres :: Maybe Int64, + ctime :: CTime + } + deriving stock (Generic) diff --git a/inferno-ml-server/src/Inferno/ML/Server/Types.hs b/inferno-ml-server/src/Inferno/ML/Server/Types.hs index e671842..c67c80d 100644 --- a/inferno-ml-server/src/Inferno/ML/Server/Types.hs +++ b/inferno-ml-server/src/Inferno/ML/Server/Types.hs @@ -76,6 +76,7 @@ import Inferno.Core (Interpreter) import Inferno.ML.Server.Module.Types as M import "inferno-ml-server-types" Inferno.ML.Server.Types as M hiding ( BridgeInfo, + EvaluationEnv, EvaluationInfo, InferenceParam, InferenceParamWithModels, @@ -480,6 +481,8 @@ type InfernoMlServerAPI = PID VCObjectHash +type EvaluationEnv = Types.EvaluationEnv (EntityId GId) PID + -- Orphans instance FromField VCObjectHash where From 33ad140ee1bf0f17e09a61ea1d085210e55222da Mon Sep 17 00:00:00 2001 From: Rory Tyler Hayford Date: Fri, 4 Oct 2024 11:00:41 +0700 Subject: [PATCH 04/13] Mostly finish testing route --- .../src/Inferno/ML/Server/Inference.hs | 72 +++++++++++-------- 1 file changed, 42 insertions(+), 30 deletions(-) diff --git a/inferno-ml-server/src/Inferno/ML/Server/Inference.hs b/inferno-ml-server/src/Inferno/ML/Server/Inference.hs index 7067ee1..b47b110 100644 --- a/inferno-ml-server/src/Inferno/ML/Server/Inference.hs +++ b/inferno-ml-server/src/Inferno/ML/Server/Inference.hs @@ -1,10 +1,13 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE DuplicateRecordFields #-} {-# LANGUAGE LexicalNegation #-} {-# LANGUAGE NumericUnderscores #-} +{-# LANGUAGE OverloadedRecordDot #-} {-# LANGUAGE QuasiQuotes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE ViewPatterns 
#-} +{-# LANGUAGE NoFieldSelectors #-} module Inferno.ML.Server.Inference ( runInferenceParam, @@ -118,20 +121,10 @@ runInferenceParam ipid mres uuid = where mkScriptEnv :: InferenceParamWithModels -> RemoteM ScriptEnv mkScriptEnv pwm = - ScriptEnv param (view #models pwm) (view #inputs param) - <$> getVcObject script - <*> pure script - <*> pure mres - <*> liftIO nowCTime - where - param :: InferenceParam - param = pwm ^. #param - - script :: VCObjectHash - script = param ^. #script - - nowCTime :: IO CTime - nowCTime = fromIntegral @Int . round <$> getPOSIXTime + ScriptEnv pwm.param pwm.models pwm.param.inputs + <$> getVcObject pwm.param.script + ?? pwm.param.script + ?? mres testInferenceParam :: Id InferenceParam -> @@ -139,7 +132,29 @@ testInferenceParam :: UUID -> EvaluationEnv -> RemoteM (WriteStream IO) -testInferenceParam = undefined +testInferenceParam ipid mres uuid eenv = + runInferenceParamWithEnv ipid uuid + =<< mkScriptEnv + -- Just need to get the param, we already have the model information + -- from the overrides + =<< getParam + where + -- Note that, unlike `runInferenceParam`, several of the items required + -- for script eval come from the `EvaluationEnv` + mkScriptEnv :: InferenceParam -> RemoteM ScriptEnv + mkScriptEnv param = + ScriptEnv param eenv.models eenv.inputs + <$> getVcObject eenv.script + ?? eenv.script + ?? mres + + getParam :: RemoteM InferenceParam + getParam = + firstOrThrow (NoSuchParameter (wrappedTo ipid)) + =<< queryStore q (Only ipid) + where + q :: Query + q = [sql| SELECT * FROM params WHERE id = ? |] runInferenceParamWithEnv :: Id InferenceParam -> @@ -175,22 +190,23 @@ runInferenceParamWithEnv ipid uuid senv = -- will not have been initialized yet. After that, it will be reused -- until the server is started again interpreter <- getOrMkInferno ipid - param <- getParameterWithModels ipid cache <- view $ #config . #cache + t <- liftIO $ fromIntegral @Int . 
round <$> getPOSIXTime -- Change working directories to the model cache so that Hasktorch -- can find the models using relative paths (otherwise the AST would -- need to be updated to use an absolute path to a versioned model, -- e.g. `loadModel "~/inferno/.cache/..."`) - withCurrentDirectory (view #path cache) $ do + withCurrentDirectory cache.path $ do logInfo $ EvaluatingParam ipid traverse_ linkVersionedModel - =<< getAndCacheModels cache (view #models param) - runEval interpreter + =<< getAndCacheModels cache senv.models + runEval interpreter t where runEval :: Interpreter RemoteM BridgeMlValue -> + CTime -> RemoteM (WriteStream IO) - runEval Interpreter {evalExpr, mkEnvFromClosure} = + runEval Interpreter {evalExpr, mkEnvFromClosure} t = senv ^. #obj . #obj & \case VCFunction {} -> do let -- Note that this both includes inputs (i.e. readable) @@ -270,11 +286,8 @@ runInferenceParamWithEnv ipid uuid senv = expr :: Expr (Maybe VCObjectHash) () expr = flip (foldl' App) args $ - Var () mhash LocalScope dummy + Var () (Just senv.script) LocalScope dummy where - mhash :: Maybe VCObjectHash - mhash = senv ^. #script & Just - -- See note above about inputs/outputs args :: [Expr (Maybe a) ()] args = exprsFrom "input$" pids <> exprsFrom "model$" models @@ -327,7 +340,7 @@ runInferenceParamWithEnv ipid uuid senv = implEnv :: Map ExtIdent (Value BridgeMlValue m) implEnv = Map.fromList - [ (ExtIdent $ Right "now", VEpochTime $ view #ctime senv), + [ (ExtIdent $ Right "now", VEpochTime t), ( ExtIdent $ Right "resolution", VCustom . VExtended $ VResolution resolution ) @@ -445,14 +458,14 @@ linkVersionedModel withVersion = do withExt = dropExtensions withVersion <.> "ts" <.> "pt" getParameterWithModels :: Id InferenceParam -> RemoteM InferenceParamWithModels -getParameterWithModels iid = +getParameterWithModels ipid = fmap ( uncurry InferenceParamWithModels . fmap (getAeson . fromOnly) . joinToTuple ) - . 
firstOrThrow (NoSuchParameter (wrappedTo iid)) - =<< queryStore q (Only iid) + . firstOrThrow (NoSuchParameter (wrappedTo ipid)) + =<< queryStore q (Only ipid) where -- This query is somewhat complex in order to get all relevent information -- for creating the script evaluator's Inferno environment. @@ -611,7 +624,6 @@ data ScriptEnv = ScriptEnv inputs :: Map Ident (SingleOrMany PID, ScriptInputType), obj :: VCMeta VCObject, script :: VCObjectHash, - mres :: Maybe Int64, - ctime :: CTime + mres :: Maybe Int64 } deriving stock (Generic) From b2a74ce993c51ea525ea128d332173c290d2c704 Mon Sep 17 00:00:00 2001 From: Rory Tyler Hayford Date: Fri, 4 Oct 2024 11:13:14 +0700 Subject: [PATCH 05/13] Switch to UUIDs for DB primary keys --- .../src/Inferno/ML/Server/Types.hs | 20 +++++++++---------- .../src/Inferno/ML/Server/Inference.hs | 8 ++++++-- .../src/Inferno/ML/Server/Types.hs | 13 ++++++++---- .../migrations/v1-create-tables.sql | 8 ++++---- 4 files changed, 29 insertions(+), 20 deletions(-) diff --git a/inferno-ml-server-types/src/Inferno/ML/Server/Types.hs b/inferno-ml-server-types/src/Inferno/ML/Server/Types.hs index 4f8198b..0ed5e5c 100644 --- a/inferno-ml-server-types/src/Inferno/ML/Server/Types.hs +++ b/inferno-ml-server-types/src/Inferno/ML/Server/Types.hs @@ -28,7 +28,7 @@ import qualified Data.ByteString.Char8 as ByteString.Char8 import Data.Char (chr) import Data.Data (Typeable) import Data.Generics.Product (HasType (typed)) -import Data.Generics.Wrapped (wrappedFrom, wrappedTo) +import Data.Generics.Wrapped (wrappedTo) import Data.Hashable (Hashable) import qualified Data.IP import Data.Int (Int64) @@ -95,7 +95,6 @@ import System.Posix (EpochTime) import Test.QuickCheck ( Arbitrary (arbitrary), Gen, - Positive (getPositive), choose, chooseInt, listOf, @@ -211,7 +210,7 @@ instance ToRow (BridgeInfo uid gid p s) where ] -- | The ID of a database entity -newtype Id a = Id Int64 +newtype Id a = Id UUID deriving stock (Show, Generic) deriving newtype ( Eq, @@ 
-224,13 +223,11 @@ newtype Id a = Id Int64 FromJSONKey, ToJSONKey, ToHttpApiData, - FromHttpApiData + FromHttpApiData, + Arbitrary ) deriving anyclass (NFData, ToADTArbitrary) -instance Arbitrary (Id a) where - arbitrary = wrappedFrom . getPositive <$> arbitrary - -- | Row for the table containing inference script closures data InferenceScript uid gid = InferenceScript { -- | This is the ID for each row, stored as a @bytea@ (bytes of the hash) @@ -332,7 +329,10 @@ instance instance ToField gid => ToRow (Model gid) where -- NOTE: Order of fields must align exactly with DB schema toRow m = - [ toField Default, + [ -- Normally the ID will be missing for new rows, in that case the default + -- will be used (a random UUID). But in other cases it is useful to set + -- the ID explicitly (e.g. for testing) + m.id & maybe (toField Default) toField, m.name & toField, m.gid & toField, m.visibility & Aeson & toField, @@ -441,7 +441,7 @@ instance FromField gid => FromRow (ModelVersion gid Oid) where instance ToField gid => ToRow (ModelVersion gid Oid) where -- NOTE: Order of fields must align exactly with DB schema toRow mv = - [ toField Default, + [ mv.id & maybe (toField Default) toField, mv.model & toField, mv.description & toField, mv.card & Aeson & toField, @@ -748,7 +748,7 @@ instance where -- NOTE: Do not change the order of the field actions toRow ip = - [ toField Default, + [ ip.id & maybe (toField Default) toField, ip.script & VCObjectHashRow & toField, ip.inputs & Aeson & toField, ip.resolution & Aeson & toField, diff --git a/inferno-ml-server/src/Inferno/ML/Server/Inference.hs b/inferno-ml-server/src/Inferno/ML/Server/Inference.hs index b47b110..855f4c5 100644 --- a/inferno-ml-server/src/Inferno/ML/Server/Inference.hs +++ b/inferno-ml-server/src/Inferno/ML/Server/Inference.hs @@ -126,6 +126,10 @@ runInferenceParam ipid mres uuid = ?? pwm.param.script ?? mres +-- | Test an inference param. 
This requires a script object to be saved to +-- the DB, but it is does not need to be linked to the parameter itself. It +-- also allows for overriding models and inputs, which normally need to be +-- fixed to the script or param, respectively testInferenceParam :: Id InferenceParam -> Maybe Int64 -> @@ -150,7 +154,7 @@ testInferenceParam ipid mres uuid eenv = getParam :: RemoteM InferenceParam getParam = - firstOrThrow (NoSuchParameter (wrappedTo ipid)) + firstOrThrow (NoSuchParameter ipid) =<< queryStore q (Only ipid) where q :: Query @@ -464,7 +468,7 @@ getParameterWithModels ipid = . fmap (getAeson . fromOnly) . joinToTuple ) - . firstOrThrow (NoSuchParameter (wrappedTo ipid)) + . firstOrThrow (NoSuchParameter ipid) =<< queryStore q (Only ipid) where -- This query is somewhat complex in order to get all relevent information diff --git a/inferno-ml-server/src/Inferno/ML/Server/Types.hs b/inferno-ml-server/src/Inferno/ML/Server/Types.hs index c67c80d..873fc6d 100644 --- a/inferno-ml-server/src/Inferno/ML/Server/Types.hs +++ b/inferno-ml-server/src/Inferno/ML/Server/Types.hs @@ -46,7 +46,6 @@ import qualified Data.ByteString.Char8 as ByteString.Char8 import Data.Data (Typeable) import Data.Generics.Labels () import Data.Generics.Wrapped (wrappedTo) -import Data.Int (Int64) import Data.Map.Strict (Map) import Data.Scientific (Scientific) import Data.Text (Text) @@ -301,9 +300,9 @@ data RemoteError | -- | Either the requested model version does not exist, or the -- parent model row corresponding to the model version does not -- exist - NoSuchModel Int64 + NoSuchModel (Either (Id Model) (Id ModelVersion)) | NoSuchScript VCObjectHash - | NoSuchParameter Int64 + | NoSuchParameter (Id InferenceParam) | InvalidScript Text | InvalidOutput Text | -- | Any error condition returned by Inferno script evaluation @@ -317,12 +316,18 @@ data RemoteError instance Exception RemoteError where displayException = \case CacheSizeExceeded -> "Model exceeds maximum cache size" - 
NoSuchModel m -> + NoSuchModel (Left m) -> unwords [ "Model:", "'" <> show m <> "'", "does not exist in the store" ] + NoSuchModel (Right mv) -> + unwords + [ "Model version:", + "'" <> show mv <> "'", + "does not exist in the store" + ] NoSuchScript vch -> unwords [ "Script identified by hash", diff --git a/nix/inferno-ml/migrations/v1-create-tables.sql b/nix/inferno-ml/migrations/v1-create-tables.sql index 0bc2a19..6f35a66 100644 --- a/nix/inferno-ml/migrations/v1-create-tables.sql +++ b/nix/inferno-ml/migrations/v1-create-tables.sql @@ -17,7 +17,7 @@ create extension lo; -- "deleted" and cannot be used any longer create table if not exists models - ( id serial primary key + ( id uuid primary key default gen_random_uuid() , name text not null , gid numeric not null , visibility jsonb @@ -29,7 +29,7 @@ create table if not exists models ); create table if not exists mversions - ( id serial primary key + ( id uuid primary key default gen_random_uuid() , model integer references models (id) -- Short, high-level model description , description text not null @@ -66,7 +66,7 @@ create table if not exists mselections ); create table if not exists params - ( id serial primary key + ( id uuid primary key default gen_random_uuid() -- Script hash from `inferno-vc` , script bytea not null references scripts (id) -- Strictly speaking, this includes both inputs and outputs. 
The @@ -97,7 +97,7 @@ create table if not exists evalinfo -- Stores information required to call the data bridge create table if not exists bridges ( -- Same ID as the referenced param - id integer not null references params (id) + id uuid not null references params (id) -- Host of the bridge server , ip inet not null , port integer check (port > 0) From 340c9188a8a333e70d3553f2be8100dd06bdf548 Mon Sep 17 00:00:00 2001 From: Rory Tyler Hayford Date: Fri, 4 Oct 2024 11:50:28 +0700 Subject: [PATCH 06/13] Fix everything to use UUIDs instead of ints --- inferno-ml-server/exe/ParseAndSave.hs | 53 +++++++++++-------- inferno-ml-server/inferno-ml-server.cabal | 2 + inferno-ml-server/test/Client.hs | 16 +++--- inferno-ml-server/test/Main.hs | 3 +- .../migrations/v1-create-tables.sql | 6 +-- nix/inferno-ml/tests/server.nix | 36 +++++++++---- 6 files changed, 72 insertions(+), 44 deletions(-) diff --git a/inferno-ml-server/exe/ParseAndSave.hs b/inferno-ml-server/exe/ParseAndSave.hs index 886fda1..f4dafaf 100644 --- a/inferno-ml-server/exe/ParseAndSave.hs +++ b/inferno-ml-server/exe/ParseAndSave.hs @@ -22,14 +22,13 @@ import qualified Data.Map.Strict as Map import Data.Text (Text) import qualified Data.Text.IO as Text.IO import Data.Time.Clock.POSIX (getPOSIXTime) +import qualified Data.UUID as UUID import Database.PostgreSQL.Simple ( Connection, - Only (fromOnly), Query, close, connectPostgreSQL, execute, - query, withTransaction, ) import Database.PostgreSQL.Simple.SqlQQ (sql) @@ -52,7 +51,6 @@ import Inferno.VersionControl.Types VCObjectPred (Init), VCObjectVisibility (VCObjectPublic), ) -import Lens.Micro.Platform import System.Environment (getArgs) import System.Exit (die) import UnliftIO.Exception (bracket, throwString) @@ -60,32 +58,36 @@ import UnliftIO.Exception (bracket, throwString) main :: IO () main = getArgs >>= \case - scriptp : pstr : conns : _ -> - either throwString (parseAndSave scriptp (Char8.pack conns)) - . 
eitherDecode - $ Lazy.Char8.pack pstr - _ -> die "Usage ./parse " + ustr : scriptp : pstr : conns : _ -> do + pids <- either throwString pure $ eitherDecode (Lazy.Char8.pack pstr) + ipid <- + maybe (throwString "Invalid inference param Id") pure $ + UUID.fromString ustr + parseAndSave (Id ipid) scriptp (Char8.pack conns) pids + _ -> die "Usage ./parse " parseAndSave :: + Id InferenceParam -> FilePath -> ByteString -> Map Ident (SingleOrMany PID, ScriptInputType) -> IO () -parseAndSave p conns inputs = do +parseAndSave ipid p conns inputs = do t <- Text.IO.readFile p now <- fromIntegral @Int . round <$> getPOSIXTime ast <- either (throwString . displayException) pure . (`parse` t) =<< mkInferno @_ @BridgeMlValue (mkBridgePrelude funs) customTypes - bracket (connectPostgreSQL conns) close (saveScriptAndParam ast now inputs) + bracket (connectPostgreSQL conns) close (saveScriptAndParam ipid ast now inputs) saveScriptAndParam :: + Id InferenceParam -> (Expr (Pinned VCObjectHash) (), TCScheme) -> CTime -> Map Ident (SingleOrMany PID, ScriptInputType) -> Connection -> IO () -saveScriptAndParam x now inputs conn = insertScript *> insertParam +saveScriptAndParam ipid x now inputs conn = insertScript *> insertParam where insertScript :: IO () insertScript = @@ -108,17 +110,15 @@ saveScriptAndParam x now inputs conn = insertScript *> insertParam |] insertParam :: IO () - insertParam = - maybe (throwString "Can't get param ID") (saveBridgeInfo . fromOnly) - . preview _head - =<< saveParam + insertParam = saveParam *> saveBridgeInfo where - saveParam :: IO [Only (Id InferenceParam)] + saveParam :: IO () saveParam = - withTransaction conn - . query conn q + void + . withTransaction conn + . execute conn q . InferenceParam - Nothing + (Just ipid) hash inputs 128 @@ -129,12 +129,18 @@ saveScriptAndParam x now inputs conn = insertScript *> insertParam q = [sql| INSERT INTO params + ( id + , script + , inputs + , resolution + , terminated + , uid + ) VALUES (?, ?, ?, ?, ?, ?) 
- RETURNING id |] - saveBridgeInfo :: Id InferenceParam -> IO () - saveBridgeInfo ipid = + saveBridgeInfo :: IO () + saveBridgeInfo = void . withTransaction conn . execute conn q @@ -161,7 +167,8 @@ saveScriptAndParam x now inputs conn = insertScript *> insertParam MLInferenceScript . InferenceOptions . Map.singleton "mnist" - $ Id 1 + . Id + $ UUID.fromWords 6 0 0 0 uid :: EntityId UId uid = entityIdFromInteger 0 diff --git a/inferno-ml-server/inferno-ml-server.cabal b/inferno-ml-server/inferno-ml-server.cabal index f912db7..7728a8b 100644 --- a/inferno-ml-server/inferno-ml-server.cabal +++ b/inferno-ml-server/inferno-ml-server.cabal @@ -118,6 +118,7 @@ executable tests , mtl , plow-log , text + , uuid , unliftio , vector @@ -184,4 +185,5 @@ executable parse-and-save , text , time , unliftio + , uuid , vector diff --git a/inferno-ml-server/test/Client.hs b/inferno-ml-server/test/Client.hs index b23ed88..8f125fb 100644 --- a/inferno-ml-server/test/Client.hs +++ b/inferno-ml-server/test/Client.hs @@ -9,8 +9,9 @@ module Client (main) where import Conduit import Control.Monad (unless) import Data.Coerce (coerce) -import Data.Int (Int64) import qualified Data.Map as Map +import Data.UUID (UUID) +import qualified Data.UUID as UUID import Inferno.ML.Server.Client (inferenceC) import Inferno.ML.Server.Types ( IValue (IDouble), @@ -46,7 +47,7 @@ main = _ -> die "Usage: test-client " -- Check that the returned write stream matches the expected value -verifyWrites :: Int64 -> WriteStream IO -> IO () +verifyWrites :: UUID -> WriteStream IO -> IO () verifyWrites ipid c = do expected <- getExpected -- Note that there is only one chunk per PID in the output stream, so we @@ -64,17 +65,18 @@ verifyWrites ipid c = do where getExpected :: IO [(Int, [(EpochTime, IValue)])] getExpected = - maybe (throwString "Missing PID") pure . Map.lookup ipid $ - Map.fromList - [ ( 1, + maybe (throwString "Missing output PID for parameter") pure + . 
Map.lookup ipid + $ Map.fromList + [ ( UUID.fromWords 1 0 0 0, [ (1, [(151, IDouble 2.5), (251, IDouble 3.5)]) ] ), - ( 2, + ( UUID.fromWords 2 0 0 0, [ (2, [(300, IDouble 25.0)]) ] ), - ( 3, + ( UUID.fromWords 3 0 0 0, [ (3, [(100, IDouble 7.0)]), (4, [(100, IDouble 8.0)]) ] diff --git a/inferno-ml-server/test/Main.hs b/inferno-ml-server/test/Main.hs index ef17c04..babb5ef 100644 --- a/inferno-ml-server/test/Main.hs +++ b/inferno-ml-server/test/Main.hs @@ -15,6 +15,7 @@ import Data.Foldable (toList, traverse_) import Data.Map.Strict (Map) import qualified Data.Map.Strict as Map import Data.Text (Text) +import qualified Data.UUID as UUID import Data.Vector (Vector) import qualified Data.Vector as Vector import Data.Word (Word8) @@ -127,7 +128,7 @@ getWithContents env = flip runReaderT env $ do v : _ -> getModelVersionSizeAndContents $ view #contents v mnistV1 :: Id ModelVersion -mnistV1 = Id 1 +mnistV1 = Id $ UUID.fromWords 6 0 0 0 models :: Vector (Id ModelVersion) models = Vector.singleton mnistV1 diff --git a/nix/inferno-ml/migrations/v1-create-tables.sql b/nix/inferno-ml/migrations/v1-create-tables.sql index 6f35a66..66264a1 100644 --- a/nix/inferno-ml/migrations/v1-create-tables.sql +++ b/nix/inferno-ml/migrations/v1-create-tables.sql @@ -30,7 +30,7 @@ create table if not exists models create table if not exists mversions ( id uuid primary key default gen_random_uuid() - , model integer references models (id) + , model uuid references models (id) -- Short, high-level model description , description text not null -- Model card (description and metadata) serialized as JSON @@ -59,7 +59,7 @@ create table if not exists scripts -- between `scripts` and `mversions`) create table if not exists mselections ( script bytea not null references scripts (id) - , model integer not null references mversions (id) + , model uuid not null references mversions (id) -- Inferno identifier linked to this specific model version , ident text not null , unique (script, model) @@ -83,7 
+83,7 @@ create table if not exists params -- Execution info for inference evaluation create table if not exists evalinfo ( id uuid primary key - , param integer not null references params (id) + , param uuid not null references params (id) -- When inference evaluation began , started timestamptz not null -- When inference evaluation ended diff --git a/nix/inferno-ml/tests/server.nix b/nix/inferno-ml/tests/server.nix index 823d632..f56a9f7 100644 --- a/nix/inferno-ml/tests/server.nix +++ b/nix/inferno-ml/tests/server.nix @@ -1,5 +1,14 @@ { pkgs, ... }: +let + # Fixed UUIDs for parameters. The Haskell test executable will use these + # IDs as well + ids = { + ones = "00000001-0000-0000-0000-000000000000"; + contrived = "00000002-0000-0000-0000-000000000000"; + mnist = "00000003-0000-0000-0000-000000000000"; + }; +in pkgs.nixosTest { name = "inferno-ml-server-test"; nodes.node = { config, ... }: { @@ -35,26 +44,30 @@ pkgs.nixosTest { '' psql -U inferno -d inferno << EOF INSERT INTO models - ( name + ( id + , name , gid , visibility ) VALUES - ( 'mnist' + ( '00000005-0000-0000-0000-000000000000'::uuid + , 'mnist' , 1::bigint , '"VCObjectPublic"'::jsonb ); \lo_import ${./models/mnist.ts.pt} INSERT INTO mversions - ( model + ( id + , model , description , card , contents , version ) VALUES - ( 1 + ( '00000006-0000-0000-0000-000000000000'::uuid + , '00000005-0000-0000-0000-000000000000'::uuid , 'My first model' , '${card}'::jsonb , :LASTOID @@ -87,9 +100,12 @@ pkgs.nixosTest { }; in '' - parse-and-save ${./scripts/ones.inferno} '${ios.ones}' ${dbstr} - parse-and-save ${./scripts/contrived.inferno} '${ios.contrived}' ${dbstr} - parse-and-save ${./scripts/mnist.inferno} '${ios.mnist}' ${dbstr} + parse-and-save \ + ${./scripts/ones.inferno} '${ids.ones}' '${ios.ones}' ${dbstr} + parse-and-save \ + ${./scripts/contrived.inferno} '${ids.contrived}' '${ios.contrived}' ${dbstr} + parse-and-save \ + ${./scripts/mnist.inferno} '${ids.mnist}' '${ios.mnist}' ${dbstr} ''; } ) @@ 
-206,12 +222,12 @@ pkgs.nixosTest { node.succeed('sudo -HE -u inferno run-db-test >&2') # `tests/scripts/ones.inferno` - runtest(1) + runtest('${ids.ones}') # `tests/scripts/contrived.inferno` - runtest(2) + runtest('${ids.contrived}') # `tests/scripts/mnist.inferno` - runtest(3) + runtest('${ids.mnist}') ''; } From eb2d4f3a24e218c6f1131c4dc0f8e016b8af330a Mon Sep 17 00:00:00 2001 From: Rory Tyler Hayford Date: Fri, 4 Oct 2024 12:00:53 +0700 Subject: [PATCH 07/13] More fixes --- inferno-ml-server/exe/ParseAndSave.hs | 2 +- nix/inferno-ml/tests/server.nix | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/inferno-ml-server/exe/ParseAndSave.hs b/inferno-ml-server/exe/ParseAndSave.hs index f4dafaf..dfd9ed1 100644 --- a/inferno-ml-server/exe/ParseAndSave.hs +++ b/inferno-ml-server/exe/ParseAndSave.hs @@ -105,7 +105,7 @@ saveScriptAndParam ipid x now inputs conn = insertScript *> insertParam RETURNING id ) INSERT INTO mselections (script, model, ident) - SELECT id, 1::integer, 'mnist' + SELECT id, '00000006-0000-0000-0000-000000000000'::uuid, 'mnist' FROM ins |] diff --git a/nix/inferno-ml/tests/server.nix b/nix/inferno-ml/tests/server.nix index f56a9f7..4658242 100644 --- a/nix/inferno-ml/tests/server.nix +++ b/nix/inferno-ml/tests/server.nix @@ -101,11 +101,11 @@ pkgs.nixosTest { in '' parse-and-save \ - ${./scripts/ones.inferno} '${ids.ones}' '${ios.ones}' ${dbstr} + '${ids.ones}' ${./scripts/ones.inferno} '${ios.ones}' ${dbstr} parse-and-save \ - ${./scripts/contrived.inferno} '${ids.contrived}' '${ios.contrived}' ${dbstr} + '${ids.contrived}' ${./scripts/contrived.inferno} '${ios.contrived}' ${dbstr} parse-and-save \ - ${./scripts/mnist.inferno} '${ids.mnist}' '${ios.mnist}' ${dbstr} + '${ids.mnist}' ${./scripts/mnist.inferno} '${ios.mnist}' ${dbstr} ''; } ) From 26343f9d6948678d0024ce7c661a1089123c230a Mon Sep 17 00:00:00 2001 From: Rory Tyler Hayford Date: Fri, 4 Oct 2024 12:40:28 +0700 Subject: [PATCH 08/13] Switch to GID for params --- 
.../src/Inferno/ML/Server/Client.hs | 12 +-- .../src/Inferno/ML/Server/Types.hs | 85 +++++++++---------- inferno-ml-server/exe/ParseAndSave.hs | 2 +- .../src/Inferno/ML/Server/Types.hs | 27 +++--- inferno-ml-server/test/Client.hs | 4 +- .../migrations/v1-create-tables.sql | 2 +- 6 files changed, 61 insertions(+), 71 deletions(-) diff --git a/inferno-ml-server-types/src/Inferno/ML/Server/Client.hs b/inferno-ml-server-types/src/Inferno/ML/Server/Client.hs index d952108..41461ae 100644 --- a/inferno-ml-server-types/src/Inferno/ML/Server/Client.hs +++ b/inferno-ml-server-types/src/Inferno/ML/Server/Client.hs @@ -27,9 +27,9 @@ cancelC = client $ Proxy @CancelAPI -- | Run an inference parameter inferenceC :: - forall uid gid p s. + forall gid p s. -- | SQL identifier of the inference parameter to be run - Id (InferenceParam uid gid p s) -> + Id (InferenceParam gid p s) -> -- | Optional resolution for scripts that use e.g. @valueAt@; defaults to -- the param\'s stored resolution if not provided. This lets users override -- the resolution on an ad-hoc basis without needing to alter the stored @@ -44,16 +44,16 @@ inferenceC :: -- (not defined in this repository) to verify this before directing -- the writes to their final destination ClientM (WriteStream IO) -inferenceC = client $ Proxy @(InferenceAPI uid gid p s) +inferenceC = client $ Proxy @(InferenceAPI gid p s) -- | Run an inference parameter inferenceTestC :: - forall uid gid p s. + forall gid p s. 
ToJSON p => -- | SQL identifier of the inference parameter to be run - Id (InferenceParam uid gid p s) -> + Id (InferenceParam gid p s) -> Maybe Int64 -> UUID -> EvaluationEnv gid p -> ClientM (WriteStream IO) -inferenceTestC = client $ Proxy @(InferenceTestAPI uid gid p s) +inferenceTestC = client $ Proxy @(InferenceTestAPI gid p s) diff --git a/inferno-ml-server-types/src/Inferno/ML/Server/Types.hs b/inferno-ml-server-types/src/Inferno/ML/Server/Types.hs index 0ed5e5c..7985e0c 100644 --- a/inferno-ml-server-types/src/Inferno/ML/Server/Types.hs +++ b/inferno-ml-server-types/src/Inferno/ML/Server/Types.hs @@ -74,7 +74,11 @@ import Inferno.Types.VersionControl byteStringToVCObjectHash, vcObjectHashToByteString, ) -import Inferno.VersionControl.Types (VCMeta, VCObject, VCObjectVisibility) +import Inferno.VersionControl.Types + ( VCMeta, + VCObject, + VCObjectVisibility, + ) import Lens.Micro.Platform hiding ((.=)) import Servant ( Capture, @@ -121,11 +125,11 @@ import Web.HttpApiData ) -- API type for `inferno-ml-server` -type InfernoMlServerAPI uid gid p s = +type InfernoMlServerAPI gid p s = StatusAPI -- Evaluate an inference script - :<|> InferenceAPI uid gid p s - :<|> InferenceTestAPI uid gid p s + :<|> InferenceAPI gid p s + :<|> InferenceTestAPI gid p s :<|> CancelAPI type StatusAPI = @@ -134,17 +138,17 @@ type StatusAPI = type CancelAPI = "inference" :> "cancel" :> Put '[JSON] () -type InferenceAPI uid gid p s = +type InferenceAPI gid p s = "inference" - :> Capture "id" (Id (InferenceParam uid gid p s)) + :> Capture "id" (Id (InferenceParam gid p s)) :> QueryParam "res" Int64 :> QueryParam' '[Required] "uuid" UUID :> StreamPost NewlineFraming JSON (WriteStream IO) -type InferenceTestAPI uid gid p s = +type InferenceTestAPI gid p s = -- Evaluate an inference script "inference" - :> Capture "id" (Id (InferenceParam uid gid p s)) + :> Capture "id" (Id (InferenceParam gid p s)) :> QueryParam "res" Int64 :> QueryParam' '[Required] "uuid" UUID :> ReqBody '[JSON] 
(EvaluationEnv gid p) @@ -187,22 +191,22 @@ data ServerStatus deriving anyclass (FromJSON, ToJSON) -- | Information for contacting a bridge server that implements the 'BridgeAPI' -data BridgeInfo uid gid p s = BridgeInfo - { id :: Id (InferenceParam uid gid p s), +data BridgeInfo gid p s = BridgeInfo + { id :: Id (InferenceParam gid p s), host :: IPv4, port :: Word64 } deriving stock (Show, Eq, Generic) deriving anyclass (FromJSON, ToJSON, NFData) -instance FromRow (BridgeInfo uid gid p s) where +instance FromRow (BridgeInfo gid p s) where fromRow = BridgeInfo <$> field <*> field <*> fmap (fromIntegral @Int64) field -instance ToRow (BridgeInfo uid gid p s) where +instance ToRow (BridgeInfo gid p s) where toRow bi = [ bi.id & toField, bi.host & toField, @@ -669,8 +673,8 @@ showVersion (Version ns ts) = -- | Row of the inference parameter table, parameterized by the user, group, and -- script type -data InferenceParam uid gid p s = InferenceParam - { id :: Maybe (Id (InferenceParam uid gid p s)), +data InferenceParam gid p s = InferenceParam + { id :: Maybe (Id (InferenceParam gid p s)), -- | The script of the parameter -- -- For new parameters, this will be textual or some other identifier @@ -695,18 +699,13 @@ data InferenceParam uid gid p s = InferenceParam -- | The time that this parameter was \"deleted\", if any. For active -- parameters, this will be @Nothing@ terminated :: Maybe UTCTime, - uid :: uid + gid :: gid } deriving stock (Show, Eq, Generic) deriving anyclass (NFData, ToJSON) {- ORMOLU_DISABLE -} -instance - ( FromJSON s, - FromJSON p, - FromJSON uid - ) => - FromJSON (InferenceParam uid gid p s) +instance (FromJSON s, FromJSON p, FromJSON gid) => FromJSON (InferenceParam gid p s) where parseJSON = withObject "InferenceParam" $ \o -> InferenceParam @@ -717,19 +716,18 @@ instance <*> o .:? "resolution" .!= 128 -- We shouldn't require this field <*> o .:? 
"terminated" - <*> o .: "uid" + <*> o .: "gid" {- ORMOLU_ENABLE -} -- We only want this instance if the `script` is a `VCObjectHash` (because it -- should not be possible to store a new param with a raw script) instance ( FromJSON p, - Typeable p, - FromField uid, + FromField gid, Typeable gid, - Typeable uid + Typeable p ) => - FromRow (InferenceParam uid gid p VCObjectHash) + FromRow (InferenceParam gid p VCObjectHash) where fromRow = InferenceParam @@ -740,12 +738,7 @@ instance <*> field <*> field -instance - ( ToJSON p, - ToField uid - ) => - ToRow (InferenceParam uid gid p VCObjectHash) - where +instance (ToJSON p, ToField gid) => ToRow (InferenceParam gid p VCObjectHash) where -- NOTE: Do not change the order of the field actions toRow ip = [ ip.id & maybe (toField Default) toField, @@ -753,16 +746,16 @@ instance ip.inputs & Aeson & toField, ip.resolution & Aeson & toField, toField Default, - ip.uid & toField + ip.gid & toField ] -- Not derived generically in order to use special `Gen UTCTime` instance - ( Arbitrary s, + ( Arbitrary gid, Arbitrary p, - Arbitrary uid + Arbitrary s ) => - Arbitrary (InferenceParam uid gid p s) + Arbitrary (InferenceParam gid p s) where arbitrary = InferenceParam @@ -775,11 +768,11 @@ instance -- Can't be derived because there is (intentially) no `Arbitrary UTCTime` in scope instance - ( Arbitrary s, + ( Arbitrary gid, Arbitrary p, - Arbitrary uid + Arbitrary s ) => - ToADTArbitrary (InferenceParam uid gid p s) + ToADTArbitrary (InferenceParam gid p s) where toADTArbitrarySingleton _ = ADTArbitrarySingleton "Inferno.ML.Server.Types" "InferenceParam" @@ -792,8 +785,8 @@ instance -- | An 'InferenceParam' together with all of the model versions that are -- linked to it indirectly via its script. 
This is provided for convenience -data InferenceParamWithModels uid gid p s = InferenceParamWithModels - { param :: InferenceParam uid gid p s, +data InferenceParamWithModels gid p s = InferenceParamWithModels + { param :: InferenceParam gid p s, models :: Map Ident @@ -841,11 +834,11 @@ instance Arbitrary ScriptInputType where -- @inferno-ml-server@ after script evaluation completes and can be queried -- later by using the same job identifier that was provided to the @/inference@ -- route -data EvaluationInfo uid gid p = EvaluationInfo +data EvaluationInfo gid p = EvaluationInfo { -- | Note that this is the job identifier provided to the inference -- evaluation route, and is also the primary key of the database table id :: UUID, - param :: Id (InferenceParam uid gid p VCObjectHash), + param :: Id (InferenceParam gid p VCObjectHash), -- | When inference evaluation started start :: UTCTime, -- | When inference evaluation ended @@ -863,7 +856,7 @@ data EvaluationInfo uid gid p = EvaluationInfo deriving stock (Show, Eq, Generic) deriving anyclass (FromJSON, ToJSON) -instance FromRow (EvaluationInfo uid gid p) where +instance FromRow (EvaluationInfo gid p) where fromRow = EvaluationInfo <$> field @@ -873,7 +866,7 @@ instance FromRow (EvaluationInfo uid gid p) where <*> fmap (fromIntegral @Int64) field <*> fmap (fromIntegral @Int64) field -instance ToRow (EvaluationInfo uid gid p) where +instance ToRow (EvaluationInfo gid p) where toRow ei = [ ei.id & toField, ei.param & toField, @@ -884,7 +877,7 @@ instance ToRow (EvaluationInfo uid gid p) where ] -- Not derived generically in order to use special `Gen UTCTime` -instance Arbitrary (EvaluationInfo uid gid p) where +instance Arbitrary (EvaluationInfo gid p) where arbitrary = EvaluationInfo <$> arbitrary diff --git a/inferno-ml-server/exe/ParseAndSave.hs b/inferno-ml-server/exe/ParseAndSave.hs index dfd9ed1..f1aa9da 100644 --- a/inferno-ml-server/exe/ParseAndSave.hs +++ b/inferno-ml-server/exe/ParseAndSave.hs @@ -134,7 
+134,7 @@ saveScriptAndParam ipid x now inputs conn = insertScript *> insertParam , inputs , resolution , terminated - , uid + , gid ) VALUES (?, ?, ?, ?, ?, ?) |] diff --git a/inferno-ml-server/src/Inferno/ML/Server/Types.hs b/inferno-ml-server/src/Inferno/ML/Server/Types.hs index 873fc6d..23231c6 100644 --- a/inferno-ml-server/src/Inferno/ML/Server/Types.hs +++ b/inferno-ml-server/src/Inferno/ML/Server/Types.hs @@ -297,9 +297,8 @@ newtype InferenceOptions = InferenceOptions data RemoteError = CacheSizeExceeded - | -- | Either the requested model version does not exist, or the - -- parent model row corresponding to the model version does not - -- exist + | -- | Either the parent model row corresponding to the model version, or + -- the requested model version itself, does not exist NoSuchModel (Either (Id Model) (Id ModelVersion)) | NoSuchScript VCObjectHash | NoSuchParameter (Id InferenceParam) @@ -415,15 +414,15 @@ infixl 4 ?? f ?? x = ($ x) <$> f type InferenceParam = - Types.InferenceParam (EntityId UId) (EntityId GId) PID VCObjectHash + Types.InferenceParam (EntityId GId) PID VCObjectHash type InferenceParamWithModels = - Types.InferenceParamWithModels (EntityId UId) (EntityId GId) PID VCObjectHash + Types.InferenceParamWithModels (EntityId GId) PID VCObjectHash type BridgeInfo = - Types.BridgeInfo (EntityId UId) (EntityId GId) PID VCObjectHash + Types.BridgeInfo (EntityId GId) PID VCObjectHash -type EvaluationInfo = Types.EvaluationInfo (EntityId UId) (EntityId GId) PID +type EvaluationInfo = Types.EvaluationInfo (EntityId GId) PID type Model = Types.Model (EntityId GId) @@ -442,10 +441,10 @@ pattern InferenceParam :: Map Ident (SingleOrMany PID, ScriptInputType) -> Word64 -> Maybe UTCTime -> - EntityId UId -> + EntityId GId -> InferenceParam -pattern InferenceParam iid s ios res mt uid = - Types.InferenceParam iid s ios res mt uid +pattern InferenceParam iid s ios res mt gid = + Types.InferenceParam iid s ios res mt gid pattern InferenceParamWithModels :: 
InferenceParam -> @@ -480,11 +479,7 @@ pattern EvaluationInfo :: pattern EvaluationInfo u i s e m c = Types.EvaluationInfo u i s e m c type InfernoMlServerAPI = - Types.InfernoMlServerAPI - (EntityId UId) - (EntityId GId) - PID - VCObjectHash + Types.InfernoMlServerAPI (EntityId GId) PID VCObjectHash type EvaluationEnv = Types.EvaluationEnv (EntityId GId) PID @@ -500,6 +495,8 @@ deriving newtype instance ToHttpApiData EpochTime deriving newtype instance FromHttpApiData EpochTime +-- Etc + joinToTuple :: (a :. b) -> (a, b) joinToTuple (a :. b) = (a, b) diff --git a/inferno-ml-server/test/Client.hs b/inferno-ml-server/test/Client.hs index 8f125fb..0ceef22 100644 --- a/inferno-ml-server/test/Client.hs +++ b/inferno-ml-server/test/Client.hs @@ -50,8 +50,8 @@ main = verifyWrites :: UUID -> WriteStream IO -> IO () verifyWrites ipid c = do expected <- getExpected - -- Note that there is only one chunk per PID in the output stream, so we - -- don't need to concatenate the results by PID. We can just sink it into + -- Note that there are only one or two chunks per PID in the output stream, so + -- we don't need to concatenate the results by PID. We can just sink it into -- a list directly result <- runConduit $ c .| sinkList unless (result == expected) . throwString . 
unwords $ diff --git a/nix/inferno-ml/migrations/v1-create-tables.sql b/nix/inferno-ml/migrations/v1-create-tables.sql index 66264a1..5835cce 100644 --- a/nix/inferno-ml/migrations/v1-create-tables.sql +++ b/nix/inferno-ml/migrations/v1-create-tables.sql @@ -77,7 +77,7 @@ create table if not exists params , resolution integer not null -- See note above , terminated timestamptz - , uid numeric not null + , gid numeric not null ); -- Execution info for inference evaluation From 7232eed3c073984fcc13a730f9bc030779e32cc7 Mon Sep 17 00:00:00 2001 From: Rory Tyler Hayford Date: Fri, 4 Oct 2024 13:36:45 +0700 Subject: [PATCH 09/13] More cleanup --- inferno-ml-server/src/Inferno/ML/Server/Bridge.hs | 8 ++------ .../src/Inferno/ML/Server/Inference.hs | 15 +++++++-------- 2 files changed, 9 insertions(+), 14 deletions(-) diff --git a/inferno-ml-server/src/Inferno/ML/Server/Bridge.hs b/inferno-ml-server/src/Inferno/ML/Server/Bridge.hs index b238252..9be33d7 100644 --- a/inferno-ml-server/src/Inferno/ML/Server/Bridge.hs +++ b/inferno-ml-server/src/Inferno/ML/Server/Bridge.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE OverloadedRecordDot #-} {-# LANGUAGE QuasiQuotes #-} module Inferno.ML.Server.Bridge @@ -69,9 +70,4 @@ callBridge bi c = mkEnv = asks $ (`mkClientEnv` url) . view #manager where url :: BaseUrl - url = - BaseUrl - Http - (view (#host . to show) bi) - (view (#port . to fromIntegral) bi) - mempty + url = BaseUrl Http (show bi.host) (fromIntegral bi.port) mempty diff --git a/inferno-ml-server/src/Inferno/ML/Server/Inference.hs b/inferno-ml-server/src/Inferno/ML/Server/Inference.hs index 855f4c5..2b1aca3 100644 --- a/inferno-ml-server/src/Inferno/ML/Server/Inference.hs +++ b/inferno-ml-server/src/Inferno/ML/Server/Inference.hs @@ -283,9 +283,8 @@ runInferenceParamWithEnv ipid uuid senv = closure :: Map VCObjectHash VCObject closure = - senv ^. #script - & ( `Map.singleton` view (#obj . #obj) senv - ) + Map.singleton senv.script $ + view (#obj . 
#obj) senv expr :: Expr (Maybe VCObjectHash) () expr = @@ -354,7 +353,7 @@ runInferenceParamWithEnv ipid uuid senv = . InvalidScript $ Text.unwords [ "Script identified by VC hash", - senv ^. #script & tshow, + tshow senv.script, "is not a function" ] @@ -444,7 +443,7 @@ getVcObject vch = -- It's easier to `SELECT *` and then get the `obj` field, instead of -- `SELECT obj`, because the `FromRow` instance for `InferenceScript` -- deals with the JSON encoding of the `obj` - fmap (view #obj) . firstOrThrow (NoSuchScript vch) + fmap (.obj) . firstOrThrow (NoSuchScript vch) =<< queryStore @_ @InferenceScript q (Only vch) where -- The script hash is used as the primary key in the table @@ -544,9 +543,9 @@ getAndCacheModels cache = copyAndCache model mversion = versioned <$ do unlessM (doesPathExist versioned) $ do - mversion ^. #id & (`whenJust` logInfo . CopyingModel) + whenJust mversion.id $ logInfo . CopyingModel bitraverse_ checkCacheSize (writeBinaryFileDurableAtomic versioned) - =<< getModelVersionSizeAndContents (view #contents mversion) + =<< getModelVersionSizeAndContents mversion.contents where -- Cache the model with its specific version, i.e. -- `.ts.pt.`, which will later be @@ -606,7 +605,7 @@ getAndCacheModels cache = =<< getCurrentDirectory maxSize :: Integer - maxSize = cache ^. #maxSize & fromIntegral + maxSize = fromIntegral cache.maxSize -- Get a list of models by their access time, so that models that have not been -- used recently can be deleted. 
This will put the least-recently-used paths From 9908f7ec227389219006db9cb33ff41348638151 Mon Sep 17 00:00:00 2001 From: Rory Tyler Hayford Date: Fri, 4 Oct 2024 14:33:37 +0700 Subject: [PATCH 10/13] Stop linking model version paths to parent name --- .../src/Inferno/ML/Server/Types.hs | 16 +-- inferno-ml-server/inferno-ml-server.cabal | 3 +- .../src/Inferno/ML/Server/Inference.hs | 132 ++++++------------ .../src/Inferno/ML/Server/Types.hs | 4 +- inferno-ml-server/test/Main.hs | 53 ++++--- inferno-ml/inferno-ml.cabal | 1 - inferno-ml/src/Inferno/ML/Module/Prelude.hs | 3 +- 7 files changed, 76 insertions(+), 136 deletions(-) diff --git a/inferno-ml-server-types/src/Inferno/ML/Server/Types.hs b/inferno-ml-server-types/src/Inferno/ML/Server/Types.hs index 7985e0c..52c66a3 100644 --- a/inferno-ml-server-types/src/Inferno/ML/Server/Types.hs +++ b/inferno-ml-server-types/src/Inferno/ML/Server/Types.hs @@ -787,13 +787,7 @@ instance -- linked to it indirectly via its script. This is provided for convenience data InferenceParamWithModels gid p s = InferenceParamWithModels { param :: InferenceParam gid p s, - models :: - Map - Ident - ( Id (ModelVersion gid Oid), - -- Name of parent model - Text - ) + models :: Map Ident (Id (ModelVersion gid Oid)) } deriving stock (Show, Eq, Generic) @@ -1055,13 +1049,7 @@ instance Ord a => Ord (SingleOrMany a) where data EvaluationEnv gid p = EvaluationEnv { script :: VCObjectHash, inputs :: Map Ident (SingleOrMany p, ScriptInputType), - models :: - Map - Ident - ( Id (ModelVersion gid Oid), - -- Name of parent model - Text - ) + models :: Map Ident (Id (ModelVersion gid Oid)) } deriving stock (Show, Eq, Generic) deriving anyclass (FromJSON, ToJSON) diff --git a/inferno-ml-server/inferno-ml-server.cabal b/inferno-ml-server/inferno-ml-server.cabal index 7728a8b..5d87ffa 100644 --- a/inferno-ml-server/inferno-ml-server.cabal +++ b/inferno-ml-server/inferno-ml-server.cabal @@ -110,6 +110,7 @@ executable tests , base , bytestring , containers 
+ , filepath , generic-lens , hspec , inferno-ml-server @@ -118,8 +119,8 @@ executable tests , mtl , plow-log , text - , uuid , unliftio + , uuid , vector executable test-client diff --git a/inferno-ml-server/src/Inferno/ML/Server/Inference.hs b/inferno-ml-server/src/Inferno/ML/Server/Inference.hs index 2b1aca3..a2f93db 100644 --- a/inferno-ml-server/src/Inferno/ML/Server/Inference.hs +++ b/inferno-ml-server/src/Inferno/ML/Server/Inference.hs @@ -13,7 +13,6 @@ module Inferno.ML.Server.Inference ( runInferenceParam, testInferenceParam, getAndCacheModels, - linkVersionedModel, ) where @@ -24,6 +23,7 @@ import Control.Monad.Extra (loopM, unlessM, whenJust, whenM) import Control.Monad.IO.Class (MonadIO (liftIO)) import Control.Monad.ListM (sortByM) import Data.Bifoldable (bitraverse_) +import Data.Bifunctor (bimap) import Data.Conduit.List (chunksOf, sourceList) import Data.Foldable (foldl', traverse_) import Data.Generics.Wrapped (wrappedFrom, wrappedTo) @@ -36,7 +36,7 @@ import Data.Time (UTCTime, getCurrentTime) import Data.Time.Clock.POSIX (getPOSIXTime) import Data.Traversable (for) import Data.UUID (UUID) -import Data.Vector (Vector) +import qualified Data.UUID as UUID import qualified Data.Vector as Vector import Data.Word (Word64) import Database.PostgreSQL.Simple @@ -75,20 +75,18 @@ import Inferno.VersionControl.Types ) import Lens.Micro.Platform import System.CPUTime (getCPUTime) -import System.FilePath (dropExtensions, (<.>)) +import System.FilePath ((<.>)) import System.Mem (getAllocationCounter, setAllocationCounter) import System.Posix.Types (EpochTime) import UnliftIO (withRunInIO) import UnliftIO.Async (wait, withAsync) import UnliftIO.Directory - ( createFileLink, - doesPathExist, + ( doesPathExist, getAccessTime, getCurrentDirectory, getFileSize, listDirectory, removeFile, - removePathForcibly, withCurrentDirectory, ) import UnliftIO.Exception @@ -202,8 +200,7 @@ runInferenceParamWithEnv ipid uuid senv = -- e.g. 
`loadModel "~/inferno/.cache/..."`) withCurrentDirectory cache.path $ do logInfo $ EvaluatingParam ipid - traverse_ linkVersionedModel - =<< getAndCacheModels cache senv.models + getAndCacheModels cache $ senv.models runEval interpreter t where runEval :: @@ -225,24 +222,11 @@ runInferenceParamWithEnv ipid uuid senv = pids = senv ^.. #inputs . to Map.toAscList . each . _2 . _1 - -- These are all of the models selected for use with the - -- script. The ID of the actual model version is included, - -- along with the name of the parent model. The latter is - -- required to generate the correct call to Hasktorch's - -- `loadScript`, indirectly via the `ML.loadModel` primitve - -- - -- For example, given a parent model name of `"mnist"` and - -- a binding of `model0` (which is provided to the script), - -- `ML.loadModel model0` will ultimately produce - -- `Torch.Script.loadScript "mnist.ts.pt"`; note that the - -- correct model name and extension is handled for the user; - -- furthermore, the script evaluator caches the model _version_ - -- based on the ID and links it to the name of the parent - -- model, so that the `"mnist.ts.pt"` is an existing path - -- in the model cache (see `getAndCacheModels` below) - models :: [(Id ModelVersion, Text)] - models = - senv ^.. #models . to Map.toAscList . each . _2 + -- List of model versions, which are used to evaluate the + -- `loadModel` primitive (eventually calling Hasktorch to + -- load the script module) + models :: [Id ModelVersion] + models = senv ^.. #models . to Map.toAscList . each . _2 mkIdentWith :: Text -> Int -> ExtIdent mkIdentWith x = ExtIdent . Right . (x <>) . tshow @@ -250,8 +234,8 @@ runInferenceParamWithEnv ipid uuid senv = toSeries :: PID -> Value BridgeMlValue m toSeries = VCustom . VExtended . VSeries - toModelName :: FilePath -> Value BridgeMlValue m - toModelName = VCustom . VModelName . wrappedFrom + toModelPath :: Id ModelVersion -> Value BridgeMlValue m + toModelPath = VCustom . VModelName . 
wrappedFrom . mkModelPath argsFrom :: [a] -> @@ -276,15 +260,11 @@ runInferenceParamWithEnv ipid uuid senv = modelArgs :: [(ExtIdent, Value BridgeMlValue m)] modelArgs = - argsFrom models $ \(i, (_, name)) -> - ( mkIdentWith "model$" i, - toModelName $ Text.unpack name - ) + argsFrom models $ + bimap (mkIdentWith "model$") toModelPath closure :: Map VCObjectHash VCObject - closure = - Map.singleton senv.script $ - view (#obj . #obj) senv + closure = senv ^. #obj . #obj & Map.singleton senv.script expr :: Expr (Maybe VCObjectHash) () expr = @@ -358,7 +338,7 @@ runInferenceParamWithEnv ipid uuid senv = ] -- Convert the script timeout from seconds (for ease of configuration) to - -- milliseconds, for use with `timeout` + -- milliseconds, for use with `timeout`; then execute the action withTimeoutMillis :: (Int -> RemoteM b) -> RemoteM b withTimeoutMillis = (view (#config . #timeout) >>=) @@ -450,16 +430,6 @@ getVcObject vch = q :: Query q = [sql| SELECT * FROM scripts WHERE id = ? |] --- Link the versioned versioned to the model, e.g. `.` to just --- `.ts.pt`, so it can be loaded by Hasktorch -linkVersionedModel :: FilePath -> RemoteM () -linkVersionedModel withVersion = do - whenM (doesPathExist withExt) $ removePathForcibly withExt - createFileLink withVersion withExt - where - withExt :: FilePath - withExt = dropExtensions withVersion <.> "ts" <.> "pt" - getParameterWithModels :: Id InferenceParam -> RemoteM InferenceParamWithModels getParameterWithModels ipid = fmap @@ -474,47 +444,33 @@ getParameterWithModels ipid = -- for creating the script evaluator's Inferno environment. -- -- For each row in the `mselections` table linked to the param's script, - -- it selects the model version and the parent model. To create the - -- `Map Ident (Id ModelVersion, Text)` that is used in `ParamWithModels`, - -- it generates a JSONB object. For example, given the following rows from - -- `mselections`: + -- it selects the model version. 
To create the `Map Ident (Id ModelVersion)` + -- that is used in `ParamWithModels`, it generates a JSONB object. For + -- example, given the following rows from `mselections`: -- -- ``` - -- | script | model | ident | - -- - - - - - -- | \x123abc... | 1 | 'model0' | - -- | \x123abc... | 2 | 'model1' | + -- | script | id | ident | + -- - - - - + -- | \x123abc... | 00000-0000... | 'model0' | + -- | \x123abc... | 00001-0000... | 'model1' | -- ``` -- -- the query will create the following JSONB object: -- -- ``` -- { - -- "model0": [1, "my-model"], - -- "model1": [2, "my-other-model"] + -- "model0": "00000-0000...", + -- "model1": "00001-0000..." -- } -- ``` - -- where the second element of each tuple value is the name of the parent - -- model - -- - -- `jsonb_object_agg` is used in order to convert the row of results - -- into a single JSONB object - -- - -- Note that `jsonb_build_array` is used with the model version ID and - -- parent model name to create a two-element array, because this is the - -- tuple encoding expected by Aeson, and the `FromJSON` instance is - -- reused in order to parse the `InferenceParamWithModels` q :: Query q = [sql| - SELECT - P.*, - jsonb_object_agg(MS.ident, jsonb_build_array(MS.model, M.name)) models + SELECT P.*, jsonb_object_agg(MS.ident, MS.model) mversions FROM params P INNER JOIN scripts S ON P.script = S.id INNER JOIN mselections MS ON MS.script = S.id INNER JOIN mversions MV ON MV.id = MS.model - INNER JOIN models M ON MV.model = M.id WHERE P.id = ? AND P.terminated IS NULL GROUP BY @@ -532,29 +488,25 @@ getParameterWithModels ipid = -- NOTE: This action assumes that the current working directory is the model -- cache! It can be run using e.g. 
'withCurrentDirectory' getAndCacheModels :: - ModelCache -> Map Ident (Id ModelVersion, Text) -> RemoteM (Vector FilePath) + ModelCache -> Map Ident (Id ModelVersion) -> RemoteM () getAndCacheModels cache = - traverse (uncurry copyAndCache) + traverse_ (uncurry copyAndCache) <=< getModelsAndVersions . Vector.fromList - . toListOf (each . _1) + . toListOf each where - copyAndCache :: Model -> ModelVersion -> RemoteM FilePath - copyAndCache model mversion = - versioned <$ do - unlessM (doesPathExist versioned) $ do + copyAndCache :: Model -> ModelVersion -> RemoteM () + copyAndCache _ mversion = + mkPath >>= \path -> + unlessM (doesPathExist path) $ do whenJust mversion.id $ logInfo . CopyingModel - bitraverse_ checkCacheSize (writeBinaryFileDurableAtomic versioned) + bitraverse_ checkCacheSize (writeBinaryFileDurableAtomic path) =<< getModelVersionSizeAndContents mversion.contents where - -- Cache the model with its specific version, i.e. - -- `.ts.pt.`, which will later be - -- symlinked to `.ts.pt` - versioned :: FilePath - versioned = model ^. #name . unpacked & (<.> v) - where - v :: FilePath - v = mversion ^. #version . to showVersion . unpacked + mkPath :: RemoteM FilePath + mkPath = + maybe (throwM (OtherRemoteError "todo")) (pure . mkModelPath) $ + mversion.id -- Checks that the configured cache size will not be exceeded by -- caching the new model. If it will, least-recently-used models @@ -618,12 +570,18 @@ modelsByAccessTime = sortByM compareAccessTime <=< listDirectory getAccessTime f1 >>= \t1 -> compare t1 <$> getAccessTime f2 +-- There should only be one way to generate a filepath from a model version, so +-- that the path pointing to the contents is always unambiguous. This uses its +-- UUID to do so +mkModelPath :: Id ModelVersion -> FilePath +mkModelPath = (<.> "ts" <.> "pt") . UUID.toString . wrappedTo + -- Everything needed to evaluate an ML script. For the normal endpoint, all of -- these will be derived directly from the param. 
For the interactive test -- endpoint, these will be overridden data ScriptEnv = ScriptEnv { param :: InferenceParam, - models :: Map Ident (Id ModelVersion, Text), + models :: Map Ident (Id ModelVersion), inputs :: Map Ident (SingleOrMany PID, ScriptInputType), obj :: VCMeta VCObject, script :: VCObjectHash, diff --git a/inferno-ml-server/src/Inferno/ML/Server/Types.hs b/inferno-ml-server/src/Inferno/ML/Server/Types.hs index 23231c6..e16c8a1 100644 --- a/inferno-ml-server/src/Inferno/ML/Server/Types.hs +++ b/inferno-ml-server/src/Inferno/ML/Server/Types.hs @@ -447,9 +447,7 @@ pattern InferenceParam iid s ios res mt gid = Types.InferenceParam iid s ios res mt gid pattern InferenceParamWithModels :: - InferenceParam -> - Map Ident (Id ModelVersion, Text) -> - InferenceParamWithModels + InferenceParam -> Map Ident (Id ModelVersion) -> InferenceParamWithModels pattern InferenceParamWithModels ip mvs = Types.InferenceParamWithModels ip mvs pattern BridgeInfo :: Id InferenceParam -> IPv4 -> Word64 -> BridgeInfo diff --git a/inferno-ml-server/test/Main.hs b/inferno-ml-server/test/Main.hs index babb5ef..a1e39fe 100644 --- a/inferno-ml-server/test/Main.hs +++ b/inferno-ml-server/test/Main.hs @@ -1,3 +1,6 @@ +{-# LANGUAGE OverloadedRecordDot #-} +{-# LANGUAGE NoFieldSelectors #-} + -- These test items do not run a full `inferno-ml-server` instance and only -- check that a limited subset of server operations work as intended (e.g. model -- fetching and caching). For full server tests, see `tests/server.nix`. 
@@ -12,32 +15,24 @@ import Data.Aeson (eitherDecodeFileStrict) import Data.ByteString (ByteString) import qualified Data.ByteString as ByteString import Data.Foldable (toList, traverse_) +import Data.Generics.Wrapped (wrappedTo) import Data.Map.Strict (Map) import qualified Data.Map.Strict as Map -import Data.Text (Text) import qualified Data.UUID as UUID import Data.Vector (Vector) import qualified Data.Vector as Vector import Data.Word (Word8) import Inferno.ML.Server (runInEnv) -import Inferno.ML.Server.Inference - ( getAndCacheModels, - linkVersionedModel, - ) +import Inferno.ML.Server.Inference (getAndCacheModels) import Inferno.ML.Server.Inference.Model ( getModelVersionSizeAndContents, getModelsAndVersions, ) import Inferno.ML.Server.Types - ( Config, - Env, - Id (Id), - ModelVersion, - showVersion, - ) import Inferno.Types.Syntax (Ident) import Lens.Micro.Platform import Plow.Logging.Message (LogLevel (LevelWarn)) +import System.FilePath ((<.>)) import Test.Hspec (Spec) import qualified Test.Hspec as Hspec import UnliftIO (throwString) @@ -69,22 +64,21 @@ mkCacheSpec :: Env -> Spec mkCacheSpec env = Hspec.before_ clearCache . Hspec.describe "Model cache" $ do Hspec.it "caches a model" . cdCache $ do cacheModel - dir <- env ^. #config . #cache . #path & listDirectory - dir `Hspec.shouldMatchList` ["mnist.v1", "mnist.ts.pt"] - contents <- ByteString.readFile "mnist.ts.pt" + dir <- listDirectory env.config.cache.path + dir `Hspec.shouldMatchList` [mnistV1Path] + contents <- ByteString.readFile mnistV1Path ByteString.length contents `Hspec.shouldBe` mnistV1Size getZipMagic contents `Hspec.shouldBe` zipMagic Hspec.it "doesn't re-cache" . cdCache $ do - atime1 <- cacheModel *> getModificationTime "mnist.ts.pt" - atime2 <- cacheModel *> getModificationTime "mnist.ts.pt" + atime1 <- cacheModel *> getModificationTime mnistV1Path + atime2 <- cacheModel *> getModificationTime mnistV1Path atime2 `Hspec.shouldBe` atime1 where cacheModel :: IO () cacheModel = void . 
flip runReaderT env $ - traverse_ linkVersionedModel - =<< (`getAndCacheModels` modelsWithIdents) + (`getAndCacheModels` modelsWithIdents) =<< view (#config . #cache) clearCache :: IO () @@ -95,19 +89,19 @@ mkCacheSpec env = Hspec.before_ clearCache . Hspec.describe "Model cache" $ do =<< getCurrentDirectory cdCache :: IO a -> IO a - cdCache = env ^. #config . #cache . #path & withCurrentDirectory + cdCache = withCurrentDirectory env.config.cache.path -modelsWithIdents :: Map Ident (Id ModelVersion, Text) -modelsWithIdents = Map.singleton "dummy" (mnistV1, "mnist") +modelsWithIdents :: Map Ident (Id ModelVersion) +modelsWithIdents = Map.singleton "dummy" mnistV1 mkDbSpec :: Env -> Spec mkDbSpec env = Hspec.describe "Database" $ do Hspec.it "gets a model" $ do - runReaderT (getModelsAndVersions models) env >>= \case + runReaderT (getModelsAndVersions modelVersions) env >>= \case v | Just (model, mversion) <- v ^? _head -> do - view #name model `Hspec.shouldBe` "mnist" - view (#version . to showVersion) mversion `Hspec.shouldBe` "v1" + model.name `Hspec.shouldBe` "mnist" + showVersion mversion.version `Hspec.shouldBe` "v1" | otherwise -> Hspec.expectationFailure "No models were retrieved" Hspec.it "gets model size and contents" $ do @@ -123,15 +117,18 @@ mkDbSpec env = Hspec.describe "Database" $ do getWithContents :: Env -> IO (Integer, ByteString) getWithContents env = flip runReaderT env $ do - (getModelsAndVersions models >>=) . (. fmap snd . toList) $ \case + (getModelsAndVersions modelVersions >>=) . (. fmap snd . 
toList) $ \case [] -> throwString "No model was retrieved" - v : _ -> getModelVersionSizeAndContents $ view #contents v + v : _ -> getModelVersionSizeAndContents v.contents mnistV1 :: Id ModelVersion mnistV1 = Id $ UUID.fromWords 6 0 0 0 -models :: Vector (Id ModelVersion) -models = Vector.singleton mnistV1 +mnistV1Path :: FilePath +mnistV1Path = UUID.toString (wrappedTo mnistV1) <.> "ts" <.> "pt" + +modelVersions :: Vector (Id ModelVersion) +modelVersions = Vector.singleton mnistV1 mnistV1Size :: Int mnistV1Size = 4808991 diff --git a/inferno-ml/inferno-ml.cabal b/inferno-ml/inferno-ml.cabal index 234922d..5a29c1b 100644 --- a/inferno-ml/inferno-ml.cabal +++ b/inferno-ml/inferno-ml.cabal @@ -34,7 +34,6 @@ library , template-haskell , text , prettyprinter - , filepath default-language: Haskell2010 default-extensions: LambdaCase diff --git a/inferno-ml/src/Inferno/ML/Module/Prelude.hs b/inferno-ml/src/Inferno/ML/Module/Prelude.hs index af564b8..b3c57a2 100644 --- a/inferno-ml/src/Inferno/ML/Module/Prelude.hs +++ b/inferno-ml/src/Inferno/ML/Module/Prelude.hs @@ -25,7 +25,6 @@ import qualified Inferno.Module.Prelude as Prelude import Inferno.Types.Syntax (Ident) import Inferno.Types.Value (Value (..)) import Prettyprinter (Pretty) -import System.FilePath ((<.>)) import Torch import qualified Torch.DType as TD import Torch.Functional @@ -143,7 +142,7 @@ loadModelFun = VFun $ \case =<< liftIO (try @_ @SomeException loadModel) where loadModel :: IO ScriptModule - loadModel = TS.loadScript TS.WithoutRequiredGrad $ mn <.> "ts.pt" + loadModel = TS.loadScript TS.WithoutRequiredGrad mn _ -> throwM $ RuntimeError "Expected a modelName" forwardFun :: ScriptModule -> [Tensor] -> [Tensor] From 482db0dd5b10588b9c7bd6a846b80f12c7a745c2 Mon Sep 17 00:00:00 2001 From: Rory Tyler Hayford Date: Fri, 18 Oct 2024 11:47:32 +0700 Subject: [PATCH 11/13] Fix URLs for running/testing --- inferno-ml-server-types/src/Inferno/ML/Server/Types.hs | 2 ++ 1 file changed, 2 insertions(+) diff --git 
a/inferno-ml-server-types/src/Inferno/ML/Server/Types.hs b/inferno-ml-server-types/src/Inferno/ML/Server/Types.hs index 52c66a3..ccf1666 100644 --- a/inferno-ml-server-types/src/Inferno/ML/Server/Types.hs +++ b/inferno-ml-server-types/src/Inferno/ML/Server/Types.hs @@ -140,6 +140,7 @@ type CancelAPI = "inference" :> "cancel" :> Put '[JSON] () type InferenceAPI gid p s = "inference" + :> "run" :> Capture "id" (Id (InferenceParam gid p s)) :> QueryParam "res" Int64 :> QueryParam' '[Required] "uuid" UUID @@ -148,6 +149,7 @@ type InferenceAPI gid p s = type InferenceTestAPI gid p s = -- Evaluate an inference script "inference" + :> "test" :> Capture "id" (Id (InferenceParam gid p s)) :> QueryParam "res" Int64 :> QueryParam' '[Required] "uuid" UUID From fed8e75c9ca3eff1fbbe8ea53af07648f5a48fe4 Mon Sep 17 00:00:00 2001 From: Rory Tyler Hayford Date: Fri, 18 Oct 2024 15:08:36 +0700 Subject: [PATCH 12/13] Versions, changelogs --- inferno-ml-server-types/CHANGELOG.md | 4 ++++ inferno-ml-server-types/inferno-ml-server-types.cabal | 2 +- inferno-ml-server/CHANGELOG.md | 5 +++++ inferno-ml-server/inferno-ml-server.cabal | 2 +- inferno-ml-server/src/Inferno/ML/Server/Inference.hs | 4 +++- 5 files changed, 14 insertions(+), 3 deletions(-) diff --git a/inferno-ml-server-types/CHANGELOG.md b/inferno-ml-server-types/CHANGELOG.md index 95cad3c..1da27b7 100644 --- a/inferno-ml-server-types/CHANGELOG.md +++ b/inferno-ml-server-types/CHANGELOG.md @@ -1,6 +1,10 @@ # Revision History for inferno-ml-server-types *Note*: we use https://pvp.haskell.org/ (MAJOR.MAJOR.MINOR.PATCH) +## 0.10.0 +* Change `Id` to `UUID` +* Add new testing endpoint to override models, script, etc... 
+ ## 0.9.1 * `Ord`/`VCHashUpdate` instances for `ScriptInputType` diff --git a/inferno-ml-server-types/inferno-ml-server-types.cabal b/inferno-ml-server-types/inferno-ml-server-types.cabal index df56801..4be5e40 100644 --- a/inferno-ml-server-types/inferno-ml-server-types.cabal +++ b/inferno-ml-server-types/inferno-ml-server-types.cabal @@ -1,6 +1,6 @@ cabal-version: 2.4 name: inferno-ml-server-types -version: 0.9.1 +version: 0.10.0 synopsis: Types for Inferno ML server description: Types for Inferno ML server homepage: https://github.com/plow-technologies/inferno.git#readme diff --git a/inferno-ml-server/CHANGELOG.md b/inferno-ml-server/CHANGELOG.md index c13f6f9..bd4d854 100644 --- a/inferno-ml-server/CHANGELOG.md +++ b/inferno-ml-server/CHANGELOG.md @@ -1,5 +1,10 @@ # Revision History for `inferno-ml-server` +## 2023.10.18 +* Add new testing route +* Some improvements to model caching +* Make `/status` not awful and confusing + ## 2023.9.27 * Change entity DB representation to `numeric` diff --git a/inferno-ml-server/inferno-ml-server.cabal b/inferno-ml-server/inferno-ml-server.cabal index 5d87ffa..ca96748 100644 --- a/inferno-ml-server/inferno-ml-server.cabal +++ b/inferno-ml-server/inferno-ml-server.cabal @@ -1,6 +1,6 @@ cabal-version: 2.4 name: inferno-ml-server -version: 2023.9.27 +version: 2023.10.18 synopsis: Server for Inferno ML description: Server for Inferno ML homepage: https://github.com/plow-technologies/inferno.git#readme diff --git a/inferno-ml-server/src/Inferno/ML/Server/Inference.hs b/inferno-ml-server/src/Inferno/ML/Server/Inference.hs index a2f93db..2d88f2f 100644 --- a/inferno-ml-server/src/Inferno/ML/Server/Inference.hs +++ b/inferno-ml-server/src/Inferno/ML/Server/Inference.hs @@ -505,7 +505,9 @@ getAndCacheModels cache = where mkPath :: RemoteM FilePath mkPath = - maybe (throwM (OtherRemoteError "todo")) (pure . mkModelPath) $ + maybe + (throwM (OtherRemoteError "Missing model version ID")) + (pure . 
mkModelPath) mversion.id -- Checks that the configured cache size will not be exceeded by From 35350fdf4ac71d434db8d857f0679a9df1861cd8 Mon Sep 17 00:00:00 2001 From: Rory Tyler Hayford Date: Fri, 18 Oct 2024 15:21:09 +0700 Subject: [PATCH 13/13] More small cleanup --- inferno-ml-server/src/Inferno/ML/Server/Inference.hs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/inferno-ml-server/src/Inferno/ML/Server/Inference.hs b/inferno-ml-server/src/Inferno/ML/Server/Inference.hs index 2d88f2f..2f62ba2 100644 --- a/inferno-ml-server/src/Inferno/ML/Server/Inference.hs +++ b/inferno-ml-server/src/Inferno/ML/Server/Inference.hs @@ -73,6 +73,7 @@ import Inferno.Utils.Prettyprinter (renderPretty) import Inferno.VersionControl.Types ( VCObject (VCFunction), ) +import qualified Inferno.VersionControl.Types import Lens.Micro.Platform import System.CPUTime (getCPUTime) import System.FilePath ((<.>)) @@ -208,7 +209,7 @@ runInferenceParamWithEnv ipid uuid senv = CTime -> RemoteM (WriteStream IO) runEval Interpreter {evalExpr, mkEnvFromClosure} t = - senv ^. #obj . #obj & \case + case senv.obj.obj of VCFunction {} -> do let -- Note that this both includes inputs (i.e. readable) -- and outputs (i.e. writable, or readable/writable). @@ -264,7 +265,7 @@ runInferenceParamWithEnv ipid uuid senv = bimap (mkIdentWith "model$") toModelPath closure :: Map VCObjectHash VCObject - closure = senv ^. #obj . #obj & Map.singleton senv.script + closure = Map.singleton senv.script senv.obj.obj expr :: Expr (Maybe VCObjectHash) () expr = @@ -315,10 +316,9 @@ runInferenceParamWithEnv ipid uuid senv = -- use that. Otherwise, use the resolution stored in the -- parameter resolution :: InverseResolution - resolution = senv ^. #mres . non res & toResolution - where - res :: Int64 - res = senv ^. #param . #resolution & fromIntegral + resolution = + senv ^. #mres . 
non (fromIntegral senv.param.resolution) + & toResolution implEnv :: Map ExtIdent (Value BridgeMlValue m) implEnv =