diff --git a/Project.toml b/Project.toml index c54c70e..8783c2a 100644 --- a/Project.toml +++ b/Project.toml @@ -6,13 +6,18 @@ version = "0.1.5" [deps] DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" DataFramesMeta = "1313f7d8-7da2-5740-9ea0-a2ca25f37964" +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +JuMP = "4076af6c-e467-56ae-b986-b466b2749572" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" [compat] -julia = "1.9" DataFrames = "1.6.1" DataFramesMeta = "0.14.1" +Distributions = "0.25.107" +StatsBase = "0.34.2" +julia = "1.9" [extras] Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/src/MDPs.jl b/src/MDPs.jl index 7dbccca..fa62da6 100644 --- a/src/MDPs.jl +++ b/src/MDPs.jl @@ -34,6 +34,9 @@ export mrp!, mrp, mrp_sparse include("algorithms/policyiteration.jl") export policy_iteration, policy_iteration_sparse +include("algorithms/linprogsolve.jl") +export lp_solve + include("simulation.jl") export simulate, random_π export Policy, PolicyStationary, PolicyMarkov @@ -45,6 +48,8 @@ export Transition module Domains include("domains/simple.jl") export Simple +include("domains/garnet.jl") +export Garnet include("domains/inventory.jl") export Inventory include("domains/machine.jl") diff --git a/src/algorithms/linprogsolve.jl b/src/algorithms/linprogsolve.jl new file mode 100644 index 0000000..6cf906d --- /dev/null +++ b/src/algorithms/linprogsolve.jl @@ -0,0 +1,33 @@ +using JuMP + +# ---------------------------------------------------------------- +# Linear Program Solver +# ---------------------------------------------------------------- + + +""" +lp_solve(model, γ, lpm) + +Implements the linear program primal problem for an MDP `model` with a discount factor `γ`. +It uses the JuMP model `lpm` as the linear program solver and returns the state values +found by `lpm`. +""" + +function lp_solve(model::TabMDP, γ::Number, lpm) + 0 ≤ γ < 1 || error("γ must be between 0 and 1") + set_silent(lpm) + n = state_count(model) + @variable(lpm, v[1:n]) + @objective(lpm,Min, sum(v[1:n])) + π::Vector{Vector{ConstraintRef}} = [] + for s in 1:n + m = action_count(model,s) + π_s::Vector{ConstraintRef} = [] + for a in 1:m + push!(π_s, @constraint(lpm, v[s] ≥ sum(sp[2]*(sp[3]+γ*v[sp[1]]) for sp in transition(model,s,a)))) + end + push!(π, π_s) + end + optimize!(lpm) + (value = value.(v), policy = map(x->argmax(dual.(x)), π)) +end \ No newline at end of file diff --git a/src/domains/garnet.jl b/src/domains/garnet.jl new file mode 100644 index 0000000..f91549f --- /dev/null +++ b/src/domains/garnet.jl @@ -0,0 +1,73 @@ +module Garnet + +import ...TabMDP, ...transition, ...state_count, ...action_count +import ...actions, ...states + +# TODO: are these reasonable or can we replace them? +import StatsBase, Distributions +# ---------------------------------------------------------------- +# A Garnet MDP +# ---------------------------------------------------------------- + +struct GarnetMDP <: TabMDP + reward::Vector{Vector{Float64}} + transition::Vector{Vector{Vector{Float64}}} + S::Int + A::Vector{Int} + + # TODO: add a constructor that checks for consistency +end + +""" +A Garnet MDP is a tabular MDP where the number of next states available from any current state is a fixed proportion of the total number of states in the model. +This proportion is called "nbranch" and it must between 0 and 1. +""" + +function make_garnet(S::Integer, A::AbstractVector{Int}, nbranch::Number, min_reward::Integer, max_reward::Integer) + + 0.0 ≤ nbranch ≤ 1.0 || error("nbranch must be in [0,1]") + + reward = Vector{Vector{Float64}}() + transition = Vector{Vector{Vector{Float64}}}() + dist = Distributions.Exponential(1) + sout = Int(round(nbranch*S)) + + for i in 1:S + r = Vector{Float64}() + p = Vector{Vector{Float64}}() + for j in 1:A[i] + push!(r, rand(min_reward:max_reward)) + inds = StatsBase.sample(1:S, sout, replace=false) + z = rand(dist,sout) + z /= sum(z) + pp = zeros(S) + for (k,l) in enumerate(inds) pp[l] = z[k] end + push!(p,pp) + end + push!(reward,r) + push!(transition, p) + end + + GarnetMDP(reward,transition,S,A) +end + +make_garnet(S::Integer, A::Integer, nbranch, min_reward, max_reward) = make_garnet(S, fill(Int(A),S), nbranch, min_reward, max_reward) + +function transition(model::GarnetMDP, state::Int, action::Int) + @assert state in 1:model.S + @assert action in 1:model.A[state] + + next = [] + for (s,p) in enumerate(model.transition[state][action]) + if p != 0 + push!(next, (s,p,model.reward[state][action])) + end + end + return next +end + +state_count(model::GarnetMDP) = model.S +action_count(model::GarnetMDP, s::Int) = model.A[s] + +end +# Module: Garnet diff --git a/test/Manifest.toml b/test/Manifest.toml index f7d791d..35eee0e 100644 --- a/test/Manifest.toml +++ b/test/Manifest.toml @@ -1,8 +1,8 @@ # This file is machine-generated - editing it directly is not advised -julia_version = "1.10.0" +julia_version = "1.10.1" manifest_format = "2.0" -project_hash = "ecea5b9035b56afaf9684d09b30a2c06bccec517" +project_hash = "de2366d46b86291191f09b798751f58a1255d4a2" [[deps.ArgTools]] uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" @@ -26,12 +26,24 @@ uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" [[deps.Base64]] uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" +[[deps.BenchmarkTools]] +deps = ["JSON", "Logging", "Printf", "Profile", "Statistics", "UUIDs"] +git-tree-sha1 = "f1dff6729bc61f4d49e140da1af55dcd1ac97b2f" +uuid = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" +version = "1.5.0" + [[deps.BitIntegers]] deps = ["Random"] git-tree-sha1 = "a55462dfddabc34bc97d3a7403a2ca2802179ae6" uuid = "c3b6d118-76ef-56ca-8cc7-ebb389d030a1" version = "0.3.1" +[[deps.Bzip2_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "9e2a6b69137e6969bab0152632dcb3bc108c8bdd" +uuid = "6e34b625-4abd-537c-b88f-471c36dfa7a0" +version = "1.0.8+1" + [[deps.CEnum]] git-tree-sha1 = "eb4cb44a499229b3b8426dcfb5dd85333951ff90" uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82" @@ -43,6 +55,12 @@ git-tree-sha1 = "44dbf560808d49041989b8a96cae4cffbeb7966a" uuid = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" version = "0.10.11" +[[deps.CodecBzip2]] +deps = ["Bzip2_jll", "Libdl", "TranscodingStreams"] +git-tree-sha1 = "f8889d1770addf59d0a015c49a473fa2bdb9f809" +uuid = "523fee87-0ab8-5b00-afb7-3ecf72e48cfd" +version = "0.8.3" + [[deps.CodecLz4]] deps = ["Lz4_jll", "TranscodingStreams"] git-tree-sha1 = "59fe0cb37784288d6b9f1baebddbf75457395d40" @@ -61,6 +79,12 @@ git-tree-sha1 = "849470b337d0fa8449c21061de922386f32949d9" uuid = "6b39b394-51ab-5f42-8807-6242bab2b4c2" version = "0.7.2" +[[deps.CommonSubexpressions]] +deps = ["MacroTools", "Test"] +git-tree-sha1 = "7b8a93dba8af7e3b42fecabf646260105ac373f7" +uuid = "bbf7d656-a473-5ed7-a52c-81e309532950" +version = "0.3.0" + [[deps.Compat]] deps = ["UUIDs"] git-tree-sha1 = "8a62af3e248a8c4bad6b32cbbe663ae02275e32c" @@ -74,7 +98,7 @@ weakdeps = ["Dates", "LinearAlgebra"] [[deps.CompilerSupportLibraries_jll]] deps = ["Artifacts", "Libdl"] uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" -version = "1.0.5+1" +version = "1.1.0+0" [[deps.ConcurrentUtilities]] deps = ["Serialization", "Sockets"] @@ -113,6 +137,24 @@ version = "1.0.0" deps = ["Printf"] uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" +[[deps.DiffResults]] +deps = ["StaticArraysCore"] +git-tree-sha1 = "782dd5f4561f5d267313f23853baaaa4c52ea621" +uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" +version = "1.1.0" + +[[deps.DiffRules]] +deps = ["IrrationalConstants", "LogExpFunctions", "NaNMath", "Random", "SpecialFunctions"] +git-tree-sha1 = "23163d55f885173722d1e4cf0f6110cdbaf7e272" +uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" +version = "1.15.1" + +[[deps.DocStringExtensions]] +deps = ["LibGit2"] +git-tree-sha1 = "2fb1e02f2b635d0845df5d7c167fec4dd739b00d" +uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" +version = "0.9.3" + [[deps.Downloads]] deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"] uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" @@ -137,10 +179,34 @@ version = "0.9.21" [[deps.FileWatching]] uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" +[[deps.ForwardDiff]] +deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "LogExpFunctions", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions"] +git-tree-sha1 = "cf0fe81336da9fb90944683b8c41984b08793dad" +uuid = "f6369f11-7733-5829-9624-2563aa707210" +version = "0.10.36" + + [deps.ForwardDiff.extensions] + ForwardDiffStaticArraysExt = "StaticArrays" + + [deps.ForwardDiff.weakdeps] + StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" + [[deps.Future]] deps = ["Random"] uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820" +[[deps.HiGHS]] +deps = ["HiGHS_jll", "MathOptInterface", "PrecompileTools", "SparseArrays"] +git-tree-sha1 = "1042e72e93e5916bbfe034576f2fc2fae73d5ec7" +uuid = "87dc4568-4c63-4d18-b0c0-bb2238e4078b" +version = "1.9.1" + +[[deps.HiGHS_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Zlib_jll"] +git-tree-sha1 = "656db2048ed731484df16fc91e7232a190e330fb" +uuid = "8fd58aa0-07eb-5a78-9b36-339c94fd15ea" +version = "1.7.1+0" + [[deps.InlineStrings]] deps = ["Parsers"] git-tree-sha1 = "9cc2baf75c6d09f9da536ddf58eb2f29dedaf461" @@ -156,6 +222,11 @@ git-tree-sha1 = "0dc7b50b8d436461be01300fd8cd45aa0274b038" uuid = "41ab1584-1d38-5bbf-9106-f11c6c58b48f" version = "1.3.0" +[[deps.IrrationalConstants]] +git-tree-sha1 = "630b497eafcc20001bba38a4651b327dcfc491d2" +uuid = "92d709cd-6900-40b7-9082-c6be49f344b6" +version = "0.2.2" + [[deps.IteratorInterfaceExtensions]] git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856" uuid = "82899510-4779-5014-852e-03e436cf321d" @@ -167,6 +238,24 @@ git-tree-sha1 = "7e5d6779a1e09a36db2a7b6cff50942a0a7d0fca" uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210" version = "1.5.0" +[[deps.JSON]] +deps = ["Dates", "Mmap", "Parsers", "Unicode"] +git-tree-sha1 = "31e996f0a15c7b280ba9f76636b3ff9e2ae58c9a" +uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" +version = "0.21.4" + +[[deps.JuMP]] +deps = ["LinearAlgebra", "MacroTools", "MathOptInterface", "MutableArithmetics", "OrderedCollections", "PrecompileTools", "Printf", "SparseArrays"] +git-tree-sha1 = "7e10a0d8b534f2d8e9f712b33488584254624fb1" +uuid = "4076af6c-e467-56ae-b986-b466b2749572" +version = "1.22.2" + + [deps.JuMP.extensions] + JuMPDimensionalDataExt = "DimensionalData" + + [deps.JuMP.weakdeps] + DimensionalData = "0703355e-b756-11e9-17c0-8b28908087d0" + [[deps.LaTeXStrings]] git-tree-sha1 = "50901ebc375ed41dbf8058da26f9de442febbbec" uuid = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f" @@ -207,6 +296,22 @@ uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" deps = ["Libdl", "OpenBLAS_jll", "libblastrampoline_jll"] uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +[[deps.LogExpFunctions]] +deps = ["DocStringExtensions", "IrrationalConstants", "LinearAlgebra"] +git-tree-sha1 = "a2d09619db4e765091ee5c6ffe8872849de0feea" +uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" +version = "0.3.28" + + [deps.LogExpFunctions.extensions] + LogExpFunctionsChainRulesCoreExt = "ChainRulesCore" + LogExpFunctionsChangesOfVariablesExt = "ChangesOfVariables" + LogExpFunctionsInverseFunctionsExt = "InverseFunctions" + + [deps.LogExpFunctions.weakdeps] + ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" + InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" + [[deps.Logging]] uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" @@ -222,10 +327,22 @@ git-tree-sha1 = "6c26c5e8a4203d43b5497be3ec5d4e0c3cde240a" uuid = "5ced341a-0733-55b8-9ab6-a4889d929147" version = "1.9.4+0" +[[deps.MacroTools]] +deps = ["Markdown", "Random"] +git-tree-sha1 = "2fa9ee3e63fd3a4f7a9a4f4744a52f4856de82df" +uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" +version = "0.5.13" + [[deps.Markdown]] deps = ["Base64"] uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" +[[deps.MathOptInterface]] +deps = ["BenchmarkTools", "CodecBzip2", "CodecZlib", "DataStructures", "ForwardDiff", "JSON", "LinearAlgebra", "MutableArithmetics", "NaNMath", "OrderedCollections", "PrecompileTools", "Printf", "SparseArrays", "SpecialFunctions", "Test", "Unicode"] +git-tree-sha1 = "91b08d27a27d83cf1e63e50837403e7f53a0fd74" +uuid = "b8f27783-ece8-5eb3-8dc8-9495eed66fee" +version = "1.31.0" + [[deps.MbedTLS_jll]] deps = ["Artifacts", "Libdl"] uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" @@ -250,6 +367,18 @@ version = "0.7.7" uuid = "14a3606d-f60d-562e-9121-12d972cd8159" version = "2023.1.10" +[[deps.MutableArithmetics]] +deps = ["LinearAlgebra", "SparseArrays", "Test"] +git-tree-sha1 = "898c56fbf8bf71afb0c02146ef26f3a454e88873" +uuid = "d8a4904e-b15c-11e9-3269-09a3773c0cb0" +version = "1.4.5" + +[[deps.NaNMath]] +deps = ["OpenLibm_jll"] +git-tree-sha1 = "0877504529a3e5c3343c6f8b4c0381e57e4387e4" +uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" +version = "1.0.2" + [[deps.NetworkOptions]] uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" version = "1.2.0" @@ -257,7 +386,18 @@ version = "1.2.0" [[deps.OpenBLAS_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" -version = "0.3.23+2" +version = "0.3.23+4" + +[[deps.OpenLibm_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "05823500-19ac-5b8b-9628-191a04bc5112" +version = "0.8.1+2" + +[[deps.OpenSpecFun_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "13652491f6856acfd2db29360e1bbcd4565d04f1" +uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e" +version = "0.5.5+0" [[deps.OrderedCollections]] git-tree-sha1 = "2e73fe17cac3c62ad1aebe70d44c963c3cfdc3e3" @@ -303,6 +443,10 @@ version = "2.3.1" deps = ["Unicode"] uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" +[[deps.Profile]] +deps = ["Printf"] +uuid = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79" + [[deps.REPL]] deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" @@ -349,6 +493,23 @@ deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"] uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" version = "1.10.0" +[[deps.SpecialFunctions]] +deps = ["IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"] +git-tree-sha1 = "2f5d4697f21388cbe1ff299430dd169ef97d7e14" +uuid = "276daf66-3868-5448-9aa4-cd146d93841b" +version = "2.4.0" + + [deps.SpecialFunctions.extensions] + SpecialFunctionsChainRulesCoreExt = "ChainRulesCore" + + [deps.SpecialFunctions.weakdeps] + ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + +[[deps.StaticArraysCore]] +git-tree-sha1 = "192954ef1208c7019899fbf8049e717f92959682" +uuid = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" +version = "1.4.3" + [[deps.Statistics]] deps = ["LinearAlgebra", "SparseArrays"] uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" diff --git a/test/Project.toml b/test/Project.toml index c34f129..14e823e 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -2,4 +2,6 @@ Arrow = "69666777-d1a9-59fb-9406-91d4454c9d45" CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" +HiGHS = "87dc4568-4c63-4d18-b0c0-bb2238e4078b" +JuMP = "4076af6c-e467-56ae-b986-b466b2749572" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/test/runtests.jl b/test/runtests.jl index ac13831..111ca54 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,5 +4,6 @@ using Test include("src/tabular.jl") include("src/integral.jl") include("src/domains/inventory.jl") +include("src/domains/garnet.jl") include("src/domains/make_domains.jl") include("src/domains/solvers.jl") diff --git a/test/src/domains/.inventory.jl.~undo-tree~ b/test/src/domains/.inventory.jl.~undo-tree~ index c4fcc49..fd6b827 100644 --- a/test/src/domains/.inventory.jl.~undo-tree~ +++ b/test/src/domains/.inventory.jl.~undo-tree~ @@ -1,42 +1,9 @@ (undo-tree-save-format-version . 1) -"b8ac467663076265d2b8b3e1ed6e71547c1d01d7" -[nil nil nil nil (25921 45772 644259 104000) 0 nil] -([nil nil ((#("RobustRL" 0 8 (fontified t)) . 7) (undo-tree-id0 . -8) (undo-tree-id1 . -8) (undo-tree-id2 . -7) (t 25356 47990 803365 550000)) nil (25921 45772 644257 331000) 0 nil]) -([nil nil ((7 . 11)) nil (25921 45772 644239 207000) 0 nil]) -([nil nil ((1058 . 1061) (#("generic" 0 7 (fontified t)) . 1058) (undo-tree-id37 . -7) (undo-tree-id38 . -7) (undo-tree-id39 . -6) (t 25921 45772 655742 181000)) nil (25985 42540 497413 44000) 0 nil]) -([nil nil ((1114 . 1117) (#("generic" 0 7 (fontified t)) . 1114) (undo-tree-id31 . -6) (undo-tree-id32 . 5) (undo-tree-id33 . -1) (undo-tree-id34 . -1) (undo-tree-id35 . -7) (undo-tree-id36 . -7)) nil (25985 42540 497410 32000) 0 nil]) -([nil nil ((#(" " 0 1 (face font-lock-comment-face fontified t)) . -1366) (undo-tree-id0 . -1) (undo-tree-id1 . -1) 1367 (1367 . 1370) (#("w" 0 1 (face font-lock-comment-face fontified t)) . -1366) (undo-tree-id2 . -1) (undo-tree-id3 . -1) 1367 (1366 . 1367) (#("generic" 0 7 (face font-lock-comment-face fontified t)) . 1366) (undo-tree-id4 . -6) (undo-tree-id5 . -2) (undo-tree-id6 . -2) (undo-tree-id7 . -2) (undo-tree-id8 . -2) (undo-tree-id9 . -3) (undo-tree-id10 . -3) (undo-tree-id11 . -3) (undo-tree-id12 . -3) (undo-tree-id13 . -4) (undo-tree-id14 . -4) (undo-tree-id15 . -4) (undo-tree-id16 . -4) (undo-tree-id17 . -5) (undo-tree-id18 . -5) (undo-tree-id19 . -5) (undo-tree-id20 . -5) (undo-tree-id21 . -6) (undo-tree-id22 . -6) (undo-tree-id23 . -6) (undo-tree-id24 . -6) (undo-tree-id25 . -7) (undo-tree-id26 . -7) (undo-tree-id27 . -7) (undo-tree-id28 . -7) (undo-tree-id29 . -7) (undo-tree-id30 . -7)) nil (25985 42540 497402 987000) 0 nil]) -([nil nil ((1182 . 1187) (t 25985 42540 533735 444000)) nil (25985 42576 573749 436000) 0 nil]) -([nil nil ((#("Infin" 0 5 (fontified t)) . -1182) (1187 . 1196) 1187) nil (25985 42576 573748 709000) 0 nil]) -([nil nil ((1191 . 1192)) nil (25985 42576 573747 970000) 0 nil]) -([nil nil ((1196 . 1197)) nil (25985 42576 573747 493000) 0 nil]) -([nil nil ((1244 . 1249)) nil (25985 42576 573747 8000) 0 nil]) -([nil nil ((#("Infin" 0 5 (fontified t)) . -1244) (1249 . 1258) 1249) nil (25985 42576 573746 312000) 0 nil]) -([nil nil ((1253 . 1254)) nil (25985 42576 573745 468000) 0 nil]) -([nil nil ((1258 . 1259)) nil (25985 42576 573744 770000) 0 nil]) -([nil nil ((1307 . 1312)) nil (25985 42576 573744 226000) 0 nil]) -([nil nil ((#("Infin" 0 5 (fontified t)) . -1307) (1312 . 1321) 1312) nil (25985 42576 573742 873000) 0 nil]) -([nil nil ((1316 . 1317)) nil (25985 42576 573741 450000) 0 nil]) -([nil nil ((1321 . 1322)) nil (25985 42576 573737 814000) 0 nil]) -([nil nil ((1749 . 1754) (t 25985 42576 610358 510000)) nil (25985 42598 406494 861000) 0 nil]) -([nil nil ((#("Infin" 0 5 (fontified t)) . -1749) (1754 . 1763) 1754) nil (25985 42598 406494 267000) 0 nil]) -([nil nil ((1758 . 1759)) nil (25985 42598 406493 707000) 0 nil]) -([nil nil ((1763 . 1764)) nil (25985 42598 406493 241000) 0 nil]) -([nil nil ((#(" " 0 1 (fontified t)) . -1766) (undo-tree-id46 . -1) (undo-tree-id47 . -1) (#(" " 0 1 (fontified t)) . -1767) (undo-tree-id48 . -1) (undo-tree-id49 . -1) 1768) nil (25985 42598 406492 394000) 0 nil]) -([nil nil ((1802 . 1806)) nil (25985 42598 406489 695000) 0 nil]) -([nil nil ((#("f" 0 1 (fontified t)) . -1803) (undo-tree-id40 . -1) (undo-tree-id41 . -1) (#("i" 0 1 (fontified t)) . -1804) (undo-tree-id42 . -1) (undo-tree-id43 . -1) (#("n" 0 1 (fontified t)) . -1805) (undo-tree-id44 . -1) (undo-tree-id45 . -1) 1806) nil (25985 42598 406488 681000) 0 nil]) -([nil nil ((1803 . 1806)) nil (25985 42598 406477 309000) 0 nil]) -([nil nil ((#("Infi" 0 4 (fontified t)) . -1802) (1806 . 1815) 1806) nil (25985 42598 406476 629000) 0 nil]) -([nil nil ((1811 . 1812)) nil (25985 42598 406475 887000) 0 nil]) -([nil nil ((1816 . 1817)) nil (25985 42598 406475 486000) 0 nil]) -([nil nil ((1856 . 1861)) nil (25985 42598 406475 30000) 0 nil]) -([nil nil ((#("Infin" 0 5 (fontified t)) . -1856) (1861 . 1870) 1861) nil (25985 42598 406473 855000) 0 nil]) -([nil nil ((1865 . 1866)) nil (25985 42598 406472 336000) 0 nil]) -([nil nil ((1870 . 1871)) nil (25985 42598 406467 925000) 0 nil]) -([nil nil ((#(" -" 0 1 (fontified t)) . 1377) (undo-tree-id0 . -1) (t 26265 12265 956803 277000)) nil (26265 19756 387981 380000) 0 nil]) -([nil nil ((#(" -" 0 1 (fontified t)) . 2012) (undo-tree-id1 . -1) (t 26265 19756 401304 619000)) nil (26265 19758 713797 213000) 0 nil]) -([nil current ((#(" -" 0 5 (fontified t)) . 315) (undo-tree-id2 . -4) (undo-tree-id3 . -3) (undo-tree-id4 . -4) (undo-tree-id5 . -4) (undo-tree-id6 . -4) (undo-tree-id7 . -4) (undo-tree-id8 . -4) (undo-tree-id9 . -4) (undo-tree-id10 . -4) (undo-tree-id11 . -4) (undo-tree-id12 . -4) (undo-tree-id13 . -4) (undo-tree-id14 . -3) (undo-tree-id15 . -3) (undo-tree-id16 . -3) (undo-tree-id17 . -3) (undo-tree-id18 . -3) (undo-tree-id19 . -5) 318 (t 26265 19758 724525 17000)) nil (26265 19761 492996 643000) 0 nil]) +"7f5c20e0493d52a78b3976a36ddef4420bed61da" +[nil nil nil nil (26303 463 21508 884000) 0 nil] +([nil nil ((#("======= +>>>>>>> main +" 0 7 (face smerge-markers fontified t) 7 8 (face nil fontified t) 8 21 (face smerge-markers fontified t)) . 1463) (undo-tree-id3 . -8) (undo-tree-id4 . -8) (undo-tree-id5 . -8) (undo-tree-id6 . -8) (undo-tree-id7 . -21) (undo-tree-id8 . -21) (undo-tree-id9 . -21) (undo-tree-id10 . -21) (undo-tree-id11 . -8) (undo-tree-id12 . -8) (undo-tree-id13 . -21) (undo-tree-id14 . -20) 1471 (t 26303 431 259432 314000)) nil (26303 463 21506 379000) 0 nil]) +([nil current ((#("<<<<<<< HEAD +" 0 1 (face smerge-markers smerge-refine-part (13 . 2) fontified t) 1 13 (face smerge-markers fontified t)) . 1391) (undo-tree-id0 . -13) (undo-tree-id1 . -12) (undo-tree-id2 . -13)) nil (26303 463 21492 460000) 0 nil]) nil diff --git a/test/src/domains/garnet.jl b/test/src/domains/garnet.jl new file mode 100644 index 0000000..1054a5e --- /dev/null +++ b/test/src/domains/garnet.jl @@ -0,0 +1,30 @@ +using MDPs.Domains +import HiGHS, JuMP + +@testset "Solve Garnet" begin + + g = Garnet.GarnetMDP([[1,1],[2,0]],[[[1,0],[0,1]],[[0,1],[1,0]]],2,[2,2]) + simulate(g, random_π(g), 1, 10000, 500) + g1 = make_int_mdp(g; docompress=false) + g2 = make_int_mdp(g; docompress=true) + + v1 = value_iteration(g, InfiniteH(0.95); ϵ=1e-10) + v2 = value_iteration(g1, InfiniteH(0.95); ϵ=1e-10) + v3 = value_iteration(g2, InfiniteH(0.95); ϵ=1e-10) + v4 = policy_iteration(g2, 0.95) + v5 = lp_solve(g, .95, JuMP.Model(HiGHS.Optimizer)) + + # Ensure value functions are close + V = hcat(v1.value, v2.value[1:end-1], v3.value[1:end-1], v4.value[1:end-1], v5.value) + @test map(x -> x[2] - x[1], mapslices(extrema, V; dims=2)) |> maximum ≤ 1e-6 + + # Ensure policies are identical + p1 = greedy(g, InfiniteH(0.95), v1.value) + p2 = greedy(g1, InfiniteH(0.95), v2.value) + p3 = greedy(g2, InfiniteH(0.95), v3.value) + p4 = v4.policy + p5 = v5.policy + + P = hcat(p1, p2[1:end-1], p3[1:end-1], p4[1:end-1]) + @test all(mapslices(allequal, P; dims=2)) +end diff --git a/test/src/domains/inventory.jl b/test/src/domains/inventory.jl index d61fc77..e693da1 100644 --- a/test/src/domains/inventory.jl +++ b/test/src/domains/inventory.jl @@ -1,4 +1,5 @@ using MDPs.Domains +import HiGHS, JuMP @testset "Solve Inventory" begin @@ -35,12 +36,13 @@ using MDPs.Domains v2 = value_iteration(model_g, InfiniteH(0.95); ϵ = 1e-10) v3 = value_iteration(model_gc, InfiniteH(0.95); ϵ = 1e-10) v4 = policy_iteration(model_gc, 0.95) + v5 = lp_solve(model, .95, JuMP.Model(HiGHS.Optimizer)) # note that the IntMDP does not have terminal states, # so the last action will not be -1 #make sure value functions are close - V = hcat(v1.value, v2.value[1:(end-1)], v3.value[1:(end-1)], v4.value[1:(end-1)]) + V = hcat(v1.value, v2.value[1:(end-1)], v3.value[1:(end-1)], v4.value[1:(end-1)], v5.value) @test map(x->x[2] - x[1], mapslices(extrema, V; dims = 2)) |> maximum ≤ 1e-6 # make sure policies are identical @@ -48,6 +50,7 @@ using MDPs.Domains p2 = greedy(model_g, InfiniteH(0.95), v2.value) p3 = greedy(model_gc, InfiniteH(0.95), v3.value) p4 = v4.policy + p5 = v5.policy P = hcat(p1, p2[1:(end-1)], p3[1:(end-1)], p4[1:(end-1)]) @test all(mapslices(allequal, P; dims = 2))