From e3ae802983ce2cfb7f1a6577dd6f67462abb29ce Mon Sep 17 00:00:00 2001 From: jorendumoulin <47864363+jorendumoulin@users.noreply.github.com> Date: Wed, 15 Nov 2023 10:38:52 +0100 Subject: [PATCH] dialects: (memref) add layout and memory space support in memref.alloc(a) (#1784) This PR adds support for memref layout and memory space attributes in the memref.alloc and memref.alloca operations @JosseVanDelm --- tests/dialects/test_memref.py | 16 ++++++++++++++++ tests/filecheck/dialects/memref/memref_ops.mlir | 9 +++++++++ xdsl/dialects/memref.py | 16 ++++++++++++++-- 3 files changed, 39 insertions(+), 2 deletions(-) diff --git a/tests/dialects/test_memref.py b/tests/dialects/test_memref.py index b0e0fc7f63..2a2e4e7bd0 100644 --- a/tests/dialects/test_memref.py +++ b/tests/dialects/test_memref.py @@ -105,8 +105,11 @@ def test_memref_store_i32_with_dimensions(): def test_memref_alloc(): my_i32 = IntegerType(32) + my_layout = StridedLayoutAttr(strides=(2, 4, 6), offset=8) + my_memspace = builtin.IntegerAttr(0, i32) alloc0 = Alloc.get(my_i32, 64, [3, 1, 2]) alloc1 = Alloc.get(my_i32, 64) + alloc2 = Alloc.get(my_i32, 64, [3, 1, 2], my_layout, my_memspace) assert alloc0.dynamic_sizes == () assert type(alloc0.results[0]) is OpResult @@ -115,12 +118,20 @@ def test_memref_alloc(): assert type(alloc1.results[0]) is OpResult assert type(alloc1.results[0].type) is MemRefType assert alloc1.results[0].type.get_shape() == (1,) + assert type(alloc2.results[0]) is OpResult + assert type(alloc2.results[0].type) is MemRefType + assert alloc2.results[0].type.get_shape() == (3, 1, 2) + assert alloc2.results[0].type.layout == my_layout + assert alloc2.results[0].type.memory_space == my_memspace def test_memref_alloca(): my_i32 = IntegerType(32) + my_layout = StridedLayoutAttr(strides=(2, 4, 6), offset=8) + my_memspace = builtin.IntegerAttr(0, i32) alloc0 = Alloca.get(my_i32, 64, [3, 1, 2]) alloc1 = Alloca.get(my_i32, 64) + alloc2 = Alloc.get(my_i32, 64, [3, 1, 2], my_layout, my_memspace) assert type(alloc0.results[0]) is OpResult assert type(alloc0.results[0].type) is MemRefType @@ -128,6 +139,11 @@ def test_memref_alloca(): assert type(alloc1.results[0]) is OpResult assert type(alloc1.results[0].type) is MemRefType assert alloc1.results[0].type.get_shape() == (1,) + assert type(alloc2.results[0]) is OpResult + assert type(alloc2.results[0].type) is MemRefType + assert alloc2.results[0].type.get_shape() == (3, 1, 2) + assert alloc2.results[0].type.layout == my_layout + assert alloc2.results[0].type.memory_space == my_memspace def test_memref_dealloc(): diff --git a/tests/filecheck/dialects/memref/memref_ops.mlir b/tests/filecheck/dialects/memref/memref_ops.mlir index d686ce6f0b..7b443550af 100644 --- a/tests/filecheck/dialects/memref/memref_ops.mlir +++ b/tests/filecheck/dialects/memref/memref_ops.mlir @@ -21,9 +21,14 @@ builtin.module { %7 = "memref.cast"(%5) : (memref<10x2xindex>) -> memref %8 = "memref.alloca"() {"operandSegmentSizes" = array} : () -> memref<1xindex> %9 = "memref.memory_space_cast"(%5) : (memref<10x2xindex>) -> memref<10x2xindex, 1: i32> + %10 = "memref.alloc"() {"operandSegmentSizes" = array} : () -> memref<64x64xindex, strided<[2, 4], offset: 6>, 2 : i32> + %11 = "memref.alloca"() {"operandSegmentSizes" = array} : () -> memref<64x64xindex, strided<[2, 4], offset: 6>, 2 : i32> "memref.dealloc"(%2) : (memref<1xindex>) -> () "memref.dealloc"(%5) : (memref<10x2xindex>) -> () "memref.dealloc"(%8) : (memref<1xindex>) -> () + "memref.dealloc"(%10) : (memref<64x64xindex, strided<[2, 4], offset: 6>, 2 : i32>) -> () + "memref.dealloc"(%11) : (memref<64x64xindex, strided<[2, 4], offset: 6>, 2 : i32>) -> () + func.return } } @@ -49,9 +54,13 @@ builtin.module { // CHECK-NEXT: %{{.*}} = "memref.cast"(%{{.*}}) : (memref<10x2xindex>) -> memref // CHECK-NEXT: %{{.*}} = "memref.alloca"() <{"operandSegmentSizes" = array}> : () -> memref<1xindex> // CHECK-NEXT: %{{.*}} = "memref.memory_space_cast"(%{{.*}}) : (memref<10x2xindex>) -> memref<10x2xindex, 1 : i32> +// CHECK-NEXT: %{{.*}} = "memref.alloc"() <{"operandSegmentSizes" = array}> : () -> memref<64x64xindex, strided<[2, 4], offset: 6>, 2 : i32> +// CHECK-NEXT: %{{.*}} = "memref.alloca"() <{"operandSegmentSizes" = array}> : () -> memref<64x64xindex, strided<[2, 4], offset: 6>, 2 : i32> // CHECK-NEXT: "memref.dealloc"(%{{.*}}) : (memref<1xindex>) -> () // CHECK-NEXT: "memref.dealloc"(%{{.*}}) : (memref<10x2xindex>) -> () // CHECK-NEXT: "memref.dealloc"(%{{.*}}) : (memref<1xindex>) -> () +// CHECK-NEXT: "memref.dealloc"(%{{.*}}) : (memref<64x64xindex, strided<[2, 4], offset: 6>, 2 : i32>) -> () +// CHECK-NEXT: "memref.dealloc"(%{{.*}}) : (memref<64x64xindex, strided<[2, 4], offset: 6>, 2 : i32>) -> () // CHECK-NEXT: func.return // CHECK-NEXT: } // CHECK-NEXT: } diff --git a/xdsl/dialects/memref.py b/xdsl/dialects/memref.py index 6af16844b3..6dd633e7d9 100644 --- a/xdsl/dialects/memref.py +++ b/xdsl/dialects/memref.py @@ -310,12 +310,18 @@ def get( return_type: Attribute, alignment: int | None = None, shape: Iterable[int | AnyIntegerAttr] | None = None, + layout: Attribute = NoneAttr(), + memory_space: Attribute = NoneAttr(), ) -> Alloc: if shape is None: shape = [1] return Alloc.build( operands=[[], []], - result_types=[MemRefType.from_element_type_and_shape(return_type, shape)], + result_types=[ + MemRefType.from_element_type_and_shape( + return_type, shape, layout, memory_space + ) + ], attributes={ "alignment": IntegerAttr.from_int_and_width(alignment, 64) if alignment is not None @@ -369,6 +375,8 @@ def get( alignment: int | AnyIntegerAttr | None = None, shape: Iterable[int | AnyIntegerAttr] | None = None, dynamic_sizes: Sequence[SSAValue | Operation] | None = None, + layout: Attribute = NoneAttr(), + memory_space: Attribute = NoneAttr(), ) -> Alloca: if shape is None: shape = [1] @@ -381,7 +389,11 @@ def get( return Alloca.build( operands=[dynamic_sizes, []], - result_types=[MemRefType.from_element_type_and_shape(return_type, shape)], + result_types=[ + MemRefType.from_element_type_and_shape( + return_type, shape, layout, memory_space + ) + ], properties={ "alignment": alignment, },