Skip to content

Commit

Permalink
UB: Add lowering to pairs
Browse files Browse the repository at this point in the history
  • Loading branch information
math-fehr committed Jan 16, 2025
1 parent 459f55f commit 76cfc0e
Show file tree
Hide file tree
Showing 4 changed files with 191 additions and 4 deletions.
25 changes: 25 additions & 0 deletions tests/filecheck/lower-ub-to-pairs.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// RUN: xdsl-smt "%s" -p=lower-ub-to-pairs | filecheck "%s"

%value = "smt.declare_const"() : () -> i32
%ub = ub.ub : !ub.ub_or<i32>
%non_ub = ub.from %value : !ub.ub_or<i32>
%res = ub.match %ub : !ub.ub_or<i32> -> i64 {
^bb0(%val: i32):
%x = "smt.declare_const"() : () -> i64
ub.yield %x : i64
} {
%y = "smt.declare_const"() : () -> i64
ub.yield %y : i64
}

// CHECK: %value = "smt.declare_const"() : () -> i32
// CHECK-NEXT: %0 = arith.constant 0 : i32
// CHECK-NEXT: %ub = "smt.constant_bool"() {"value" = #smt.bool_attr<true>} : () -> !smt.bool
// CHECK-NEXT: %ub_1 = "smt.utils.pair"(%0, %ub) : (i32, !smt.bool) -> !smt.utils.pair<i32, !smt.bool>
// CHECK-NEXT: %non_ub = "smt.constant_bool"() {"value" = #smt.bool_attr<false>} : () -> !smt.bool
// CHECK-NEXT: %non_ub_1 = "smt.utils.pair"(%value, %non_ub) : (i32, !smt.bool) -> !smt.utils.pair<i32, !smt.bool>
// CHECK-NEXT: %val = "smt.utils.first"(%ub_1) : (!smt.utils.pair<i32, !smt.bool>) -> i32
// CHECK-NEXT: %1 = "smt.utils.second"(%ub_1) : (!smt.utils.pair<i32, !smt.bool>) -> !smt.bool
// CHECK-NEXT: %x = "smt.declare_const"() : () -> i64
// CHECK-NEXT: %y = "smt.declare_const"() : () -> i64
// CHECK-NEXT: %res = "smt.ite"(%1, %x, %y) : (!smt.bool, i64, i64) -> i64
20 changes: 16 additions & 4 deletions xdsl_smt/dialects/ub.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ class UBOp(IRDLOperation):

name = "ub.ub"

new_ub = result_def(UBOrType)
res = result_def(UBOrType)

assembly_format = "attr-dict `:` type($new_ub)"
assembly_format = "attr-dict `:` type($res)"

def __init__(self, type: Attribute):
"""Create an UB value for the given type."""
Expand All @@ -67,9 +67,9 @@ class FromOp(IRDLOperation):
T = Annotated[Attribute, ConstraintVar("T")]

value = operand_def(T)
result = result_def(UBOrType[T])
res = result_def(UBOrType[T])

assembly_format = "$value attr-dict `:` type($result)"
assembly_format = "$value attr-dict `:` type($res)"

def __init__(self, value: SSAValue):
super().__init__(
Expand Down Expand Up @@ -106,6 +106,18 @@ def __init__(self, value: SSAValue):
regions=[value_region, ub_region],
)

@property
def value_terminator(self) -> YieldOp:
if not isinstance(self.value_region.block.last_op, YieldOp):
raise ValueError("Value case region must have a yield terminator")
return self.value_region.block.last_op

@property
def ub_terminator(self) -> YieldOp:
if not isinstance(self.ub_region.block.last_op, YieldOp):
raise ValueError("UB case region must have a yield terminator")
return self.ub_region.block.last_op

def verify_(self):
assert isattr(self.value.type, UBOrType[Attribute])
if self.value_region.blocks[0].arg_types != (self.value.type.type,):
Expand Down
130 changes: 130 additions & 0 deletions xdsl_smt/passes/lower_ub_to_pairs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
from dataclasses import dataclass

from xdsl.ir import Attribute, ParametrizedAttribute, Operation, SSAValue
from xdsl.utils.isattr import isattr
from xdsl.passes import ModulePass
from xdsl.context import MLContext
from xdsl.pattern_rewriter import (
PatternRewriteWalker,
GreedyRewritePatternApplier,
RewritePattern,
PatternRewriter,
op_type_rewrite_pattern,
)
from xdsl.rewriter import InsertPoint

from xdsl.dialects.builtin import ModuleOp, AnyArrayAttr
from xdsl_smt.dialects import (
smt_utils_dialect as smt_utils,
smt_dialect as smt,
ub,
)
from xdsl_smt.traits.inhabitant import create_inhabitant


def recursively_convert_attr(attr: Attribute) -> Attribute:
"""
Recursively convert an attribute to replace all references to the effect state
into a pair between the ub flag and the memory.
"""
if isattr(attr, ub.UBOrType[Attribute]):
return smt_utils.PairType(attr.type, smt.BoolType())
if isinstance(attr, ParametrizedAttribute):
return type(attr).new(
[recursively_convert_attr(param) for param in attr.parameters]
)
if isattr(attr, AnyArrayAttr):
return AnyArrayAttr((recursively_convert_attr(value) for value in attr.data))
return attr


class LowerGenericOp(RewritePattern):
"""
Recursively lower all result types, attributes, and properties.
"""

def match_and_rewrite(self, op: Operation, rewriter: PatternRewriter):
for result in op.results:
if (new_type := recursively_convert_attr(result.type)) != result.type:
rewriter.modify_value_type(result, new_type)

for region in op.regions:
for block in region.blocks:
for arg in block.args:
if (new_type := recursively_convert_attr(arg.type)) != arg.type:
rewriter.modify_value_type(arg, new_type)

has_done_action = False
for name, attr in op.attributes.items():
if (new_attr := recursively_convert_attr(attr)) != attr:
op.attributes[name] = new_attr
has_done_action = True
for name, attr in op.properties.items():
if (new_attr := recursively_convert_attr(attr)) != attr:
op.properties[name] = new_attr
has_done_action = True
if has_done_action:
rewriter.handle_operation_modification(op)


class LowerUBOp(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: ub.UBOp, rewriter: PatternRewriter):
assert isattr(op.res.type, ub.UBOrType[Attribute])
inhabitant = create_inhabitant(op.res.type.type, rewriter)
if inhabitant is None:
raise ValueError(f"Type {op.res.type.type} does not have an inhabitant.")
ub_flag = smt.ConstantBoolOp(True)
pair = smt_utils.PairOp(inhabitant, ub_flag.res)
rewriter.replace_matched_op([ub_flag, pair])


class LowerFromOp(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: ub.FromOp, rewriter: PatternRewriter):
assert isattr(op.res.type, ub.UBOrType[Attribute])
ub_flag = smt.ConstantBoolOp(False)
pair = smt_utils.PairOp(op.value, ub_flag.res)
rewriter.replace_matched_op([ub_flag, pair])


class LowerMatchOp(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: ub.MatchOp, rewriter: PatternRewriter):
value = smt_utils.FirstOp(op.value)
poison_flag = smt_utils.SecondOp(op.value)
rewriter.insert_op_before_matched_op((value, poison_flag))
value_terminator = op.value_terminator
ub_terminator = op.ub_terminator
rewriter.inline_block(
op.value_region.block, InsertPoint.before(op), (value.res,)
)
rewriter.inline_block(op.ub_region.block, InsertPoint.before(op), ())
results = list[SSAValue]()
for val_val, val_ub in zip(value_terminator.rets, ub_terminator.rets):
val = smt.IteOp(poison_flag.res, val_val, val_ub)
results.append(val.res)
rewriter.insert_op_before_matched_op(val)

rewriter.erase_op(value_terminator)
rewriter.erase_op(ub_terminator)

rewriter.replace_matched_op([], results)


@dataclass(frozen=True)
class LowerUBToPairs(ModulePass):
name = "lower-ub-to-pairs"

def apply(self, ctx: MLContext, op: ModuleOp) -> None:
walker = PatternRewriteWalker(
GreedyRewritePatternApplier(
[
LowerUBOp(),
LowerFromOp(),
LowerMatchOp(),
LowerGenericOp(),
]
)
)
walker.rewrite_module(op)
20 changes: 20 additions & 0 deletions xdsl_smt/traits/inhabitant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from abc import abstractmethod, ABC
from xdsl.pattern_rewriter import PatternRewriter
from xdsl.ir import SSAValue, Attribute
from xdsl.dialects import builtin, arith


class HasInhabitant(ABC):
"""Return an inhabitant of the type."""

@classmethod
@abstractmethod
def create_inhabitant(cls, rewriter: PatternRewriter) -> SSAValue:
...


def create_inhabitant(type: Attribute, rewriter: PatternRewriter) -> SSAValue | None:
if isinstance(type, builtin.IntegerType):
constant = arith.Constant(builtin.IntegerAttr(0, type), type)
rewriter.insert_op_before_matched_op(constant)
return constant.result

0 comments on commit 76cfc0e

Please sign in to comment.