Skip to content

Commit

Permalink
transform: (stencil-tensorize-z-dimension) Tensorize arith.constant d…
Browse files Browse the repository at this point in the history
…irectly (#2970)

Add support for directly tensorising `arith.constant`, e.g.
```
%0 = arith.constant dense<1.666600e-01> : tensor<510xf32>
```

The more generic way of tensorising any scalar operand (independent of
its source) would be `linalg.fill`ing the value a newly created empty
tensor. This still works as a generic way. However, in the special case
that the scalar operand is a constant, this is much shorter.

---------

Co-authored-by: n-io <[email protected]>
  • Loading branch information
n-io and n-io authored Aug 2, 2024
1 parent ffe6132 commit 1fd1d39
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 75 deletions.
86 changes: 42 additions & 44 deletions tests/filecheck/transforms/stencil-tensorize-z-dimension.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -33,28 +33,27 @@ builtin.module {
// CHECK-NEXT: %1 = stencil.load %0 : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> -> !stencil.temp<[-1,1023]x[-1,511]xtensor<512xf32>>
// CHECK-NEXT: %2 = stencil.external_load %b : memref<1024x512x512xf32> -> !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>
// CHECK-NEXT: %3 = stencil.apply(%4 = %1 : !stencil.temp<[-1,1023]x[-1,511]xtensor<512xf32>>) -> (!stencil.temp<[0,1022]x[0,510]xtensor<510xf32>>) {
// CHECK-NEXT: %5 = arith.constant 1.666600e-01 : f32
// CHECK-NEXT: %6 = stencil.access %4[1, 0] : !stencil.temp<[-1,1023]x[-1,511]xtensor<512xf32>>
// CHECK-NEXT: %7 = "tensor.extract_slice"(%6) <{"static_offsets" = array<i64: 0>, "static_sizes" = array<i64: 510>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<510xf32>
// CHECK-NEXT: %8 = stencil.access %4[-1, 0] : !stencil.temp<[-1,1023]x[-1,511]xtensor<512xf32>>
// CHECK-NEXT: %9 = "tensor.extract_slice"(%8) <{"static_offsets" = array<i64: 0>, "static_sizes" = array<i64: 510>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<510xf32>
// CHECK-NEXT: %10 = stencil.access %4[0, 0] : !stencil.temp<[-1,1023]x[-1,511]xtensor<512xf32>>
// CHECK-NEXT: %11 = "tensor.extract_slice"(%10) <{"static_offsets" = array<i64: 1>, "static_sizes" = array<i64: 510>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<510xf32>
// CHECK-NEXT: %12 = stencil.access %4[0, 0] : !stencil.temp<[-1,1023]x[-1,511]xtensor<512xf32>>
// CHECK-NEXT: %13 = "tensor.extract_slice"(%12) <{"static_offsets" = array<i64: -1>, "static_sizes" = array<i64: 510>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<510xf32>
// CHECK-NEXT: %14 = stencil.access %4[0, 1] : !stencil.temp<[-1,1023]x[-1,511]xtensor<512xf32>>
// CHECK-NEXT: %15 = "tensor.extract_slice"(%14) <{"static_offsets" = array<i64: 0>, "static_sizes" = array<i64: 510>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<510xf32>
// CHECK-NEXT: %16 = stencil.access %4[0, -1] : !stencil.temp<[-1,1023]x[-1,511]xtensor<512xf32>>
// CHECK-NEXT: %17 = "tensor.extract_slice"(%16) <{"static_offsets" = array<i64: 0>, "static_sizes" = array<i64: 510>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<510xf32>
// CHECK-NEXT: %18 = arith.addf %17, %15 : tensor<510xf32>
// CHECK-NEXT: %19 = arith.addf %18, %13 : tensor<510xf32>
// CHECK-NEXT: %20 = arith.addf %19, %11 : tensor<510xf32>
// CHECK-NEXT: %21 = arith.addf %20, %9 : tensor<510xf32>
// CHECK-NEXT: %22 = arith.addf %21, %7 : tensor<510xf32>
// CHECK-NEXT: %23 = tensor.empty() : tensor<510xf32>
// CHECK-NEXT: %24 = linalg.fill ins(%5 : f32) outs(%23 : tensor<510xf32>) -> tensor<510xf32>
// CHECK-NEXT: %25 = arith.mulf %22, %24 : tensor<510xf32>
// CHECK-NEXT: stencil.return %25 : tensor<510xf32>
// CHECK-NEXT: %5 = arith.constant dense<1.666600e-01> : tensor<510xf32>
// CHECK-NEXT: %6 = arith.constant 1.666600e-01 : f32
// CHECK-NEXT: %7 = stencil.access %4[1, 0] : !stencil.temp<[-1,1023]x[-1,511]xtensor<512xf32>>
// CHECK-NEXT: %8 = "tensor.extract_slice"(%7) <{"static_offsets" = array<i64: 0>, "static_sizes" = array<i64: 510>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<510xf32>
// CHECK-NEXT: %9 = stencil.access %4[-1, 0] : !stencil.temp<[-1,1023]x[-1,511]xtensor<512xf32>>
// CHECK-NEXT: %10 = "tensor.extract_slice"(%9) <{"static_offsets" = array<i64: 0>, "static_sizes" = array<i64: 510>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<510xf32>
// CHECK-NEXT: %11 = stencil.access %4[0, 0] : !stencil.temp<[-1,1023]x[-1,511]xtensor<512xf32>>
// CHECK-NEXT: %12 = "tensor.extract_slice"(%11) <{"static_offsets" = array<i64: 1>, "static_sizes" = array<i64: 510>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<510xf32>
// CHECK-NEXT: %13 = stencil.access %4[0, 0] : !stencil.temp<[-1,1023]x[-1,511]xtensor<512xf32>>
// CHECK-NEXT: %14 = "tensor.extract_slice"(%13) <{"static_offsets" = array<i64: -1>, "static_sizes" = array<i64: 510>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<510xf32>
// CHECK-NEXT: %15 = stencil.access %4[0, 1] : !stencil.temp<[-1,1023]x[-1,511]xtensor<512xf32>>
// CHECK-NEXT: %16 = "tensor.extract_slice"(%15) <{"static_offsets" = array<i64: 0>, "static_sizes" = array<i64: 510>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<510xf32>
// CHECK-NEXT: %17 = stencil.access %4[0, -1] : !stencil.temp<[-1,1023]x[-1,511]xtensor<512xf32>>
// CHECK-NEXT: %18 = "tensor.extract_slice"(%17) <{"static_offsets" = array<i64: 0>, "static_sizes" = array<i64: 510>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<510xf32>
// CHECK-NEXT: %19 = arith.addf %18, %16 : tensor<510xf32>
// CHECK-NEXT: %20 = arith.addf %19, %14 : tensor<510xf32>
// CHECK-NEXT: %21 = arith.addf %20, %12 : tensor<510xf32>
// CHECK-NEXT: %22 = arith.addf %21, %10 : tensor<510xf32>
// CHECK-NEXT: %23 = arith.addf %22, %8 : tensor<510xf32>
// CHECK-NEXT: %24 = arith.mulf %23, %5 : tensor<510xf32>
// CHECK-NEXT: stencil.return %24 : tensor<510xf32>
// CHECK-NEXT: }
// CHECK-NEXT: stencil.store %3 to %2 (<[0, 0], [1022, 510]>) : !stencil.temp<[0,1022]x[0,510]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>
// CHECK-NEXT: func.return
Expand Down Expand Up @@ -86,28 +85,27 @@ builtin.module {
// CHECK: func.func @gauss_seidel_func(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %b : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) {
// CHECK-NEXT: %0 = stencil.load %a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> -> !stencil.temp<[-1,1023]x[-1,511]xtensor<512xf32>>
// CHECK-NEXT: %1 = stencil.apply(%2 = %0 : !stencil.temp<[-1,1023]x[-1,511]xtensor<512xf32>>) -> (!stencil.temp<[0,1022]x[0,510]xtensor<510xf32>>) {
// CHECK-NEXT: %3 = arith.constant 1.666600e-01 : f32
// CHECK-NEXT: %4 = stencil.access %2[1, 0] : !stencil.temp<[-1,1023]x[-1,511]xtensor<512xf32>>
// CHECK-NEXT: %5 = "tensor.extract_slice"(%4) <{"static_offsets" = array<i64: 0>, "static_sizes" = array<i64: 510>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<510xf32>
// CHECK-NEXT: %6 = stencil.access %2[-1, 0] : !stencil.temp<[-1,1023]x[-1,511]xtensor<512xf32>>
// CHECK-NEXT: %7 = "tensor.extract_slice"(%6) <{"static_offsets" = array<i64: 0>, "static_sizes" = array<i64: 510>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<510xf32>
// CHECK-NEXT: %8 = stencil.access %2[0, 0] : !stencil.temp<[-1,1023]x[-1,511]xtensor<512xf32>>
// CHECK-NEXT: %9 = "tensor.extract_slice"(%8) <{"static_offsets" = array<i64: 1>, "static_sizes" = array<i64: 510>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<510xf32>
// CHECK-NEXT: %10 = stencil.access %2[0, 0] : !stencil.temp<[-1,1023]x[-1,511]xtensor<512xf32>>
// CHECK-NEXT: %11 = "tensor.extract_slice"(%10) <{"static_offsets" = array<i64: -1>, "static_sizes" = array<i64: 510>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<510xf32>
// CHECK-NEXT: %12 = stencil.access %2[0, 1] : !stencil.temp<[-1,1023]x[-1,511]xtensor<512xf32>>
// CHECK-NEXT: %13 = "tensor.extract_slice"(%12) <{"static_offsets" = array<i64: 0>, "static_sizes" = array<i64: 510>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<510xf32>
// CHECK-NEXT: %14 = stencil.access %2[0, -1] : !stencil.temp<[-1,1023]x[-1,511]xtensor<512xf32>>
// CHECK-NEXT: %15 = "tensor.extract_slice"(%14) <{"static_offsets" = array<i64: 0>, "static_sizes" = array<i64: 510>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<510xf32>
// CHECK-NEXT: %16 = arith.addf %15, %13 : tensor<510xf32>
// CHECK-NEXT: %17 = arith.addf %16, %11 : tensor<510xf32>
// CHECK-NEXT: %18 = arith.addf %17, %9 : tensor<510xf32>
// CHECK-NEXT: %19 = arith.addf %18, %7 : tensor<510xf32>
// CHECK-NEXT: %20 = arith.addf %19, %5 : tensor<510xf32>
// CHECK-NEXT: %21 = tensor.empty() : tensor<510xf32>
// CHECK-NEXT: %22 = linalg.fill ins(%3 : f32) outs(%21 : tensor<510xf32>) -> tensor<510xf32>
// CHECK-NEXT: %23 = arith.mulf %20, %22 : tensor<510xf32>
// CHECK-NEXT: stencil.return %23 : tensor<510xf32>
// CHECK-NEXT: %3 = arith.constant dense<1.666600e-01> : tensor<510xf32>
// CHECK-NEXT: %4 = arith.constant 1.666600e-01 : f32
// CHECK-NEXT: %5 = stencil.access %2[1, 0] : !stencil.temp<[-1,1023]x[-1,511]xtensor<512xf32>>
// CHECK-NEXT: %6 = "tensor.extract_slice"(%5) <{"static_offsets" = array<i64: 0>, "static_sizes" = array<i64: 510>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<510xf32>
// CHECK-NEXT: %7 = stencil.access %2[-1, 0] : !stencil.temp<[-1,1023]x[-1,511]xtensor<512xf32>>
// CHECK-NEXT: %8 = "tensor.extract_slice"(%7) <{"static_offsets" = array<i64: 0>, "static_sizes" = array<i64: 510>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<510xf32>
// CHECK-NEXT: %9 = stencil.access %2[0, 0] : !stencil.temp<[-1,1023]x[-1,511]xtensor<512xf32>>
// CHECK-NEXT: %10 = "tensor.extract_slice"(%9) <{"static_offsets" = array<i64: 1>, "static_sizes" = array<i64: 510>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<510xf32>
// CHECK-NEXT: %11 = stencil.access %2[0, 0] : !stencil.temp<[-1,1023]x[-1,511]xtensor<512xf32>>
// CHECK-NEXT: %12 = "tensor.extract_slice"(%11) <{"static_offsets" = array<i64: -1>, "static_sizes" = array<i64: 510>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<510xf32>
// CHECK-NEXT: %13 = stencil.access %2[0, 1] : !stencil.temp<[-1,1023]x[-1,511]xtensor<512xf32>>
// CHECK-NEXT: %14 = "tensor.extract_slice"(%13) <{"static_offsets" = array<i64: 0>, "static_sizes" = array<i64: 510>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<510xf32>
// CHECK-NEXT: %15 = stencil.access %2[0, -1] : !stencil.temp<[-1,1023]x[-1,511]xtensor<512xf32>>
// CHECK-NEXT: %16 = "tensor.extract_slice"(%15) <{"static_offsets" = array<i64: 0>, "static_sizes" = array<i64: 510>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<510xf32>
// CHECK-NEXT: %17 = arith.addf %16, %14 : tensor<510xf32>
// CHECK-NEXT: %18 = arith.addf %17, %12 : tensor<510xf32>
// CHECK-NEXT: %19 = arith.addf %18, %10 : tensor<510xf32>
// CHECK-NEXT: %20 = arith.addf %19, %8 : tensor<510xf32>
// CHECK-NEXT: %21 = arith.addf %20, %6 : tensor<510xf32>
// CHECK-NEXT: %22 = arith.mulf %21, %3 : tensor<510xf32>
// CHECK-NEXT: stencil.return %22 : tensor<510xf32>
// CHECK-NEXT: }
// CHECK-NEXT: stencil.store %1 to %b (<[0, 0], [1022, 510]>) : !stencil.temp<[0,1022]x[0,510]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>
// CHECK-NEXT: func.return
Expand Down
100 changes: 69 additions & 31 deletions xdsl/transforms/experimental/stencil_tensorize_z_dimension.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from xdsl.dialects.arith import (
Addf,
BinaryOperation,
Constant,
Divf,
FloatingPointLikeBinaryOp,
Mulf,
Expand All @@ -16,6 +17,7 @@
AnyFloat,
ArrayAttr,
ContainerType,
DenseIntOrFPElementsAttr,
IntAttr,
ModuleOp,
ShapedType,
Expand All @@ -40,6 +42,8 @@
from xdsl.ir import (
Attribute,
Operation,
OpResult,
SSAValue,
)
from xdsl.passes import ModulePass
from xdsl.pattern_rewriter import (
Expand Down Expand Up @@ -135,42 +139,63 @@ def match_and_rewrite(self, op: AccessOp, rewriter: PatternRewriter, /):
rewriter.replace_matched_op(extract)


def arithBinaryOpTensorize(
op: FloatingPointLikeBinaryOp,
rewriter: PatternRewriter,
/,
):
type_constructor = type(op)
if is_tensor(op.result.type):
return
if is_tensor(op.lhs.type) and is_tensor(op.rhs.type):
rewriter.replace_matched_op(
type_constructor(op.lhs, op.rhs, flags=None, result_type=op.lhs.type)
)
elif is_tensor(op.lhs.type) and is_scalar(op.rhs.type):
emptyop = EmptyOp((), op.lhs.type)
fillop = FillOp((op.rhs,), (emptyop.results[0],), (op.lhs.type,))
rewriter.insert_op(emptyop, InsertPoint.before(op))
rewriter.insert_op(fillop, InsertPoint.before(op))
rewriter.replace_matched_op(
type_constructor(op.lhs, fillop, flags=None, result_type=op.lhs.type)
)
elif is_scalar(op.lhs.type) and is_tensor(op.rhs.type):
emptyop = EmptyOp((), op.rhs.type)
fillop = FillOp((op.lhs,), (emptyop.results[0],), (op.rhs.type,))
rewriter.insert_op(emptyop, InsertPoint.before(op))
rewriter.insert_op(fillop, InsertPoint.before(op))
rewriter.replace_matched_op(
type_constructor(fillop, op.rhs, flags=None, result_type=op.rhs.type)
)


class ArithOpTensorize(RewritePattern):
"""
Tensorises arith binary ops.
If both operands are tensor types, rebuilds the op with matching result type.
If one operand is scalar and an `arith.constant`, create a tensor constant directly.
If one operand is scalar and not an `arith.constant`, create an empty tensor and fill it with the scalar value.
"""

@op_type_rewrite_pattern
def match_and_rewrite(
self, op: Addf | Subf | Mulf | Divf, rewriter: PatternRewriter, /
):
arithBinaryOpTensorize(op, rewriter)
type_constructor = type(op)
if is_tensor(op.result.type):
return
if is_tensor(op.lhs.type) and is_tensor(op.rhs.type):
rewriter.replace_matched_op(
type_constructor(op.lhs, op.rhs, flags=None, result_type=op.lhs.type)
)
elif is_tensor(op.lhs.type) and is_scalar(op.rhs.type):
new_rhs = ArithOpTensorize._rewrite_scalar_operand(
op.rhs, op.lhs.type, op, rewriter
)
rewriter.replace_matched_op(
type_constructor(op.lhs, new_rhs, flags=None, result_type=op.lhs.type)
)
elif is_scalar(op.lhs.type) and is_tensor(op.rhs.type):
new_lhs = ArithOpTensorize._rewrite_scalar_operand(
op.lhs, op.rhs.type, op, rewriter
)
rewriter.replace_matched_op(
type_constructor(new_lhs, op.rhs, flags=None, result_type=op.rhs.type)
)

@staticmethod
def _rewrite_scalar_operand(
scalar_op: SSAValue,
dest_typ: TensorType[Attribute],
op: FloatingPointLikeBinaryOp,
rewriter: PatternRewriter,
) -> SSAValue:
"""
Rewrites a scalar operand into a tensor.
If it is a constant, create a corresponding tensor constant.
If it is not a constant, create an empty tensor and `linalg.fill` it with the scalar value.
"""
if isinstance(scalar_op, OpResult) and isinstance(scalar_op.op, Constant):
tens_const = Constant(
DenseIntOrFPElementsAttr([dest_typ, ArrayAttr([scalar_op.op.value])])
)
rewriter.insert_op(tens_const, InsertPoint.before(scalar_op.op))
return tens_const.result
emptyop = EmptyOp((), dest_typ)
fillop = FillOp((scalar_op,), (emptyop.tensor,), (dest_typ,))
rewriter.insert_op(emptyop, InsertPoint.before(op))
rewriter.insert_op(fillop, InsertPoint.before(op))
return fillop.res[0]


@dataclass(frozen=True)
Expand Down Expand Up @@ -361,6 +386,18 @@ def match_and_rewrite(self, op: FillOp, rewriter: PatternRewriter, /):
)


class ConstOpUpdateShape(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: Constant, rewriter: PatternRewriter, /):
if is_tensor(op.result.type):
if typ := get_required_result_type(op):
if needs_update_shape(op.result.type, typ):
assert isinstance(op.value, DenseIntOrFPElementsAttr)
rewriter.replace_matched_op(
Constant(DenseIntOrFPElementsAttr([typ, op.value.data]))
)


@dataclass(frozen=True)
class BackpropagateStencilShapes(ModulePass):
"""
Expand All @@ -379,6 +416,7 @@ def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None:
EmptyOpUpdateShape(),
FillOpUpdateShape(),
ArithOpUpdateShape(),
ConstOpUpdateShape(),
]
),
walk_reverse=True,
Expand Down

0 comments on commit 1fd1d39

Please sign in to comment.