Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Garnet #2

Merged
merged 30 commits into from
Aug 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
a97a52f
Added garnet MDPs to domains
keithbadger Jun 25, 2024
641be7d
added garnet domain
keithbadger Jun 25, 2024
ae9ee73
implemented garnet MDP constructor
keithbadger Jun 26, 2024
3bcef5f
fixed transition function
keithbadger Jun 26, 2024
26841ff
added tests, fail to recognize Garnet class though
keithbadger Jul 3, 2024
057a847
updating exporting
keithbadger Jul 3, 2024
3eed2bc
dependencies and bugs fixed
Jul 3, 2024
2f604d9
committing before switch
keithbadger Jul 8, 2024
205a803
Merge branch 'garnet' of https://github.com/RiskAverseRL/MDPs.jl into…
keithbadger Jul 8, 2024
ffe7a82
added linear program method for solving infinite horizon MDPs
keithbadger Jul 9, 2024
6c0c097
added linprogsolve to exports
keithbadger Jul 9, 2024
e023389
changed linear_program_solve to require gamma as float rather than en…
keithbadger Jul 9, 2024
01d1058
Added explanation of garnet MDPs
keithbadger Jul 9, 2024
75203a3
added tests for linear program solver and fixed tests for garnet domain
keithbadger Jul 10, 2024
61d52c3
fixed linprogsolve output and added optimizer options
keithbadger Jul 10, 2024
d2d3acc
updated linear program solver
keithbadger Jul 16, 2024
f3b3ef4
fixed issues with lp_solve vs linear_program_solve function naming
keithbadger Jul 16, 2024
a1c2d66
improved documentation for lp_solve()
keithbadger Jul 16, 2024
28a21a8
added HiGHS LP dependency to the test
Jul 17, 2024
aec6dc6
Merge branch 'garnet' of github.com:RiskAverseRL/MDPs.jl into garnet
Jul 17, 2024
341ace3
fixed some bugs
Jul 17, 2024
23b406b
added HiGHS and JuMP to test dependencies and removed HiGHS for MDPs …
keithbadger Jul 18, 2024
53594d2
?
keithbadger Jul 19, 2024
0f65735
Merge branch 'garnet' of https://github.com/RiskAverseRL/MDPs.jl into…
keithbadger Jul 19, 2024
c08b0a6
removed unneeded variable from lp_solve
keithbadger Jul 19, 2024
78dbd47
added optimal policy output to lp_solve using dual variables
keithbadger Jul 22, 2024
d764c04
Merge branch 'garnet' of https://github.com/RiskAverseRL/MDPs.jl into…
keithbadger Jul 22, 2024
63b704c
Delete test/src/domains/.inventory.jl.~undo-tree~
marekpetrik Aug 16, 2024
935ab0f
removed a file that is not supposed to be in the repo
Aug 16, 2024
c4d6ca5
Merge branch 'main' into garnet
Aug 16, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,18 @@ version = "0.1.5"
[deps]
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
DataFramesMeta = "1313f7d8-7da2-5740-9ea0-a2ca25f37964"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
JuMP = "4076af6c-e467-56ae-b986-b466b2749572"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

[compat]
julia = "1.9"
DataFrames = "1.6.1"
DataFramesMeta = "0.14.1"
Distributions = "0.25.107"
StatsBase = "0.34.2"
julia = "1.9"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Expand Down
5 changes: 5 additions & 0 deletions src/MDPs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ export mrp!, mrp, mrp_sparse
include("algorithms/policyiteration.jl")
export policy_iteration, policy_iteration_sparse

include("algorithms/linprogsolve.jl")
export lp_solve

include("simulation.jl")
export simulate, random_π
export Policy, PolicyStationary, PolicyMarkov
Expand All @@ -45,6 +48,8 @@ export Transition
module Domains
include("domains/simple.jl")
export Simple
include("domains/garnet.jl")
export Garnet
include("domains/inventory.jl")
export Inventory
include("domains/machine.jl")
Expand Down
33 changes: 33 additions & 0 deletions src/algorithms/linprogsolve.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
using JuMP

# ----------------------------------------------------------------
# Linear Program Solver
# ----------------------------------------------------------------


"""
    lp_solve(model, γ, lpm)

Solve the infinite-horizon discounted MDP `model` with discount factor `γ`
by formulating and solving the primal linear program

    min Σₛ v[s]   s.t.   v[s] ≥ Σ_s′ p(s′|s,a) (r(s,a,s′) + γ v[s′])  ∀ s, a

using the JuMP model `lpm` as the linear program solver. The model `lpm`
must be constructed with an LP-capable optimizer (e.g. `Model(HiGHS.Optimizer)`).

The optimal policy is recovered from the dual variables of the Bellman
constraints: for each state, the action whose constraint carries the largest
dual value is the optimal one.

## Returns

A named tuple with fields
- `value`: the optimal state values found by `lpm`
- `policy`: the optimal deterministic action for each state

## Throws

`ArgumentError` when `γ ∉ [0, 1)`.
"""
function lp_solve(model::TabMDP, γ::Number, lpm)
    0 ≤ γ < 1 || throw(ArgumentError("γ must be in [0, 1), got $γ"))
    set_silent(lpm)
    n = state_count(model)
    @variable(lpm, v[1:n])
    @objective(lpm, Min, sum(v))
    # One Bellman constraint per (state, action); references are kept so the
    # duals can be read back after the solve to recover the policy.
    π = Vector{Vector{ConstraintRef}}(undef, n)
    for s in 1:n
        m = action_count(model, s)
        π_s = Vector{ConstraintRef}(undef, m)
        for a in 1:m
            # v[s] ≥ E[ r + γ v[s′] ] where transition yields (s′, p, r) triples
            π_s[a] = @constraint(lpm,
                v[s] ≥ sum(sp[2] * (sp[3] + γ * v[sp[1]]) for sp in transition(model, s, a)))
        end
        π[s] = π_s
    end
    optimize!(lpm)
    (value = value.(v), policy = map(x -> argmax(dual.(x)), π))
end
73 changes: 73 additions & 0 deletions src/domains/garnet.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
module Garnet

import ...TabMDP, ...transition, ...state_count, ...action_count
import ...actions, ...states

# TODO: are these reasonable or can we replace them?
import StatsBase, Distributions
# ----------------------------------------------------------------
# A Garnet MDP
# ----------------------------------------------------------------

"""
A dense tabular MDP with explicit reward and transition tables.

Fields:
- `reward[s][a]`: reward for taking action `a` in state `s`
- `transition[s][a]`: length-`S` vector of probabilities of each next state
  after taking action `a` in state `s`
- `S`: total number of states
- `A[s]`: number of actions available in state `s`
"""
struct GarnetMDP <: TabMDP
    reward::Vector{Vector{Float64}}
    transition::Vector{Vector{Vector{Float64}}}
    S::Int
    A::Vector{Int}

    # TODO: add a constructor that checks for consistency
    # (e.g. length(reward) == length(transition) == S == length(A),
    # and each transition[s][a] sums to 1)
end

"""
    make_garnet(S, A, nbranch, min_reward, max_reward)

Construct a random Garnet MDP: a tabular MDP in which the number of next
states reachable from any state-action pair is a fixed proportion `nbranch`
of the total number of states `S`. The proportion must be in [0,1] and must
round to at least one successor state.

For each state-action pair, the reward is drawn uniformly from the integers
`min_reward:max_reward`, the successor states are sampled without
replacement, and their probabilities are i.i.d. Exponential(1) draws
normalized to sum to one.

`A` may be either a vector giving the number of actions in each state, or a
single integer used for every state.

## Throws

`ArgumentError` when `nbranch ∉ [0,1]`, when it yields zero successor
states, or when `length(A) ≠ S`.
"""
function make_garnet(S::Integer, A::AbstractVector{Int}, nbranch::Number,
                     min_reward::Integer, max_reward::Integer)
    0.0 ≤ nbranch ≤ 1.0 || throw(ArgumentError("nbranch must be in [0,1]"))
    length(A) == S || throw(ArgumentError("A must have one entry per state"))

    sout = round(Int, nbranch * S)
    # With no successors every transition row would be all zeros, which is
    # not a valid probability distribution.
    sout ≥ 1 || throw(ArgumentError("nbranch too small: each state needs at least one successor"))

    dist = Distributions.Exponential(1)
    reward = Vector{Vector{Float64}}()
    transition = Vector{Vector{Vector{Float64}}}()

    for s in 1:S
        r_s = Vector{Float64}()
        p_s = Vector{Vector{Float64}}()
        for a in 1:A[s]
            push!(r_s, rand(min_reward:max_reward))
            succ = StatsBase.sample(1:S, sout; replace = false)
            w = rand(dist, sout)
            w ./= sum(w)              # normalize to a probability distribution
            row = zeros(S)
            row[succ] .= w
            push!(p_s, row)
        end
        push!(reward, r_s)
        push!(transition, p_s)
    end

    GarnetMDP(reward, transition, S, A)
end

make_garnet(S::Integer, A::Integer, nbranch, min_reward, max_reward) =
    make_garnet(S, fill(Int(A), S), nbranch, min_reward, max_reward)

"""
    transition(model, state, action)

Return the sparse transition representation for `(state, action)`: a vector
of `(next_state, probability, reward)` triples covering only the next states
with nonzero probability. The reward depends only on `(state, action)` and
is repeated in every triple.

## Throws

`ArgumentError` when `state` or `action` is out of range.
"""
function transition(model::GarnetMDP, state::Int, action::Int)
    state in 1:model.S || throw(ArgumentError("invalid state $state"))
    action in 1:model.A[state] || throw(ArgumentError("invalid action $action"))

    # Reward is per (state, action); hoist the lookup out of the loop.
    r = model.reward[state][action]
    next = Vector{Tuple{Int,Float64,Float64}}()
    for (sn, p) in enumerate(model.transition[state][action])
        iszero(p) || push!(next, (sn, p, r))
    end
    return next
end

# Total number of states in the model.
state_count(model::GarnetMDP) = model.S
# Number of actions available in state `s`.
action_count(model::GarnetMDP, s::Int) = model.A[s]

end
# Module: Garnet
Loading
Loading