Skip to content

Commit

Permalink
model builder now also uses reward function
Browse files Browse the repository at this point in the history
  • Loading branch information
PimLeerkes committed Dec 8, 2024
1 parent 9a5d267 commit 072c6cb
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 3 deletions.
2 changes: 2 additions & 0 deletions stormvogel/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -804,6 +804,8 @@ def get_action(self, name: str) -> Action:
)
assert self.actions is not None
if name not in self.actions:
print(name)
print(self.actions)
raise RuntimeError(
f"Tried to get action {name} but that action does not exist"
)
Expand Down
31 changes: 28 additions & 3 deletions stormvogel/pgc.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ def __eq__(self, other):
def build_pgc(
delta, # Callable[[State, Action], list[tuple[float, State]]],
initial_state_pgc: State, # TODO rewards function, label function
rewards=None,
labels=None,
available_actions: Callable[[State], list[Action]] | None = None,
modeltype: stormvogel.model.ModelType = stormvogel.model.ModelType.MDP,
) -> stormvogel.model.Model:
Expand Down Expand Up @@ -67,13 +69,13 @@ def build_pgc(
while len(states_to_be_visited) > 0:
state = states_to_be_visited[0]
states_to_be_visited.remove(state)
# we loop over all available actions and call the delta function for each action
transition = {}

if state not in states_seen:
states_seen.append(state)

if model.supports_actions():
# we loop over all available actions and call the delta function for each action
assert available_actions is not None
for action in available_actions(state):
try:
Expand All @@ -98,8 +100,6 @@ def build_pgc(
branch.append((tuple[0], new_state))
states_to_be_visited.append(tuple[1])
else:
# print(tuple[1].__dict__)
# print(model.states)
branch.append(
(tuple[0], model.get_state_by_name(str(tuple[1].__dict__)))
)
Expand Down Expand Up @@ -133,4 +133,29 @@ def build_pgc(
stormvogel.model.Transition(transition),
)

# Add the rewards: a single reward model named "rewards" is created and filled by calling the `rewards` callback.
# TODO: support multiple reward models
if rewards is not None:
rewardmodel = model.add_rewards("rewards")
if model.supports_actions():
for state in states_seen:
assert available_actions is not None
for action in available_actions(state):
reward = rewards(state, action)
s = model.get_state_by_name(str(state.__dict__))
assert s is not None
rewardmodel.set_state_action_reward(
s,
model.get_action(str(action.labels)),
reward,
)
else:
for state in states_seen:
reward = rewards(state)
s = model.get_state_by_name(str(state.__dict__))
assert s is not None
rewardmodel.set_state_reward(s, reward)

# TODO: add the labels (the `labels` argument is accepted but nothing is done with it here yet)

return model
16 changes: 16 additions & 0 deletions tests/test_pgc.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ def test_pgc_mdp():
def available_actions(s: pgc.State):
return [left, right]

def rewards(s: pgc.State, a: pgc.Action):
return 1

def delta(s: pgc.State, action: pgc.Action):
if action == left:
return (
Expand All @@ -40,6 +43,7 @@ def delta(s: pgc.State, action: pgc.Action):
delta=delta,
available_actions=available_actions,
initial_state_pgc=initial_state,
rewards=rewards,
)

# we build the model in the regular way:
Expand All @@ -60,6 +64,10 @@ def delta(s: pgc.State, action: pgc.Action):
model.add_transitions(state2, stormvogel.model.Transition({right: branch2}))
model.add_transitions(state0, stormvogel.model.Transition({left: branch0}))

rewardmodel = model.add_rewards("rewards")
for i in range(2 * N):
rewardmodel.set_state_action_reward_at_id(i, 1)

assert model == pgc_model


Expand All @@ -68,6 +76,9 @@ def test_pgc_dtmc():
p = 0.5
initial_state = pgc.State(s=0)

def rewards(s: pgc.State):
return 1

def delta(s: pgc.State):
match s.s:
case 0:
Expand Down Expand Up @@ -96,6 +107,7 @@ def delta(s: pgc.State):
pgc_model = stormvogel.pgc.build_pgc(
delta=delta,
initial_state_pgc=initial_state,
rewards=rewards,
modeltype=stormvogel.model.ModelType.DTMC,
)

Expand Down Expand Up @@ -161,4 +173,8 @@ def delta(s: pgc.State):
model.set_transitions(model.get_state_by_id(12), [(1, model.get_state_by_id(13))])
model.set_transitions(model.get_state_by_id(13), [(1, model.get_state_by_id(13))])

rewardmodel = model.add_rewards("rewards")
for state in model.states.values():
rewardmodel.set_state_reward(state, 1)

assert pgc_model == model

0 comments on commit 072c6cb

Please sign in to comment.