Skip to content

Commit

Permalink
bug: (csl-lowering) Make multi-apply lowering work (#3614)
Browse files Browse the repository at this point in the history
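This PR includes a few small fixes, visible in the diff below: the receive-chunk block now reads from the chunk buffer instead of the accumulator and inserts each slice with rank-matched `static_sizes`; the second apply block yields nothing; bufferization redirects other uses of an apply's accumulator to its `bufferization.to_memref` result; `InjectApplyOutsIntoLinalgOuts` skips apply ops that have no `dest`; the required result type for an inserted lower-rank slice drops the leading dimensions; and `memref_to_dsd` replaces a subview that lowers to no ops with its source.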

---------

Co-authored-by: n-io <[email protected]>
  • Loading branch information
n-io and n-io authored Dec 19, 2024
1 parent 4b15917 commit d5dd188
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -238,12 +238,12 @@ builtin.module {
// CHECK-NEXT: %0 = tensor.empty() : tensor<1x64xf32>
// CHECK-NEXT: csl_stencil.apply(%arg1 : !stencil.field<[-1,1]x[-1,1]xtensor<64xf32>>, %0 : tensor<1x64xf32>) -> () <{"swaps" = [#csl_stencil.exchange<to [-1, 0]>], "topo" = #dmp.topo<64x64>, "num_chunks" = 2 : i64, "operandSegmentSizes" = array<i32: 1, 1, 0, 0, 0>}> ({
// CHECK-NEXT: ^0(%1 : tensor<1x32xf32>, %2 : index, %3 : tensor<1x64xf32>):
// CHECK-NEXT: %4 = csl_stencil.access %3[-1, 0] : tensor<1x64xf32>
// CHECK-NEXT: %5 = "tensor.insert_slice"(%4, %3, %2) <{"static_offsets" = array<i64: 0, -9223372036854775808>, "static_sizes" = array<i64: 32>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 1, 1, 0, 0>}> : (tensor<32xf32>, tensor<1x64xf32>, index) -> tensor<1x64xf32>
// CHECK-NEXT: %4 = csl_stencil.access %1[-1, 0] : tensor<1x32xf32>
// CHECK-NEXT: %5 = "tensor.insert_slice"(%4, %3, %2) <{"static_offsets" = array<i64: 0, -9223372036854775808>, "static_sizes" = array<i64: 1, 32>, "static_strides" = array<i64: 1, 1>, "operandSegmentSizes" = array<i32: 1, 1, 1, 0, 0>}> : (tensor<32xf32>, tensor<1x64xf32>, index) -> tensor<1x64xf32>
// CHECK-NEXT: csl_stencil.yield %5 : tensor<1x64xf32>
// CHECK-NEXT: }, {
// CHECK-NEXT: ^1(%6 : !stencil.field<[-1,1]x[-1,1]xtensor<64xf32>>, %7 : tensor<1x64xf32>):
// CHECK-NEXT: csl_stencil.yield %7 : tensor<1x64xf32>
// CHECK-NEXT: csl_stencil.yield
// CHECK-NEXT: })
// CHECK-NEXT: %1 = tensor.empty() : tensor<64xf32>
// CHECK-NEXT: csl_stencil.apply(%arg0 : !stencil.field<[-1,1]x[-1,1]xtensor<64xf32>>, %1 : tensor<64xf32>, %arg1 : !stencil.field<[-1,1]x[-1,1]xtensor<64xf32>>, %0 : tensor<1x64xf32>) outs (%arg4 : !stencil.field<[-1,1]x[-1,1]xtensor<64xf32>>) <{"swaps" = [#csl_stencil.exchange<to [0, -1]>], "topo" = #dmp.topo<64x64>, "num_chunks" = 2 : i64, "bounds" = #stencil.bounds<[0, 0], [1, 1]>, "operandSegmentSizes" = array<i32: 1, 1, 0, 2, 1>}> ({
Expand Down
10 changes: 6 additions & 4 deletions xdsl/transforms/convert_stencil_to_csl_stencil.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,19 +606,21 @@ def match_and_rewrite(

block = Block(arg_types=[chunk_buf_t, builtin.IndexType(), op.result.type])
block2 = Block(arg_types=[op.input_stencil.type, op.result.type])
block2.add_op(csl_stencil.YieldOp(block2.args[1]))
block2.add_op(csl_stencil.YieldOp())

with ImplicitBuilder(block) as (_, offset, acc):
with ImplicitBuilder(block) as (buf, offset, acc):
dest = acc
for i, acc_offset in enumerate(offsets):
ac_op = csl_stencil.AccessOp(
dest, stencil.IndexAttr.get(*acc_offset), chunk_t
buf, stencil.IndexAttr.get(*acc_offset), chunk_t
)
assert isa(ac_op.result.type, AnyTensorType)
# inserts 1 (see static_sizes) 1d slice into a 2d tensor at offset (i, `offset`) (see static_offsets)
# where the latter offset is provided dynamically (see offsets)
dest = tensor.InsertSliceOp.get(
source=ac_op.result,
dest=dest,
static_sizes=ac_op.result.type.get_shape(),
static_sizes=[1, *ac_op.result.type.get_shape()],
static_offsets=[i, memref.SubviewOp.DYNAMIC_INDEX],
offsets=[offset],
).result
Expand Down
9 changes: 9 additions & 0 deletions xdsl/transforms/csl_stencil_bufferize.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,10 @@ def match_and_rewrite(self, op: csl_stencil.ApplyOp, rewriter: PatternRewriter,
# convert args
buf_args: list[SSAValue] = []
to_memrefs: list[Operation] = [buf_iter_arg := to_memref_op(op.accumulator)]
# in case of subsequent apply ops accessing this accumulator, replace uses with `bufferization.to_memref`
op.accumulator.replace_by_if(
buf_iter_arg.memref, lambda use: use.operation != buf_iter_arg
)
for arg in [*op.args_rchunk, *op.args_dexchng]:
if isa(arg.type, TensorType[Attribute]):
to_memrefs.append(new_arg := to_memref_op(arg))
Expand Down Expand Up @@ -385,6 +389,11 @@ def match_and_rewrite(self, op: arith.ConstantOp, rewriter: PatternRewriter, /):
class InjectApplyOutsIntoLinalgOuts(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: csl_stencil.ApplyOp, rewriter: PatternRewriter, /):
# require bufferized apply (with op.dest specified)
# zero-output apply ops may be used for communicate-only, to which this pattern does not apply
if not op.dest:
return

yld = op.done_exchange.block.last_op
assert isinstance(yld, csl_stencil.YieldOp)
new_dest: list[SSAValue] = []
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,13 @@ def get_required_result_type(op: Operation) -> TensorType[Attribute] | None:
tuple[int, ...],
)
):
assert is_tensor(use.operation.source.type)
# inserting an (n-1)d tensor into an (n)d tensor should not require the input tensor to also be (n)d
# instead, drop the first `dimdiff` dimensions
dimdiff = len(static_sizes) - len(use.operation.source.type.shape)
return TensorType(
use.operation.result.type.get_element_type(),
static_sizes,
static_sizes[dimdiff:],
)
for ret in use.operation.results:
if isa(r_type := ret.type, TensorType[Attribute]):
Expand Down
7 changes: 6 additions & 1 deletion xdsl/transforms/memref_to_dsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,12 @@ def match_and_rewrite(self, op: memref.SubviewOp, rewriter: PatternRewriter, /):
last_op = stride_ops[-1] if len(stride_ops) > 0 else last_op
offset_ops = self._update_offsets(op, last_op)

rewriter.replace_matched_op([*size_ops, *stride_ops, *offset_ops])
new_ops = [*size_ops, *stride_ops, *offset_ops]
if new_ops:
rewriter.replace_matched_op([*size_ops, *stride_ops, *offset_ops])
else:
# subview has no effect (todo: this could be canonicalized away)
rewriter.replace_matched_op([], new_results=[op.source])

@staticmethod
def _update_sizes(
Expand Down

0 comments on commit d5dd188

Please sign in to comment.