Skip to content

Commit

Permalink
transformations: (memref-streamify) don't streamify 0D memrefs (#3677)
Browse files Browse the repository at this point in the history
  • Loading branch information
superlopuh authored Dec 26, 2024
1 parent 9e1f976 commit d7f6068
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 2 deletions.
32 changes: 32 additions & 0 deletions tests/filecheck/transforms/memref_streamify.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,38 @@

// CHECK: builtin.module {

func.func @fill_empty_shape(%scalar: memref<f64>) {
%zero_float = arith.constant 0.000000e+00 : f64
memref_stream.generic {
bounds = [],
indexing_maps = [
affine_map<() -> ()>,
affine_map<() -> ()>
],
iterator_types = []
} ins(%zero_float : f64) outs(%scalar : memref<f64>) {
^bb0(%in: f64, %out: f64):
linalg.yield %in : f64
}
return
}

// CHECK-NEXT: func.func @fill_empty_shape(%scalar : memref<f64>) {
// CHECK-NEXT: %zero_float = arith.constant 0.000000e+00 : f64
// CHECK-NEXT: memref_stream.generic {
// CHECK-NEXT: bounds = [],
// CHECK-NEXT: indexing_maps = [
// CHECK-NEXT: affine_map<() -> ()>,
// CHECK-NEXT: affine_map<() -> ()>
// CHECK-NEXT: ],
// CHECK-NEXT: iterator_types = []
// CHECK-NEXT: } ins(%zero_float : f64) outs(%scalar : memref<f64>) {
// CHECK-NEXT: ^0(%in : f64, %out : f64):
// CHECK-NEXT: linalg.yield %in : f64
// CHECK-NEXT: }
// CHECK-NEXT: func.return
// CHECK-NEXT: }

func.func public @dsum(%arg0 : memref<8x16xf64>, %arg1 : memref<8x16xf64>, %arg2 : memref<8x16xf64>) -> memref<8x16xf64> {
memref_stream.generic {
bounds = [8, 16],
Expand Down
6 changes: 4 additions & 2 deletions xdsl/transforms/memref_streamify.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,17 @@ def match_and_rewrite(
for index, (i, arg) in enumerate(
zip(op.inputs, op.body.block.args[:input_count])
)
if isinstance(i.type, memref.MemRefType) and arg.uses
if isinstance(i_type := i.type, memref.MemRefType) and arg.uses
if i_type.get_shape()
)
streamable_output_indices = tuple(
(index, arg.type)
for index, (o, arg) in enumerate(
zip(op.outputs, op.body.block.args[input_count:])
)
if isinstance(o.type, memref.MemRefType)
if isinstance(o_type := o.type, memref.MemRefType)
if index in init_indices or not arg.uses
if o_type.get_shape()
)
if not streamable_input_indices and not streamable_output_indices:
# No memrefs to convert to streams
Expand Down

0 comments on commit d7f6068

Please sign in to comment.