domain improvements
Marek Petrik committed Sep 10, 2024
1 parent 9920336 commit d67bbf9
Showing 2 changed files with 55 additions and 11 deletions.
33 changes: 32 additions & 1 deletion src/algorithms/valueiteration.jl
@@ -1,5 +1,5 @@
"""
value_iteration(model, objective; [v_terminal, iterations = 1000, ϵ = 1e-3] )
value_iteration(model, objective[, π]; [v_terminal, iterations = 1000, ϵ = 1e-3] )
Compute value function and policy for a tabular MDP `model` with an objective
@@ -19,6 +19,8 @@ The argument `v_terminal` represents the terminal value function. It should be provided as
a function that maps the state id to its terminal value (at time T+1). If this value is provided,
then it is used in place of 0.
If a policy `π` is provided, then the algorithm evaluates it.
Infinite Horizon
----------------
For a Bellman error `ϵ`, the computed value function is guaranteed to be within
@@ -32,6 +34,11 @@ is more efficient, but the goal of this function is also to compute the value
function.
The time steps go from 1 to T+1.
## See also
`value_iteration!`
"""
function value_iteration end
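
For orientation, a minimal usage sketch of the call documented above (not part of the diff). The `InfiniteH` objective constructor and the `MDPs.Domains.Gambler` module path are assumptions about the surrounding package; the `RuinTransient` constructor matches the second file in this commit.

```julia
using MDPs
# Assumed names: InfiniteH(γ) for the discounted infinite-horizon objective and
# MDPs.Domains.Gambler for the domain module; RuinTransient(win, max_capital, noop)
# is defined in src/domains/gambler.jl in this same commit.
model = MDPs.Domains.Gambler.RuinTransient(0.7, 10, true)
sol = value_iteration(model, InfiniteH(0.95); ϵ = 1e-4)
sol.value    # state values, within the ϵ-based bound described above
sol.policy   # greedy action index for each state
```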

@@ -69,6 +76,30 @@ function value_iteration!(v::Vector{Vector{Float64}}, π::Vector{Vector{Int}},
return (policy = π, value = v)
end


function value_iteration(model::TabMDP, objective::Markov, π::Vector{Vector{Int}};
v_terminal = nothing)
length(π) == horizon(objective) ||
    error("Policy π length must match the horizon $(horizon(objective))")
vp = make_value(model, objective)
v = vp.value

n = state_count(model)
# final value function
v[horizon(objective)+1] .= isnothing(v_terminal) ? 0 : map(v_terminal, 1:n)

for t ∈ horizon(objective):-1:1
# evaluate the given policy at time step t, in parallel over states
Threads.@threads for s ∈ 1:n
length(π[t]) == n || error("Policy π[$t] length must match state count $n")
v[t][s] = qvalue(model, objective, t, s, π[t][s], v[t+1])
end
end
return (policy = π, value = v)
end
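
A sketch of calling the policy-evaluation method added above. The `FiniteH(γ, T)` objective constructor is an assumption; everything else follows the new signature.

```julia
# Assumed objective constructor FiniteH(γ, T); `model` is any TabMDP, e.g. the
# RuinTransient instance from the earlier sketch.
obj = FiniteH(0.95, 10)
n = state_count(model)
policy = [ones(Int, n) for _ in 1:horizon(obj)]   # always take action 1
res = value_iteration(model, obj, policy; v_terminal = s -> 0.0)
res.value[1]   # values of the fixed policy at time step 1
```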


function value_iteration(model::TabMDP, objective::Stationary;
iterations::Integer = 10000, ϵ::Number = 1e-3)
nstates = state_count(model)
33 changes: 23 additions & 10 deletions src/domains/gambler.jl
@@ -64,7 +64,7 @@ end


"""
RuinTransient(win, max_capital, noop)
RuinTransient(win, max_capital, noop; [win_reward = 1.0, lose_reward = 0.0])
Gambler's ruin; the transient version. Can decide how much to bet at any point in time. With some
probability `win`, the bet is doubled, and with `1-win` it is lost. The reward is `1` if it achieves
@@ -83,18 +83,23 @@ transient.
Special states: `state=1` is broke and `state=max_capital+1` is maximal capital. Both of the
states are absorbing/terminal.
The reward is `0` when the gambler goes broke and `+1` when it achieves the target capital. The
difference from `Ruin` is that the reward is not received in the terminal state.
By default, the reward is `0` when the gambler goes broke and `+1` when it achieves the
target capital. The difference from `Ruin` is that no reward is received in the terminal state.
The rewards for overall win and loss can be adjusted by providing `win_reward` and
`lose_reward` optional parameters.
"""
struct RuinTransient <: TabMDP
win :: Float64
max_capital :: Int
noop :: Bool
win_reward :: Float64
lose_reward :: Float64

function RuinTransient(win::Number, max_capital::Integer, noop::Bool)
function RuinTransient(win::Number, max_capital::Integer, noop::Bool;
win_reward = 1.0, lose_reward = 0.0)
zero(win) ≤ win ≤ one(win) || error("Win probability must be in [0,1]")
max_capital ≥ one(max_capital) || error("Max capital must be positive")
new(win, max_capital, noop)
new(win, max_capital, noop, win_reward, lose_reward)
end
end
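
A small construction sketch for the new keyword arguments, using only the constructor defined above: penalize going broke while keeping the default reward for reaching the target capital.

```julia
# Win probability 0.55, target capital 20, no-op bet allowed; going broke now
# yields -1 instead of the default 0.
gambler = RuinTransient(0.55, 20, true; win_reward = 1.0, lose_reward = -1.0)
```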

@@ -107,15 +112,15 @@ function action_count(model::RuinTransient, state::Int)
1
else
capital = state - 1
model.noop ? model.max_capital + 1 : model.max_capital
model.noop ? capital + 1 : capital
end
end
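
A short illustration of the corrected action counts: with `noop` a non-terminal state with capital `c` has `c + 1` actions, without it `c` actions.

```julia
with_noop    = RuinTransient(0.5, 10, true)
without_noop = RuinTransient(0.5, 10, false)
action_count(with_noop, 5)      # state 5 has capital 4, so 4 + 1 = 5 actions
action_count(without_noop, 5)   # 4 actions
```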

function transition(model::RuinTransient, state::Int, action::Int)
absorbing = state_count(model) # the "last" state

1 ≤ state ≤ absorbing || error("invalid state")
1 ≤ action ≤ action_count(model, state) || error("invalid action")
1 ≤ state ≤ absorbing || error("invalid state: $state")
1 ≤ action ≤ action_count(model, state) || error("invalid action $action in state $state")

if state == 1 # broke
(mt(state, 1.0, 0.0),)
@@ -127,11 +132,19 @@ win_state = min(model.max_capital + 1, (state - 1) + bet + 1)
win_state = min(model.max_capital + 1, (state - 1) + bet + 1)
lose_state = max(1, (state - 1) - bet + 1)

zero_rew = 1e-8 * rand()

# the reward is win_reward if and only if we achieve the target capital
win_reward = win_state == absorbing ? 1.0 : 0.0
win_reward = win_state == absorbing ? model.win_reward : 0.0
lose_reward = lose_state == 1 ? model.lose_reward : 0.0

# transition to the absorbing last state
if lose_state == 1
lose_state = absorbing
end

# the reward is lose_reward (0 by default) when we lose
(mt(win_state, model.win, win_reward), mt(lose_state, 1.0 - model.win, 0.))
(mt(win_state, model.win, win_reward), mt(lose_state, 1.0 - model.win, lose_reward))
end
end
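
A rough sanity-check sketch for the transition function. The assumption that each entry returned by `transition` unpacks as `(next_state, probability, reward)` is mine, based on how `mt(...)` is used above.

```julia
m = RuinTransient(0.5, 10, false; lose_reward = -1.0)
for s in 1:state_count(m), a in 1:action_count(m, s)
    # outgoing probabilities from every state should sum to one
    total = sum(p for (_, p, _) in transition(m, s, a))
    isapprox(total, 1.0; atol = 1e-10) || @warn "probabilities do not sum to 1" s a total
end
```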
