-
Notifications
You must be signed in to change notification settings - Fork 1
/
policy_parser.py
183 lines (153 loc) · 7.21 KB
/
policy_parser.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
"""Implements a parser for Policy class."""
import json
import os
from typing import Dict, Optional, Tuple
import symengine.lib.symengine_wrapper as core
from xaddpy.xadd.xadd import VAR_TYPE
from pyRDDLGym_symbolic.mdp.mdp import MDP
from pyRDDLGym_symbolic.mdp.policy import Policy
from pyRDDLGym_symbolic.utils.xadd_utils import ValueAssertion
class PolicyParser:
    """Parses a policy from a json file and its associated policy XADDs.

    The json file must contain an ``'action-fluents'`` entry listing action
    fluent names, plus one entry per action fluent of the model whose value
    is the path of an existing ``.xadd`` file. Each XADD is imported into the
    MDP's XADD context, validated, and collected into a ``Policy`` object.

    Note: ``parse`` stores the given MDP on ``self.mdp``; the ``model`` and
    ``context`` properties (and therefore every private helper) read from it,
    so the helpers are only usable after ``parse`` has been called.
    """

    def parse(
        self,
        mdp: MDP,
        policy_fname: str,
        assert_concurrency: bool = True,
        concurrency: Optional[int] = None,
    ) -> Policy:
        """Parses the policy from a given json file.

        Args:
            mdp: The MDP object.
            policy_fname: The policy file name.
            assert_concurrency: Whether to assert concurrency. If set to True,
                this method will raise an error when the concurrency condition
                does not match with the given MDP configuration.
            concurrency: If assert_concurrency is set to True, this argument
                specifies the maximum concurrency number, i.e. the upper bound
                on the number of boolean actions that may be true at once.

        Returns:
            The parsed policy object.

        Raises:
            AssertionError: If the file is not a ``.json`` file, does not
                exist, or ``concurrency`` is missing while
                ``assert_concurrency`` is True.
            RuntimeError: If the json cannot be loaded or any action XADD fails
                to parse/validate (the original error is chained as the cause).
        """
        self.mdp = mdp
        assert policy_fname.endswith('.json'), 'Policy file must be a json file.'
        assert os.path.exists(policy_fname), 'Policy file does not exist.'
        try:
            # Use a context manager so the file handle is always closed.
            with open(policy_fname, 'r') as policy_file:
                policy_dict = json.load(policy_file)
        except Exception as e:
            raise RuntimeError(f'Failed to load policy from {policy_fname}.') from e
        assert not assert_concurrency or concurrency is not None, (
            'concurrency must be provided when assert_concurrency is True.'
        )

        parsed_policy_dict = {}
        try:
            # Get the action fluents and validate them. The remaining entries
            # of policy_dict map action names to .xadd file paths.
            a_vars = policy_dict.pop('action-fluents')
            self._validate_action_fluents(a_vars)
            self._validate_policy_json(policy_dict)
            for a_name in a_vars:
                a_symbol, a_xadd = self._parse_policy_for_single_action(
                    policy_dict, a_name)
                parsed_policy_dict[a_symbol] = a_xadd
            policy = Policy(parsed_policy_dict, policy_fname)
        except Exception as e:
            # Chain the underlying error rather than printing it, so callers
            # and tracebacks retain the actual cause.
            raise RuntimeError(f'Failed to load policy from {policy_fname}.') from e

        # Assert concurrency.
        if assert_concurrency:
            self._assert_concurrency(parsed_policy_dict, concurrency)
        return policy

    def _validate_action_fluents(self, a_vars):
        """Validates that every listed action fluent is defined in the model.

        Raises:
            ValueError: If an action fluent is not among the model's action
                fluents.
        """
        for a_var in a_vars:
            if a_var not in self.model.action_fluents:
                raise ValueError(f'Action fluent {a_var} is not defined in the model.')

    def _validate_policy_json(self, policy_dict):
        """Validates the policy json (with ``'action-fluents'`` already removed).

        Checks that the keys are exactly the model's action fluents and that
        every value is a path string ending in ``.xadd`` that exists on disk.
        """
        action_set_from_model = set(self.model.action_fluents)
        action_set_from_policy = set(policy_dict.keys())
        assert action_set_from_model == action_set_from_policy, \
            'Action set from the policy does not match with the model.'
        for a, val in policy_dict.items():
            assert isinstance(val, str), f'Value for action {a} must be a string file path.'
            assert val.endswith('.xadd'), f'Value for action {a} must be a string file path ending with .xadd.'
            assert os.path.exists(val), f'Value for action {a} must be a string file path that exists.'

    def _parse_policy_for_single_action(self, policy_dict, a_name) -> Tuple[VAR_TYPE, int]:
        """Parses the policy for a single action.

        Imports the XADD stored at ``policy_dict[a_name]`` into the context.
        For a boolean action, the imported XADD is wrapped in a decision node
        on the action variable itself, with ``high`` = imported XADD and
        ``low`` = 1 - imported XADD.

        Returns:
            A ``(symbol, xadd_id)`` pair where ``symbol`` is the model's
            variable for ``a_name``.
        """
        assert a_name in policy_dict, f'Action {a_name} not found in the policy.'
        path = policy_dict[a_name]
        a_type = self.model.variable_ranges[a_name]

        # Parse the XADD from file.
        a_dd = self.context.import_xadd(path)
        a_var = self.model.ns[a_name]

        # Boolean action should condition on itself being True or False.
        if a_type == 'bool':
            # create=False: presumably the decision for the action variable
            # must already be registered in the context — TODO confirm.
            a_var_id, _ = self.context.get_dec_expr_index(a_var, create=False)
            high = a_dd
            low = self.context.apply(self.context.ONE, high, op='subtract')
            a_dd = self.context.get_inode_canon(a_var_id, low, high)

        # Validate the given action XADD.
        self._validate_action(a_dd, a_type)
        return a_var, a_dd

    def _validate_action(self, a_dd: int, a_type: str):
        """Dispatches validation of an action XADD based on its type.

        Raises:
            ValueError: If ``a_type`` is neither ``'bool'`` nor ``'real'``.
        """
        if a_type == 'bool':
            self._validate_bool_action_dd(a_dd)
        elif a_type == 'real':
            self._validate_cont_action_dd(a_dd)
        else:
            raise ValueError(f'Action type {a_type} not supported.')

    def _validate_bool_action_dd(self, a_dd: int):
        """Validates the boolean action XADD.

        The XADD may depend only on state fluents plus exactly one other
        variable (the action variable it was conditioned on), and every leaf
        must be a number in the closed interval [0, 1].
        """
        var_set = self.context.collect_vars(a_dd)

        # Boolean action should only depend on state fluents.
        s_vars = self.mdp.cont_s_vars.union(self.mdp.bool_s_vars)
        non_state_vars = var_set.difference(s_vars)
        assert len(non_state_vars) == 1, (
            'Boolean action should only depend on state fluents plus the action '
            f'variable, but found {non_state_vars}'
        )

        # Check leaf value types. Compare as float: int() would truncate and
        # let values such as 1.5 or -0.5 pass the [0, 1] check.
        leaf_op = ValueAssertion(
            self.context,
            fn=lambda x: 0 <= float(x) <= 1,
            msg='Boolean action leaf {leaf_val} is not within range [0, 1]',
        )
        self.context.reduce_process_xadd_leaf(a_dd, leaf_op, [], [])

    def _validate_cont_action_dd(self, a_dd: int):
        """Validates the continuous action XADD.

        The XADD may depend only on state fluents and boolean action
        variables, and no leaf may be a symengine Boolean atom.
        """
        var_set = self.context.collect_vars(a_dd)

        # Continuous action should depend on state fluents and boolean actions.
        s_vars = self.mdp.cont_s_vars.union(self.mdp.bool_s_vars)
        bool_a_vars = self.mdp.bool_a_vars
        assert len(var_set.difference(s_vars.union(bool_a_vars))) == 0, \
            'Continuous action should only depend on state fluents and boolean actions.'

        # Check leaf value types.
        leaf_op = ValueAssertion(
            self.context,
            fn=lambda x: not isinstance(x, core.BooleanAtom),
            msg='Continuous action leaf {leaf_val} is a Boolean value.'
        )
        self.context.reduce_process_xadd_leaf(a_dd, leaf_op, [], [])

    def _assert_concurrency(
        self,
        policy: Dict[VAR_TYPE, int],
        concurrency: int,
    ):
        """Asserts the concurrency is satisfied for the given policy.

        Sums ``ceil`` of every boolean action's XADD (so any positive value
        counts as the action being on) and checks that every leaf of the sum
        is at most ``concurrency``.
        """
        bool_dd = self.context.ZERO
        for a, dd in policy.items():
            a_name = self.model._sym_var_name_to_var_name[str(a)]
            a_type = self.model.variable_ranges[a_name]
            if a_type == 'bool':
                a_true_dd = self.context.unary_op(dd, op='ceil')
                bool_dd = self.context.apply(bool_dd, a_true_dd, op='add')

        # Assert concurrency (nothing to check when no boolean action is on).
        if bool_dd != self.context.ZERO:
            leaf_op = ValueAssertion(
                self.context,
                fn=lambda x: float(x) <= concurrency,
                msg='Concurrency condition is not satisfied with {leaf_val}.'
            )
            self.context.reduce_process_xadd_leaf(bool_dd, leaf_op, [], [])

    @property
    def model(self):
        """The model of the MDP passed to the most recent ``parse`` call."""
        return self.mdp.model

    @property
    def context(self):
        """The XADD context of the MDP passed to the most recent ``parse`` call."""
        return self.mdp.context