diff --git a/docs/getting_started/simulator.ipynb b/docs/getting_started/simulator.ipynb index 8d27f49..0dbb8d1 100644 --- a/docs/getting_started/simulator.ipynb +++ b/docs/getting_started/simulator.ipynb @@ -2,17 +2,18 @@ "cells": [ { "cell_type": "code", - "execution_count": 14, + "execution_count": null, "id": "a8ddc37c-66d2-43e4-8162-6be19a1d70a1", "metadata": {}, "outputs": [], "source": [ - "from stormvogel import visualization, show, simulator, model" + "from stormvogel import visualization, show, simulator\n", + "import stormvogel.model" ] }, { "cell_type": "code", - "execution_count": 15, + "execution_count": null, "id": "cab40f99-3460-4497-8b9f-3d669eee1e11", "metadata": {}, "outputs": [], @@ -85,27 +86,10 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": null, "id": "eb0fadc0-7bb6-4c1d-ae3e-9e16527726ab", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "ModelType.MDP with name None\n", - "\n", - "States:\n", - "State 0 with labels ['init'] and features {}\n", - "State 1 with labels ['carchosen'] and features {}\n", - "State 2 with labels ['open'] and features {}\n", - "State 3 with labels ['goatrevealed'] and features {}\n", - "State 4 with labels ['target', 'done'] and features {}\n", - "\n", - "Transitions:\n" - ] - } - ], + "outputs": [], "source": [ "#we want to simulate this model. That is, we start at the initial state and then\n", "#we walk through the model according to transition probabilities.\n", @@ -125,29 +109,12 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": null, "id": "59ac1e34-866c-42c4-b19b-c2a15c830e2e", "metadata": { "scrolled": true }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "ModelType.MDP with name None\n", - "\n", - "States:\n", - "State 0 with labels ['init'] and features {}\n", - "State 1 with labels ['carchosen'] and features {}\n", - "State 2 with labels ['open'] and features {}\n", - "State 3 with labels ['goatrevealed'] and features {}\n", - "State 4 with labels ['done'] and features {}\n", - "\n", - "Transitions:\n" - ] - } - ], + "outputs": [], "source": [ "#it still chooses random actions but we can prevent this by providing a scheduler:\n", "taken_actions = {}\n", @@ -161,74 +128,35 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": null, "id": "22871288-755c-463f-9150-f207c2f5c211", "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "f92e1a66e2034c70832b36e2bd7fe70f", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Output()" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "b413c4c87cb34895b0a4551b4e49ed16", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(Output(), Output()))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 20, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "#we can also visualize the partial model that we get from the simulator:\n", - "show.show(partial_model)" + "show.show(partial_model)\n" ] }, { "cell_type": "code", - "execution_count": 19, + "execution_count": null, "id": "34d0c293-d090-4e3d-9e80-4351f5fcba62", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "initial state --(action: empty)--> state: 2 --(action: open0)--> state: 7 --(action: empty)--> state: 17 --(action: stay)--> state: 33\n" - ] - } - ], + "outputs": [], "source": [ "#we can also use another simulator function that returns a path instead of a partial model:\n", "path = stormvogel.simulator.simulate_path(mdp, steps=4, scheduler=scheduler, seed=123456)\n", "\n", "print(path)" ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "99c763fa-82ea-42ff-8833-79c640f14518", + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/stormvogel/model.py b/stormvogel/model.py index f0b9ac1..6b62afd 100644 --- a/stormvogel/model.py +++ b/stormvogel/model.py @@ -4,6 +4,7 @@ from enum import Enum from fractions import Fraction from typing import cast +import copy Parameter = str @@ -391,6 +392,13 @@ def supports_observations(self): """Returns whether this model supports observations.""" return self.type == ModelType.POMDP + def get_sub_model(self, states: list[State]) -> "Model": + sub_model = copy.deepcopy(self) + for state in states: + sub_model.delete_state(state, normalize=True, reassign_ids=True) + + return sub_model + def is_well_defined(self) -> bool: """Checks if all sums of outgoing transition probabilities for all states equal 1""" diff --git a/stormvogel/simulator.py b/stormvogel/simulator.py index 34d7400..2414c4a 100644 --- a/stormvogel/simulator.py +++ b/stormvogel/simulator.py @@ -277,4 +277,7 @@ def get_range_index(stateid: int): if simulator.is_done(): break - return partial_model + # TODO: refactor + states = list(partial_model.states.values()) + sub_model = model.get_sub_model(states) + return sub_model diff --git a/tests/test_simulator.py b/tests/test_simulator.py index b839d97..aef1f28 100644 --- a/tests/test_simulator.py +++ b/tests/test_simulator.py @@ -34,6 +34,8 @@ def test_simulate(): for stateid in other_dtmc.states.keys(): rewardmodel3.rewards[stateid] = float(1) + print(partial_model) + assert partial_model == other_dtmc # we make a monty hall mdp and run the simulator with it