Merge pull request #8 from RiskAverseRL/transient
Transient MDPs
marekpetrik authored Sep 10, 2024
2 parents 3eac75f + 58a696a commit eb8b45a
Showing 21 changed files with 498 additions and 118 deletions.
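For orientation, here is a minimal sketch of the workflow this change set enables. It assumes the HiGHS optimizer factory is available; the domain constructor arguments are taken from the doctests in the diff below, and the outputs are not verified here.

```julia
using MDPs, HiGHS

# Transient variant of the Gambler's ruin domain (arguments as in the doctests below)
model = Domains.Gambler.RuinTransient(0.5, 4, true)

# Solve the undiscounted total-reward objective via the new LP formulation
sol = lp_solve(model, TotalReward(), HiGHS.Optimizer)
sol.policy                            # greedy action in each state

# Terminal states are now detected structurally from the transition model
isterminal.((model,), states(model))
```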
1 change: 1 addition & 0 deletions .github/workflows/ci.yml
@@ -25,6 +25,7 @@ jobs:
matrix:
version:
- 1.9
- 1.10
os:
- ubuntu-latest
#- macOS-latest
2 changes: 1 addition & 1 deletion .github/workflows/docs.yml
@@ -17,7 +17,7 @@ jobs:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v1
with:
version: '1.9'
version: '1.10'
- name: Install dependencies
run: julia --project=docs/ -e 'using Pkg; Pkg.develop(PackageSpec(path=pwd())); Pkg.instantiate()'
- name: Build and deploy
8 changes: 4 additions & 4 deletions Project.toml
@@ -14,10 +14,10 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

[compat]
DataFrames = "1.6.1"
DataFramesMeta = "0.14.1"
Distributions = "0.25.107"
StatsBase = "0.34.2"
julia = "1.9"
DataFramesMeta = "0.15"
Distributions = "0.25"
StatsBase = "0.34"
julia = "1.10"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
10 changes: 10 additions & 0 deletions docs/src/index.md
@@ -56,6 +56,12 @@ Pages = ["mrp.jl"]
Modules = [MDPs]
Pages = ["policyiteration.jl"]
```

```@autodocs
Modules = [MDPs]
Pages = ["transient.jl"]
```

## Value Function Manipulation

```@autodocs
@@ -95,3 +101,7 @@ Modules = [MDPs.Domains.Inventory]
```@autodocs
Modules = [MDPs.Domains.Machine]
```

```@autodocs
Modules = [MDPs.Domains.GridWorld]
```
7 changes: 6 additions & 1 deletion src/MDPs.jl
@@ -2,10 +2,11 @@ module MDPs

include("objectives.jl")
export InfiniteH, FiniteH, Markov, Stationary, MarkovDet, StationaryDet
export TotalReward

include("models/mdp.jl")
export MDP
export getnext, transition, isterminal
export getnext, transition
export valuefunction


@@ -37,6 +38,10 @@ export policy_iteration, policy_iteration_sparse
include("algorithms/linprogsolve.jl")
export lp_solve

include("algorithms/transient.jl")
export lp_solve, anytransient, alltransient
export isterminal

include("simulation.jl")
export simulate, random_π
export Policy, PolicyStationary, PolicyMarkov
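The exports above make `TotalReward` available next to the existing objectives, so the same `lp_solve` entry point covers both. A brief sketch contrasting the two (HiGHS used as the optimizer factory; domains and arguments follow the doctests further down):

```julia
using MDPs, HiGHS

# Discounted objective on the standard Gambler's ruin domain
disc = lp_solve(Domains.Gambler.Ruin(0.5, 10), InfiniteH(0.9), HiGHS.Optimizer)

# Total-reward objective on the transient variant
tot = lp_solve(Domains.Gambler.RuinTransient(0.5, 4, true), TotalReward(), HiGHS.Optimizer)

disc.policy, tot.policy
```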
61 changes: 46 additions & 15 deletions src/algorithms/linprogsolve.jl
@@ -6,28 +6,59 @@ using JuMP


"""
lp_solve(model, γ, lpm)
lp_solve(model, γ, lpmf, [silent = true])
Formulates and solves the primal linear program for an MDP `model` with a discount factor `γ`.
It uses JuMP to build the linear program and returns the state values
found by `lpm`.
found by the solver. The `lpmf` argument is an optimizer factory that can be passed to `JuMP.Model`.
The function needs to be provided with a solver. See the example below.
# Example
```jldoctest
using MDPs, HiGHS
model = Domains.Gambler.Ruin(0.5, 10)
val = lp_solve(model, 0.9, HiGHS.Optimizer)
maximum(val.policy)
# output
6
```
"""

function lp_solve(model::TabMDP, γ::Number, lpm)
0 ≤ γ < 1 || error("γ must be between 0 and 1")
set_silent(lpm)
function lp_solve(model::TabMDP, obj::InfiniteH, lpmf; silent = true)
γ = discount(obj)
0 ≤ γ < 1 || error("γ must be between 0 and 1.")


lpm = Model(lpmf)
silent && set_silent(lpm)
n = state_count(model)

@variable(lpm, v[1:n])
@objective(lpm,Min, sum(v[1:n]))
π::Vector{Vector{ConstraintRef}} = []
for s in 1:n
m = action_count(model,s)
π_s::Vector{ConstraintRef} = []
for a in 1:m
push!(π_s, @constraint(lpm, v[s] ≥ sum(sp[2]*(sp[3]+γ*v[sp[1]]) for sp in transition(model,s,a))))
end
push!(π, π_s)
@objective(lpm, Min, sum(v[1:n]))

u = Vector{Vector{ConstraintRef}}(undef, n)
for s ∈ 1:n
u[s] = [@constraint(lpm, v[s] ≥ sum(sp[2]*(sp[3]+γ*v[sp[1]])
for sp in transition(model,s,a)))
for a ∈ 1:action_count(model,s)]
end

optimize!(lpm)
(value = value.(v), policy = map(x->argmax(dual.(x)), π))

is_solved_and_feasible(lpm; dual = true) ||
error("Failed to solve the MDP linear program")

(value = value.(v),
policy = map(x->argmax(dual.(x)), u))
end

lp_solve(model::TabMDP, γ::Number, lpm; args...) =
lp_solve(model, InfiniteH(γ), lpm; args...)
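The positional-discount method above is a thin wrapper, so the two calls below should be interchangeable (a sketch assuming the HiGHS optimizer factory):

```julia
using MDPs, HiGHS

model = Domains.Gambler.Ruin(0.5, 10)

a = lp_solve(model, 0.9, HiGHS.Optimizer)             # positional discount factor
b = lp_solve(model, InfiniteH(0.9), HiGHS.Optimizer)  # explicit objective
a.policy == b.policy                                  # expected to hold
```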



15 changes: 5 additions & 10 deletions src/algorithms/mrp.jl
@@ -16,16 +16,11 @@ function mrp!(P_π::AbstractMatrix{<:Real}, r_π::AbstractVector{<:Real},
S = state_count(model)
fill!(P_π, 0.); fill!(r_π, 0.)
for s ∈ 1:S
#TODO: remove the definition of terminal states
if !isterminal(model, s)
for (sn, p, r) ∈ transition(model, s, π[s])
P_π[s,sn] ≈ 0. ||
error("duplicated transition entries (s1->s2, s1->s2) not allowed")
P_π[s,sn] += p
r_π[s] += p * r
end
else
r_π[s] = reward_T(model, s)
for (sn, p, r) ∈ transition(model, s, π[s])
P_π[s,sn] ≈ 0. ||
error("duplicated transition entries (s1->s2, s1->s2) not allowed")
P_π[s,sn] += p
r_π[s] += p * r
end
end
end
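For reference, a minimal sketch of what the simplified loop in `mrp!` computes for a tabular model and a fixed deterministic policy (the names `P_pol` and `r_pol` and the constant policy are only illustrative):

```julia
using MDPs

model = Domains.Gambler.Ruin(0.5, 10)
S = state_count(model)
pol = fill(1, S)                     # an arbitrary stationary deterministic policy

P_pol = zeros(S, S); r_pol = zeros(S)
for s ∈ 1:S, (sn, p, r) ∈ transition(model, s, pol[s])
    P_pol[s, sn] += p                # transition matrix of the induced Markov chain
    r_pol[s] += p * r                # expected one-step reward under the policy
end
```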
150 changes: 150 additions & 0 deletions src/algorithms/transient.jl
@@ -0,0 +1,150 @@
using JuMP

# ----------------------------------------------------------------
# Linear Program Solver
# ----------------------------------------------------------------


"""
isterminal(model, state)
Checks whether `state` is terminal in `model`. A state is terminal if it
1) has a single action,
2) transitions only to itself, and
3) receives a reward of 0.
# Example
```jldoctest
using MDPs
model = Domains.Gambler.RuinTransient(0.5, 4, true)
isterminal.((model,), states(model))[1:2]
# output
2-element BitVector:
1
0
```
"""
function isterminal(model::MDP{S,A}, state::S) where {S,A}
as = actions(model, state)
length(as) == 1 || return false
trs = transition(model, state, first(actions(model, state)))
length(trs) == 1 || return false
t = first(trs)
(t[1] == state && t[2] ≈ 1.0 && t[3] ≈ 0.0) || return false
return true
end


# A helper function that solves the total-reward LP; also used to check for transience.
# reward: `nothing` means the MDP's own rewards are used; a `Float64` replaces every
#         transition reward with that constant value (`anytransient` passes -1.0,
#         `alltransient` passes 1.0).
# Terminal states (see `isterminal`) are fixed to a value of 0.
function _transient_lp(model::TabMDP, reward::Union{Float64, Nothing},
lpmf; silent) :: Union{Nothing,NamedTuple}

@assert minimum(states(model)) == 1 # make sure that the index is 1-based

lpm = Model(lpmf)
silent && set_silent(lpm)

rew(r) = isnothing(reward) ? r :: Float64 : reward :: Float64

n = state_count(model)

@variable(lpm, v[1:n])
@objective(lpm, Min, sum(v))

u = Vector{Vector{ConstraintRef}}(undef, n)
for s ∈ 1:n
@assert minimum(actions(model,s)) == 1 # make sure that the index is 1-based
if isterminal(model, s) # set terminal state(s) to 0 value
u[s] = [@constraint(lpm, v[s] == 0)]
else
u[s] = [@constraint(lpm, v[s] ≥ sum(p*(rew(r) + v[sn])
for (sn,p,r) ∈ transition(model,s,a)))
for a in actions(model,s)]
end
end

optimize!(lpm)

if is_solved_and_feasible(lpm)
(value = value.(v), policy = map(x -> argmax(dual.(x)), u))
else
nothing
end
end


"""
lp_solve(model, obj, lpmf, [silent = true])
Solves the total-reward (transient) MDP `model` by formulating and solving its primal linear program.
It returns the state values and the greedy policy
found using the solver constructed by `JuMP.Model(lpmf)`.
# Example
```jldoctest
using MDPs, HiGHS
model = Domains.Gambler.RuinTransient(0.5, 4, true)
lp_solve(model, TotalReward(), HiGHS.Optimizer).policy
# output
5-element Vector{Int64}:
1
4
2
2
1
```
"""
function lp_solve(model::TabMDP, obj::TotalReward, lpmf; silent = true)
# nothing => run with the true rewards
solution = _transient_lp(model, nothing, lpmf; silent = silent)
if isnothing(solution)
error("Failed to solve LP formulation. Is MDP transient?")
else
solution
end
end


"""
anytransient(model, lpmf, [silent = true])
Checks if the MDP `model` has some transient policy. A policy is transient if it
is guaranteed to terminate with positive probability after some finite number of steps.
Note that the function may return true even when some policies are not transient.
The parameters match the use in `lp_solve`.
"""
function anytransient(model::TabMDP, lpmf; silent = true)
solution = _transient_lp(model, -1., lpmf; silent = silent)
!isnothing(solution)
end

"""
alltransient(model, lpmf, [silent = true])
Checks whether every policy of the MDP `model` is transient. A policy is transient if it
is guaranteed to terminate with positive probability after some finite number of steps.
Note that the function returns true only if all policies are transient.
The parameters match the use in `lp_solve`.
"""
function alltransient(model::TabMDP, lpmf; silent = true)
solution = _transient_lp(model, 1., lpmf; silent = silent)
!isnothing(solution)
end
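A short usage sketch for the two checks next to the total-reward solver (HiGHS as the optimizer factory, mirroring the doctests above; outputs are indicative, not verified):

```julia
using MDPs, HiGHS

model = Domains.Gambler.RuinTransient(0.5, 4, true)

anytransient(model, HiGHS.Optimizer)   # true if at least one policy is transient
alltransient(model, HiGHS.Optimizer)   # true only if every policy is transient

# lp_solve with TotalReward() raises an error when the LP cannot be solved,
# which indicates that the MDP may not be transient
lp_solve(model, TotalReward(), HiGHS.Optimizer).value
```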
