Merge branch 'garnet' of github.com:RiskAverseRL/MDPs.jl into garnet
Marek Petrik committed Jul 17, 2024
2 parents 28a21a8 + a1c2d66 commit aec6dc6
Showing 6 changed files with 59 additions and 15 deletions.
6 changes: 4 additions & 2 deletions Project.toml
@@ -7,16 +7,18 @@ version = "0.1.5"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
DataFramesMeta = "1313f7d8-7da2-5740-9ea0-a2ca25f37964"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
HiGHS = "87dc4568-4c63-4d18-b0c0-bb2238e4078b"
JuMP = "4076af6c-e467-56ae-b986-b466b2749572"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

[compat]
DataFrames = "1.6.1"
DataFramesMeta = "0.14.1"
julia = "1.9"
StatsBase = "0.34.2"
Distributions = "0.25.107"
StatsBase = "0.34.2"
julia = "1.9"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3 changes: 3 additions & 0 deletions src/MDPs.jl
@@ -34,6 +34,9 @@ export mrp!, mrp, mrp_sparse
include("algorithms/policyiteration.jl")
export policy_iteration, policy_iteration_sparse

include("algorithms/linprogsolve.jl")
export lp_solve

include("simulation.jl")
export simulate, random_π
export Policy, PolicyStationary, PolicyMarkov
31 changes: 31 additions & 0 deletions src/algorithms/linprogsolve.jl
@@ -0,0 +1,31 @@
using JuMP

# ----------------------------------------------------------------
# Linear Program Solver
# ----------------------------------------------------------------


"""
lp_solve(model, γ, lpm)
Implements the linear program primal problem for an MDP `model` with a discount factor `γ`.
It uses the JuMP model `lpm` as the linear program solver and returns the state values
found by `lpm`.
"""

function lp_solve(model::TabMDP, γ::Number, lpm)
    0 ≤ γ < 1 || error("γ must be between 0 and 1")
    set_silent(lpm)
    n = state_count(model)
    @variable(lpm, v[1:n])
    @objective(lpm, Min, sum(v[1:n]))
    for s in 1:n
        m = action_count(model, s)
        for a in 1:m
            snext = transition(model, s, a)
            @constraint(lpm, v[s] ≥ sum(sp[2] * (sp[3] + γ * v[sp[1]]) for sp in snext))
        end
    end
    optimize!(lpm)
    return value.(v)
end
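
For orientation (a sketch, not part of the commit): the function builds the standard primal linear program for a discounted MDP, minimizing the sum of state values v(s) subject to v(s) ≥ Σ p(s,a,s')·(r(s,a,s') + γ·v(s')) for every state s and action a. A minimal usage sketch follows; it reuses the GarnetMDP construction from the test file below and assumes HiGHS as the LP optimizer. Note that the function body expects an actual JuMP model, so the optimizer is wrapped in Model(...).

# Usage sketch (assumptions: GarnetMDP arguments copied from the test below; HiGHS as optimizer).
using JuMP, HiGHS
using MDPs, MDPs.Domains

g = Domains.Garnet.GarnetMDP([[1,1],[2,0]], [[[1,0],[0,1]],[[0,1],[1,0]]], 2, [2,2])
lpm = Model(HiGHS.Optimizer)   # fresh JuMP model handed to lp_solve
v = lp_solve(g, 0.95, lpm)     # vector of optimal state values, one entry per state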
5 changes: 5 additions & 0 deletions src/domains/garnet.jl
@@ -18,6 +18,11 @@ struct GarnetMDP <: TabMDP
# TODO: add a constructor that checks for consistency
end

"""
A Garnet MDP is a tabular MDP where the number of next states available from any current state is a fixed proportion of the total number of states in the model.
This proportion is called "nbranch" and it must between 0 and 1.
"""

function make_garnet(S::Integer, A::AbstractVector{Int}, nbranch::Number, min_reward::Integer, max_reward::Integer)

    0.0 ≤ nbranch ≤ 1.0 || error("nbranch must be in [0,1]")
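
For orientation, a rough call sketch based only on the signature above; the argument values are illustrative and the body of make_garnet is collapsed in this diff:

# Illustrative values only (not from the commit): 10 states, 2 actions per state,
# each state branching to roughly 30% of all states, rewards in [0, 5].
g = make_garnet(10, fill(2, 10), 0.3, 0, 5)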
23 changes: 12 additions & 11 deletions test/src/domains/garnet.jl
@@ -1,29 +1,30 @@
using Main.MDPs
using MDPs.Domains
import HiGHS

@testset "Solve Garnet" begin

g = Domains.Garnet.GarnetMDP([[1,1],[2,0]],[[[1,0],[0,1]],[[0,1],[1,0]]],2,[2,2])
"""
g = Garnet.GarnetMDP([[1,1],[2,0]],[[[1,0],[0,1]],[[0,1],[1,0]]],2,[2,2])
simulate(g, random_π(g), 1, 10000, 500)
g1 = make_int_mdp(g; docompress=false)
g2 = make_int_mdp(g; docompress=true)

v1 = value_iteration(g, InfiniteH(0.99); ϵ=1e-7)
v2 = value_iteration(g1, InfiniteH(0.99); ϵ=1e-7)
v3 = value_iteration(g2, InfiniteH(0.99); ϵ=1e-7)
v1 = value_iteration(g, InfiniteH(0.95); ϵ=1e-10)
v2 = value_iteration(g1, InfiniteH(0.95); ϵ=1e-10)
v3 = value_iteration(g2, InfiniteH(0.95); ϵ=1e-10)
v4 = policy_iteration(g2, 0.95)
#v5 = lp_solve(g, .95, HiGHS.Optimizer)

# Ensure value functions are close
V = hcat(v1.value, v2.value[1:end-1], v3.value[1:end-1], v4.value[1:end-1])
V = hcat(v1.value, v2.value[1:end-1], v3.value[1:end-1], v4.value[1:end-1], v5)
@test map(x -> x[2] - x[1], mapslices(extrema, V; dims=2)) |> maximum ≤ 1e-6

# Ensure policies are identical
p1 = greedy(model, InfiniteH(0.95), v1.value)
p2 = greedy(model_g, InfiniteH(0.95), v2.value)
p3 = greedy(model_gc, InfiniteH(0.95), v3.value)
p1 = greedy(g, InfiniteH(0.95), v1.value)
p2 = greedy(g1, InfiniteH(0.95), v2.value)
p3 = greedy(g2, InfiniteH(0.95), v3.value)
p4 = v4.policy
p5 = greedy(g, InfiniteH(0.95), v5)

P = hcat(p1, p2[1:end-1], p3[1:end-1], p4[1:end-1])
@test all(mapslices(allequal, P; dims=2))
"""
end
6 changes: 4 additions & 2 deletions test/src/domains/inventory.jl
@@ -1,4 +1,5 @@
using MDPs.Domains
import HiGHS

@testset "Solve Inventory" begin

@@ -36,20 +37,21 @@ using MDPs.Domains
v2 = value_iteration(model_g, InfiniteH(0.95); ϵ = 1e-10)
v3 = value_iteration(model_gc, InfiniteH(0.95); ϵ = 1e-10)
v4 = policy_iteration(model_gc, 0.95)

#v5 = lp_solve(g, .95, HiGHS.Optimizer)

# note that the IntMDP does not have terminal states,
# so the last action will not be -1

#make sure value functions are close
V = hcat(v1.value, v2.value[1:(end-1)], v3.value[1:(end-1)], v4.value[1:(end-1)])
V = hcat(v1.value, v2.value[1:(end-1)], v3.value[1:(end-1)], v4.value[1:(end-1)], v5)
@test map(x->x[2] - x[1], mapslices(extrema, V; dims = 2)) |> maximum 1e-6

# make sure policies are identical
p1 = greedy(model, InfiniteH(0.95), v1.value)
p2 = greedy(model_g, InfiniteH(0.95), v2.value)
p3 = greedy(model_gc, InfiniteH(0.95), v3.value)
p4 = v4.policy
p5 = greedy(model, InfiniteH(0.95), v5)

P = hcat(p1, p2[1:(end-1)], p3[1:(end-1)], p4[1:(end-1)])
@test all(mapslices(allequal, P; dims = 2))
