Skip to content

Commit

Permalink
dependencies and bugs fixed
Browse files Browse the repository at this point in the history
  • Loading branch information
Marek Petrik committed Jul 3, 2024
1 parent 057a847 commit 3eed2bc
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 30 deletions.
6 changes: 5 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,17 @@ version = "0.1.5"
[deps]
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
DataFramesMeta = "1313f7d8-7da2-5740-9ea0-a2ca25f37964"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

[compat]
julia = "1.9"
DataFrames = "1.6.1"
DataFramesMeta = "0.14.1"
julia = "1.9"
StatsBase = "0.34.2"
Distributions = "0.25.107"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Expand Down
65 changes: 36 additions & 29 deletions src/domains/garnet.jl
Original file line number Diff line number Diff line change
@@ -1,45 +1,52 @@
module Garnet

import ...TabMDP, ...transition, ...state_count, ...action_count#, StatsBase, Distributions
import ...TabMDP, ...transition, ...state_count, ...action_count
import ...actions, ...states

# TODO: are these reasonable or can we replace them?
import StatsBase, Distributions
# ----------------------------------------------------------------
# A Garnet MDP
# ----------------------------------------------------------------

"""
    GarnetMDP <: TabMDP

Tabular MDP whose rewards and transition probabilities are stored explicitly.

# Fields
- `reward`: `reward[s][a]` is the reward stored for action `a` in state `s`.
- `transition`: `transition[s][a]` is a length-`S` vector of next-state probabilities.
- `S`: number of states.
- `A`: `A[s]` is the number of actions available in state `s`.
"""
struct GarnetMDP <: TabMDP
    reward::Vector{Vector{Float64}}
    transition::Vector{Vector{Vector{Float64}}}
    S::Int
    A::Vector{Int}

    # TODO: add a constructor that checks for consistency
end

"""
    make_garnet(S, A::AbstractVector{Int}, nbranch, min_reward, max_reward)

Generate a random `GarnetMDP` with `S` states and `A[s]` actions in state `s`.

For every state-action pair, in this order:
1. the reward is drawn uniformly from the integers `min_reward:max_reward`;
2. the successor states are sampled without replacement from `1:S`;
3. their probabilities are `Exponential(1)` draws normalized to sum to one.

# Arguments
- `S`: number of states.
- `A`: number of actions per state; must have exactly `S` entries.
- `nbranch`: fraction of states reachable from each state-action pair, in `[0, 1]`.
  The number of successors is `round(nbranch * S)`, but at least one when `S >= 1`.
- `min_reward`, `max_reward`: inclusive integer bounds for the sampled rewards.

# Throws
- `ErrorException` if `nbranch` is outside `[0, 1]` or `length(A) != S`.
"""
function make_garnet(S::Integer, A::AbstractVector{Int}, nbranch::Number,
                     min_reward::Integer, max_reward::Integer)
    0 <= nbranch <= 1 || error("nbranch must be in [0,1]")
    # Without this check a short `A` fails mid-loop with a BoundsError and a long
    # `A` silently produces a model whose `A` disagrees with `S`.
    length(A) == S || error("A must have exactly S entries")

    reward = Vector{Vector{Float64}}()
    transition = Vector{Vector{Vector{Float64}}}()
    dist = Distributions.Exponential(1)
    # Zero successors would leave every transition row all-zero (not a probability
    # distribution), so keep at least one successor whenever there are states.
    sout = max(round(Int, nbranch * S), min(1, S))

    for i in 1:S
        r = Vector{Float64}()
        p = Vector{Vector{Float64}}()
        for _ in 1:A[i]
            push!(r, rand(min_reward:max_reward))
            inds = StatsBase.sample(1:S, sout, replace=false)
            z = rand(dist, sout)
            z ./= sum(z)
            # Scatter the normalized weights onto the sampled successor states.
            pp = zeros(S)
            pp[inds] .= z
            push!(p, pp)
        end
        push!(reward, r)
        push!(transition, p)
    end

    return GarnetMDP(reward, transition, S, A)
end

"""
    make_garnet(S::Integer, A::Integer, nbranch, min_reward, max_reward)

Convenience method for the case where every state has the same number of actions:
expands `A` into a length-`S` vector and forwards to the vector-based method.
"""
function make_garnet(S::Integer, A::Integer, nbranch, min_reward, max_reward)
    actions_per_state = fill(Int(A), S)
    return make_garnet(S, actions_per_state, nbranch, min_reward, max_reward)
end

function transition(model::GarnetMDP, state::Int, action::Int)
@assert state in 1:model.S
Expand All @@ -58,4 +65,4 @@ state_count(model::GarnetMDP) = model.S
"""
    action_count(model::GarnetMDP, s::Int)

Return the entry of `model.A` stored for state index `s`.
"""
function action_count(model::GarnetMDP, s::Int)
    return model.A[s]
end

end
# Module: Garnet
# Module: Garnet

0 comments on commit 3eed2bc

Please sign in to comment.