diff --git a/src/domains/gridworld.jl b/src/domains/gridworld.jl index 7b0b3be..000ce1d 100644 --- a/src/domains/gridworld.jl +++ b/src/domains/gridworld.jl @@ -4,7 +4,6 @@ import ...TabMDP, ...transition, ...state_count, ...action_count import ...actions, ...states # TODO: Add docs -# TODO: Add tests """ Models values of demand in `values` and probabilities in `probabilities`. """ @@ -27,6 +26,15 @@ struct Parameters rewards_s::Vector{Float64} max_side_length::Int wind::Float64 + + function Parameters(rewards_s, max_side_length, wind) + length(rewards_s) == max_side_length * max_side_length || + error("Rewards must have the same length as the number of states.") + wind ≥ 0.0 || error("Wind must be non-negative.") + wind ≤ 1.0 || error("Wind must be less than or equal to 1.") + + new(rewards_s, max_side_length, wind) + end end @@ -50,7 +58,7 @@ function transition(model::Model, state::Int, action::Int) remaining_wind = model.params.wind / 3 ret = [] # Wrap the state around the grid 1-based indexing - # Julia for the love of God please implement a proper modulo function + # NOTE: Julia for the love of God please implement a proper modulo function upstate = state - n <= 0 ? state + n_states - n : state - n downstate = (state + n) > n_states ? state - n_states + n : state + n leftstate = state % n == 1 ? state + (n - 1) : state - 1 diff --git a/test/src/domains/gridworld.jl b/test/src/domains/gridworld.jl index c84f1cb..797e5bc 100644 --- a/test/src/domains/gridworld.jl +++ b/test/src/domains/gridworld.jl @@ -1,30 +1,11 @@ using MDPs.Domains @testset "Solve Gridworld" begin - reward = [0.1, 0.1, 0.2, -10, -15, 100, 1, 0.5] + reward = [0.1, 0.1, 0.2, -10, -15, 100, 1, 0.5, 0.1] max_side_length = 3 wind = 0.2 params = GridWorld.Parameters(reward, max_side_length, wind) - # Initialize flags for tests - stateok = true - actionok = true - transitionok = true - - for s in 1:GridWorld.state_count(params) - state = GridWorld.state2state(params, s) # Assuming state2state function exists - stateok &= (GridWorld.state2state(params, state) == s) - for a in 1:GridWorld.action_count(params, s) - action = GridWorld.action2action(params, a) # Assuming action2action function exists - actionok &= (GridWorld.action2action(params, action) == a) - transitionok &= (GridWorld.transition(params, state, action, 0).state == state + action) # Adjust transition logic as per the actual implementation - end - end - - @test stateok - @test actionok - @test transitionok - model = GridWorld.Model(params) simulate(model, random_π(model), 1, 10000, 500) model_g = make_int_mdp(model; docompress=false)