-
Notifications
You must be signed in to change notification settings - Fork 1
/
z3_utils.py
157 lines (131 loc) · 4.54 KB
/
z3_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
from z3 import Int, Optimize, sat, unsat, Solver
from pattern_ast import get_accesses, Op, Access, Literal, Node
from loguru import logger
from enum import Enum
class Error(Enum):
    """Enumerated error categories for this z3 helper module.

    NOTE(review): not referenced in this chunk of the file; presumably used
    by callers to tag failures — confirm against call sites.
    """

    # Named after the z3 misbehaviour that find_max / find_min defend
    # against; exact usage is not visible here.
    Z3_BUG = 0
def _int_div(numerator, denominator):
    """Divide two Python ints with z3's integer-division (SMT-LIB ``div``) semantics.

    ``div`` is Euclidean: the remainder always lies in ``[0, |denominator|)``.
    That equals Python's ``//`` for a positive divisor and ``-(n // -d)`` for a
    negative one. Returns None for a zero divisor, whose result SMT-LIB leaves
    unspecified, so the caller treats the expression as untranslatable.
    """
    if denominator == 0:
        return None
    if denominator > 0:
        return numerator // denominator
    return -(numerator // -denominator)


def expr_to_cexpr(expr, cvars):
    """Translate a pattern-AST expression into a z3 arithmetic term.

    Supported inputs:
      * a plain ``int`` -- returned unchanged;
      * a ``Literal`` whose ``ty`` is ``int`` -- its ``val``;
      * an ``Access`` whose ``pprint()`` is a key of *cvars* -- the mapped
        z3 variable;
      * an ``Op`` with one argument (unary ``+``/``-``) or two arguments
        (``+``, ``-``, ``*``, ``/``) whose operands are themselves
        translatable.

    ``/`` is integer division: for z3 Int terms z3's ``/`` already means
    ``div``. When both operands are plain Python ints it is constant-folded
    with the same Euclidean semantics instead of Python's true division,
    which would produce a float.

    Args:
        expr: the expression to translate.
        cvars: mapping from printed access name to z3 variable, e.g. as
            built by ``get_int_cvars``.

    Returns:
        A z3 arithmetic expression or a Python int, or None when any part of
        *expr* cannot be translated (unknown node type or operator, unmapped
        access, non-int literal, or a constant division by zero).
    """
    expr_type = type(expr)
    # The type checks below are mutually exclusive, so testing plain ints
    # first does not change the result.
    if expr_type == int:
        return expr
    if expr_type == Literal:
        return expr.val if expr.ty == int else None
    if expr_type == Access:
        return cvars.get(expr.pprint())
    if expr_type != Op:
        return None

    op = expr.op
    args = expr.args
    if len(args) == 1:
        operand = expr_to_cexpr(args[0], cvars)
        # Without this guard, a unary minus over an untranslatable operand
        # would evaluate `-None` and raise TypeError.
        if operand is None:
            return None
        if op == '+':
            return operand
        if op == '-':
            return -operand
        return None

    if len(args) == 2:
        left = expr_to_cexpr(args[0], cvars)
        right = expr_to_cexpr(args[1], cvars)
        if left is None or right is None:
            return None
        if op == '+':
            return left + right
        if op == '*':
            return left * right
        if op == '-':
            return left - right
        if op == '/':
            if type(left) == int and type(right) == int:
                return _int_div(left, right)
            return left / right
    return None
def get_scalar_cvars(pattern):
    """Create a z3 ``Int`` for every scalar variable that *pattern* mentions.

    Accesses are collected from the pattern itself and then from each
    non-None size expression of each entry in ``pattern.decls``. Only
    accesses for which ``access.is_scalar()`` holds are kept. The result is
    keyed by ``access.var``; a variable seen more than once is added only on
    its first occurrence.
    """
    def candidate_accesses():
        # Same visiting order as before: pattern body first, then
        # declaration sizes in declaration order.
        yield from get_accesses(pattern)
        for decl in pattern.decls:
            for size in decl.sizes:
                if size is not None:
                    yield from get_accesses(size)

    scalar_vars = {}
    for acc in candidate_accesses():
        if acc.is_scalar() and acc.var not in scalar_vars:
            scalar_vars[acc.var] = Int(acc.var)
    return scalar_vars
def get_int_cvars(pattern, types):
    """Create a z3 ``Int`` for every access in *pattern* that may be an int.

    An access qualifies when ``types.can_be(access.var, 'int')`` is true.
    Entries are keyed by the printed access (``access.pprint()``) rather than
    by ``access.var``, and the z3 variable is named after that printed form.
    NOTE(review): presumably this keeps distinct accesses to one variable
    apart; confirm against ``Access.pprint``.

    Accesses are gathered from the pattern itself, then from every non-None
    size expression in ``pattern.decls``.
    """
    int_vars = {}

    def record(accesses):
        for acc in accesses:
            key = acc.pprint()
            if types.can_be(acc.var, 'int') and key not in int_vars:
                int_vars[key] = Int(key)

    record(get_accesses(pattern))
    for decl in pattern.decls:
        for size in decl.sizes:
            if size is None:
                continue
            record(get_accesses(size))
    return int_vars
def affine_to_cexpr(affine, cvars):
    """Build the term ``coeff * var + offset`` for an affine expression.

    If the affine has no variable (falsy ``affine.var``), only the constant
    ``affine.offset`` is returned and ``affine.coeff`` is ignored. Otherwise
    the variable is looked up in *cvars*; a missing name raises ``KeyError``.
    """
    var_name = affine.var
    if var_name:
        return affine.coeff * cvars[var_name] + affine.offset
    return affine.offset
# Default per-call z3 timeout, in milliseconds.
_Z3_TIMEOUT_MS = 10000


def _find_extremum(constraints, expr, l, maximize, timeout):
    """Optimise *expr* under *constraints* and cross-check the answer.

    Shared implementation of ``find_max`` and ``find_min``.

    Args:
        constraints: iterable of z3 boolean constraints. It is materialised
            into a list once, so tuples and generators work too.
        expr: z3 integer term to optimise, or a plain ``int``, which is
            returned unchanged without calling z3.
        l: logger-like object with ``warning``/``error`` methods; None means
            the module-level ``logger``.
        maximize: True to maximise *expr*, False to minimise it.
        timeout: z3 timeout in milliseconds for each solver call.

    Returns:
        The optimum as a Python int. Returns None (after logging) when the
        optimiser does not report ``sat``, or when a plain Solver does not
        confirm that no strictly better value exists.
    """
    if type(expr) == int:
        return expr
    if l is None:
        l = logger
    # The constraints are consumed by the optimiser, by the verifier and
    # possibly by a log message. A generator would be exhausted after the
    # first pass, and the old `constraints + [...]` rejected tuples.
    constraints = list(constraints)
    label = 'max' if maximize else 'min'

    def dump():
        # Only stringify the constraints when they are actually logged.
        return '\n'.join(f'{c}' for c in constraints)

    optimizer = Optimize()
    optimizer.set('timeout', timeout)
    optimizer.assert_exprs(*constraints)
    if maximize:
        optimizer.maximize(expr)
    else:
        optimizer.minimize(expr)
    status = optimizer.check()
    if status != sat:
        l.warning(f'Unable to find {label} ({status}) for:\n' + dump())
        return None
    value = optimizer.model().eval(expr).as_long()

    # Make sure it's actually optimal, since z3's Optimize has a bug
    # (https://github.com/Z3Prover/z3/issues/4670): ask a plain Solver
    # whether any model does strictly better.
    strictly_better = expr > value if maximize else expr < value
    verifier = Solver()
    verifier.set('timeout', timeout)
    verifier.add(*constraints, strictly_better)
    status = verifier.check()
    if status != unsat:
        # `sat` means a better value exists. `unknown` (e.g. a timeout) also
        # lands here; the status is included in the message.
        l.error(f'Z3 bug\nFind {label} ({expr}) => {value} with status ({status}):\n' + dump())
        return None
    return value


def find_max(constraints, expr, l=None, *, timeout=_Z3_TIMEOUT_MS):
    """Return the maximum of *expr* subject to *constraints*, or None.

    A plain ``int`` *expr* is returned as-is. Failures are logged through
    *l* (default: the module ``logger``). *timeout* is in milliseconds and
    applies to each z3 call.
    """
    return _find_extremum(constraints, expr, l, maximize=True, timeout=timeout)


def find_min(constraints, expr, l=None, *, timeout=_Z3_TIMEOUT_MS):
    """Return the minimum of *expr* subject to *constraints*, or None.

    A plain ``int`` *expr* is returned as-is. Failures are logged through
    *l* (default: the module ``logger``). *timeout* is in milliseconds and
    applies to each z3 call.
    """
    return _find_extremum(constraints, expr, l, maximize=False, timeout=timeout)


def find_min_max(constraints, i, l=None):
    """Return ``[min, max]`` of *i* subject to *constraints*.

    Either entry may be None if the corresponding query fails.
    """
    if type(i) != int:
        # Both queries read the constraints; materialise once so a generator
        # is not drained by the min query before the max query runs.
        constraints = list(constraints)
    return [find_min(constraints, i, l), find_max(constraints, i, l)]
def is_sat(constraints, print_model=False):
    """Report whether z3 finds *constraints* satisfiable within 10 seconds.

    Both ``unsat`` and ``unknown`` (e.g. on timeout) give False. When
    *print_model* is set and the result is ``sat``, the satisfying model is
    written to the debug log.
    """
    checker = Solver()
    checker.set('timeout', 10000)
    checker.add(constraints)
    satisfiable = checker.check() == sat
    if print_model and satisfiable:
        logger.debug(f'Model:\n{checker.model()}')
    return satisfiable