Gersi gridworld #6
Merged (28 commits, Aug 16, 2024)
2 changes: 2 additions & 0 deletions src/MDPs.jl
@@ -51,6 +51,8 @@ include("domains/machine.jl")
export Machine
include("domains/gambler.jl")
export Gambler
include("domains/gridworld.jl")
export GridWorld
end
export Domains
# --------------------
2 changes: 1 addition & 1 deletion src/algorithms/policyiteration.jl
@@ -33,7 +33,7 @@ function policy_iteration(model::TabMDP, γ::Real; iterations::Int = 1000)
for it ∈ 1:iterations
policyold .= policy
greedy!(policy, model, InfiniteH(γ), v_π)
mrp!(IP_π, r_π, model, policy);
mrp!(IP_π, r_π, model, policy)
# Solve: v_π .= (I - γ * P_π) \ r_π
lmul!(-γ, IP_π)
_add_identity!(IP_π)
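For reference, the commented line above is the policy-evaluation step: the loop solves (I - γ P_π) v_π = r_π in place rather than forming a matrix inverse. A minimal standalone sketch of that step (the two-state numbers are illustrative, not taken from the package):

using LinearAlgebra

γ = 0.9
P_π = [0.8 0.2; 0.1 0.9]   # transition matrix under the current policy (illustrative)
r_π = [1.0, 0.0]           # expected one-step rewards under that policy (illustrative)
v_π = (I - γ * P_π) \ r_π  # exact policy evaluation by solving the linear system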
101 changes: 101 additions & 0 deletions src/domains/gridworld.jl
@@ -0,0 +1,101 @@
module GridWorld

import ...TabMDP, ...transition, ...state_count, ...action_count
import ...actions, ...states

# TODO: Add docs, with method signatures
"""
The four grid movement actions, encoded as 1-based integers.
"""
@enum Action begin
UP = 1
DOWN = 2
LEFT = 3
RIGHT = 4
end

"""
Parameters that define a GridWorld problem

- `rewards_s`: A vector of rewards for each state
- `max_side_length`: An integer that represents the maximum side length of the grid
- `wind`: A float that represents the wind
"""
struct Parameters
rewards_s::Vector{Float64}
max_side_length::Int
wind::Float64

function Parameters(rewards_s, max_side_length, wind)
length(rewards_s) == max_side_length * max_side_length ||
error("Rewards must have the same length as the number of states.")
wind ≥ 0.0 || error("Wind must be non-negative.")
wind ≤ 1.0 || error("Wind must be less than or equal to 1.")

new(rewards_s, max_side_length, wind)
end
end


# ----------------------------------------------------------------
# Definition of MDP models and functions
# ----------------------------------------------------------------

"""
A GridWorld MDP problem simulator

The states and actions are 1-based integers.
"""
struct Model <: TabMDP
params::Parameters
end

function transition(model::Model, state::Int, action::Int)
n = model.params.max_side_length
n_states = state_count(model.params)
compl_wind = (1.0 - model.params.wind)
remaining_wind = model.params.wind / 3
ret = Vector{Tuple{Int, Float64, Float64}}()
# Wrap the state around the grid with 1-based indexing; Base.mod1 would be a
# more direct way to express this wrap-around arithmetic.
upstate = state - n <= 0 ? state + n_states - n : state - n
downstate = (state + n) > n_states ? state - n_states + n : state + n
leftstate = state % n == 1 ? state + (n - 1) : state - 1
rightstate = state % n == 0 ? state - (n - 1) : state + 1
if action == Int(UP)
push!(ret, (upstate, compl_wind, model.params.rewards_s[upstate]))
push!(ret, (downstate, remaining_wind, model.params.rewards_s[downstate]))
push!(ret, (leftstate, remaining_wind, model.params.rewards_s[leftstate]))
push!(ret, (rightstate, remaining_wind, model.params.rewards_s[rightstate]))
elseif action == Int(DOWN)
push!(ret, (downstate, compl_wind, model.params.rewards_s[downstate]))
push!(ret, (upstate, remaining_wind, model.params.rewards_s[upstate]))
push!(ret, (leftstate, remaining_wind, model.params.rewards_s[leftstate]))
push!(ret, (rightstate, remaining_wind, model.params.rewards_s[rightstate]))
elseif action == Int(LEFT)
push!(ret, (leftstate, compl_wind, model.params.rewards_s[leftstate]))
push!(ret, (upstate, remaining_wind, model.params.rewards_s[upstate]))
push!(ret, (downstate, remaining_wind, model.params.rewards_s[downstate]))
push!(ret, (rightstate, remaining_wind, model.params.rewards_s[rightstate]))
elseif action == Int(RIGHT)
push!(ret, (rightstate, compl_wind, model.params.rewards_s[rightstate]))
push!(ret, (upstate, remaining_wind, model.params.rewards_s[upstate]))
push!(ret, (downstate, remaining_wind, model.params.rewards_s[downstate]))
push!(ret, (leftstate, remaining_wind, model.params.rewards_s[leftstate]))
else
throw(ArgumentError("Invalid action $action for GridWorld."))
end
return ret
end

state_count(params::Parameters) = params.max_side_length * params.max_side_length
action_count(params::Parameters, state::Int) = 4

state_count(model::Model) = model.params.max_side_length * model.params.max_side_length
action_count(model::Model, state::Int) = 4

states(model::Model) = 1:state_count(model.params)
actions(model::Model, state::Int) = 1:action_count(model.params, state)

end # Module: GridWorld
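As a quick illustration of the wind dynamics above, a sketch only: the zero rewards and the 3×3 size are made up, mirroring the style of the test below, and `transition` is called through the `MDPs` module since it may not be exported.

using MDPs, MDPs.Domains

params = GridWorld.Parameters(zeros(9), 3, 0.2)   # 3×3 grid, wind = 0.2
model = GridWorld.Model(params)

# Moving UP from the centre state 5: probability 1 - 0.2 = 0.8 of reaching
# state 2, and 0.2 / 3 each of being blown to states 8, 4, and 6.
MDPs.transition(model, 5, Int(GridWorld.UP))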
1 change: 1 addition & 0 deletions test/runtests.jl
@@ -6,3 +6,4 @@ include("src/integral.jl")
include("src/domains/inventory.jl")
include("src/domains/make_domains.jl")
include("src/domains/solvers.jl")
include("src/domains/gridworld.jl")
31 changes: 31 additions & 0 deletions test/src/domains/gridworld.jl
@@ -0,0 +1,31 @@
using MDPs.Domains

@testset "Solve Gridworld" begin
reward = [0.1, 0.1, 0.2, -10, -15, 100, 1, 0.5, 0.1]
max_side_length = 3
wind = 0.2
params = GridWorld.Parameters(reward, max_side_length, wind)

model = GridWorld.Model(params)
simulate(model, random_π(model), 1, 10000, 500)
model_g = make_int_mdp(model; docompress=false)
model_gc = make_int_mdp(model; docompress=true)

v1 = value_iteration(model, InfiniteH(0.95); ϵ=1e-10)
v2 = value_iteration(model_g, InfiniteH(0.95); ϵ=1e-10)
v3 = value_iteration(model_gc, InfiniteH(0.95); ϵ=1e-10)
v4 = policy_iteration(model_gc, 0.95)

# Ensure value functions are close
V = hcat(v1.value, v2.value[1:end-1], v3.value[1:end-1], v4.value[1:end-1])
@test map(x -> x[2] - x[1], mapslices(extrema, V; dims=2)) |> maximum ≤ 1e-6

# Ensure policies are identical
p1 = greedy(model, InfiniteH(0.95), v1.value)
p2 = greedy(model_g, InfiniteH(0.95), v2.value)
p3 = greedy(model_gc, InfiniteH(0.95), v3.value)
p4 = v4.policy

P = hcat(p1, p2[1:end-1], p3[1:end-1], p4[1:end-1])
@test all(mapslices(allequal, P; dims=2))
end
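The spread check above computes, per state, the gap between the largest and smallest value across the four solutions; an equivalent formulation that may read more directly (reusing `V` from the testset) would be:

@test maximum(maximum(V; dims=2) .- minimum(V; dims=2)) ≤ 1e-6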