Skip to content

Commit

Permalink
UB: Allow to match multiple values at once
Browse files Browse the repository at this point in the history
  • Loading branch information
math-fehr committed Jan 16, 2025
1 parent 76cfc0e commit fa9b0e6
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 35 deletions.
8 changes: 4 additions & 4 deletions tests/filecheck/dialects/ub.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
%value = "smt.declare_const"() : () -> i32
%ub = ub.ub : !ub.ub_or<i32>
%non_ub = ub.from %value : !ub.ub_or<i32>
%res = ub.match %ub : !ub.ub_or<i32> -> i64 {
^bb0(%val: i32):
%res = ub.match %ub, %non_ub : (!ub.ub_or<i32>, !ub.ub_or<i32>) -> i64 {
^bb0(%val1: i32, %val2: i32):
%x = "smt.declare_const"() : () -> i64
ub.yield %x : i64
} {
Expand All @@ -17,8 +17,8 @@
// CHECK-NEXT: %value = "smt.declare_const"() : () -> i32
// CHECK-NEXT: %ub = ub.ub : !ub.ub_or<i32>
// CHECK-NEXT: %non_ub = ub.from %value : !ub.ub_or<i32>
// CHECK-NEXT: %res = ub.match %ub : !ub.ub_or<i32> -> i64 {
// CHECK-NEXT: ^0(%val : i32):
// CHECK-NEXT: %res = ub.match %ub, %non_ub : (!ub.ub_or<i32>, !ub.ub_or<i32>) -> i64 {
// CHECK-NEXT: ^0(%val1 : i32, %val2 : i32):
// CHECK-NEXT: %x = "smt.declare_const"() : () -> i64
// CHECK-NEXT: ub.yield %x : i64
// CHECK-NEXT: } {
Expand Down
31 changes: 18 additions & 13 deletions tests/filecheck/lower-ub-to-pairs.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,28 @@
%value = "smt.declare_const"() : () -> i32
%ub = ub.ub : !ub.ub_or<i32>
%non_ub = ub.from %value : !ub.ub_or<i32>
%res = ub.match %ub : !ub.ub_or<i32> -> i64 {
^bb0(%val: i32):
%res = ub.match %ub, %non_ub : (!ub.ub_or<i32>, !ub.ub_or<i32>) -> i64 {
^bb0(%val1: i32, %val2: i32):
%x = "smt.declare_const"() : () -> i64
ub.yield %x : i64
} {
%y = "smt.declare_const"() : () -> i64
ub.yield %y : i64
}

// CHECK: %value = "smt.declare_const"() : () -> i32
// CHECK-NEXT: %0 = arith.constant 0 : i32
// CHECK-NEXT: %ub = "smt.constant_bool"() {"value" = #smt.bool_attr<true>} : () -> !smt.bool
// CHECK-NEXT: %ub_1 = "smt.utils.pair"(%0, %ub) : (i32, !smt.bool) -> !smt.utils.pair<i32, !smt.bool>
// CHECK-NEXT: %non_ub = "smt.constant_bool"() {"value" = #smt.bool_attr<false>} : () -> !smt.bool
// CHECK-NEXT: %non_ub_1 = "smt.utils.pair"(%value, %non_ub) : (i32, !smt.bool) -> !smt.utils.pair<i32, !smt.bool>
// CHECK-NEXT: %val = "smt.utils.first"(%ub_1) : (!smt.utils.pair<i32, !smt.bool>) -> i32
// CHECK-NEXT: %1 = "smt.utils.second"(%ub_1) : (!smt.utils.pair<i32, !smt.bool>) -> !smt.bool
// CHECK-NEXT: %x = "smt.declare_const"() : () -> i64
// CHECK-NEXT: %y = "smt.declare_const"() : () -> i64
// CHECK-NEXT: %res = "smt.ite"(%1, %x, %y) : (!smt.bool, i64, i64) -> i64
// CHECK: builtin.module {
// CHECK-NEXT: %value = "smt.declare_const"() : () -> i32
// CHECK-NEXT: %0 = arith.constant 0 : i32
// CHECK-NEXT: %ub = "smt.constant_bool"() {"value" = #smt.bool_attr<true>} : () -> !smt.bool
// CHECK-NEXT: %ub_1 = "smt.utils.pair"(%0, %ub) : (i32, !smt.bool) -> !smt.utils.pair<i32, !smt.bool>
// CHECK-NEXT: %non_ub = "smt.constant_bool"() {"value" = #smt.bool_attr<false>} : () -> !smt.bool
// CHECK-NEXT: %non_ub_1 = "smt.utils.pair"(%value, %non_ub) : (i32, !smt.bool) -> !smt.utils.pair<i32, !smt.bool>
// CHECK-NEXT: %val1 = "smt.utils.first"(%ub_1) : (!smt.utils.pair<i32, !smt.bool>) -> i32
// CHECK-NEXT: %1 = "smt.utils.second"(%ub_1) : (!smt.utils.pair<i32, !smt.bool>) -> !smt.bool
// CHECK-NEXT: %val2 = "smt.utils.first"(%non_ub_1) : (!smt.utils.pair<i32, !smt.bool>) -> i32
// CHECK-NEXT: %2 = "smt.utils.second"(%non_ub_1) : (!smt.utils.pair<i32, !smt.bool>) -> !smt.bool
// CHECK-NEXT: %3 = "smt.or"(%2, %1) : (!smt.bool, !smt.bool) -> !smt.bool
// CHECK-NEXT: %x = "smt.declare_const"() : () -> i64
// CHECK-NEXT: %y = "smt.declare_const"() : () -> i64
// CHECK-NEXT: %res = "smt.ite"(%3, %y, %x) : (!smt.bool, i64, i64) -> i64
// CHECK-NEXT: }
33 changes: 22 additions & 11 deletions xdsl_smt/dialects/ub.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from __future__ import annotations
from typing import Annotated, Generic, TypeVar
from typing import Annotated, Generic, Iterable, Sequence, TypeVar

Check failure on line 2 in xdsl_smt/dialects/ub.py

View workflow job for this annotation

GitHub Actions / build (3.10)

Import "Iterable" is not accessed (reportUnusedImport)

from xdsl.irdl import (
irdl_attr_definition,
Expand Down Expand Up @@ -86,22 +86,25 @@ class MatchOp(IRDLOperation):

T = Annotated[Attribute, ConstraintVar("T")]

value = operand_def(UBOrType[T])
values = var_operand_def(UBOrType[T])

value_region = region_def(single_block="single_block")
ub_region = region_def(single_block="single_block")

res = var_result_def()

assembly_format = "$value attr-dict-with-keyword `:` type($value) `->` type($res) $value_region $ub_region"
assembly_format = "$values attr-dict-with-keyword `:` `(` type($values) `)` `->` type($res) $value_region $ub_region"

def __init__(self, value: SSAValue):
if not isattr(value.type, UBOrType[Attribute]):
raise ValueError(f"Expected a '{UBOrType.name}' type, got {value.type}")
value_region = Region(Block((), arg_types=[value.type.type]))
def __init__(self, values: Sequence[SSAValue]):
value_types = list[UBOrType[Attribute]]()
for value in values:
if not isattr(value.type, UBOrType[Attribute]):
raise ValueError(f"Expected a '{UBOrType.name}' type, got {value.type}")
value_types.append(value.type)
value_region = Region(Block((), arg_types=value_types))
ub_region = Region(Block((), arg_types=[]))
super().__init__(
operands=[value],
operands=[values],
result_types=[],
regions=[value_region, ub_region],
)
Expand All @@ -118,12 +121,20 @@ def ub_terminator(self) -> YieldOp:
raise ValueError("UB case region must have a yield terminator")
return self.ub_region.block.last_op

@property
def value_types(self) -> Sequence[UBOrType[Attribute]]:
types = list[UBOrType[Attribute]]()
for value in self.values:
assert isattr(value.type, UBOrType[Attribute])
types.append(value.type)
return types

def verify_(self):
assert isattr(self.value.type, UBOrType[Attribute])
if self.value_region.blocks[0].arg_types != (self.value.type.type,):
value_type_type = [type.type for type in self.value_types]
if list(self.value_region.blocks[0].arg_types) != value_type_type:
raise ValueError(
"Value region must have exactly one argument of type "
f"{self.value.type.type}, got {self.value_region.blocks[0].args}"
f"{tuple(value_type_type)}, got {tuple(self.value_region.blocks[0].arg_types)}"
)
if len(self.ub_region.blocks[0].args) != 0:
raise ValueError("UB region must have no arguments")
Expand Down
27 changes: 20 additions & 7 deletions xdsl_smt/passes/lower_ub_to_pairs.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,18 +91,31 @@ def match_and_rewrite(self, op: ub.FromOp, rewriter: PatternRewriter):
class LowerMatchOp(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: ub.MatchOp, rewriter: PatternRewriter):
value = smt_utils.FirstOp(op.value)
poison_flag = smt_utils.SecondOp(op.value)
rewriter.insert_op_before_matched_op((value, poison_flag))
# Unwrap all value pairs to their value and poison flag
values = list[SSAValue]()
poison_flags = list[SSAValue]()
for value in op.values:
value_value = smt_utils.FirstOp(value)
poison_flag = smt_utils.SecondOp(value)
values.append(value_value.res)
poison_flags.append(poison_flag.res)
rewriter.insert_op_before_matched_op((value_value, poison_flag))

# Check if all values are not poison
one_is_poison = poison_flags[0]
for poison_flag in poison_flags[1:]:
or_poison = smt.OrOp(poison_flag, one_is_poison)
one_is_poison = or_poison.res
rewriter.insert_op_before_matched_op(or_poison)

# Inline both case regions
value_terminator = op.value_terminator
ub_terminator = op.ub_terminator
rewriter.inline_block(
op.value_region.block, InsertPoint.before(op), (value.res,)
)
rewriter.inline_block(op.value_region.block, InsertPoint.before(op), values)
rewriter.inline_block(op.ub_region.block, InsertPoint.before(op), ())
results = list[SSAValue]()
for val_val, val_ub in zip(value_terminator.rets, ub_terminator.rets):
val = smt.IteOp(poison_flag.res, val_val, val_ub)
val = smt.IteOp(one_is_poison, val_ub, val_val)
results.append(val.res)
rewriter.insert_op_before_matched_op(val)

Expand Down

0 comments on commit fa9b0e6

Please sign in to comment.