Skip to content

Commit

Permalink
Merge branch 'Gersi_Gridworld' of github.com:RiskAverseRL/MDPs.jl int…
Browse files Browse the repository at this point in the history
…o Gersi_Gridworld
  • Loading branch information
Marek Petrik committed Jul 3, 2024
2 parents b08199b + cc2e999 commit bc85469
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 22 deletions.
12 changes: 10 additions & 2 deletions src/domains/gridworld.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import ...TabMDP, ...transition, ...state_count, ...action_count
import ...actions, ...states

# TODO: Add docs
# TODO: Add tests
"""
Models values of demand in `values` and probabilities in `probabilities`.
"""
Expand All @@ -27,6 +26,15 @@ struct Parameters
rewards_s::Vector{Float64}
max_side_length::Int
wind::Float64

function Parameters(rewards_s, max_side_length, wind)
length(rewards_s) == max_side_length * max_side_length ||
error("Rewards must have the same length as the number of states.")
wind 0.0 || error("Wind must be non-negative.")
wind 1.0 || error("Wind must be less than or equal to 1.")

new(rewards_s, max_side_length, wind)
end
end


Expand All @@ -50,7 +58,7 @@ function transition(model::Model, state::Int, action::Int)
remaining_wind = model.params.wind / 3
ret = []
# Wrap the state around the grid 1-based indexing
# Julia for the love of God please implement a proper modulo function
# NOTE: Julia for the love of God please implement a proper modulo function
upstate = state - n <= 0 ? state + n_states - n : state - n
downstate = (state + n) > n_states ? state - n_states + n : state + n
leftstate = state % n == 1 ? state + (n - 1) : state - 1
Expand Down
21 changes: 1 addition & 20 deletions test/src/domains/gridworld.jl
Original file line number Diff line number Diff line change
@@ -1,30 +1,11 @@
using MDPs.Domains

@testset "Solve Gridworld" begin
reward = [0.1, 0.1, 0.2, -10, -15, 100, 1, 0.5]
reward = [0.1, 0.1, 0.2, -10, -15, 100, 1, 0.5, 0.1]
max_side_length = 3
wind = 0.2
params = GridWorld.Parameters(reward, max_side_length, wind)

# Initialize flags for tests
stateok = true
actionok = true
transitionok = true

for s in 1:GridWorld.state_count(params)
state = GridWorld.state2state(params, s) # Assuming state2state function exists
stateok &= (GridWorld.state2state(params, state) == s)
for a in 1:GridWorld.action_count(params, s)
action = GridWorld.action2action(params, a) # Assuming action2action function exists
actionok &= (GridWorld.action2action(params, action) == a)
transitionok &= (GridWorld.transition(params, state, action, 0).state == state + action) # Adjust transition logic as per the actual implementation
end
end

@test stateok
@test actionok
@test transitionok

model = GridWorld.Model(params)
simulate(model, random_π(model), 1, 10000, 500)
model_g = make_int_mdp(model; docompress=false)
Expand Down

0 comments on commit bc85469

Please sign in to comment.