From 76cfc0e13d2c4cc3fc1522908f8618e673f29f4c Mon Sep 17 00:00:00 2001 From: Mathieu Fehr Date: Thu, 16 Jan 2025 06:23:35 +0000 Subject: [PATCH] UB: Add lowering to pairs --- tests/filecheck/lower-ub-to-pairs.mlir | 25 +++++ xdsl_smt/dialects/ub.py | 20 +++- xdsl_smt/passes/lower_ub_to_pairs.py | 130 +++++++++++++++++++++++++ xdsl_smt/traits/inhabitant.py | 20 ++++ 4 files changed, 191 insertions(+), 4 deletions(-) create mode 100644 tests/filecheck/lower-ub-to-pairs.mlir create mode 100644 xdsl_smt/passes/lower_ub_to_pairs.py create mode 100644 xdsl_smt/traits/inhabitant.py diff --git a/tests/filecheck/lower-ub-to-pairs.mlir b/tests/filecheck/lower-ub-to-pairs.mlir new file mode 100644 index 00000000..936b1407 --- /dev/null +++ b/tests/filecheck/lower-ub-to-pairs.mlir @@ -0,0 +1,25 @@ +// RUN: xdsl-smt "%s" -p=lower-ub-to-pairs | filecheck "%s" + +%value = "smt.declare_const"() : () -> i32 +%ub = ub.ub : !ub.ub_or +%non_ub = ub.from %value : !ub.ub_or +%res = ub.match %ub : !ub.ub_or -> i64 { +^bb0(%val: i32): + %x = "smt.declare_const"() : () -> i64 + ub.yield %x : i64 +} { + %y = "smt.declare_const"() : () -> i64 + ub.yield %y : i64 +} + +// CHECK: %value = "smt.declare_const"() : () -> i32 +// CHECK-NEXT: %0 = arith.constant 0 : i32 +// CHECK-NEXT: %ub = "smt.constant_bool"() {"value" = #smt.bool_attr} : () -> !smt.bool +// CHECK-NEXT: %ub_1 = "smt.utils.pair"(%0, %ub) : (i32, !smt.bool) -> !smt.utils.pair +// CHECK-NEXT: %non_ub = "smt.constant_bool"() {"value" = #smt.bool_attr} : () -> !smt.bool +// CHECK-NEXT: %non_ub_1 = "smt.utils.pair"(%value, %non_ub) : (i32, !smt.bool) -> !smt.utils.pair +// CHECK-NEXT: %val = "smt.utils.first"(%ub_1) : (!smt.utils.pair) -> i32 +// CHECK-NEXT: %1 = "smt.utils.second"(%ub_1) : (!smt.utils.pair) -> !smt.bool +// CHECK-NEXT: %x = "smt.declare_const"() : () -> i64 +// CHECK-NEXT: %y = "smt.declare_const"() : () -> i64 +// CHECK-NEXT: %res = "smt.ite"(%1, %x, %y) : (!smt.bool, i64, i64) -> i64 diff --git a/xdsl_smt/dialects/ub.py b/xdsl_smt/dialects/ub.py index 8f91b8e3..d7bb9e48 100644 --- a/xdsl_smt/dialects/ub.py +++ b/xdsl_smt/dialects/ub.py @@ -46,9 +46,9 @@ class UBOp(IRDLOperation): name = "ub.ub" - new_ub = result_def(UBOrType) + res = result_def(UBOrType) - assembly_format = "attr-dict `:` type($new_ub)" + assembly_format = "attr-dict `:` type($res)" def __init__(self, type: Attribute): """Create an UB value for the given type.""" @@ -67,9 +67,9 @@ class FromOp(IRDLOperation): T = Annotated[Attribute, ConstraintVar("T")] value = operand_def(T) - result = result_def(UBOrType[T]) + res = result_def(UBOrType[T]) - assembly_format = "$value attr-dict `:` type($result)" + assembly_format = "$value attr-dict `:` type($res)" def __init__(self, value: SSAValue): super().__init__( @@ -106,6 +106,18 @@ def __init__(self, value: SSAValue): regions=[value_region, ub_region], ) + @property + def value_terminator(self) -> YieldOp: + if not isinstance(self.value_region.block.last_op, YieldOp): + raise ValueError("Value case region must have a yield terminator") + return self.value_region.block.last_op + + @property + def ub_terminator(self) -> YieldOp: + if not isinstance(self.ub_region.block.last_op, YieldOp): + raise ValueError("UB case region must have a yield terminator") + return self.ub_region.block.last_op + def verify_(self): assert isattr(self.value.type, UBOrType[Attribute]) if self.value_region.blocks[0].arg_types != (self.value.type.type,): diff --git a/xdsl_smt/passes/lower_ub_to_pairs.py b/xdsl_smt/passes/lower_ub_to_pairs.py new file mode 100644 index 00000000..80ef68eb --- /dev/null +++ b/xdsl_smt/passes/lower_ub_to_pairs.py @@ -0,0 +1,130 @@ +from dataclasses import dataclass + +from xdsl.ir import Attribute, ParametrizedAttribute, Operation, SSAValue +from xdsl.utils.isattr import isattr +from xdsl.passes import ModulePass +from xdsl.context import MLContext +from xdsl.pattern_rewriter import ( + PatternRewriteWalker, + GreedyRewritePatternApplier, + RewritePattern, + PatternRewriter, + op_type_rewrite_pattern, +) +from xdsl.rewriter import InsertPoint + +from xdsl.dialects.builtin import ModuleOp, AnyArrayAttr +from xdsl_smt.dialects import ( + smt_utils_dialect as smt_utils, + smt_dialect as smt, + ub, +) +from xdsl_smt.traits.inhabitant import create_inhabitant + + +def recursively_convert_attr(attr: Attribute) -> Attribute: + """ + Recursively convert an attribute to replace all references to the effect state + into a pair between the ub flag and the memory. + """ + if isattr(attr, ub.UBOrType[Attribute]): + return smt_utils.PairType(attr.type, smt.BoolType()) + if isinstance(attr, ParametrizedAttribute): + return type(attr).new( + [recursively_convert_attr(param) for param in attr.parameters] + ) + if isattr(attr, AnyArrayAttr): + return AnyArrayAttr((recursively_convert_attr(value) for value in attr.data)) + return attr + + +class LowerGenericOp(RewritePattern): + """ + Recursively lower all result types, attributes, and properties. + """ + + def match_and_rewrite(self, op: Operation, rewriter: PatternRewriter): + for result in op.results: + if (new_type := recursively_convert_attr(result.type)) != result.type: + rewriter.modify_value_type(result, new_type) + + for region in op.regions: + for block in region.blocks: + for arg in block.args: + if (new_type := recursively_convert_attr(arg.type)) != arg.type: + rewriter.modify_value_type(arg, new_type) + + has_done_action = False + for name, attr in op.attributes.items(): + if (new_attr := recursively_convert_attr(attr)) != attr: + op.attributes[name] = new_attr + has_done_action = True + for name, attr in op.properties.items(): + if (new_attr := recursively_convert_attr(attr)) != attr: + op.properties[name] = new_attr + has_done_action = True + if has_done_action: + rewriter.handle_operation_modification(op) + + +class LowerUBOp(RewritePattern): + @op_type_rewrite_pattern + def match_and_rewrite(self, op: ub.UBOp, rewriter: PatternRewriter): + assert isattr(op.res.type, ub.UBOrType[Attribute]) + inhabitant = create_inhabitant(op.res.type.type, rewriter) + if inhabitant is None: + raise ValueError(f"Type {op.res.type.type} does not have an inhabitant.") + ub_flag = smt.ConstantBoolOp(True) + pair = smt_utils.PairOp(inhabitant, ub_flag.res) + rewriter.replace_matched_op([ub_flag, pair]) + + +class LowerFromOp(RewritePattern): + @op_type_rewrite_pattern + def match_and_rewrite(self, op: ub.FromOp, rewriter: PatternRewriter): + assert isattr(op.res.type, ub.UBOrType[Attribute]) + ub_flag = smt.ConstantBoolOp(False) + pair = smt_utils.PairOp(op.value, ub_flag.res) + rewriter.replace_matched_op([ub_flag, pair]) + + +class LowerMatchOp(RewritePattern): + @op_type_rewrite_pattern + def match_and_rewrite(self, op: ub.MatchOp, rewriter: PatternRewriter): + value = smt_utils.FirstOp(op.value) + poison_flag = smt_utils.SecondOp(op.value) + rewriter.insert_op_before_matched_op((value, poison_flag)) + value_terminator = op.value_terminator + ub_terminator = op.ub_terminator + rewriter.inline_block( + op.value_region.block, InsertPoint.before(op), (value.res,) + ) + rewriter.inline_block(op.ub_region.block, InsertPoint.before(op), ()) + results = list[SSAValue]() + for val_val, val_ub in zip(value_terminator.rets, ub_terminator.rets): + val = smt.IteOp(poison_flag.res, val_val, val_ub) + results.append(val.res) + rewriter.insert_op_before_matched_op(val) + + rewriter.erase_op(value_terminator) + rewriter.erase_op(ub_terminator) + + rewriter.replace_matched_op([], results) + + +@dataclass(frozen=True) +class LowerUBToPairs(ModulePass): + name = "lower-ub-to-pairs" + + def apply(self, ctx: MLContext, op: ModuleOp) -> None: + walker = PatternRewriteWalker( + GreedyRewritePatternApplier( + [ + LowerUBOp(), + LowerFromOp(), + LowerMatchOp(), + LowerGenericOp(), + ] + ) + ) + walker.rewrite_module(op) diff --git a/xdsl_smt/traits/inhabitant.py b/xdsl_smt/traits/inhabitant.py new file mode 100644 index 00000000..5ce0e1fb --- /dev/null +++ b/xdsl_smt/traits/inhabitant.py @@ -0,0 +1,20 @@ +from abc import abstractmethod, ABC +from xdsl.pattern_rewriter import PatternRewriter +from xdsl.ir import SSAValue, Attribute +from xdsl.dialects import builtin, arith + + +class HasInhabitant(ABC): + """Return an inhabitant of the type.""" + + @classmethod + @abstractmethod + def create_inhabitant(cls, rewriter: PatternRewriter) -> SSAValue: + ... + + +def create_inhabitant(type: Attribute, rewriter: PatternRewriter) -> SSAValue | None: + if isinstance(type, builtin.IntegerType): + constant = arith.Constant(builtin.IntegerAttr(0, type), type) + rewriter.insert_op_before_matched_op(constant) + return constant.result