Skip to content

Commit

Permalink
all simple tests now pass for mdps
Browse files Browse the repository at this point in the history
  • Loading branch information
PimLeerkes committed Jul 25, 2024
1 parent 54ab304 commit 595372e
Show file tree
Hide file tree
Showing 5 changed files with 213 additions and 38 deletions.
4 changes: 2 additions & 2 deletions examples/monty_hall.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import stormvogel.model


def create_monty_hall_dtmc():
def create_monty_hall_mdp():
mdp = stormvogel.model.new_mdp("Monty Hall")

init = mdp.get_initial_state()
Expand Down Expand Up @@ -68,4 +68,4 @@ def create_monty_hall_dtmc():

if __name__ == "__main__":
# Print the resulting model in dot format.
print(create_monty_hall_dtmc().to_dot())
print(create_monty_hall_mdp().to_dot())
128 changes: 128 additions & 0 deletions examples/stormpy_mdp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
import stormpy


# Knuth's model of a fair die using only fair coins
def example_building_mdps_01():
nr_states = 13
nr_choices = 14

# Transition matrix with custom row grouping: nondeterministic choice over the actions available in states
builder = stormpy.SparseMatrixBuilder(
rows=0,
columns=0,
entries=0,
force_dimensions=False,
has_custom_row_grouping=True,
row_groups=0,
)

# New row group, for actions of state 0
builder.new_row_group(0)
builder.add_next_value(0, 1, 0.5)
builder.add_next_value(0, 2, 0.5)
builder.add_next_value(1, 1, 0.2)
builder.add_next_value(1, 2, 0.8)
# State 1
builder.new_row_group(2)
builder.add_next_value(2, 3, 0.5)
builder.add_next_value(2, 4, 0.5)
# State 2
builder.new_row_group(3)
builder.add_next_value(3, 5, 0.5)
builder.add_next_value(3, 6, 0.5)
# State 3
builder.new_row_group(4)
builder.add_next_value(4, 7, 0.5)
builder.add_next_value(4, 1, 0.5)
# State 4
builder.new_row_group(5)
builder.add_next_value(5, 8, 0.5)
builder.add_next_value(5, 9, 0.5)
# State 5
builder.new_row_group(6)
builder.add_next_value(6, 10, 0.5)
builder.add_next_value(6, 11, 0.5)
# State 6
builder.new_row_group(7)
builder.add_next_value(7, 2, 0.5)
builder.add_next_value(7, 12, 0.5)

# Add transitions for the final states
for s in range(8, 14):
builder.new_row_group(s)
builder.add_next_value(s, s - 1, 1)

transition_matrix = builder.build()

# State labeling
state_labeling = stormpy.storage.StateLabeling(nr_states)
# Add labels
labels = {"init", "one", "two", "three", "four", "five", "six", "done", "deadlock"}
for label in labels:
state_labeling.add_label(label)

# Set labeling of states
state_labeling.add_label_to_state("init", 0)
state_labeling.add_label_to_state("one", 7)
state_labeling.add_label_to_state("two", 8)
state_labeling.add_label_to_state("three", 9)
state_labeling.add_label_to_state("four", 10)
state_labeling.add_label_to_state("five", 11)
state_labeling.add_label_to_state("six", 12)

# Set label 'done' for multiple states
state_labeling.set_states(
"done", stormpy.BitVector(nr_states, [7, 8, 9, 10, 11, 12])
)

# Choice labeling
choice_labeling = stormpy.storage.ChoiceLabeling(nr_choices)
choice_labels = {"a", "b"}
# Add labels
for label in choice_labels:
choice_labeling.add_label(label)

# Set labels
choice_labeling.add_label_to_choice("a", 0)
choice_labeling.add_label_to_choice("b", 1)
print(choice_labeling)

# Reward models
reward_models = {}
# Create a vector representing the state-action rewards
action_reward = [
0.0,
0.0,
1.0,
1.0,
1.0,
1.0,
1.0,
1.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
]
reward_models["coin_flips"] = stormpy.SparseRewardModel(
optional_state_action_reward_vector=action_reward
)

# Collect components
components = stormpy.SparseModelComponents(
transition_matrix=transition_matrix,
state_labeling=state_labeling,
reward_models=reward_models,
rate_transitions=False,
)
components.choice_labeling = choice_labeling

# Build the model
mdp = stormpy.storage.SparseMdp(components)
return mdp


if __name__ == "__main__":
example_building_mdps_01()
56 changes: 31 additions & 25 deletions stormvogel/map.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def stormvogel_to_stormpy(
) -> stormpy.storage.SparseDtmc | stormpy.storage.SparseMdp | None:
def map_dtmc(model: stormvogel.model.Model) -> stormpy.storage.SparseDtmc:
"""
Takes a simple representation as input and outputs a dtmc how it is represented in stormpy
Takes a simple representation of a dtmc as input and outputs a dtmc how it is represented in stormpy
"""

# we first build the SparseMatrix
Expand All @@ -62,7 +62,7 @@ def map_dtmc(model: stormvogel.model.Model) -> stormpy.storage.SparseDtmc:

def map_mdp(model: stormvogel.model.Model) -> stormpy.storage.SparseMdp:
"""
Takes a simple representation as input and outputs an mdp how it is represented in stormpy
Takes a simple representation of an mdp as input and outputs an mdp how it is represented in stormpy
"""

# we determine the number of choices and the choice labels
Expand Down Expand Up @@ -128,7 +128,7 @@ def map_mdp(model: stormvogel.model.Model) -> stormpy.storage.SparseMdp:


def stormpy_to_stormvogel(
sparsedtmc: stormpy.storage.SparseDtmc | stormpy.storage.SparseMdp,
sparsemodel: stormpy.storage.SparseDtmc | stormpy.storage.SparseMdp,
) -> stormvogel.model.Model:
def map_dtmc(sparsedtmc: stormpy.storage.SparseDtmc) -> stormvogel.model.Model:
"""
Expand All @@ -155,55 +155,61 @@ def map_dtmc(sparsedtmc: stormpy.storage.SparseDtmc) -> stormvogel.model.Model:
transitions = stormvogel.model.transition_from_shorthand(
transitionshorthand
)
# print("state:", state.id)
# print("type", type(state))
model.set_transitions(model.get_state_by_id(state.id), transitions)

# TODO rewards

return model

def map_mdp(sparsedtmc: stormpy.storage.SparseDtmc) -> stormvogel.model.Model:
def map_mdp(sparsemdp: stormpy.storage.SparseDtmc) -> stormvogel.model.Model:
"""
Takes a dtmc stormpy representation as input and outputs a simple stormvogel representation
Takes a mdp stormpy representation as input and outputs a simple stormvogel representation
"""

# we create the model
model = stormvogel.model.new_dtmc(name=None)
model = stormvogel.model.new_mdp(name=None)

# we add the states
# print("states:", len(sparsedtmc.states))
for state in sparsedtmc.states:
for state in sparsemdp.states:
# the initial state is automatically added so we don't add it
if state.id > 0:
model.new_state(labels=list(state.labels))

# we add the transitions
matrix = sparsedtmc.transition_matrix

for index, state in enumerate(sparsedtmc.states):
# we add the transitions
matrix = sparsemdp.transition_matrix
for index, state in enumerate(sparsemdp.states):
row_group_start = matrix.get_row_group_start(index)
row_group_end = matrix.get_row_group_end(index)
# print("rowgroupstart:", row_group_start)
# print("rowgroupend:", row_group_end)
# within a row group we add for each action the transitions
transition = dict()

for i in range(row_group_start, row_group_end):
row = matrix.get_row(i)
transitionshorthand = [
(x.value(), model.get_state_by_id(x.column)) for x in row
]
transitions = stormvogel.model.transition_from_shorthand(
transitionshorthand
)

# actionlabels = sparsemdp.choice_labeling.get_labels_of_choice(i)
# actionlabelslist = [str(x) for x in actionlabels]
# print(actionlabelslist)

# for now assign a name based on index
# TODO assign the correct labels and name
action = stormvogel.model.Action(str(i))

branch = [(x.value(), model.get_state_by_id(x.column)) for x in row]

transition[action] = stormvogel.model.Branch(branch)

transitions = stormvogel.model.Transition(transition)

model.set_transitions(model.get_state_by_id(state.id), transitions)

# TODO rewards

return model

if sparsedtmc.transition_matrix.has_trivial_row_grouping:
return map_dtmc(sparsedtmc)
if sparsemodel.transition_matrix.has_trivial_row_grouping:
return map_dtmc(sparsemodel)
else:
return map_mdp(sparsedtmc)
return map_mdp(sparsemodel)


if __name__ == "__main__":
Expand Down
34 changes: 31 additions & 3 deletions stormvogel/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

Parameter = str

Number = float | Fraction | Parameter
Number = float | Parameter | Fraction


class ModelType(Enum):
Expand All @@ -32,11 +32,19 @@ class State:
model: The model this state belongs to.
"""

# name: str | None
labels: list[str]
features: dict[str, int]
id: int
model: "Model"

def __init__(self, labels: list[str], features: dict[str, int], id: int, model):
self.labels = labels
self.features = features
self.id = id
self.model = model
# TODO how to handle state names?

def set_transitions(self, transitions: "Transition | TransitionShorthand"):
"""Set transitions from this state."""
self.model.set_transitions(self, transitions)
Expand All @@ -51,6 +59,8 @@ def __str__(self):
def __eq__(self, other):
if isinstance(other, State):
if self.id == other.id:
self.labels.sort()
other.labels.sort()
return self.labels == other.labels
return True
return False
Expand All @@ -66,13 +76,22 @@ class Action:
"""

name: str
# labels: list[str]

# TODO action labels

# def __init__(self, name: str):
# self.name = name

# def set_labels(labels: list[str]):
# self.labels = labels

def __str__(self):
return f"Action {self.name}"

def __eq__(self, other):
if isinstance(other, Action):
return self.name == other.name
return True
return False


Expand Down Expand Up @@ -103,6 +122,11 @@ def __eq__(self, other):
return self.branch == other.branch
return False

def __lt__(self, other):
if not isinstance(other, Branch):
return NotImplemented
return str(self.branch) < str(other.branch)


@dataclass
class Transition:
Expand All @@ -125,7 +149,11 @@ def __str__(self):

def __eq__(self, other):
if isinstance(other, Transition):
return self.transition == other.transition
self_values = list(self.transition.values())
other_values = list(other.transition.values())
self_values.sort()
other_values.sort()
return self_values == other_values
return False


Expand Down
29 changes: 21 additions & 8 deletions tests/test_map.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import examples.monty_hall
import examples.stormpy_mdp
import stormvogel.map
import stormvogel.model
import examples.stormpy_to_stormvogel
Expand All @@ -7,14 +8,13 @@


def matrix_equals(
dtmc0: stormpy.storage.SparseDtmc, dtmc1: stormpy.storage.SparseDtmc
model0: stormpy.storage.SparseDtmc | stormpy.storage.SparseMdp,
model1: stormpy.storage.SparseDtmc | stormpy.storage.SparseMdp,
) -> bool:
"""
outputs true if the sparsematrices of the two sparsedtmcs are the same and false otherwise
"""
# outputs true if the sparsematrices of the two sparsedtmcs are the same and false otherwise

# TODO is there a better check for equality for matrices in storm(py)? otherwise one should perhaps be implemented
return str(dtmc0.transition_matrix) == str(dtmc1.transition_matrix)
return str(model0.transition_matrix) == str(model1.transition_matrix)


def test_stormpy_to_stormvogel_and_back_dtmc():
Expand Down Expand Up @@ -44,10 +44,23 @@ def test_stormvogel_to_stormpy_and_back_dtmc():
def test_stormvogel_to_stormpy_and_back_mdp():
# we test it for monty hall mdp
stormvogel_mdp = examples.monty_hall.create_monty_hall_mdp()
print(stormvogel_mdp)
# print(stormvogel_mdp)
stormpy_mdp = stormvogel.map.stormvogel_to_stormpy(stormvogel_mdp)
print(stormpy_mdp)
# print(stormpy_mdp)
new_stormvogel_mdp = stormvogel.map.stormpy_to_stormvogel(stormpy_mdp)
print(new_stormvogel_mdp)
# print(new_stormvogel_mdp)

assert new_stormvogel_mdp == stormvogel_mdp


def test_stormpy_to_stormvogel_and_back_mdp():
# we test it for monty hall mdp
stormpy_mdp = examples.stormpy_mdp.example_building_mdps_01()
# print(stormpy_mdp)
stormvogel_mdp = stormvogel.map.stormpy_to_stormvogel(stormpy_mdp)
# print(stormvogel_mdp)
new_stormpy_mdp = stormvogel.map.stormvogel_to_stormpy(stormvogel_mdp)
# print(new_stormpy_mdp)

# TODO also compare other parts than the matrix
assert matrix_equals(stormpy_mdp, new_stormpy_mdp)

0 comments on commit 595372e

Please sign in to comment.