From aff8fc93c01f712d78d2a8795f976347428dd29f Mon Sep 17 00:00:00 2001
From: Rory Tyler Hayford <52039264+ngua@users.noreply.github.com>
Date: Fri, 1 Nov 2024 10:03:08 +0700
Subject: [PATCH] [inferno-ml] Change parameter input/output representation
 (#147)

The previous way that parameter inputs/outputs were stored, with both
condensed into a single `Map Ident ...`, turned out to be quite
nightmarish to work with. We are going to be doing things like filtering
or querying parameters by output, so outputs need to be easy to look up.
This changes the inputs/outputs back to the way they used to be, i.e. two
separate `Map`s.
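To make the shape of the change concrete, here is a minimal, illustrative
sketch (simplified stand-in types, not the real definitions from this patch)
of why the split makes outputs easier to get at:

```haskell
{-# LANGUAGE LambdaCase #-}

-- Hypothetical stand-ins, for illustration only
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as Map

data Access = Readable | Writable | ReadableWritable

-- Old shape: outputs had to be dug out of the access tags
oldOutputs :: Map ident (pid, Access) -> Map ident pid
oldOutputs = Map.mapMaybe $ \case
  (p, Writable) -> Just p
  (p, ReadableWritable) -> Just p
  _ -> Nothing

-- New shape: the outputs are simply their own map
newOutputs :: Map ident pid -> Map ident pid
newOutputs = id
```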
---
 inferno-ml-server-types/CHANGELOG.md         |   3 +
 .../inferno-ml-server-types.cabal            |   2 +-
 .../src/Inferno/ML/Server/Client.hs          |  12 +-
 .../src/Inferno/ML/Server/Types.hs           | 123 ++++++------------
 inferno-ml-server/exe/ParseAndSave.hs        |  29 +++--
 .../src/Inferno/ML/Server/Inference.hs       |  14 +-
 .../src/Inferno/ML/Server/Types.hs           |  19 ++-
 .../migrations/v1-create-tables.sql          |   7 +-
 .../tests/scripts/contrived.inferno          |   4 +-
 nix/inferno-ml/tests/scripts/mnist.inferno   |   6 +-
 nix/inferno-ml/tests/scripts/ones.inferno    |   4 +-
 nix/inferno-ml/tests/server.nix              |  17 ++-
 12 files changed, 112 insertions(+), 128 deletions(-)

diff --git a/inferno-ml-server-types/CHANGELOG.md b/inferno-ml-server-types/CHANGELOG.md
index 1da27b7..7855d72 100644
--- a/inferno-ml-server-types/CHANGELOG.md
+++ b/inferno-ml-server-types/CHANGELOG.md
@@ -1,6 +1,9 @@
 # Revision History for inferno-ml-server-types
 *Note*: we use https://pvp.haskell.org/ (MAJOR.MAJOR.MINOR.PATCH)
 
+## 0.11.0
+* Split parameter inputs and outputs
+
 ## 0.10.0
 * Change `Id` to `UUID`
 * Add new testing endpoint to override models, script, etc...

diff --git a/inferno-ml-server-types/inferno-ml-server-types.cabal b/inferno-ml-server-types/inferno-ml-server-types.cabal
index 4be5e40..d976516 100644
--- a/inferno-ml-server-types/inferno-ml-server-types.cabal
+++ b/inferno-ml-server-types/inferno-ml-server-types.cabal
@@ -1,6 +1,6 @@
 cabal-version: 2.4
 name: inferno-ml-server-types
-version: 0.10.0
+version: 0.11.0
 synopsis: Types for Inferno ML server
 description: Types for Inferno ML server
 homepage: https://github.com/plow-technologies/inferno.git#readme

diff --git a/inferno-ml-server-types/src/Inferno/ML/Server/Client.hs b/inferno-ml-server-types/src/Inferno/ML/Server/Client.hs
index 41461ae..c931552 100644
--- a/inferno-ml-server-types/src/Inferno/ML/Server/Client.hs
+++ b/inferno-ml-server-types/src/Inferno/ML/Server/Client.hs
@@ -27,9 +27,9 @@ cancelC = client $ Proxy @CancelAPI
 
 -- | Run an inference parameter
 inferenceC ::
-  forall gid p s.
+  forall gid p.
   -- | SQL identifier of the inference parameter to be run
-  Id (InferenceParam gid p s) ->
+  Id (InferenceParam gid p) ->
   -- | Optional resolution for scripts that use e.g. @valueAt@; defaults to
   -- the param\'s stored resolution if not provided. This lets users override
   -- the resolution on an ad-hoc basis without needing to alter the stored
@@ -44,16 +44,16 @@ inferenceC ::
   -- (not defined in this repository) to verify this before directing
   -- the writes to their final destination
   ClientM (WriteStream IO)
-inferenceC = client $ Proxy @(InferenceAPI gid p s)
+inferenceC = client $ Proxy @(InferenceAPI gid p)
 
 -- | Run an inference parameter
 inferenceTestC ::
-  forall gid p s.
+  forall gid p.
   ToJSON p =>
   -- | SQL identifier of the inference parameter to be run
-  Id (InferenceParam gid p s) ->
+  Id (InferenceParam gid p) ->
   Maybe Int64 ->
   UUID ->
   EvaluationEnv gid p ->
   ClientM (WriteStream IO)
-inferenceTestC = client $ Proxy @(InferenceTestAPI gid p s)
+inferenceTestC = client $ Proxy @(InferenceTestAPI gid p)

diff --git a/inferno-ml-server-types/src/Inferno/ML/Server/Types.hs b/inferno-ml-server-types/src/Inferno/ML/Server/Types.hs
index 644cd68..f2f9884 100644
--- a/inferno-ml-server-types/src/Inferno/ML/Server/Types.hs
+++ b/inferno-ml-server-types/src/Inferno/ML/Server/Types.hs
@@ -70,9 +70,7 @@ import GHC.Generics (Generic)
 import Inferno.Instances.Arbitrary ()
 import Inferno.Types.Syntax (Ident)
 import Inferno.Types.VersionControl
-  ( VCHashUpdate,
-    VCHashUpdateViaShow (VCHashUpdateViaShow),
-    VCObjectHash,
+  ( VCObjectHash,
     byteStringToVCObjectHash,
     vcObjectHashToByteString,
   )
@@ -127,11 +125,11 @@ import Web.HttpApiData
   ( )
 
 -- API type for `inferno-ml-server`
-type InfernoMlServerAPI gid p s =
+type InfernoMlServerAPI gid p =
   StatusAPI
     -- Evaluate an inference script
-    :<|> InferenceAPI gid p s
-    :<|> InferenceTestAPI gid p s
+    :<|> InferenceAPI gid p
+    :<|> InferenceTestAPI gid p
     :<|> CancelAPI
 
 type StatusAPI =
@@ -140,19 +138,19 @@ type CancelAPI =
   "inference" :> "cancel" :> Put '[JSON] ()
 
-type InferenceAPI gid p s =
+type InferenceAPI gid p =
   "inference"
     :> "run"
-    :> Capture "id" (Id (InferenceParam gid p s))
+    :> Capture "id" (Id (InferenceParam gid p))
     :> QueryParam "res" Int64
     :> QueryParam' '[Required] "uuid" UUID
     :> StreamPost NewlineFraming JSON (WriteStream IO)
 
-type InferenceTestAPI gid p s =
+type InferenceTestAPI gid p =
   -- Evaluate an inference script
   "inference"
     :> "test"
-    :> Capture "id" (Id (InferenceParam gid p s))
+    :> Capture "id" (Id (InferenceParam gid p))
     :> QueryParam "res" Int64
     :> QueryParam' '[Required] "uuid" UUID
     :> ReqBody '[JSON] (EvaluationEnv gid p)
@@ -192,25 +190,28 @@ data ServerStatus
   = Idle
   | EvaluatingScript
   deriving stock (Show, Eq, Generic)
-  deriving anyclass (FromJSON, ToJSON)
+  deriving anyclass (FromJSON, ToJSON, ToADTArbitrary, NFData)
+
+instance Arbitrary ServerStatus where
+  arbitrary = genericArbitrary
 
 -- | Information for contacting a bridge server that implements the 'BridgeAPI'
-data BridgeInfo gid p s = BridgeInfo
-  { id :: Id (InferenceParam gid p s),
+data BridgeInfo gid p = BridgeInfo
+  { id :: Id (InferenceParam gid p),
     host :: IPv4,
     port :: Word64
   }
   deriving stock (Show, Eq, Generic)
   deriving anyclass (FromJSON, ToJSON, NFData)
 
-instance FromRow (BridgeInfo gid p s) where
+instance FromRow (BridgeInfo gid p) where
   fromRow =
     BridgeInfo
       <$> field
      <*> field
       <*> fmap (fromIntegral @Int64) field
 
-instance ToRow (BridgeInfo gid p s) where
+instance ToRow (BridgeInfo gid p) where
   toRow bi =
     [ bi.id & toField,
       bi.host & toField,
@@ -677,8 +678,8 @@ showVersion (Version ns ts) =
 
 -- | Row of the inference parameter table, parameterized by the user, group, and
 -- script type
-data InferenceParam gid p s = InferenceParam
-  { id :: Maybe (Id (InferenceParam gid p s)),
+data InferenceParam gid p = InferenceParam
+  { id :: Maybe (Id (InferenceParam gid p)),
     -- | The script of the parameter
     --
     -- For new parameters, this will be textual or some other identifier
     --
     -- For existing inference params, this is the foreign key for the specific
     -- script in the 'InferenceScript' table (i.e. a @VCObjectHash@)
-    script :: s,
-    -- | This is called @inputs@ but is also used for script outputs as
-    -- well. The access (input or output) is controlled by the 'ScriptInputType'.
-    -- For example, if this field is set to @[("input0", Single (p, Readable))]@,
-    -- the script will only have a single read-only input and will not be able to
-    -- write anywhere (note that we should disallow this scenario, as script
-    -- evaluation would not work properly)
-    --
-    -- Mapping the input\/output to the Inferno identifier helps ensure that
+    script :: VCObjectHash,
+    -- | Mapping the input\/output to the Inferno identifier helps ensure that
     -- Inferno identifiers are always pointing to the correct input\/output;
     -- otherwise we would need to rely on the order of the original identifiers
-    inputs :: Map Ident (SingleOrMany p, ScriptInputType),
+    inputs :: Map Ident (SingleOrMany p),
+    outputs :: Map Ident (SingleOrMany p),
     -- | Resolution, passed to bridge routes
     resolution :: Word64,
     -- | The time that this parameter was \"deleted\", if any. For active
@@ -709,7 +704,7 @@ data InferenceParam gid p s = InferenceParam
   deriving anyclass (NFData, ToJSON)
 
 {- ORMOLU_DISABLE -}
-instance (FromJSON s, FromJSON p, FromJSON gid) => FromJSON (InferenceParam gid p s)
+instance (FromJSON p, FromJSON gid) => FromJSON (InferenceParam gid p)
   where
     parseJSON = withObject "InferenceParam" $ \o ->
       InferenceParam
@@ -717,6 +712,7 @@ instance (FromJSON s, FromJSON p, FromJSON gid) => FromJSON (InferenceParam gid
         <$> o .: "id"
         <*> o .: "script"
         <*> o .: "inputs"
+        <*> o .: "outputs"
         <*> o .:? "resolution" .!= 128
         -- We shouldn't require this field
         <*> o .:? "terminated"
@@ -731,53 +727,44 @@ instance
     Typeable gid,
     Typeable p
   ) =>
-  FromRow (InferenceParam gid p VCObjectHash)
+  FromRow (InferenceParam gid p)
   where
     fromRow =
      InferenceParam
        <$> field
        <*> fmap wrappedTo (field @VCObjectHashRow)
        <*> fmap getAeson field
+       <*> fmap getAeson field
        <*> fmap fromIntegral (field @Int64)
        <*> field
        <*> field
 
-instance (ToJSON p, ToField gid) => ToRow (InferenceParam gid p VCObjectHash) where
+instance (ToJSON p, ToField gid) => ToRow (InferenceParam gid p) where
   -- NOTE: Do not change the order of the field actions
   toRow ip =
     [ ip.id & maybe (toField Default) toField,
       ip.script & VCObjectHashRow & toField,
      ip.inputs & Aeson & toField,
+      ip.outputs & Aeson & toField,
      ip.resolution & Aeson & toField,
      toField Default,
      ip.gid & toField
    ]
 
 -- Not derived generically in order to use special `Gen UTCTime`
-instance
-  ( Arbitrary gid,
-    Arbitrary p,
-    Arbitrary s
-  ) =>
-  Arbitrary (InferenceParam gid p s)
-  where
+instance (Arbitrary gid, Arbitrary p) => Arbitrary (InferenceParam gid p) where
   arbitrary =
     InferenceParam
      <$> arbitrary
      <*> arbitrary
      <*> arbitrary
      <*> arbitrary
+      <*> arbitrary
      <*> genMUtc
      <*> arbitrary
 
 -- Can't be derived because there is (intentially) no `Arbitrary UTCTime` in scope
-instance
-  ( Arbitrary gid,
-    Arbitrary p,
-    Arbitrary s
-  ) =>
-  ToADTArbitrary (InferenceParam gid p s)
-  where
+instance (Arbitrary gid, Arbitrary p) => ToADTArbitrary (InferenceParam gid p) where
   toADTArbitrarySingleton _ =
     ADTArbitrarySingleton "Inferno.ML.Server.Types" "InferenceParam"
       . ConstructorArbitraryPair "InferenceParam"
@@ -789,46 +776,12 @@ instance
 
 -- | An 'InferenceParam' together with all of the model versions that are
 -- linked to it indirectly via its script. This is provided for convenience
-data InferenceParamWithModels gid p s = InferenceParamWithModels
-  { param :: InferenceParam gid p s,
+data InferenceParamWithModels gid p = InferenceParamWithModels
+  { param :: InferenceParam gid p,
     models :: Map Ident (Id (ModelVersion gid Oid))
   }
   deriving stock (Show, Eq, Generic)
 
--- | Controls input interaction within a script, i.e. ability to read from
--- and\/or write to this input. Although the term \"input\" is used, those with
--- writes enabled can also be described as \"outputs\"
-data ScriptInputType
-  = -- | Script input can be read, but not written
-    Readable
-  | -- | Script input can be written, i.e. can be used in array of
-    -- write objects returned from script evaluation
-    Writable
-  | -- | Script input can be both read from and written to; this allows
-    -- the same script identifier to point to the same PID with both
-    -- types of access enabled
-    ReadableWritable
-  deriving stock (Show, Eq, Ord, Generic)
-  deriving anyclass (NFData, ToADTArbitrary)
-  deriving (VCHashUpdate) via (VCHashUpdateViaShow ScriptInputType)
-
-instance FromJSON ScriptInputType where
-  parseJSON = withText "ScriptInputType" $ \case
-    "r" -> pure Readable
-    "w" -> pure Writable
-    "rw" -> pure ReadableWritable
-    s -> fail $ "Invalid script input type: " <> Text.unpack s
-
-instance ToJSON ScriptInputType where
-  toJSON =
-    String . \case
-      Readable -> "r"
-      Writable -> "w"
-      ReadableWritable -> "rw"
-
-instance Arbitrary ScriptInputType where
-  arbitrary = genericArbitrary
-
 -- | Information about execution time and resource usage. This is saved by
 -- @inferno-ml-server@ after script evaluation completes and can be queried
 -- later by using the same job identifier that was provided to the @/inference@
@@ -837,7 +790,7 @@ data EvaluationInfo gid p = EvaluationInfo
   { -- | Note that this is the job identifier provided to the inference
     -- evaluation route, and is also the primary key of the database table
     id :: UUID,
-    param :: Id (InferenceParam gid p VCObjectHash),
+    param :: Id (InferenceParam gid p),
     -- | When inference evaluation started
     start :: UTCTime,
     -- | When inference evaluation ended
@@ -1053,11 +1006,15 @@ instance Ord a => Ord (SingleOrMany a) where
 -- evaluator. This allows for more interactive testing
 data EvaluationEnv gid p = EvaluationEnv
   { script :: VCObjectHash,
-    inputs :: Map Ident (SingleOrMany p, ScriptInputType),
+    inputs :: Map Ident (SingleOrMany p),
+    outputs :: Map Ident (SingleOrMany p),
     models :: Map Ident (Id (ModelVersion gid Oid))
   }
   deriving stock (Show, Eq, Generic)
-  deriving anyclass (FromJSON, ToJSON)
+  deriving anyclass (FromJSON, ToJSON, ToADTArbitrary)
+
+instance Arbitrary p => Arbitrary (EvaluationEnv gid p) where
+  arbitrary = genericArbitrary
 
 tshow :: Show a => a -> Text
 tshow = Text.pack . show
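On the wire, the split simply means two sibling JSON objects, `inputs` and
`outputs`, instead of one map tagged with access types. A minimal, hypothetical
sketch with plain stand-in types (the real values are `SingleOrMany p`, whose
encoding this patch leaves untouched):

```haskell
{-# LANGUAGE OverloadedStrings #-}

import Data.Aeson (decode)
import Data.Map.Strict (Map)
import Data.Text (Text)

-- Decodes the same "inputs"/"outputs" layout used by the NixOS test further
-- down, with plain Ints standing in for the real value type.
example :: Maybe (Map Text (Map Text Int))
example = decode "{\"inputs\":{\"input0\":1},\"outputs\":{\"output0\":1}}"
```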
diff --git a/inferno-ml-server/exe/ParseAndSave.hs b/inferno-ml-server/exe/ParseAndSave.hs
index f1aa9da..09a6b85 100644
--- a/inferno-ml-server/exe/ParseAndSave.hs
+++ b/inferno-ml-server/exe/ParseAndSave.hs
@@ -1,4 +1,7 @@
 {-# LANGUAGE DataKinds #-}
+{-# LANGUAGE DeriveAnyClass #-}
+{-# LANGUAGE DeriveGeneric #-}
+{-# LANGUAGE OverloadedRecordDot #-}
 {-# LANGUAGE QuasiQuotes #-}
 {-# LANGUAGE ScopedTypeVariables #-}
 
@@ -13,7 +16,7 @@ module ParseAndSave (main) where
 import Control.Category ((>>>))
 import Control.Exception (Exception (displayException))
 import Control.Monad (void)
-import Data.Aeson (eitherDecode)
+import Data.Aeson (FromJSON, eitherDecode)
 import Data.ByteString (ByteString)
 import qualified Data.ByteString.Char8 as Char8
 import qualified Data.ByteString.Lazy.Char8 as Lazy.Char8
@@ -33,6 +36,7 @@ import Database.PostgreSQL.Simple
   ( )
 import Database.PostgreSQL.Simple.SqlQQ (sql)
 import Foreign.C (CTime)
+import GHC.Generics (Generic)
 import Inferno.Core
   ( Interpreter (Interpreter, parseAndInfer),
     mkInferno,
   )
@@ -70,24 +74,24 @@ parseAndSave ::
   Id InferenceParam ->
   FilePath ->
   ByteString ->
-  Map Ident (SingleOrMany PID, ScriptInputType) ->
+  InputsOutputs ->
   IO ()
-parseAndSave ipid p conns inputs = do
+parseAndSave ipid p conns ios = do
   t <- Text.IO.readFile p
   now <- fromIntegral @Int . round <$> getPOSIXTime
   ast <- either (throwString . displayException) pure . (`parse` t)
      =<< mkInferno @_ @BridgeMlValue (mkBridgePrelude funs) customTypes
-  bracket (connectPostgreSQL conns) close (saveScriptAndParam ipid ast now inputs)
+  bracket (connectPostgreSQL conns) close (saveScriptAndParam ipid ast now ios)
 
 saveScriptAndParam ::
   Id InferenceParam ->
   (Expr (Pinned VCObjectHash) (), TCScheme) ->
   CTime ->
-  Map Ident (SingleOrMany PID, ScriptInputType) ->
+  InputsOutputs ->
   Connection ->
   IO ()
-saveScriptAndParam ipid x now inputs conn = insertScript *> insertParam
+saveScriptAndParam ipid x now ios conn = insertScript *> insertParam
   where
     insertScript :: IO ()
     insertScript =
@@ -120,7 +124,8 @@ saveScriptAndParam ipid x now inputs conn = insertScript *> insertParam
         . InferenceParam
           (Just ipid)
          hash
-          inputs
+          ios.inputs
+          ios.outputs
          128
          Nothing
        $ entityIdFromInteger 0
@@ -132,11 +137,12 @@ saveScriptAndParam ipid x now inputs conn = insertScript *> insertParam
            ( id
            , script
            , inputs
+            , outputs
            , resolution
            , terminated
            , gid
            )
-            VALUES (?, ?, ?, ?, ?, ?)
+            VALUES (?, ?, ?, ?, ?, ?, ?)
          |]
 
     saveBridgeInfo :: IO ()
@@ -193,3 +199,10 @@ funs = BridgeFuns notSupported notSupported notSupported notSupported
   where
     notSupported :: a
     notSupported = error "Not supported"
+
+data InputsOutputs = InputsOutputs
+  { inputs :: Map Ident (SingleOrMany PID),
+    outputs :: Map Ident (SingleOrMany PID)
+  }
+  deriving stock (Generic)
+  deriving anyclass (FromJSON)
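The consuming side gets simpler in the same way: instead of projecting PIDs out
of `(SingleOrMany PID, ScriptInputType)` pairs, a consumer can now just
concatenate the two maps, which is what the `pids` change in the Inference
module below does with lenses. A rough, illustrative equivalent in plain
`Data.Map` terms:

```haskell
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as Map

-- All PIDs referenced by a parameter: inputs first, then outputs,
-- each in ascending key order (mirroring toAscList in the real code).
allPids :: Map ident pid -> Map ident pid -> [pid]
allPids inputs outputs = Map.elems inputs <> Map.elems outputs
```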
diff --git a/inferno-ml-server/src/Inferno/ML/Server/Inference.hs b/inferno-ml-server/src/Inferno/ML/Server/Inference.hs
index 2f62ba2..733298b 100644
--- a/inferno-ml-server/src/Inferno/ML/Server/Inference.hs
+++ b/inferno-ml-server/src/Inferno/ML/Server/Inference.hs
@@ -120,7 +120,7 @@ runInferenceParam ipid mres uuid =
   where
     mkScriptEnv :: InferenceParamWithModels -> RemoteM ScriptEnv
     mkScriptEnv pwm =
-      ScriptEnv pwm.param pwm.models pwm.param.inputs
+      ScriptEnv pwm.param pwm.models pwm.param.inputs pwm.param.outputs
        <$> getVcObject pwm.param.script
        ?? pwm.param.script
        ?? mres
@@ -146,7 +146,7 @@ testInferenceParam ipid mres uuid eenv =
    -- for script eval come from the `EvaluationEnv`
    mkScriptEnv :: InferenceParam -> RemoteM ScriptEnv
    mkScriptEnv param =
-      ScriptEnv param eenv.models eenv.inputs
+      ScriptEnv param eenv.models eenv.inputs eenv.outputs
        <$> getVcObject eenv.script
        ?? eenv.script
        ?? mres
@@ -220,8 +220,11 @@ runInferenceParamWithEnv ipid uuid senv =
            -- runtime that runs as a script evaluation engine
            -- and commits the output write object
            pids :: [SingleOrMany PID]
-            pids =
-              senv ^.. #inputs . to Map.toAscList . each . _2 . _1
+            pids = is <> os
+              where
+                is, os :: [SingleOrMany PID]
+                is = senv ^.. #inputs . to Map.toAscList . each . _2
+                os = senv ^.. #outputs . to Map.toAscList . each . _2
 
            -- List of model versions, which are used to evaluate
            -- `loadModel` primitive (eventually calling Hasktorch to
@@ -584,7 +587,8 @@ mkModelPath = (<.> "ts" <.> "pt") . UUID.toString . wrappedTo
 data ScriptEnv = ScriptEnv
   { param :: InferenceParam,
     models :: Map Ident (Id ModelVersion),
-    inputs :: Map Ident (SingleOrMany PID, ScriptInputType),
+    inputs :: Map Ident (SingleOrMany PID),
+    outputs :: Map Ident (SingleOrMany PID),
     obj :: VCMeta VCObject,
     script :: VCObjectHash,
     mres :: Maybe Int64

diff --git a/inferno-ml-server/src/Inferno/ML/Server/Types.hs b/inferno-ml-server/src/Inferno/ML/Server/Types.hs
index e16c8a1..ff7eda1 100644
--- a/inferno-ml-server/src/Inferno/ML/Server/Types.hs
+++ b/inferno-ml-server/src/Inferno/ML/Server/Types.hs
@@ -413,14 +413,11 @@ infixl 4 ??
 (??) :: Functor f => f (a -> b) -> a -> f b
 f ?? x = ($ x) <$> f
 
-type InferenceParam =
-  Types.InferenceParam (EntityId GId) PID VCObjectHash
+type InferenceParam = Types.InferenceParam (EntityId GId) PID
 
-type InferenceParamWithModels =
-  Types.InferenceParamWithModels (EntityId GId) PID VCObjectHash
+type InferenceParamWithModels = Types.InferenceParamWithModels (EntityId GId) PID
 
-type BridgeInfo =
-  Types.BridgeInfo (EntityId GId) PID VCObjectHash
+type BridgeInfo = Types.BridgeInfo (EntityId GId) PID
 
 type EvaluationInfo =
   Types.EvaluationInfo (EntityId GId) PID
@@ -438,13 +435,14 @@ pattern InferenceScript h o = Types.InferenceScript h o
 pattern InferenceParam ::
   Maybe (Id InferenceParam) ->
   VCObjectHash ->
-  Map Ident (SingleOrMany PID, ScriptInputType) ->
+  Map Ident (SingleOrMany PID) ->
+  Map Ident (SingleOrMany PID) ->
   Word64 ->
   Maybe UTCTime ->
   EntityId GId ->
   InferenceParam
-pattern InferenceParam iid s ios res mt gid =
-  Types.InferenceParam iid s ios res mt gid
+pattern InferenceParam iid s is os res mt gid =
+  Types.InferenceParam iid s is os res mt gid
 
 pattern InferenceParamWithModels ::
   InferenceParam -> Map Ident (Id ModelVersion) -> InferenceParamWithModels
@@ -476,8 +474,7 @@ pattern EvaluationInfo ::
   EvaluationInfo
 pattern EvaluationInfo u i s e m c = Types.EvaluationInfo u i s e m c
 
-type InfernoMlServerAPI =
-  Types.InfernoMlServerAPI (EntityId GId) PID VCObjectHash
+type InfernoMlServerAPI = Types.InfernoMlServerAPI (EntityId GId) PID
 
 type EvaluationEnv =
   Types.EvaluationEnv (EntityId GId) PID

diff --git a/nix/inferno-ml/migrations/v1-create-tables.sql b/nix/inferno-ml/migrations/v1-create-tables.sql
index 5835cce..7f84bd3 100644
--- a/nix/inferno-ml/migrations/v1-create-tables.sql
+++ b/nix/inferno-ml/migrations/v1-create-tables.sql
@@ -69,10 +69,11 @@ create table if not exists params
   ( id uuid primary key default gen_random_uuid()
   -- Script hash from `inferno-vc`
   , script bytea not null references scripts (id)
-  -- Strictly speaking, this includes both inputs and outputs. The
-  -- corresponding Haskell type contains `(p, ScriptInputType)`, with
-  -- the second element determining readability and writability
+  -- Inputs and outputs are a `Map Ident (SingleOrMany p)` on the Haskell
+  -- side. Stored as JSONB for convenience (e.g. Postgres subarrays must all
+  -- be the same length, making `SingleOrMany` harder to represent)
   , inputs jsonb not null
+  , outputs jsonb not null
  -- Resolution passed to script evaluator
  , resolution integer not null
  -- See note above
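With `outputs` in its own JSONB column, the "filtering or querying parameters
by output" use case mentioned in the description becomes a single key-existence
check. A hypothetical helper (not part of this patch) using postgresql-simple
against the `params` table from the migration above:

```haskell
{-# LANGUAGE OverloadedStrings #-}

import Data.Text (Text)
import Data.UUID (UUID)
import Database.PostgreSQL.Simple (Connection, Only (..), query)

-- IDs of active params whose script writes to the given output identifier
paramsWithOutput :: Connection -> Text -> IO [Only UUID]
paramsWithOutput conn out =
  query
    conn
    "select id from params where jsonb_exists (outputs, ?) and terminated is null"
    (Only out)
```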
diff --git a/nix/inferno-ml/tests/scripts/contrived.inferno b/nix/inferno-ml/tests/scripts/contrived.inferno
index 0ddc378..925858f 100644
--- a/nix/inferno-ml/tests/scripts/contrived.inferno
+++ b/nix/inferno-ml/tests/scripts/contrived.inferno
@@ -1,5 +1,5 @@
-fun input0 mnist ->
+fun input0 output0 mnist ->
   let t = Time.toTime (Time.seconds 200) in
   let ?resolution = (toResolution 128) in
   let v = valueAt input0 t ? 0.0 in
-  [makeWrites input0 [(Time.toTime (Time.seconds 300), v + 5.0)]]
+  [makeWrites output0 [(Time.toTime (Time.seconds 300), v + 5.0)]]

diff --git a/nix/inferno-ml/tests/scripts/mnist.inferno b/nix/inferno-ml/tests/scripts/mnist.inferno
index c974084..03bb044 100644
--- a/nix/inferno-ml/tests/scripts/mnist.inferno
+++ b/nix/inferno-ml/tests/scripts/mnist.inferno
@@ -1,4 +1,4 @@
-fun input0 input1 mnist ->
+fun input0 output0 output1 mnist ->
   let input = ML.asTensor4 ML.#float [[[
     [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
     [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
@@ -34,7 +34,7 @@ fun input0 input1 mnist ->
   match ML.forward model [input] with {
     | [scores] ->
       let m = ML.toType ML.#double (ML.argmax 1 #false scores) in
-      [makeWrites input0 [(t, ML.asDouble m)], makeWrites input1 [(t, ML.asDouble m + 1.0)]]
+      [makeWrites output0 [(t, ML.asDouble m)], makeWrites output1 [(t, ML.asDouble m + 1.0)]]
     | _ ->
-      [makeWrites input0 [(t, -1.0)], makeWrites input1 [(t, -1.0)]]
+      [makeWrites output0 [(t, -1.0)], makeWrites output1 [(t, -1.0)]]
   }

diff --git a/nix/inferno-ml/tests/scripts/ones.inferno b/nix/inferno-ml/tests/scripts/ones.inferno
index 79eb6b1..230cf14 100644
--- a/nix/inferno-ml/tests/scripts/ones.inferno
+++ b/nix/inferno-ml/tests/scripts/ones.inferno
@@ -1,8 +1,8 @@
-fun input0 mnist ->
+fun input0 output0 mnist ->
   let mkV = fun t -> valueAt input0 (Time.toTime (Time.seconds t)) ? 0.0 in
   let ts = [150, 250] in
   let vs = Array.map mkV ts in
   let xs = ML.ones ML.#double [2] in
   let ts1 = Array.map (fun t -> Time.toTime (Time.seconds (t + 1))) ts in
   let vs1 = ML.asArray1 (ML.add xs (ML.asTensor1 ML.#double vs)) in
-  [makeWrites input0 (zip ts1 vs1)]
+  [makeWrites output0 (zip ts1 vs1)]

diff --git a/nix/inferno-ml/tests/server.nix b/nix/inferno-ml/tests/server.nix
index 4658242..0e3a61f 100644
--- a/nix/inferno-ml/tests/server.nix
+++ b/nix/inferno-ml/tests/server.nix
@@ -90,12 +90,21 @@ pkgs.nixosTest {
       dbstr = "host='127.0.0.1' dbname='inferno' user='inferno' password=''";
       ios = builtins.mapAttrs (_: builtins.toJSON) {
-        ones = { input0 = [ 1 "rw" ]; };
-        contrived = { input0 = [ 2 "rw" ]; };
+        ones = {
+          inputs.input0 = 1;
+          outputs.output0 = 1;
+        };
+        contrived = {
+          inputs.input0 = 2;
+          outputs.output0 = 2;
+        };
         # This test uses two outputs
         mnist = {
-          input0 = [ 3 "rw" ];
-          input1 = [ 4 "w" ];
+          inputs.input0 = 3;
+          outputs = {
+            output0 = 3;
+            output1 = 4;
+          };
         };
       };
     in