diff --git a/data/inventory.arr b/data/inventory.arr new file mode 100644 index 0000000..759932c Binary files /dev/null and b/data/inventory.arr differ diff --git a/data/population.arr b/data/population.arr new file mode 100644 index 0000000..a41c692 Binary files /dev/null and b/data/population.arr differ diff --git a/data/riverswim.arrow b/data/riverswim.arrow deleted file mode 100644 index 2cebda1..0000000 Binary files a/data/riverswim.arrow and /dev/null differ diff --git a/docs/make.jl b/docs/make.jl index bfe61fa..d4c3a1e 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -4,7 +4,8 @@ makedocs(sitename="MDPs.jl", modules = [MDPs], format = Documenter.HTML(; prettyurls = get(ENV, "CI", nothing) == "true"), pages = ["index.md", - "simulation.md"] + "simulation.md", + "recipes.md"] ) deploydocs(; diff --git a/docs/src/recipes.md b/docs/src/recipes.md new file mode 100644 index 0000000..1d1080d --- /dev/null +++ b/docs/src/recipes.md @@ -0,0 +1,34 @@ +# Recipes + +## Converting a file format of an MDP + +Converting from a CSV to an Arrow file +```jldoctest +using MDPs +using DataFrames +using Arrow +using CSV + +filein = joinpath(dirname(pathof(MDPs)), "..", "data", "riverswim.csv") +fileout = tempname() + +model = load_mdp(CSV.File(filein); idoutcome = 1) + +output = save_mdp(DataFrame, model) +``` + +Converting from an Arrow to a CSV file +```jldoctest +using MDPs +using DataFrames +using Arrow +using CSV + +filein = joinpath(dirname(pathof(MDPs)), "..", "data", "inventory.arr") +fileout = tempname() + +model = load_mdp(Arrow.Table(filein)) + +output = save_mdp(DataFrame, model) +CSV.write(fileout, output) +``` diff --git a/src/MDPs.jl b/src/MDPs.jl index 8b20ade..7dbccca 100644 --- a/src/MDPs.jl +++ b/src/MDPs.jl @@ -12,7 +12,7 @@ export valuefunction include("models/tabular.jl") export TabMDP export state_count, action_count, states, actions -export transform +export save_mdp include("models/integral.jl") export IntMDP, IntState, IntAction diff --git a/src/models/tabular.jl b/src/models/tabular.jl index eec90e1..e458477 100644 --- a/src/models/tabular.jl +++ b/src/models/tabular.jl @@ -3,14 +3,12 @@ using LinearAlgebra using SparseArrays """ -An abstract tabular Markov Decision Process, time independent. - -Default interpretation -- State: Positive integer (>0) is non-terminal, zero or negative integer is terminal -- Action: Positive integer, anything else is invalid +An abstract tabular Markov Decision Process which is specified by a transition function. Functions that should be defined for any subtype for value and policy iterations -to work are: `state_count`, `action_count`, `transition` +to work are: `state_count`, `states`, `action_count`, `actions`, and `transition`. + +Generally, states should be 1-based. The methods `state_count` and `states` should only include non-terminal states """ @@ -27,9 +25,9 @@ function state_count end function action_count end # enumerates possible states -function states end +states(model::TabMDP) = 1:state_count(model) # enumerated possible actions -function actions end +actions(model::TabMDP, s::Int) = 1:action_count(model, s) # ---------------------------------------------------------------- @@ -37,23 +35,38 @@ function actions end # ---------------------------------------------------------------- using DataFrames: DataFrame, append! -#import Base: convert + """ - transform(T::DataFrame, model) + save_mdp(T::DataFrame, model::TabMDP) -Convert a tabular MDP to a data frame representation +Convert an MDP `model` to a `DataFrame` representation with 0-based indices. + +Important: The MDP representation uses 0-based indexes while the output +DataFrame is 0-based for backwards compatibility. + +The columns are: `idstatefrom`, `idaction`, `idstateto`, `probability`, +and `reward`. """ -function transform(::Type{DataFrame}, model::TabMDP) - result = DataFrame() +function save_mdp(::Type{DataFrame}, model::TabMDP) + arr_idstatefrom = Vector{Int}() + arr_idstateto = Vector{Int}() + arr_idaction = Vector{Int}() + arr_prob = Vector{Float64}() + arr_reward = Vector{Float64}() + for s ∈ states(model) for a ∈ actions(model, s) for (sn,p,r) ∈ transition(model, s, a) - newrow = (idstatefrom = s, idaction = a, idstateto = sn, - probability = p, reward = r) - push!(result, newrow) + push!(arr_idstatefrom, s - 1) + push!(arr_idaction, a - 1) + push!(arr_idstateto, sn - 1) + push!(arr_prob, p) + push!(arr_reward, r) end end end - result + DataFrame(idstatefrom = arr_idstatefrom, idaction = arr_idaction, + idstateto = arr_idstateto, probability = arr_prob, + reward = arr_reward) end diff --git a/test/.runtests.jl.~undo-tree~ b/test/.runtests.jl.~undo-tree~ index 79a1587..94d3b49 100644 --- a/test/.runtests.jl.~undo-tree~ +++ b/test/.runtests.jl.~undo-tree~ @@ -1,5 +1,5 @@ (undo-tree-save-format-version . 1) -"abe5057d9a119420551598e3651c2e431183c411" +"c0b404be3585f50ca0ea945d5c634100c32ebb33" [nil nil nil nil (25921 45716 187939 756000) 0 nil] ([nil nil ((#("#end " 0 1 (face font-lock-comment-delimiter-face fontified t) 1 5 (face font-lock-comment-face fontified t)) . 157) (undo-tree-id48 . -3) (undo-tree-id49 . -4) (undo-tree-id50 . -4) (undo-tree-id51 . -4) (undo-tree-id52 . -3) (undo-tree-id53 . -3) (undo-tree-id54 . -3) (undo-tree-id55 . -3) (undo-tree-id56 . -3) (undo-tree-id57 . -5) (undo-tree-id58 . -4) 160 (t 25921 45265 790689 741000)) nil (25921 45716 187938 722000) 0 nil]) @@ -13,5 +13,9 @@ ([nil nil ((#("n" 0 1 (face font-lock-string-face fontified t)) . -63) (undo-tree-id16 . -1) (undo-tree-id17 . -1) (undo-tree-id18 . -1) (undo-tree-id19 . -1) (undo-tree-id20 . -1) (undo-tree-id21 . -1) (undo-tree-id22 . -1) (undo-tree-id23 . -1) (undo-tree-id24 . -1) (undo-tree-id25 . -1) (undo-tree-id26 . -1) (undo-tree-id27 . -1) (undo-tree-id28 . -1) (undo-tree-id29 . -1) (undo-tree-id30 . -1) (undo-tree-id31 . -1) (undo-tree-id32 . -1) (undo-tree-id33 . -1) (undo-tree-id34 . -1) (#("u" 0 1 (face font-lock-string-face fontified t)) . -64) (undo-tree-id35 . -1) (undo-tree-id36 . -1) (undo-tree-id37 . -1) (undo-tree-id38 . -1) (undo-tree-id39 . -1) (undo-tree-id40 . -1) (undo-tree-id41 . -1) (undo-tree-id42 . -1) (undo-tree-id43 . -1) (undo-tree-id44 . -1) (undo-tree-id45 . -1) (undo-tree-id46 . -1) (undo-tree-id47 . -1) (undo-tree-id48 . -1) (undo-tree-id49 . -1) (undo-tree-id50 . -1) (undo-tree-id51 . -1) (undo-tree-id52 . -1) (undo-tree-id53 . -1) (#("m" 0 1 (face font-lock-string-face fontified t)) . -65) (undo-tree-id54 . -1) (undo-tree-id55 . -1) (undo-tree-id56 . -1) (undo-tree-id57 . -1) (undo-tree-id58 . -1) (undo-tree-id59 . -1) (undo-tree-id60 . -1) (undo-tree-id61 . -1) (undo-tree-id62 . -1) (undo-tree-id63 . -1) (undo-tree-id64 . -1) (undo-tree-id65 . -1) (undo-tree-id66 . -1) (undo-tree-id67 . -1) (undo-tree-id68 . -1) (undo-tree-id69 . -1) (undo-tree-id70 . -1) (undo-tree-id71 . -1) (undo-tree-id72 . -1) (#("e" 0 1 (face font-lock-string-face fontified t)) . -66) (undo-tree-id73 . -1) (undo-tree-id74 . -1) (undo-tree-id75 . -1) (undo-tree-id76 . -1) (undo-tree-id77 . -1) (undo-tree-id78 . -1) (undo-tree-id79 . -1) (undo-tree-id80 . -1) (undo-tree-id81 . -1) (undo-tree-id82 . -1) (undo-tree-id83 . -1) (undo-tree-id84 . -1) (undo-tree-id85 . -1) (undo-tree-id86 . -1) (undo-tree-id87 . -1) (undo-tree-id88 . -1) (undo-tree-id89 . -1) (undo-tree-id90 . -1) (undo-tree-id91 . -1) (#("r" 0 1 (face font-lock-string-face fontified t)) . -67) (undo-tree-id92 . -1) (undo-tree-id93 . -1) (undo-tree-id94 . -1) (undo-tree-id95 . -1) (undo-tree-id96 . -1) (undo-tree-id97 . -1) (undo-tree-id98 . -1) (undo-tree-id99 . -1) (undo-tree-id100 . -1) (undo-tree-id101 . -1) (undo-tree-id102 . -1) (undo-tree-id103 . -1) (undo-tree-id104 . -1) (undo-tree-id105 . -1) (undo-tree-id106 . -1) (undo-tree-id107 . -1) (undo-tree-id108 . -1) (undo-tree-id109 . -1) (undo-tree-id110 . -1) (#("i" 0 1 (face font-lock-string-face fontified t)) . -68) (undo-tree-id111 . -1) (undo-tree-id112 . -1) (undo-tree-id113 . -1) (undo-tree-id114 . -1) (undo-tree-id115 . -1) (undo-tree-id116 . -1) (undo-tree-id117 . -1) (undo-tree-id118 . -1) (undo-tree-id119 . -1) (undo-tree-id120 . -1) (undo-tree-id121 . -1) (undo-tree-id122 . -1) (undo-tree-id123 . -1) (undo-tree-id124 . -1) (undo-tree-id125 . -1) (undo-tree-id126 . -1) (undo-tree-id127 . -1) (undo-tree-id128 . -1) (undo-tree-id129 . -1) (#("c" 0 1 (face font-lock-string-face fontified t)) . -69) (undo-tree-id130 . -1) (undo-tree-id131 . -1) (undo-tree-id132 . -1) (undo-tree-id133 . -1) (undo-tree-id134 . -1) (undo-tree-id135 . -1) (undo-tree-id136 . -1) (undo-tree-id137 . -1) (undo-tree-id138 . -1) (undo-tree-id139 . -1) (undo-tree-id140 . -1) (undo-tree-id141 . -1) 70 (t 25984 44233 739300 220000)) nil (25984 45016 204347 271000) 0 nil]) ([nil nil ((63 . 72)) nil (25984 45016 203612 532000) 0 nil]) ([nil nil ((#("e" 0 1 (face font-lock-string-face fontified t)) . -69) (undo-tree-id8 . -1) (undo-tree-id9 . -1) (undo-tree-id10 . -1) (undo-tree-id11 . -1) (#("a" 0 1 (face font-lock-string-face fontified t)) . -70) (undo-tree-id12 . -1) (undo-tree-id13 . -1) (#("l" 0 1 (face font-lock-string-face fontified t)) . -71) (undo-tree-id14 . -1) (undo-tree-id15 . -1) 72) nil (25984 45016 203610 436000) 0 nil]) -([nil current ((69 . 71)) nil (25984 45016 203594 834000) 0 nil]) +([nil nil ((69 . 71)) nil (25984 45016 203594 834000) 0 nil]) +([nil nil ((nil rear-nonsticky nil 112 . 113) (#(" +" 0 1 (fontified nil)) . -177) (112 . 178) 101 (t 25984 45016 242127 193000)) nil (26005 48360 626922 471000) 0 nil]) +([nil nil ((157 . 161) (t 26005 48360 661988 281000)) nil (26005 48400 680711 788000) 0 nil]) +([nil current ((122 . 126)) nil (26005 48400 680708 268000) 0 nil]) nil diff --git a/test/Manifest.toml b/test/Manifest.toml index 35a648b..f7d791d 100644 --- a/test/Manifest.toml +++ b/test/Manifest.toml @@ -1,8 +1,8 @@ # This file is machine-generated - editing it directly is not advised -julia_version = "1.10.0-beta3" +julia_version = "1.10.0" manifest_format = "2.0" -project_hash = "107325876e5df39b0cb4915f34106a9f7d2947ee" +project_hash = "ecea5b9035b56afaf9684d09b30a2c06bccec517" [[deps.ArgTools]] uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" @@ -82,11 +82,28 @@ git-tree-sha1 = "5372dbbf8f0bdb8c700db5367132925c0771ef7e" uuid = "f0e56b4a-5159-44fe-b623-3e5288b988bb" version = "2.2.1" +[[deps.Crayons]] +git-tree-sha1 = "249fe38abf76d48563e2f4556bebd215aa317e15" +uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f" +version = "4.1.1" + [[deps.DataAPI]] git-tree-sha1 = "8da84edb865b0b5b0100c0666a9bc9a0b71c553c" uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" version = "1.15.0" +[[deps.DataFrames]] +deps = ["Compat", "DataAPI", "DataStructures", "Future", "InlineStrings", "InvertedIndices", "IteratorInterfaceExtensions", "LinearAlgebra", "Markdown", "Missings", "PooledArrays", "PrecompileTools", "PrettyTables", "Printf", "REPL", "Random", "Reexport", "SentinelArrays", "SortingAlgorithms", "Statistics", "TableTraits", "Tables", "Unicode"] +git-tree-sha1 = "04c738083f29f86e62c8afc341f0967d8717bdb8" +uuid = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" +version = "1.6.1" + +[[deps.DataStructures]] +deps = ["Compat", "InteractiveUtils", "OrderedCollections"] +git-tree-sha1 = "3dbd312d370723b6bb43ba9d02fc36abade4518d" +uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" +version = "0.18.15" + [[deps.DataValueInterfaces]] git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6" uuid = "e2d170a0-9d28-54be-80f0-106bbe20a464" @@ -134,6 +151,11 @@ version = "1.4.0" deps = ["Markdown"] uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" +[[deps.InvertedIndices]] +git-tree-sha1 = "0dc7b50b8d436461be01300fd8cd45aa0274b038" +uuid = "41ab1584-1d38-5bbf-9106-f11c6c58b48f" +version = "1.3.0" + [[deps.IteratorInterfaceExtensions]] git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856" uuid = "82899510-4779-5014-852e-03e436cf321d" @@ -145,6 +167,11 @@ git-tree-sha1 = "7e5d6779a1e09a36db2a7b6cff50942a0a7d0fca" uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210" version = "1.5.0" +[[deps.LaTeXStrings]] +git-tree-sha1 = "50901ebc375ed41dbf8058da26f9de442febbbec" +uuid = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f" +version = "1.3.1" + [[deps.LazyArtifacts]] deps = ["Artifacts", "Pkg"] uuid = "4af54fe1-eca0-43a8-85a7-787d91b784e3" @@ -157,7 +184,7 @@ version = "0.6.4" [[deps.LibCURL_jll]] deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" -version = "8.0.1+1" +version = "8.4.0+0" [[deps.LibGit2]] deps = ["Base64", "LibGit2_jll", "NetworkOptions", "Printf", "SHA"] @@ -204,6 +231,12 @@ deps = ["Artifacts", "Libdl"] uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" version = "2.28.2+1" +[[deps.Missings]] +deps = ["DataAPI"] +git-tree-sha1 = "f66bdc5de519e8f8ae43bdc598782d35a25b1272" +uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28" +version = "1.1.0" + [[deps.Mmap]] uuid = "a63ad114-7e13-5084-954f-fe012c677804" @@ -260,6 +293,12 @@ git-tree-sha1 = "00805cd429dcb4870060ff49ef443486c262e38e" uuid = "21216c6a-2e73-6563-6e65-726566657250" version = "1.4.1" +[[deps.PrettyTables]] +deps = ["Crayons", "LaTeXStrings", "Markdown", "PrecompileTools", "Printf", "Reexport", "StringManipulation", "Tables"] +git-tree-sha1 = "88b895d13d53b5577fd53379d913b9ab9ac82660" +uuid = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" +version = "2.3.1" + [[deps.Printf]] deps = ["Unicode"] uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" @@ -272,6 +311,11 @@ uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" deps = ["SHA"] uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +[[deps.Reexport]] +git-tree-sha1 = "45e428421666073eab6f2da5c9d310d99bb12f9b" +uuid = "189a3867-3050-52da-a836-e630ba90ab69" +version = "1.2.2" + [[deps.SHA]] uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" version = "0.7.0" @@ -294,6 +338,33 @@ uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" [[deps.Sockets]] uuid = "6462fe0b-24de-5631-8697-dd941f90decc" +[[deps.SortingAlgorithms]] +deps = ["DataStructures"] +git-tree-sha1 = "66e0a8e672a0bdfca2c3f5937efb8538b9ddc085" +uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c" +version = "1.2.1" + +[[deps.SparseArrays]] +deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"] +uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +version = "1.10.0" + +[[deps.Statistics]] +deps = ["LinearAlgebra", "SparseArrays"] +uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +version = "1.10.0" + +[[deps.StringManipulation]] +deps = ["PrecompileTools"] +git-tree-sha1 = "a04cabe79c5f01f4d723cc6704070ada0b9d46d5" +uuid = "892a3eda-7b42-436c-8928-eab12a02cf0e" +version = "0.3.4" + +[[deps.SuiteSparse_jll]] +deps = ["Artifacts", "Libdl", "libblastrampoline_jll"] +uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c" +version = "7.2.1+1" + [[deps.TOML]] deps = ["Dates"] uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" diff --git a/test/Project.toml b/test/Project.toml index 74b6467..c34f129 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,4 +1,5 @@ [deps] Arrow = "69666777-d1a9-59fb-9406-91d4454c9d45" CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" +DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/test/runtests.jl b/test/runtests.jl index bb8f693..ac13831 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,3 +4,5 @@ using Test include("src/tabular.jl") include("src/integral.jl") include("src/domains/inventory.jl") +include("src/domains/make_domains.jl") +include("src/domains/solvers.jl") diff --git a/test/src/.runtests.jl.~undo-tree~ b/test/src/.runtests.jl.~undo-tree~ index 93263b2..7a8c095 100644 --- a/test/src/.runtests.jl.~undo-tree~ +++ b/test/src/.runtests.jl.~undo-tree~ @@ -1,5 +1,5 @@ (undo-tree-save-format-version . 1) -"93f6811ecfa1cdb8a882f389613babbb84d7bd17" +"0b97d8960b5c1c493e0e4910c2a1c37b79cf9a09" [nil nil nil nil (25921 45381 925302 206000) 0 nil] ([nil nil ((#("#end " 0 1 (face font-lock-comment-delimiter-face fontified t) 1 5 (face font-lock-comment-face fontified t)) . 198) (undo-tree-id10 . -3) (undo-tree-id11 . -4) (undo-tree-id12 . -5) (undo-tree-id13 . -5) (undo-tree-id14 . -5) (undo-tree-id15 . -5) (undo-tree-id16 . -5) (undo-tree-id17 . -5) (undo-tree-id18 . -5) (undo-tree-id19 . -5) (undo-tree-id20 . -4) (undo-tree-id21 . -4) (undo-tree-id22 . -3) (undo-tree-id23 . -3) (undo-tree-id24 . -3) (undo-tree-id25 . -3) (undo-tree-id26 . -3) (undo-tree-id27 . -3) (undo-tree-id28 . -3) (undo-tree-id29 . -5) (undo-tree-id30 . -4) 201 (t 25356 48225 350037 195000)) nil (25921 45381 925300 175000) 0 nil]) @@ -14,6 +14,19 @@ nil ([nil nil ((10 . 11)) nil (25921 45385 154707 840000) 0 nil]) ([nil nil ((#("include(\"mdpo.jl\") " 0 8 (fontified t) 8 17 (face font-lock-string-face fontified t) 17 19 (fontified t)) . 69) (undo-tree-id4 . -17) (undo-tree-id5 . -17) (undo-tree-id6 . -19) (undo-tree-id7 . -18) (undo-tree-id8 . -18) (undo-tree-id9 . -18) (undo-tree-id10 . -18) (undo-tree-id11 . -18) 86 (t 25921 45385 188546 528000)) nil (25986 17230 294986 665000) 0 nil]) -([nil current ((#("include(\"robust_mmdp.jl\") +([nil nil ((#("include(\"robust_mmdp.jl\") " 0 8 (fontified t) 8 24 (face font-lock-string-face fontified t) 24 25 (fontified t) 25 26 (fontified t)) . 69) (undo-tree-id0 . -17) (undo-tree-id1 . -26) (undo-tree-id2 . 26) (undo-tree-id3 . -25) 86) nil (25986 17230 294977 84000) 0 nil]) +([nil nil ((100 . 101) (t 25986 17230 305802 701000) 99) nil (26005 47962 10945 382000) 0 nil]) +([nil nil ((101 . 122)) nil (26005 47962 10944 740000) 0 nil]) +([nil nil ((122 . 131)) nil (26005 47962 10944 323000) 0 nil]) +([nil nil ((#("/" 0 1 (face font-lock-string-face fontified t)) . -130) (undo-tree-id58 . -1) (undo-tree-id59 . -1) 131) nil (26005 47962 10943 714000) 0 nil]) +([nil nil ((130 . 135)) nil (26005 47962 10942 356000) 0 nil]) +([nil nil ((135 . 136)) nil (26005 47962 10942 9000) 0 nil]) +([nil nil ((136 . 153)) nil (26005 47962 10941 528000) 0 nil]) +([nil nil ((#("d" 0 1 (face font-lock-string-face fontified t)) . -151) (undo-tree-id54 . -1) (undo-tree-id55 . -1) (#("/" 0 1 (face font-lock-string-face fontified t)) . -152) (undo-tree-id56 . -1) (undo-tree-id57 . -1) 153) nil (26005 47962 10940 767000) 0 nil]) +([nil nil ((151 . 157)) nil (26005 47962 10938 405000) 0 nil]) +([nil nil ((#("m" 0 1 (face font-lock-string-face fontified t)) . -153) (undo-tree-id27 . -1) (undo-tree-id28 . -1) (undo-tree-id29 . -1) (undo-tree-id30 . -1) (undo-tree-id31 . -1) (undo-tree-id32 . -1) (undo-tree-id33 . -1) (undo-tree-id34 . -1) (#("a" 0 1 (face font-lock-string-face fontified t)) . -154) (undo-tree-id35 . -1) (undo-tree-id36 . -1) (undo-tree-id37 . -1) (undo-tree-id38 . -1) (undo-tree-id39 . -1) (undo-tree-id40 . -1) (undo-tree-id41 . -1) (undo-tree-id42 . -1) (#("k" 0 1 (face font-lock-string-face fontified t)) . -155) (undo-tree-id43 . -1) (undo-tree-id44 . -1) (undo-tree-id45 . -1) (undo-tree-id46 . -1) (undo-tree-id47 . -1) (undo-tree-id48 . -1) (#("e" 0 1 (face font-lock-string-face fontified t)) . -156) (undo-tree-id49 . -1) (undo-tree-id50 . -1) (undo-tree-id51 . -1) (undo-tree-id52 . -1) (undo-tree-id53 . -1) 157) nil (26005 47962 10937 6000) 0 nil]) +([nil nil ((153 . 157)) nil (26005 47962 10925 584000) 0 nil]) +([nil nil ((#("m" 0 1 (face font-lock-string-face fontified t)) . -153) (undo-tree-id0 . -1) (undo-tree-id1 . -1) (undo-tree-id2 . -1) (undo-tree-id3 . -1) (undo-tree-id4 . -1) (undo-tree-id5 . -1) (undo-tree-id6 . -1) (undo-tree-id7 . -1) (#("a" 0 1 (face font-lock-string-face fontified t)) . -154) (undo-tree-id8 . -1) (undo-tree-id9 . -1) (undo-tree-id10 . -1) (undo-tree-id11 . -1) (undo-tree-id12 . -1) (undo-tree-id13 . -1) (undo-tree-id14 . -1) (undo-tree-id15 . -1) (#("k" 0 1 (face font-lock-string-face fontified t)) . -155) (undo-tree-id16 . -1) (undo-tree-id17 . -1) (undo-tree-id18 . -1) (undo-tree-id19 . -1) (undo-tree-id20 . -1) (undo-tree-id21 . -1) (#("e" 0 1 (face font-lock-string-face fontified t)) . -156) (undo-tree-id22 . -1) (undo-tree-id23 . -1) (undo-tree-id24 . -1) (undo-tree-id25 . -1) (undo-tree-id26 . -1) 157) nil (26005 47962 10923 141000) 0 nil]) +([nil current ((153 . 165)) nil (26005 47962 10894 582000) 0 nil]) nil diff --git a/test/src/domains/make_domains.jl b/test/src/domains/make_domains.jl new file mode 100644 index 0000000..851c88c --- /dev/null +++ b/test/src/domains/make_domains.jl @@ -0,0 +1,87 @@ +using Arrow +using MDPs.Domains +using CSV + + +struct Problem{M <: TabMDP} + γ :: Float64 + horizon :: Int + initstate :: Int + model :: M +end + +# creates a set of benchmark problems +function make_domains() + problems = Dict{String, Problem}() + # inventory + begin + # risk parameters + γ = 0.8 + initstate = 1 # initial state + horizon = 100 + # Define the inventory model + demand = Inventory.Demand([0,2,3,4,5,30,3,2], + [0.1,0.3,0.1,0.1,0.1,0.1,0.0,0.2]) + costs = Inventory.Costs(5.,2.,0.3,0.5) + limits = Inventory.Limits(100, 0, 50) + params = Inventory.Parameters(demand, costs, 16., limits) + model = Inventory.Model(params) + problems["inventory"] = Problem(γ, horizon, initstate, model) + end + #invetory_generic + begin + γ = 0.9 + filein = joinpath(dirname(pathof(MDPs)), "..", "data", "inventory.arr") + model = load_mdp(Arrow.Table(filein)) + initstate = 1 # initial state + horizon = 100 + problems["inventory_generic"] = Problem(γ, horizon, initstate, model) + end + # machine + begin + γ = 0.8 + initstate = 1 # initial state + horizon = 100 + model = Domains.Machine.Replacement() + problems["machine"] = Problem(γ, horizon, initstate, model) + end + # ruin + begin + α = 0.9 # var, cvar, evar + horizon = 200 + initstate = 8 # capital: state - 1 + model = Domains.Gambler.Ruin(0.7, 10) + problems["ruin"] = Problem(γ, horizon, initstate, model) + end + # riverswim + begin + filein = joinpath(dirname(pathof(MDPs)), "..", "data", "riverswim.csv") + model = load_mdp(CSV.File(filein); idoutcome = 1) + γ = 0.98 + horizon = 100 + initstate = 1 # initial state + problems["riverswim"] = Problem(γ, horizon, initstate, model) + end + # population + begin + filein = joinpath(dirname(pathof(MDPs)), "..", "data", "inventory.arr") + model = load_mdp(Arrow.Table(filein)) + α = 0.9 # var, cvar, evar + β = 0.5 # erm + γ = 0.7 + horizon = 50 + initstate = 1 # initial state + problems["population"] = Problem(γ, horizon, initstate, model) + end + # onestatepm + begin + model = Domains.Simple.OneStatePlusMinus(100) + initstate = 1 # initial state + horizon = 100 + γ = 0.95 + problems["onestatepm"] = Problem(γ, horizon, initstate, model) + end + problems +end + +#make_domains() diff --git a/test/src/domains/solvers.jl b/test/src/domains/solvers.jl new file mode 100644 index 0000000..7af7ad5 --- /dev/null +++ b/test/src/domains/solvers.jl @@ -0,0 +1,47 @@ +using MDPs +using DataFrames + +#include("domains/make_domains.jl") + +function solve_domain(probname, prob) + + episodes::Int = 10000 + + # evaluation helper variables + rweights::Vector{Float64} = prob.γ .^ (0:prob.horizon-1) # reward weights + edist::Vector{Float64} = ones(episodes) / episodes # distribution over episodes + + results::DataFrame = DataFrame() + + # Risk neutral solution + #println("Risk neutral infinite ...") + #time = @elapsed v = value_iteration(model, γ; ϵ = 0.1) + #π = greedy(model, γ, v.value) + #report_disc!(results, "Neutral, inf", π, v, time) + + # Risk-neutral finite + vp = value_iteration(prob.model, FiniteH(prob.γ, prob.horizon)) + v = vp.value + π = vp.policy + + # confirm using simulation + roundresult(x) = round(x; sigdigits = 3) + + H = simulate(prob.model, π, prob.initstate, prob.horizon, episodes) + returns = rweights' * H.rewards |> vec + rmean = sum(returns) / length(returns) + + @test rmean ≈ vp.value[1][prob.initstate] rtol = 0.05 + #println(rmean, " <===> ", vp.value[1][prob.initstate]) + #println(isapprox(rmean, vp.value[1][prob.initstate], rtol = 0.05)) +end + +@testset "Solve benchmark domains" begin + # general parameters + + domains::Dict{String, Problem} = make_domains() + + for (dname, domain) ∈ domains + solve_domain(dname, domain) + end +end diff --git a/test/src/tabular.jl b/test/src/tabular.jl index e69de29..c630172 100644 --- a/test/src/tabular.jl +++ b/test/src/tabular.jl @@ -0,0 +1,14 @@ +using CSV +using Arrow +using DataFrames + + +@testset "Serialize and load an MDP file" begin + filein = joinpath(dirname(pathof(MDPs)), "..", "data", "population.arr") + + model = load_mdp(Arrow.Table(filein)) + output = save_mdp(DataFrame, model) + model2 = load_mdp(output) + output2 = save_mdp(DataFrame, model2) + @test all(map(all, eachcol(output .≈ output2))) +end