From 2fc119e5ccfd695cc49207e212626a66358cf893 Mon Sep 17 00:00:00 2001
From: Marek Petrik
Date: Wed, 3 Jan 2024 19:52:25 -0500
Subject: [PATCH] added time dependency to qvalue and finite horizon VI

---
 Project.toml                     |  2 +-
 src/algorithms/valueiteration.jl |  2 +-
 src/valuefunction/bellman.jl     | 74 ++++++++++++++++----------------
 3 files changed, 40 insertions(+), 38 deletions(-)

diff --git a/Project.toml b/Project.toml
index a184a19..c54c70e 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "MDPs"
 uuid = "faa839ec-fb2b-412e-b807-f8b3264e2d6a"
 authors = ["Marek Petrik "]
-version = "0.1.4"
+version = "0.1.5"

 [deps]
 DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
diff --git a/src/algorithms/valueiteration.jl b/src/algorithms/valueiteration.jl
index be7fe5c..0ac329a 100644
--- a/src/algorithms/valueiteration.jl
+++ b/src/algorithms/valueiteration.jl
@@ -61,7 +61,7 @@ function value_iteration!(v::Vector{Vector{Float64}}, π::Vector{Vector{Int}},
     for t ∈ horizon(objective):-1:1
         # initialize vectors
         Threads.@threads for s ∈ 1:n
-            bg = bellmangreedy(model, objective, s, v[t+1])
+            bg = bellmangreedy(model, objective, t, s, v[t+1])
             v[t][s] = bg.qvalue
             π[t][s] = bg.action
         end
diff --git a/src/valuefunction/bellman.jl b/src/valuefunction/bellman.jl
index 7dd00e0..6ba69bb 100644
--- a/src/valuefunction/bellman.jl
+++ b/src/valuefunction/bellman.jl
@@ -4,7 +4,7 @@
 # ----------------------------------------------------------------

 """
-    qvalues(model, objective, s, v)
+    qvalues(model, objective, [t=0,] s, v)

 Compute the state-action-value for state `s`, and value function `v` for
 `objective`. There is no set representation of the value function `v`.
@@ -14,16 +14,18 @@
 and transitions. The function is tractable only if there are a small number
 of actions and transitions.
 """
-function qvalues(model::MDP{S,A}, objective::Union{FiniteH, InfiniteH}, s::S, v) where
-    {S,A}
+function qvalues(model::MDP{S,A}, objective::Union{FiniteH, InfiniteH},
+                 t::Integer, s::S, v) where {S,A}
     acts = actions(model, s)
     qvalues = Vector{Float64}(undef, length(acts))
-    qvalues!(qvalues, model, objective, s, v)
+    qvalues!(qvalues, model, objective, t, s, v)
     return qvalues
 end

+qvalues(model, objective, s, v) = qvalues(model, objective, 0, s, v)
+
 """
-    qvalues!(qvalues, model, objective, s, v)
+    qvalues!(qvalues, model, objective, [t=0,] s, v)

 Compute the state-action-values for state `s`, and value function `v` for
 the `objective`.
@@ -35,7 +37,7 @@ count are set to `-Inf`.
 See `qvalues` for more information.
 """
 function qvalues!(qvalues::AbstractVector{<:Real}, model::MDP{S,A},
-                  obj::Objective, s::S, v) where {S,A}
+                  obj::Objective, t::Integer, s::S, v) where {S,A}

     if isterminal(model, s)
         qvalues .= -Inf
@@ -43,13 +45,15 @@ function qvalues!(qvalues::AbstractVector{<:Real}, model::MDP{S,A},
     else
         acts = actions(model, s)
         for (ia,a) ∈ enumerate(acts)
-            qvalues[ia] = qvalue(model, obj, s, a, v)
+            qvalues[ia] = qvalue(model, obj, t, s, a, v)
         end
     end
 end

+qvalues!(qvalues, model, obj, s, v) = qvalues!(qvalues, model, obj, 0, s, v)
+
 """
-    qvalue(model, objective, s, a, v)
+    qvalue(model, objective, [t=0,] s, a, v)

 Compute the state-action-values for state `s`, action `a`, and
 value function `v` for an `objective`.
@@ -57,7 +61,7 @@ value function `v` for an `objective`.
 There is no set representation for the value function.
 """
 function qvalue(model::MDP{S,A}, objective::Union{FiniteH, InfiniteH},
-                s::S, a::A, v) where {S,A}
+                t::Integer, s::S, a::A, v) where {S,A}
     val :: Float64 = 0.0
     # much much faster than sum( ... for)
     for (sn, p, r) ∈ transition(model, s, a)
@@ -67,16 +71,9 @@ function qvalue(model::MDP{S,A}, objective::Union{FiniteH, InfiniteH},
 end

-"""
-    qvalue(model, γ, s, a, v)
-
-Compute the state-action-values for state `s`, action `a`, and
-value function `v` for a discount factor `γ`.
-
-This function is just a more efficient version of the standard definition.
-"""
+# more efficient version for IntMDPs
 function qvalue(model::IntMDP, objective::Union{FiniteH, InfiniteH},
-               s::Int, a::Int, v::AbstractVector{<:Real})
+               t::Integer, s::Integer, a::Integer, v::AbstractVector{<:Real})
     x = model.states[s].actions[a]
     val = 0.0
     # much much faster than sum( ... for)
@@ -87,37 +84,42 @@ function qvalue(model::IntMDP, objective::Union{FiniteH, InfiniteH},
     val :: Float64
 end

+qvalue(model, objective, s, a, v) = qvalue(model, objective, 0, s, a, v)
+
 # ----------------------------------------------------------------
 # Generalized Bellman operators
 # ----------------------------------------------------------------

 """
-    bellmangreedy(model, obj, s, v)
+    bellmangreedy(model, obj, [t=0,] s, v)

 Compute the Bellman operator and greedy action for state `s`, and value
-function `v` assuming an objective `obj`.
+function `v` assuming an objective `obj`. The optional time parameter `t` allows for
+time-dependent updates.

 The function uses `qvalue` to compute the Bellman operator and the greedy policy.
 """
-function bellmangreedy(model::MDP{S,A}, obj::Objective, s::S, v) where {S,A}
+function bellmangreedy(model::MDP{S,A}, obj::Objective, t::Integer, s::S, v) where {S,A}
     if isterminal(model, s)
         (qvalue = 0 :: Float64, action = emptyaction(model) :: A)
     else
         acts = actions(model, s)
-        (qval, ia) = findmax(a->qvalue(model, obj, s, a, v), acts)
+        (qval, ia) = findmax(a->qvalue(model, obj, t, s, a, v), acts)
         (qvalue = qval :: Float64, action = acts[ia] :: A)
     end
 end

+# default fallback when t is not provided
+bellmangreedy(model, obj, s, v) = bellmangreedy(model, obj, 0, s, v)
+
 # ----------------------------------------------------------------
 # Greedy policies and Bellman
 # ----------------------------------------------------------------

 """
-    greedy(model, obj, [s,] v)
+    greedy(model, obj, [t=0,] s, v)

 Compute the greedy action for state `s` and value function `v` assuming
 an objective `obj`.
@@ -125,11 +127,11 @@
 If `s` is not provided, then computes a value function for all states. The
 model must support `states` function.
 """
-greedy(model::MDP{S,A}, obj::Objective, s::S, v) where {S,A} =
-    bellmangreedy(model, obj, s, v).action :: A
+greedy(model::MDP{S,A}, obj::Objective, t::Integer, s::S, v) where {S,A} =
+    bellmangreedy(model, obj, t, s, v).action :: A
+
+greedy(model, obj, s, v) = greedy(model, obj, 0, s, v)

-#greedy(model::TabMDP{S,A}, obj::Objective, v) where {S,A} =
-#    greedy.((model,), (obj,), states(model), (v,))

 """
     greedy!(π, model, obj, v)
@@ -137,7 +139,7 @@
 Update policy `π` with the greedy policy for value function `v` and MDP
 `model` and an objective `obj`.
""" -function greedy!(π::Vector{Int}, model::TabMDP, obj::Objective, +function greedy!(π::Vector{Int}, model::TabMDP, obj::Objective, t::Integer, v::AbstractVector{<:Real}) length(π) == state_count(model) || @@ -145,15 +147,17 @@ function greedy!(π::Vector{Int}, model::TabMDP, obj::Objective, length(v) == state_count(model) || error("Value function length must be the same as the state count") - π .= greedy.((model,), (obj,), states(model), (v,)) + π .= greedy.((model,), (obj,), (t,), states(model), (v,)) end +greedy!(π, model, obj, v) = greedy!(π, model, obj, 0, v) + """ greedy(model, obj, v) Compute the greedy action for all states and value function `v` assuming -an objective `obj`. +an objective `obj` and time `t=0`. """ function greedy(model::TabMDP, obj::Stationary, v::AbstractVector{<:Real}) π = Vector{Int}(undef, state_count(model)) @@ -161,16 +165,14 @@ function greedy(model::TabMDP, obj::Stationary, v::AbstractVector{<:Real}) π end -greedy(model, γ::Real, v::AbstractVector{<:Real}) = greedy(model, InfiniteH(γ), v) - """ - bellman(model, γ, s, v) + bellman(model, obj, [t=0,] s, v) Compute the Bellman operator for state `s`, and value function `v` assuming an objective `obj`. - -A real-valued objective `obj` is interpreted as a discount factor. """ -bellman(model::MDP{S,A}, obj::Objective, s::S, v) where {S,A} = - bellmangreedy(model, obj, s, v).qvalue :: Float64 +bellman(model::MDP{S,A}, obj::Objective, t::Integer, s::S, v) where {S,A} = + bellmangreedy(model, obj, t, s, v).qvalue :: Float64 + +bellman(model, obj, s, v) = bellman(model, obj, 0, s, v)