From 53d70a67185aa641ae89a06efeabb426b0fd6bde Mon Sep 17 00:00:00 2001 From: spica Date: Fri, 20 Dec 2024 10:31:15 -0800 Subject: [PATCH] transfer function utils added --- .../utils/transfer_function_check_util.py | 424 ++++++++++++++++++ xdsl_smt/utils/transfer_function_util.py | 347 ++++++++++++++ 2 files changed, 771 insertions(+) create mode 100644 xdsl_smt/utils/transfer_function_check_util.py create mode 100644 xdsl_smt/utils/transfer_function_util.py diff --git a/xdsl_smt/utils/transfer_function_check_util.py b/xdsl_smt/utils/transfer_function_check_util.py new file mode 100644 index 00000000..44823341 --- /dev/null +++ b/xdsl_smt/utils/transfer_function_check_util.py @@ -0,0 +1,424 @@ +from .transfer_function_util import ( + replace_abstract_value_width, + get_argument_widths_with_effect, + get_argument_instances_with_effect, + call_function_and_assert_result_with_effect, + call_function_with_effect, +) +from ..dialects.smt_dialect import ( + DefineFunOp, + AssertOp, + CheckSatOp, + ConstantBoolOp, +) +from ..dialects.smt_bitvector_dialect import ( + ConstantOp, +) +from ..dialects.smt_utils_dialect import FirstOp +from xdsl.dialects.func import FuncOp +from ..dialects.transfer import AbstractValueType +from xdsl.ir import Operation, SSAValue, Attribute +from ..utils.transfer_function_util import ( + call_function_and_assert_result, + get_result_width, + SMTTransferFunction, + FunctionCollection, +) + +""" +This file contains property checkers used in the client including: +1. valid_abstract_domain_check +2. int_attr_check +3. forward_soundness_check +4. backward_soundness_check +5. forward_precision_check (not implemented) +6. backward_precision_check (not implemented) +""" + + +""" +Given a transfer function, check if its result construct a valid abstract domain +""" + + +def valid_abstract_domain_check( + transfer_function: SMTTransferFunction, + domain_constraint: FunctionCollection, + int_attr: dict[int, int], +): + effect = ConstantBoolOp(False) + abstract_func = transfer_function.transfer_function + abs_op_constraint = transfer_function.abstract_constraint + assert abstract_func is not None + abs_arg_ops = get_argument_instances_with_effect(abstract_func, int_attr) + abs_args: list[SSAValue] = [arg.res for arg in abs_arg_ops] + is_abstract_arg = transfer_function.is_abstract_arg + + constant_bv_0 = ConstantOp(0, 1) + constant_bv_1 = ConstantOp(1, 1) + + arg_widths = get_argument_widths_with_effect(abstract_func) + result_width = get_result_width(abstract_func) + + abs_domain_constraints_ops: list[Operation] = [] + for i, abs_arg in enumerate(abs_args): + if is_abstract_arg[i]: + abs_domain_constraints_ops += call_function_and_assert_result_with_effect( + domain_constraint.getFunctionByWidth(arg_widths[i]), + [abs_arg], + constant_bv_1, + effect.res, + ) + + abs_arg_constraints_ops: list[Operation] = [] + if abs_op_constraint is not None: + abs_arg_constraints_ops = call_function_and_assert_result_with_effect( + abs_op_constraint, abs_args, constant_bv_1, effect.res + ) + + call_abs_func_op, call_abs_func_first_op = call_function_with_effect( + abstract_func, abs_args, effect.res + ) + abs_result_domain_invalid_ops = call_function_and_assert_result_with_effect( + domain_constraint.getFunctionByWidth(result_width), + [call_abs_func_first_op.res], + constant_bv_0, + effect.res, + ) + return ( + [effect] + + abs_arg_ops + + [constant_bv_0, constant_bv_1] + + abs_domain_constraints_ops + + abs_arg_constraints_ops + + [call_abs_func_op, call_abs_func_first_op] + + abs_result_domain_invalid_ops + ) + + +""" +Given the transfer function and a set of integer attributes associated with the function, +returns if the attr set makes the operation valid. +For example: +trunc %a from i32 to i64 (invalid) +trunc %a from i5 to i3 (valid) +""" + + +def int_attr_check( + transfer_function: SMTTransferFunction, + domain_constraint: FunctionCollection, + instance_constraint: FunctionCollection, + int_attr: dict[int, int], +) -> list[Operation]: + if transfer_function.int_attr_constraint is not None: + effect = ConstantBoolOp(False) + int_attr_constraint = transfer_function.int_attr_constraint + int_attr_constraint_arg_ops = get_argument_instances_with_effect( + int_attr_constraint, int_attr + ) + int_attr_constraint_arg: list[SSAValue] = [ + arg.res for arg in int_attr_constraint_arg_ops + ] + + constant_bv_1 = ConstantOp(1, 1) + + call_constraint_ops = call_function_and_assert_result_with_effect( + int_attr_constraint, int_attr_constraint_arg, constant_bv_1, effect.res + ) + return ( + [effect] + + int_attr_constraint_arg_ops + + [constant_bv_1] + + call_constraint_ops + + [CheckSatOp()] + ) + else: + true_op = ConstantBoolOp(True) + assert_op = AssertOp(true_op.res) + return [true_op, assert_op, CheckSatOp()] + + +""" +Check the soundness for a forward transfer function +""" + + +def forward_soundness_check( + transfer_function: SMTTransferFunction, + domain_constraint: FunctionCollection, + instance_constraint: FunctionCollection, + int_attr: dict[int, int], +) -> list[Operation]: + assert transfer_function.is_forward + abstract_func = transfer_function.transfer_function + concrete_func = transfer_function.concrete_function + abs_op_constraint = transfer_function.abstract_constraint + op_constraint = transfer_function.op_constraint + is_abstract_arg = transfer_function.is_abstract_arg + + assert abstract_func is not None + assert concrete_func is not None + + abs_arg_ops = get_argument_instances_with_effect(abstract_func, int_attr) + abs_args: list[SSAValue] = [arg.res for arg in abs_arg_ops] + crt_arg_ops = get_argument_instances_with_effect(concrete_func, int_attr) + crt_args_with_poison: list[SSAValue] = [arg.res for arg in crt_arg_ops] + crt_arg_first_ops: list[FirstOp] = [FirstOp(arg) for arg in crt_args_with_poison] + crt_args: list[SSAValue] = [arg.res for arg in crt_arg_first_ops] + + assert len(abs_args) == len(crt_args) + arg_widths = get_argument_widths_with_effect(concrete_func) + result_width = get_result_width(concrete_func) + + effect = ConstantBoolOp(False) + constant_bv_0 = ConstantOp(0, 1) + constant_bv_1 = ConstantOp(1, 1) + + abs_arg_include_crt_arg_constraints_ops: list[Operation] = [] + abs_domain_constraints_ops: list[Operation] = [] + for i, (abs_arg, crt_arg) in enumerate(zip(abs_args, crt_args)): + if is_abstract_arg[i]: + abs_arg_include_crt_arg_constraints_ops += ( + call_function_and_assert_result_with_effect( + instance_constraint.getFunctionByWidth(arg_widths[i]), + [abs_arg, crt_arg], + constant_bv_1, + effect.res, + ) + ) + abs_domain_constraints_ops += call_function_and_assert_result_with_effect( + domain_constraint.getFunctionByWidth(arg_widths[i]), + [abs_arg], + constant_bv_1, + effect.res, + ) + + abs_arg_constraints_ops: list[Operation] = [] + if abs_op_constraint is not None: + abs_arg_constraints_ops = call_function_and_assert_result_with_effect( + abs_op_constraint, abs_args, constant_bv_1, effect.res + ) + crt_args_constraints_ops: list[Operation] = [] + if op_constraint is not None: + crt_args_constraints_ops = call_function_and_assert_result_with_effect( + op_constraint, crt_args, constant_bv_1, effect.res + ) + + call_abs_func_op, call_abs_func_first_op = call_function_with_effect( + abstract_func, abs_args, effect.res + ) + call_crt_func_op, call_crt_func_first_op = call_function_with_effect( + concrete_func, crt_args_with_poison, effect.res + ) + call_crt_first_op = FirstOp(call_crt_func_first_op.res) + + abs_result_not_include_crt_result_ops = call_function_and_assert_result_with_effect( + instance_constraint.getFunctionByWidth(result_width), + [call_abs_func_first_op.res, call_crt_first_op.res], + constant_bv_0, + effect.res, + ) + + return ( + [effect] + + abs_arg_ops + + crt_arg_ops + + crt_arg_first_ops + + [constant_bv_0, constant_bv_1] + + abs_domain_constraints_ops + + abs_arg_include_crt_arg_constraints_ops + + abs_arg_constraints_ops + + crt_args_constraints_ops + + [ + call_abs_func_op, + call_abs_func_first_op, + call_crt_func_op, + call_crt_func_first_op, + call_crt_first_op, + ] + + abs_result_not_include_crt_result_ops + + [CheckSatOp()] + ) + + +""" +Check the soundness for a backward transfer function +""" + + +def backward_soundness_check( + transfer_function: SMTTransferFunction, + domain_constraint: FunctionCollection, + instance_constraint: FunctionCollection, + int_attr: dict[int, int], +) -> list[Operation]: + assert not transfer_function.is_forward + operationNo = transfer_function.operationNo + abstract_func = transfer_function.transfer_function + concrete_func = transfer_function.concrete_function + abs_op_constraint = transfer_function.abstract_constraint + op_constraint = transfer_function.op_constraint + is_abstract_arg = transfer_function.is_abstract_arg + + effect = ConstantBoolOp(False) + assert abstract_func is not None + assert concrete_func is not None + arg_widths = get_argument_widths_with_effect(concrete_func) + result_width = get_result_width(concrete_func) + + # replace the only abstract arg in transfer_function with bv with result_width + assert sum(is_abstract_arg) == 1 + abs_arg_idx = is_abstract_arg.index(True) + old_abs_arg = abstract_func.body.block.args[abs_arg_idx] + assert isinstance(old_abs_arg.type, Attribute) + new_abs_arg_type = replace_abstract_value_width(old_abs_arg.type, result_width) + new_abs_arg = abstract_func.body.block.insert_arg(new_abs_arg_type, abs_arg_idx) + abstract_func.body.block.args[abs_arg_idx + 1].replace_by(new_abs_arg) + abstract_func.body.block.erase_arg(old_abs_arg) + + abs_arg_ops = get_argument_instances_with_effect(abstract_func, int_attr) + abs_args: list[SSAValue] = [arg.res for arg in abs_arg_ops] + + crt_arg_ops = get_argument_instances_with_effect(concrete_func, int_attr) + crt_args_with_poison: list[SSAValue] = [arg.res for arg in crt_arg_ops] + crt_arg_first_ops = [FirstOp(arg) for arg in crt_args_with_poison] + crt_args: list[SSAValue] = [arg.res for arg in crt_arg_first_ops] + + constant_bv_0 = ConstantOp(0, 1) + constant_bv_1 = ConstantOp(1, 1) + + call_abs_func_op, call_abs_func_first_op = call_function_with_effect( + abstract_func, abs_args, effect.res + ) + call_crt_func_op, call_crt_func_first_op = call_function_with_effect( + concrete_func, crt_args_with_poison, effect.res + ) + call_crt_func_res_op = FirstOp(call_crt_func_first_op.res) + + abs_domain_constraints_ops = call_function_and_assert_result_with_effect( + domain_constraint.getFunctionByWidth(result_width), + [abs_args[0]], + constant_bv_1, + effect.res, + ) + + abs_arg_include_crt_res_constraint_ops = ( + call_function_and_assert_result_with_effect( + instance_constraint.getFunctionByWidth(result_width), + [abs_args[0], call_crt_func_res_op.res], + constant_bv_1, + effect.res, + ) + ) + + abs_arg_constraints_ops: list[Operation] = [] + if abs_op_constraint is not None: + abs_arg_constraints_ops = call_function_and_assert_result( + abs_op_constraint, abs_args, constant_bv_1 + ) + crt_args_constraints_ops: list[Operation] = [] + if op_constraint is not None: + crt_args_constraints_ops = call_function_and_assert_result_with_effect( + op_constraint, crt_args, constant_bv_1, effect.res + ) + + abs_result_not_include_crt_arg_constraint_ops = ( + call_function_and_assert_result_with_effect( + instance_constraint.getFunctionByWidth(arg_widths[operationNo]), + [call_abs_func_first_op.res, crt_args[operationNo]], + constant_bv_0, + effect.res, + ) + ) + + return ( + [effect] + + abs_arg_ops + + crt_arg_ops + + [constant_bv_0, constant_bv_1] + + [ + call_abs_func_op, + call_abs_func_first_op, + call_crt_func_op, + call_crt_func_first_op, + call_crt_func_res_op, + ] + + abs_domain_constraints_ops + + abs_arg_include_crt_res_constraint_ops + + abs_arg_constraints_ops + + crt_args_constraints_ops + + abs_result_not_include_crt_arg_constraint_ops + + [CheckSatOp()] + ) + + +""" +Check the precision for a forward transfer function +""" + + +def forward_precision_check( + transfer_function: SMTTransferFunction, + domain_constraint: FunctionCollection, + instance_constraint: FunctionCollection, +): + assert transfer_function.is_forward + + +""" +Check the precision for a backward transfer function +""" + + +def backward_precision_check( + transfer_function: SMTTransferFunction, + domain_constraint: FunctionCollection, + instance_constraint: FunctionCollection, +): + assert not transfer_function.is_forward + + +""" +Check if the transfer function breaks any other constraints. Such constraints +are described by the parameter counter_func +""" + + +def counterexample_check( + counter_func: FuncOp, + smt_counter_func: DefineFunOp, + domain_constraint: FunctionCollection, + int_attr: dict[int, int], +): + is_abstract_arg: list[bool] = [ + isinstance(arg, AbstractValueType) for arg in counter_func.args + ] + effect = ConstantBoolOp(False) + arg_ops = get_argument_instances_with_effect(smt_counter_func, int_attr) + args: list[SSAValue] = [arg.res for arg in arg_ops] + arg_widths = get_argument_widths_with_effect(smt_counter_func) + + constant_bv_1 = ConstantOp(1, 1) + + abs_domain_constraints_ops: list[Operation] = [] + for i, arg in enumerate(args): + if is_abstract_arg[i]: + abs_domain_constraints_ops += call_function_and_assert_result_with_effect( + domain_constraint.getFunctionByWidth(arg_widths[i]), + [arg], + constant_bv_1, + effect.res, + ) + call_counterexample_func_ops = call_function_and_assert_result_with_effect( + smt_counter_func, args, constant_bv_1, effect.res + ) + + return ( + [effect] + + arg_ops + + [constant_bv_1] + + abs_domain_constraints_ops + + call_counterexample_func_ops + + [CheckSatOp()] + ) diff --git a/xdsl_smt/utils/transfer_function_util.py b/xdsl_smt/utils/transfer_function_util.py new file mode 100644 index 00000000..c8677230 --- /dev/null +++ b/xdsl_smt/utils/transfer_function_util.py @@ -0,0 +1,347 @@ +from typing import Callable + +from xdsl.context import MLContext +from xdsl.utils.hints import isa +from ..dialects.smt_dialect import ( + DefineFunOp, + DeclareConstOp, + CallOp, + AssertOp, + EqOp, + AndOp, + BoolType, +) +from ..dialects.smt_bitvector_dialect import ( + ConstantOp, + BitVectorType, +) +from ..dialects.smt_utils_dialect import FirstOp, PairType, SecondOp, AnyPairType +from xdsl.dialects.func import FuncOp +from ..dialects.transfer import AbstractValueType +from xdsl.ir import Operation, SSAValue, Attribute +from xdsl.dialects.builtin import ( + FunctionType, +) + + +# Given a function in smt dialect and its args, return CallOp(func, args) with type checking +def call_function(func: DefineFunOp, args: list[SSAValue]) -> CallOp: + func_args = func.body.block.args + assert len(func_args) == len(args) + for f_arg, arg in zip(func_args, args): + if f_arg.type != arg.type: + print(func.fun_name) + print(func_args) + print(args) + assert f_arg.type == arg.type + callOp = CallOp.get(func.results[0], args) + return callOp + + +# In current design, a FuncOp is lowered to DefineFunOp with receiving and returning the global effect. +# However, transfer functions don't use that field. +# This function is a shortcut for calling a DefineFunOp with adding the effect to arguments, +# and removing the effect from the returned value. +def call_function_with_effect( + func: DefineFunOp, args: list[SSAValue], effect: SSAValue +) -> tuple[CallOp, FirstOp]: + new_args = args + [effect] + callOp = call_function(func, new_args) + assert len(callOp.res) == 1 + callOpFirst = FirstOp(callOp.res[0]) + callOpSecond = SecondOp(callOp.res[0]) + # Assume the global effect has a bool type + assert isinstance(callOpSecond.res.type, BoolType) + return callOp, callOpFirst + + +# Given a SSAValue and bv constant, asserts the SSAValue equals to the given constant +def assert_result(result: SSAValue, bv: ConstantOp) -> list[Operation]: + eqOp = EqOp.get(result, bv.res) + assertOp = AssertOp.get(eqOp.res) + return [eqOp, assertOp] + + +# Given a function, its argument and a constant bv, assert the return value by CallOp(func, args) +# equals to the bv +def call_function_and_assert_result( + func: DefineFunOp, args: list[SSAValue], bv: ConstantOp +) -> list[Operation]: + callOp = call_function(func, args) + assert len(callOp.results) == 1 + firstOp = FirstOp(callOp.results[0]) + assertOps = assert_result(firstOp.res, bv) + return [callOp, firstOp] + assertOps + + +# Given a function with global effect, its argument and a constant bv, assert the return value of function calling +# equals to the bv +def call_function_and_assert_result_with_effect( + func: DefineFunOp, args: list[SSAValue], bv: ConstantOp, effect: SSAValue +) -> list[Operation]: + callOp, callFirstOp = call_function_with_effect(func, args, effect) + firstOp = FirstOp(callFirstOp.res) + assertOps = assert_result(firstOp.res, bv) + return [callOp, callFirstOp, firstOp] + assertOps + + +# Given a function, construct a list of argument instances +# by DeclareConstOp or ConstantOp except for the last effect argument +# Some operations require certain arguments must be constant (e.g. the length of truncated integer), +# and this information is maintained in int_attr +def get_argument_instances_with_effect( + func: DefineFunOp, int_attr: dict[int, int] +) -> list[DeclareConstOp | ConstantOp]: + result: list[DeclareConstOp | ConstantOp] = [] + # ignore last effect arg + assert isinstance(func.body.block.args[-1].type, BoolType) + for i, arg in enumerate(func.body.block.args[:-1]): + argType = arg.type + if i in int_attr: + result.append(ConstantOp(int_attr[i], get_width_from_type(argType))) + else: + result.append(DeclareConstOp(argType)) + return result + + +# Given a function, construct a list of argument instances by DeclareConstOp or ConstantOp +# Some operations require certain arguments must be constant (e.g. the length of truncated integer), +# and this information is maintained in int_attr +def get_argument_instances( + func: DefineFunOp, int_attr: dict[int, int] +) -> list[DeclareConstOp | ConstantOp]: + result: list[DeclareConstOp | ConstantOp] = [] + for i, arg in enumerate(func.body.block.args): + argType = arg.type + if i in int_attr: + result.append(ConstantOp(int_attr[i], get_width_from_type(argType))) + else: + result.append(DeclareConstOp(argType)) + return result + + +# Given a function, construct a list of its returned value by DeclareConstOp +# We assume only the first returned value is useful +def get_result_instance(func: DefineFunOp) -> list[DeclareConstOp]: + return_type = func.func_type.outputs.data[0] + return [DeclareConstOp(return_type)] + + +# Given a bit vector type or a pair type including a bit vector, +# returns the bit width of that bit vector +def get_width_from_type(ty: Attribute) -> int: + while isa(ty, AnyPairType): + assert isinstance(ty.first, Attribute) + ty = ty.first + if isinstance(ty, BitVectorType): + return ty.width.data + assert False + + +# Given a pair type and a bit width, this function replaces all bit vector type +# in the input pair type with a new bit vector type with the new bit width +def replace_abstract_value_width( + abs_val_ty: AnyPairType | Attribute, new_width: int +) -> AnyPairType: + types: list[Attribute] = [] + while isa(abs_val_ty, AnyPairType): + assert isinstance(abs_val_ty.first, Attribute) + types.append(abs_val_ty.first) + assert isinstance(abs_val_ty.second, Attribute) + abs_val_ty = abs_val_ty.second + types.append(abs_val_ty) + for i in range(len(types)): + if isinstance(types[i], BitVectorType): + types[i] = BitVectorType.from_int(new_width) + resultType = types.pop() + while len(types) > 0: + resultType = PairType(types.pop(), resultType) + assert isa(resultType, AnyPairType) + return resultType + + +# Given a smt function, returns a list of bit width for every argument of the function except for the last effect +def get_argument_widths_with_effect(func: DefineFunOp) -> list[int]: + # ignore last effect + return [get_width_from_type(arg.type) for arg in func.body.block.args[:-1]] + + +# Given a smt function, returns a list of bit width for every argument of the function +def get_argument_widths(func: DefineFunOp) -> list[int]: + return [get_width_from_type(arg.type) for arg in func.body.block.args] + + +# Given a smt function, returns the bit width of the returned value +def get_result_width(func: DefineFunOp) -> int: + return get_width_from_type(func.func_type.outputs.data[0]) + + +# Given a list of operations returning bool type, this function performs +# bool and operation on all operations, and returns a tuple of the final combined +# result and a list of constructed and operations +def compress_and_op(lst: list[Operation]) -> tuple[SSAValue, list[Operation]]: + if len(lst) == 0: + assert False and "cannot compress lst with size 0 to an AndOp" + elif len(lst) == 1: + empty_result: list[Operation] = [] + return (lst[0].results[0], empty_result) + else: + new_ops: list[Operation] = [AndOp(lst[0].results[0], lst[1].results[0])] + for i in range(2, len(lst)): + new_ops.append(AndOp(new_ops[-1].results[0], lst[i].results[0])) + return (new_ops[-1].results[0], new_ops) + + +# Given two smt functions, returns true if both are None or have the same function type +def compare_defining_op(func: DefineFunOp | None, func1: DefineFunOp | None) -> bool: + func_none: bool = func is None + func1_none: bool = func1 is None + if func_none ^ func1_none: + return False + if func_none or func1_none: + return True + for arg, arg1 in zip(func.body.block.args, func1.body.block.args): + if arg.type != arg1.type: + return False + return func.func_type.outputs.data[0] == func1.func_type.outputs.data[0] + + +# Given a smt function, if the type of returned value doesn't match +# the type of function signature, it replaces the output type of function signature +# with the actual returned type +def fix_defining_op_return_type(func: DefineFunOp) -> DefineFunOp: + smt_func_type = func.func_type + ret_val_type = [ret.type for ret in func.return_values] + if smt_func_type != ret_val_type: + new_smt_func_type = FunctionType.from_lists( + smt_func_type.inputs.data, ret_val_type + ) + func.ret.type = new_smt_func_type + return func + + +# This class maintains a map from width(int) -> smt function function +# When the desired function with given width doesn't exist, it generates one +# and returns it as the result +# This class is used when we need several instances of one same function but with different +# possible bit widths. +class FunctionCollection: + main_func: FuncOp + smt_funcs: dict[int, DefineFunOp] = {} + create_smt: Callable[[FuncOp, int, MLContext], DefineFunOp] + ctx: MLContext + + def __init__( + self, + func: FuncOp, + create_smt: Callable[[FuncOp, int, MLContext], DefineFunOp], + ctx: MLContext, + ): + self.main_func = func + self.create_smt = create_smt + self.smt_funcs = {} + self.ctx = ctx + + def getFunctionByWidth(self, width: int) -> DefineFunOp: + if width not in self.smt_funcs: + self.smt_funcs[width] = self.create_smt(self.main_func, width, self.ctx) + return self.smt_funcs[width] + + +# This class maintains information about a transfer function before lowering to smt +class TransferFunction: + # is_abstract_arg[ith] == True -> ith argument of the transfer function is an abstract value + # is_abstract_arg[ith] == False -> ith argument of the transfer function is not an abstract value, + # which maybe a constant value or extra parameters + is_abstract_arg: list[bool] = [] + name: str = "" + # indicates if this transfer function applies forwards or backwards + is_forward: bool = True + # This field indicates if some arguments should be replaced by a constant such as + # the length of truncated integer + replace_int_attr: bool = False + # When the transfer function applies backwards, this field indicates which argument it applies to + operationNo: int = -1 + transfer_function: FuncOp + + def __init__( + self, + transfer_function: FuncOp, + is_forward: bool = True, + operationNo: int = -1, + replace_int_attr: bool = False, + ): + self.name = transfer_function.sym_name.data + self.is_forward = is_forward + self.operationNo = operationNo + is_abstract_arg: list[bool] = [] + self.transfer_function = transfer_function + func_type = transfer_function.function_type + for func_type_arg, arg in zip(func_type.inputs, transfer_function.args): + assert func_type_arg == arg.type + is_abstract_arg.append(isinstance(arg.type, AbstractValueType)) + self.is_abstract_arg = is_abstract_arg + self.replace_int_attr = replace_int_attr + + +# This class maintains information about a transfer function after lowering to SMT +class SMTTransferFunction: + is_abstract_arg: list[bool] = [] + is_forward: bool = True + operationNo: int = -1 + transfer_function_name: str + transfer_function: DefineFunOp | None = None + concrete_function_name: str + concrete_function: DefineFunOp | None = None + + # This function describes constraints applied on arguments of the transfer function. + # For example, transfer functions in demanded bits use known bits information as extra parameters, + # we have to make sure all known bits are in valid domain. + abstract_constraint: DefineFunOp | None + # This function describes constraints applied on arguments of the concrete function. + # For example, SHL requires the shifting amount must be in a valid range + op_constraint: DefineFunOp | None + # Except for the basic soundness property checker, if there are other scenarios making the transfer function unsound + soundness_counterexample: DefineFunOp | None + + int_attr_arg: list[int] | None + # This function maintains the constraint of integer attributes + # For example, the truncated length should be larger than 0 and less than the total bit width + int_attr_constraint: DefineFunOp | None + + def __init__( + self, + transfer_function_name: str, + transfer_function: DefineFunOp | None, + tfRecord: dict[str, TransferFunction], + concrete_function_name: str, + concrete_function: DefineFunOp | None, + abstract_constraint: DefineFunOp | None, + op_constraint: DefineFunOp | None, + soundness_counterexample: DefineFunOp | None, + int_attr_arg: list[int] | None, + int_attr_constraint: DefineFunOp | None, + ): + self.transfer_function_name = transfer_function_name + self.concrete_function_name = concrete_function_name + assert self.transfer_function_name in tfRecord + tf = tfRecord[self.transfer_function_name] + self.transfer_function = transfer_function + self.is_forward = tf.is_forward + self.is_abstract_arg = tf.is_abstract_arg + self.concrete_function = concrete_function + self.abstract_constraint = abstract_constraint + self.op_constraint = op_constraint + self.operationNo = tf.operationNo + self.soundness_counterexample = soundness_counterexample + self.int_attr_arg = int_attr_arg + self.int_attr_constraint = int_attr_constraint + + def verify(self): + assert compare_defining_op(self.transfer_function, self.abstract_constraint) + assert compare_defining_op(self.concrete_function, self.op_constraint) + if self.transfer_function is not None: + assert len(self.is_abstract_arg) == len( + self.transfer_function.body.block.args + ) + assert self.is_forward ^ (self.operationNo != -1)