Skip to content

Commit

Permalink
dialects: Implement affine.min. (#1847)
Browse files Browse the repository at this point in the history
  • Loading branch information
PapyChacal authored Dec 12, 2023
1 parent 82e4af2 commit 5ac911a
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 0 deletions.
2 changes: 2 additions & 0 deletions tests/filecheck/dialects/affine/affine_ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,12 @@

%zero = "test.op"() : () -> index
%2 = "affine.apply"(%zero, %zero) <{"map" = affine_map<(d0)[s0] -> (((d0 + (s0 * 42)) + -1))>}> : (index, index) -> index
%min = "affine.min"(%zero) <{"map" = affine_map<(d0) -> ((d0 + 41), d0)>}> : (index) -> index
%same_value = "affine.load"(%memref, %zero, %zero) <{"map" = affine_map<(d0, d1) -> (d0, d1)>}> : (memref<2x3xf64>, index, index) -> f64

// CHECK: %zero = "test.op"() : () -> index
// CHECK-NEXT: %{{.*}} = "affine.apply"(%{{.*}}, %{{.*}}) <{"map" = affine_map<(d0)[s0] -> (((d0 + (s0 * 42)) + -1))>}> : (index, index) -> index
// CHECK-NEXT: %{{.*}} = "affine.min"(%{{.*}}) <{"map" = affine_map<(d0) -> ((d0 + 41), d0)>}> : (index) -> index
// CHECK-NEXT: %same_value = "affine.load"(%memref, %zero, %zero) <{"map" = affine_map<(d0, d1) -> (d0, d1)>}> : (memref<2x3xf64>, index, index) -> f64

func.func @empty() {
Expand Down
16 changes: 16 additions & 0 deletions xdsl/dialects/affine.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,21 @@ def __init__(
)


@irdl_op_definition
class MinOp(IRDLOperation):
name = "affine.min"
arguments = var_operand_def(IndexType())
result = result_def(IndexType())

map = prop_def(AffineMapAttr)

def verify_(self) -> None:
if len(self.operands) != self.map.data.num_dims + self.map.data.num_symbols:
raise VerifyException(
f"{self.name} expects {self.map.data.num_dims + self.map.data.num_symbols} operands, but got {len(self.operands)}. The number of map operands must match the sum of the dimensions and symbols of its map."
)


@irdl_op_definition
class Yield(IRDLOperation):
name = "affine.yield"
Expand All @@ -242,6 +257,7 @@ def get(*operands: SSAValue | Operation) -> Yield:
If,
Store,
Load,
MinOp,
Yield,
],
[],
Expand Down

0 comments on commit 5ac911a

Please sign in to comment.