
added time dependency to qvalue and finite horizon VI
Marek Petrik committed Jan 4, 2024
1 parent c58c930 commit 2fc119e
Showing 3 changed files with 40 additions and 38 deletions.
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "MDPs"
uuid = "faa839ec-fb2b-412e-b807-f8b3264e2d6a"
authors = ["Marek Petrik <[email protected]>"]
version = "0.1.4"
version = "0.1.5"

[deps]
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
2 changes: 1 addition & 1 deletion src/algorithms/valueiteration.jl
@@ -61,7 +61,7 @@ function value_iteration!(v::Vector{Vector{Float64}}, π::Vector{Vector{Int}},
for t ∈ horizon(objective):-1:1
# initialize vectors
Threads.@threads for s ∈ 1:n
bg = bellmangreedy(model, objective, s, v[t+1])
bg = bellmangreedy(model, objective, t, s, v[t+1])
v[t][s] = bg.qvalue
π[t][s] = bg.action
end
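The hunk above now passes the stage index `t` into `bellmangreedy`, so each step of the backward induction can use a stage-dependent backup. The following self-contained sketch (not the MDPs.jl API; the two-state chain, reward, and horizon are invented for illustration) shows the same backward-in-time pattern with a time-dependent reward:

    # Minimal backward induction with a stage-dependent backup, mirroring the loop above.
    n, T, γ = 2, 3, 1.0                      # states, horizon, undiscounted
    P = [0.9 0.1; 0.2 0.8]                   # P[s, sn] under a single action
    r(t, s) = t == T ? 0.0 : float(s)        # hypothetical time-dependent reward

    v = [zeros(n) for _ in 1:T+1]            # v[t][s]; v[T+1] holds the terminal values
    for t in T:-1:1                          # backward in time, as in value_iteration!
        for s in 1:n
            # the backup receives t because r (and, in general, the dynamics) may depend on it
            v[t][s] = r(t, s) + γ * sum(P[s, sn] * v[t+1][sn] for sn in 1:n)
        end
    end
    println(v[1])                            # stage-1 values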
74 changes: 38 additions & 36 deletions src/valuefunction/bellman.jl
@@ -4,7 +4,7 @@
# ----------------------------------------------------------------

"""
qvalues(model, objective, s, v)
qvalues(model, objective, [t=0,] s, v)
Compute the state-action-value for state `s`, and value function `v` for
`objective`. There is no set representation of the value function `v`.
@@ -14,16 +14,18 @@ and transitions.
The function is tractable only if there are a small number of actions and transitions.
"""
function qvalues(model::MDP{S,A}, objective::Union{FiniteH, InfiniteH}, s::S, v) where
{S,A}
function qvalues(model::MDP{S,A}, objective::Union{FiniteH, InfiniteH},
t::Integer, s::S, v) where {S,A}
acts = actions(model, s)
qvalues = Vector{Float64}(undef, length(acts))
qvalues!(qvalues, model, objective, s, v)
return qvalues
end

qvalues(model, objective, s, v) = qvalues(model, objective, 0, s, v)

"""
qvalues!(qvalues, model, objective, s, v)
qvalues!(qvalues, model, objective, [t=0,] s, v)
Compute the state-action-values for state `s`, and value function `v` for the
`objective`.
@@ -35,29 +37,31 @@ count are set to `-Inf`.
See `qvalues` for more information.
"""
function qvalues!(qvalues::AbstractVector{<:Real}, model::MDP{S,A},
obj::Objective, s::S, v) where {S,A}
obj::Objective, t::Integer, s::S, v) where {S,A}

if isterminal(model, s)
qvalues .= -Inf
qvalues[1] = 0
else
acts = actions(model, s)
for (ia,a) ∈ enumerate(acts)
qvalues[ia] = qvalue(model, obj, s, a, v)
qvalues[ia] = qvalue(model, obj, t, s, a, v)
end
end
end

qvalues!(qvalues, model, obj, s, v) = qvalues!(qvalues, model, obj, 0, s, v)

"""
qvalue(model, objective, s, a, v)
qvalue(model, objective, [t=0,] s, a, v)
Compute the state-action-values for state `s`, action `a`, and
value function `v` for an `objective`.
There is no set representation for the value function.
"""
function qvalue(model::MDP{S,A}, objective::Union{FiniteH, InfiniteH},
s::S, a::A, v) where {S,A}
t::Integer, s::S, a::A, v) where {S,A}
val :: Float64 = 0.0
# much much faster than sum( ... for)
for (sn, p, r) ∈ transition(model, s, a)
@@ -67,16 +71,9 @@ function qvalue(model::MDP{S,A}, objective::Union{FiniteH, InfiniteH},
end


"""
qvalue(model, γ, s, a, v)
Compute the state-action-values for state `s`, action `a`, and
value function `v` for a discount factor `γ`.
This function is just a more efficient version of the standard definition.
"""
# more efficient version for IntMDPs
function qvalue(model::IntMDP, objective::Union{FiniteH, InfiniteH},
s::Int, a::Int, v::AbstractVector{<:Real})
t::Integer, s::Integer, a::Integer, v::AbstractVector{<:Real})
x = model.states[s].actions[a]
val = 0.0
# much much faster than sum( ... for)
@@ -87,90 +84,95 @@ function qvalue(model::IntMDP, objective::Union{FiniteH, InfiniteH},
val :: Float64
end

qvalue(model, objective, s, a, v) = qvalue(model, objective, 0, s, a, v)
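The `qvalue` methods above accumulate, over the listed transitions, the immediate reward plus the discounted successor value (the discount comes from the objective); the new `t` argument is threaded through these methods but not used in the bodies shown here, and the trailing one-line method keeps old `t`-less call sites working with `t = 0`. A self-contained worked computation of the same sum (the transition list, values, and discount below are invented):

    # Worked q-value: q(s, a) = Σ p * (r + γ * v[sn]) over (sn, p, r) triples.
    function qval(transitions, v, γ)
        q = 0.0
        for (sn, p, r) ∈ transitions              # same accumulation pattern as qvalue above
            q += p * (r + γ * v[sn])
        end
        return q
    end

    transitions = [(1, 0.7, 2.0), (2, 0.3, 0.0)]  # (next state sn, probability p, reward r)
    v = [10.0, 4.0]                               # values of the two successor states
    println(qval(transitions, v, 0.95))           # 0.7*(2 + 9.5) + 0.3*(0 + 3.8) = 9.19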

# ----------------------------------------------------------------
# Generalized Bellman operators
# ----------------------------------------------------------------

"""
bellmangreedy(model, obj, s, v)
bellmangreedy(model, obj, [t=0,] s, v)
Compute the Bellman operator and greedy action for state `s`, and value
function `v` assuming an objective `obj`.
function `v` assuming an objective `obj`. The optional time parameter `t` allows for
time-dependent updates.
The function uses `qvalue` to compute the Bellman operator and the greedy policy.
"""
function bellmangreedy(model::MDP{S,A}, obj::Objective, s::S, v) where {S,A}
function bellmangreedy(model::MDP{S,A}, obj::Objective, t::Integer, s::S, v) where {S,A}
if isterminal(model, s)
(qvalue = 0.0 :: Float64,
action = emptyaction(model) :: A)
else
acts = actions(model, s)
(qval, ia) = findmax(a->qvalue(model, obj, s, a, v), acts)
(qval, ia) = findmax(a->qvalue(model, obj, t, s, a, v), acts)
(qvalue = qval :: Float64,
action = acts[ia] :: A)
end
end

# default fallback when t is not provided
bellmangreedy(model, obj, s, v) = bellmangreedy(model, obj, 0, s, v)
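`bellmangreedy` maximizes the q-value over the available actions and returns the maximum together with the maximizing action as a named tuple; `findmax(f, collection)` (Julia 1.7+) returns both the largest `f(x)` and the index where it occurs. A standalone illustration with invented actions and values:

    # The findmax idiom used in bellmangreedy, on made-up data.
    acts = [:left, :right, :stay]                 # hypothetical action labels
    q(a) = a == :right ? 1.5 : 0.0                # stand-in for qvalue(model, obj, t, s, a, v)
    (qval, ia) = findmax(q, acts)
    best = (qvalue = qval, action = acts[ia])     # same named-tuple shape as bellmangreedy
    println(best)                                 # (qvalue = 1.5, action = :right)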

# ----------------------------------------------------------------
# Greedy policies and Bellman
# ----------------------------------------------------------------

"""
greedy(model, obj, [s,] v)
greedy(model, obj, [t=0,] s, v)
Compute the greedy action for state `s` and value
function `v` assuming an objective `obj`.
If `s` is not provided, then computes a value function for all states.
The model must support `states` function.
"""
greedy(model::MDP{S,A}, obj::Objective, s::S, v) where {S,A} =
bellmangreedy(model, obj, s, v).action :: A
greedy(model::MDP{S,A}, obj::Objective, t::Integer, s::S, v) where {S,A} =
bellmangreedy(model, obj, t, s, v).action :: A

greedy(model, obj, s, v) = greedy(model, obj, 0, s, v)

#greedy(model::TabMDP{S,A}, obj::Objective, v) where {S,A} =
# greedy.((model,), (obj,), states(model), (v,))

"""
greedy!(π, model, obj, v)
Update policy `π` with the greedy policy for value function `v` and MDP `model`
and an objective `obj`.
"""
function greedy!(π::Vector{Int}, model::TabMDP, obj::Objective,
function greedy!(π::Vector{Int}, model::TabMDP, obj::Objective, t::Integer,
v::AbstractVector{<:Real})

length(π) == state_count(model) ||
error("Policy π length must be the same as the state count")
length(v) == state_count(model) ||
error("Value function length must be the same as the state count")

π .= greedy.((model,), (obj,), states(model), (v,))
π .= greedy.((model,), (obj,), (t,), states(model), (v,))
end

greedy!(π, model, obj, v) = greedy!(π, model, obj, 0, v)
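In `greedy!`, the broadcast `greedy.((model,), (obj,), (t,), states(model), (v,))` wraps every argument except the state in a 1-tuple so that broadcasting iterates over the states only and reuses the other arguments unchanged. A minimal standalone illustration of that idiom (function and data invented):

    # Wrapping an argument in a 1-tuple keeps it fixed while the other varies under broadcast.
    f(offset, s) = offset + 10 * s               # stand-in for greedy(model, obj, t, s, v)
    println(f.((7,), 1:3))                       # [17, 27, 37]; the 7 is reused for every state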


"""
greedy(model, obj, v)
Compute the greedy action for all states and value function `v` assuming
an objective `obj`.
an objective `obj` and time `t=0`.
"""
function greedy(model::TabMDP, obj::Stationary, v::AbstractVector{<:Real})
π = Vector{Int}(undef, state_count(model))
greedy!(π,model,obj,v)
π
end

greedy(model, γ::Real, v::AbstractVector{<:Real}) = greedy(model, InfiniteH(γ), v)


"""
bellman(model, γ, s, v)
bellman(model, obj, [t=0,] s, v)
Compute the Bellman operator for state `s`, and value function `v` assuming
an objective `obj`.
A real-valued objective `obj` is interpreted as a discount factor.
"""
bellman(model::MDP{S,A}, obj::Objective, s::S, v) where {S,A} =
bellmangreedy(model, obj, s, v).qvalue :: Float64
bellman(model::MDP{S,A}, obj::Objective, t::Integer, s::S, v) where {S,A} =
bellmangreedy(model, obj, t, s, v).qvalue :: Float64

bellman(model, obj, s, v) = bellman(model, obj, 0, s, v)
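Each time-dependent method in this commit is paired with a one-line fallback that fills in `t = 0`, so existing call sites that never pass a time index keep working unchanged. A sketch of the pattern with invented names:

    # New method takes t; the old signature forwards to t = 0.
    backup(t::Integer, s::Int) = 100 * t + s     # stand-in for the time-dependent method
    backup(s::Int) = backup(0, s)                # legacy call site gets t = 0
    println((backup(5), backup(3, 5)))           # (5, 305)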
