Skip to content

Commit

Permalink
added submodel function
Browse files Browse the repository at this point in the history
  • Loading branch information
PimLeerkes committed Oct 31, 2024
1 parent fb7ae15 commit 4cc9464
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 94 deletions.
114 changes: 21 additions & 93 deletions docs/getting_started/simulator.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,18 @@
"cells": [
{
"cell_type": "code",
"execution_count": 14,
"execution_count": null,
"id": "a8ddc37c-66d2-43e4-8162-6be19a1d70a1",
"metadata": {},
"outputs": [],
"source": [
"from stormvogel import visualization, show, simulator, model"
"from stormvogel import visualization, show, simulator\n",
"import stormvogel.model"
]
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": null,
"id": "cab40f99-3460-4497-8b9f-3d669eee1e11",
"metadata": {},
"outputs": [],
Expand Down Expand Up @@ -85,27 +86,10 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": null,
"id": "eb0fadc0-7bb6-4c1d-ae3e-9e16527726ab",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ModelType.MDP with name None\n",
"\n",
"States:\n",
"State 0 with labels ['init'] and features {}\n",
"State 1 with labels ['carchosen'] and features {}\n",
"State 2 with labels ['open'] and features {}\n",
"State 3 with labels ['goatrevealed'] and features {}\n",
"State 4 with labels ['target', 'done'] and features {}\n",
"\n",
"Transitions:\n"
]
}
],
"outputs": [],
"source": [
"#we want to simulate this model. That is, we start at the initial state and then\n",
"#we walk through the model according to transition probabilities.\n",
Expand All @@ -125,29 +109,12 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": null,
"id": "59ac1e34-866c-42c4-b19b-c2a15c830e2e",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ModelType.MDP with name None\n",
"\n",
"States:\n",
"State 0 with labels ['init'] and features {}\n",
"State 1 with labels ['carchosen'] and features {}\n",
"State 2 with labels ['open'] and features {}\n",
"State 3 with labels ['goatrevealed'] and features {}\n",
"State 4 with labels ['done'] and features {}\n",
"\n",
"Transitions:\n"
]
}
],
"outputs": [],
"source": [
"#it still chooses random actions but we can prevent this by providing a scheduler:\n",
"taken_actions = {}\n",
Expand All @@ -161,74 +128,35 @@
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": null,
"id": "22871288-755c-463f-9150-f207c2f5c211",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "f92e1a66e2034c70832b36e2bd7fe70f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Output()"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "b413c4c87cb34895b0a4551b4e49ed16",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(Output(), Output()))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"<stormvogel.visualization.Visualization at 0x7f6e782fe4b0>"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"outputs": [],
"source": [
"#we can also visualize the partial model that we get from the simulator:\n",
"show.show(partial_model)"
"show.show(partial_model)\n"
]
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": null,
"id": "34d0c293-d090-4e3d-9e80-4351f5fcba62",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"initial state --(action: empty)--> state: 2 --(action: open0)--> state: 7 --(action: empty)--> state: 17 --(action: stay)--> state: 33\n"
]
}
],
"outputs": [],
"source": [
"#we can also use another simulator function that returns a path instead of a partial model:\n",
"path = stormvogel.simulator.simulate_path(mdp, steps=4, scheduler=scheduler, seed=123456)\n",
"\n",
"print(path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "99c763fa-82ea-42ff-8833-79c640f14518",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand Down
8 changes: 8 additions & 0 deletions stormvogel/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from enum import Enum
from fractions import Fraction
from typing import cast
import copy

Parameter = str

Expand Down Expand Up @@ -391,6 +392,13 @@ def supports_observations(self):
"""Returns whether this model supports observations."""
return self.type == ModelType.POMDP

def get_sub_model(self, states: list[State]) -> "Model":
sub_model = copy.deepcopy(self)
for state in states:
sub_model.delete_state(state, normalize=True, reassign_ids=True)

return sub_model

def is_well_defined(self) -> bool:
"""Checks if all sums of outgoing transition probabilities for all states equal 1"""

Expand Down
5 changes: 4 additions & 1 deletion stormvogel/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,4 +277,7 @@ def get_range_index(stateid: int):
if simulator.is_done():
break

return partial_model
# TODO: refactor
states = list(partial_model.states.values())
sub_model = model.get_sub_model(states)
return sub_model
2 changes: 2 additions & 0 deletions tests/test_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ def test_simulate():
for stateid in other_dtmc.states.keys():
rewardmodel3.rewards[stateid] = float(1)

print(partial_model)

assert partial_model == other_dtmc

# we make a monty hall mdp and run the simulator with it
Expand Down

0 comments on commit 4cc9464

Please sign in to comment.