Skip to content

Commit

Permalink
transformations: Enable dmp.swap stencil bufferization. (#3066)
Browse files Browse the repository at this point in the history
- Implement side effect interface of dmp.swap
- Slightly extend the side-effect analysis of stencil bufferization.
- Add a `dmp.swap` bufferization pattern in `stencil-bufferize`. This is only
meant to work for now, as dmp is only used with stencil so far - decoupling it,
as with shape inference and the like, is left to future work!
  • Loading branch information
PapyChacal authored Aug 19, 2024
1 parent 6291e67 commit f791473
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 10 deletions.
20 changes: 20 additions & 0 deletions tests/filecheck/transforms/distribute-stencil.mlir
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
// RUN: xdsl-opt %s -p "distribute-stencil{strategy=3d-grid slices=2,2,2}" | filecheck %s
// RUN: xdsl-opt %s -p "distribute-stencil{strategy=3d-grid slices=2,2,2},shape-inference" | filecheck %s --check-prefix SHAPE
// RUN: xdsl-opt %s -p "distribute-stencil{strategy=3d-grid slices=2,2,2},shape-inference,stencil-bufferize" | filecheck %s --check-prefix BUFF

func.func @offsets(%27 : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>, %28 : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>, %29 : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>) {
%33 = stencil.load %27 : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> -> !stencil.temp<?x?x?xf64>
Expand Down Expand Up @@ -66,6 +67,25 @@
// SHAPE-NEXT: func.return
// SHAPE-NEXT: }

// BUFF: func.func @offsets(%0 : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>, %1 : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>, %2 : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>) {
// BUFF-NEXT: "dmp.swap"(%0) {"strategy" = #dmp.grid_slice_3d<#dmp.topo<2x2x2>, false>, "swaps" = [#dmp.exchange<at [32, 0, 0] size [1, 32, 32] source offset [-1, 0, 0] to [1, 0, 0]>, #dmp.exchange<at [-1, 0, 0] size [1, 32, 32] source offset [1, 0, 0] to [-1, 0, 0]>, #dmp.exchange<at [0, 32, 0] size [32, 1, 32] source offset [0, -1, 0] to [0, 1, 0]>, #dmp.exchange<at [0, -1, 0] size [32, 1, 32] source offset [0, 1, 0] to [0, -1, 0]>]} : (!stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>) -> ()
// BUFF-NEXT: stencil.apply(%3 = %0 : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>) outs (%1 : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>, %2 : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>) {
// BUFF-NEXT: %4 = stencil.access %3[-1, 0, 0] : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>
// BUFF-NEXT: %5 = stencil.access %3[1, 0, 0] : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>
// BUFF-NEXT: %6 = stencil.access %3[0, 1, 0] : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>
// BUFF-NEXT: %7 = stencil.access %3[0, -1, 0] : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>
// BUFF-NEXT: %8 = stencil.access %3[0, 0, 0] : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>
// BUFF-NEXT: %9 = arith.addf %4, %5 : f64
// BUFF-NEXT: %10 = arith.addf %6, %7 : f64
// BUFF-NEXT: %11 = arith.addf %9, %10 : f64
// BUFF-NEXT: %cst = arith.constant -4.000000e+00 : f64
// BUFF-NEXT: %12 = arith.mulf %8, %cst : f64
// BUFF-NEXT: %13 = arith.addf %12, %11 : f64
// BUFF-NEXT: stencil.return %13, %12 : f64, f64
// BUFF-NEXT: } to <[0, 0, 0], [32, 32, 32]>
// BUFF-NEXT: func.return
// BUFF-NEXT: }

func.func @trivial_externals(%dyn_mem : memref<?x?x?xf64>, %sta_mem : memref<64x64x64xf64>, %dyn_field : !stencil.field<?x?x?xf64>, %sta_field : !stencil.field<[-2,62]x[0,64]x[2,66]xf64>) {
stencil.external_store %dyn_field to %dyn_mem : !stencil.field<?x?x?xf64> to memref<?x?x?xf64>
stencil.external_store %sta_field to %sta_mem : !stencil.field<[-2,62]x[0,64]x[2,66]xf64> to memref<64x64x64xf64>
Expand Down
32 changes: 29 additions & 3 deletions xdsl/dialects/experimental/dmp.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from abc import ABC
from collections.abc import Iterable, Sequence
from math import prod
from typing import Literal
from typing import Literal, cast

from xdsl.dialects import builtin, stencil
from xdsl.ir import Attribute, Dialect, Operation, ParametrizedAttribute, SSAValue
Expand All @@ -29,7 +29,12 @@
)
from xdsl.parser import AttrParser
from xdsl.printer import Printer
from xdsl.traits import HasShapeInferencePatternsTrait
from xdsl.traits import (
EffectInstance,
HasShapeInferencePatternsTrait,
MemoryEffect,
MemoryEffectKind,
)
from xdsl.utils.exceptions import VerifyException
from xdsl.utils.hints import isa

Expand Down Expand Up @@ -594,6 +599,27 @@ def get_shape_inference_patterns(cls):
return (DmpSwapShapeInference(), DmpSwapSwapsInference())


class SwapOpMemoryEffect(MemoryEffect):
    """
    Memory effect trait describing the side effects of dmp.swap.

    A swap that produces `swapped_values` is treated as value-semantic and
    reports no effects; otherwise it is treated as reference-semantic and
    reports both a read and a write on its `input_stencil` operand.
    """

    @classmethod
    def get_effects(cls, op: Operation) -> set[EffectInstance]:
        """Return the set of memory effects of the given dmp.swap operation."""
        swap = cast(SwapOp, op)
        effects: set[EffectInstance] = set()

        # Value-semantic mode (the op yields swapped values): no side effects.
        # Reference-semantic mode: the swap reads and writes its field in place.
        # TODO: consider the empty swaps case at some point.
        # Right now, it relies on it before inferring them, so not very safe.
        # But it could be an elegant way to generically simplify those.
        if not swap.swapped_values:
            for kind in (MemoryEffectKind.WRITE, MemoryEffectKind.READ):
                effects.add(EffectInstance(kind, swap.input_stencil))

        return effects


@irdl_op_definition
class SwapOp(IRDLOperation):
"""
Expand All @@ -609,7 +635,7 @@ class SwapOp(IRDLOperation):

strategy = attr_def(DomainDecompositionStrategy)

traits = frozenset([SwapOpHasShapeInferencePatterns()])
traits = frozenset([SwapOpHasShapeInferencePatterns(), SwapOpMemoryEffect()])

def verify_(self) -> None:
if self.swapped_values:
Expand Down
50 changes: 43 additions & 7 deletions xdsl/transforms/stencil_bufferize.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from xdsl.context import MLContext
from xdsl.dialects import builtin
from xdsl.dialects.experimental.dmp import SwapOp
from xdsl.dialects.stencil import (
AllocOp,
ApplyOp,
Expand Down Expand Up @@ -166,17 +167,19 @@ def match_and_rewrite(self, op: BufferOp, rewriter: PatternRewriter):

underlying = load.field

# TODO: proper analysis of effects in between
# For illustration, only fold a single use of the handle
# (Requires more boilerplate to analyse the whole live range otherwise)
# TODO: further analysis
# For now, only handle usages in the same block
uses = op.res.uses.copy()
if len(uses) > 1:
block = op.parent
if not block or any(use.operation.parent is not block for use in uses):
return
user = uses.pop().operation
last_user = max(
uses, key=lambda u: block.get_operation_index(u.operation)
).operation

effecting = [
o
for o in walk_from_to(load, user)
for o in walk_from_to(load, last_user)
if might_effect(o, {MemoryEffectKind.WRITE}, underlying)
]
if effecting:
Expand Down Expand Up @@ -498,6 +501,37 @@ def match_and_rewrite(self, op: CombineOp, rewriter: PatternRewriter):
return


class SwapBufferize(RewritePattern):
    """
    Bufferize a dmp.swap operation.

    A swap whose input is a stencil temp produced by a `LoadOp` is rewritten
    into three operations: a `BufferOp` on that temp, a new swap acting on the
    buffer with the same strategy and swaps, and a `LoadOp` reading the buffer
    back with the original temp type.

    NB: This should most likely consider a shared pass following canonicalize and
    shape-inference.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: SwapOp, rewriter: PatternRewriter):
        swapped = op.input_stencil
        temp_type = swapped.type

        # Nothing to do unless the swap operates on a stencil temp.
        if not isa(temp_type, TempType[Attribute]):
            return

        # Only temps coming directly out of a load are handled.
        if not isinstance(swapped.owner, LoadOp):
            return

        # Materialize the temp as a buffer the swap can act on.
        buffer = BufferOp.create(
            operands=[swapped],
            result_types=[field_from_temp(temp_type)],
        )

        # Same exchange as the matched op, now applied to the buffer.
        buffer_swap = SwapOp.get(buffer.res, op.strategy)
        buffer_swap.swaps = op.swaps

        # Read the buffer back with the original temp type.
        reload = LoadOp(operands=[buffer.res], result_types=[temp_type])

        rewriter.replace_matched_op(new_ops=[buffer, buffer_swap, reload])


@dataclass(frozen=True)
class StencilBufferize(ModulePass):
"""
Expand All @@ -520,7 +554,9 @@ def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None:
ApplyStoreFoldPattern(),
RemoveUnusedOperations(),
ApplyUnusedResults(),
SwapBufferize(),
]
)
),
apply_recursively=True,
)
walker.rewrite_module(op)

0 comments on commit f791473

Please sign in to comment.