Skip to content

Commit

Permalink
dialects: (memref) add layout and memory space support in memref.allo…
Browse files Browse the repository at this point in the history
…c(a) (#1784)

This PR adds support for memref layout and memory space attributes in
the memref.alloc and memref.alloca operations

@JosseVanDelm
  • Loading branch information
jorendumoulin authored Nov 15, 2023
1 parent c5f3ce9 commit e3ae802
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 2 deletions.
16 changes: 16 additions & 0 deletions tests/dialects/test_memref.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,11 @@ def test_memref_store_i32_with_dimensions():

def test_memref_alloc():
my_i32 = IntegerType(32)
my_layout = StridedLayoutAttr(strides=(2, 4, 6), offset=8)
my_memspace = builtin.IntegerAttr(0, i32)
alloc0 = Alloc.get(my_i32, 64, [3, 1, 2])
alloc1 = Alloc.get(my_i32, 64)
alloc2 = Alloc.get(my_i32, 64, [3, 1, 2], my_layout, my_memspace)

assert alloc0.dynamic_sizes == ()
assert type(alloc0.results[0]) is OpResult
Expand All @@ -115,19 +118,32 @@ def test_memref_alloc():
assert type(alloc1.results[0]) is OpResult
assert type(alloc1.results[0].type) is MemRefType
assert alloc1.results[0].type.get_shape() == (1,)
assert type(alloc2.results[0]) is OpResult
assert type(alloc2.results[0].type) is MemRefType
assert alloc2.results[0].type.get_shape() == (3, 1, 2)
assert alloc2.results[0].type.layout == my_layout
assert alloc2.results[0].type.memory_space == my_memspace


def test_memref_alloca():
my_i32 = IntegerType(32)
my_layout = StridedLayoutAttr(strides=(2, 4, 6), offset=8)
my_memspace = builtin.IntegerAttr(0, i32)
alloc0 = Alloca.get(my_i32, 64, [3, 1, 2])
alloc1 = Alloca.get(my_i32, 64)
alloc2 = Alloc.get(my_i32, 64, [3, 1, 2], my_layout, my_memspace)

assert type(alloc0.results[0]) is OpResult
assert type(alloc0.results[0].type) is MemRefType
assert alloc0.results[0].type.get_shape() == (3, 1, 2)
assert type(alloc1.results[0]) is OpResult
assert type(alloc1.results[0].type) is MemRefType
assert alloc1.results[0].type.get_shape() == (1,)
assert type(alloc2.results[0]) is OpResult
assert type(alloc2.results[0].type) is MemRefType
assert alloc2.results[0].type.get_shape() == (3, 1, 2)
assert alloc2.results[0].type.layout == my_layout
assert alloc2.results[0].type.memory_space == my_memspace


def test_memref_dealloc():
Expand Down
9 changes: 9 additions & 0 deletions tests/filecheck/dialects/memref/memref_ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,14 @@ builtin.module {
%7 = "memref.cast"(%5) : (memref<10x2xindex>) -> memref<?x?xindex>
%8 = "memref.alloca"() {"operandSegmentSizes" = array<i32: 0, 0>} : () -> memref<1xindex>
%9 = "memref.memory_space_cast"(%5) : (memref<10x2xindex>) -> memref<10x2xindex, 1: i32>
%10 = "memref.alloc"() {"operandSegmentSizes" = array<i32: 0, 0>} : () -> memref<64x64xindex, strided<[2, 4], offset: 6>, 2 : i32>
%11 = "memref.alloca"() {"operandSegmentSizes" = array<i32: 0, 0>} : () -> memref<64x64xindex, strided<[2, 4], offset: 6>, 2 : i32>
"memref.dealloc"(%2) : (memref<1xindex>) -> ()
"memref.dealloc"(%5) : (memref<10x2xindex>) -> ()
"memref.dealloc"(%8) : (memref<1xindex>) -> ()
"memref.dealloc"(%10) : (memref<64x64xindex, strided<[2, 4], offset: 6>, 2 : i32>) -> ()
"memref.dealloc"(%11) : (memref<64x64xindex, strided<[2, 4], offset: 6>, 2 : i32>) -> ()

func.return
}
}
Expand All @@ -49,9 +54,13 @@ builtin.module {
// CHECK-NEXT: %{{.*}} = "memref.cast"(%{{.*}}) : (memref<10x2xindex>) -> memref<?x?xindex>
// CHECK-NEXT: %{{.*}} = "memref.alloca"() <{"operandSegmentSizes" = array<i32: 0, 0>}> : () -> memref<1xindex>
// CHECK-NEXT: %{{.*}} = "memref.memory_space_cast"(%{{.*}}) : (memref<10x2xindex>) -> memref<10x2xindex, 1 : i32>
// CHECK-NEXT: %{{.*}} = "memref.alloc"() <{"operandSegmentSizes" = array<i32: 0, 0>}> : () -> memref<64x64xindex, strided<[2, 4], offset: 6>, 2 : i32>
// CHECK-NEXT: %{{.*}} = "memref.alloca"() <{"operandSegmentSizes" = array<i32: 0, 0>}> : () -> memref<64x64xindex, strided<[2, 4], offset: 6>, 2 : i32>
// CHECK-NEXT: "memref.dealloc"(%{{.*}}) : (memref<1xindex>) -> ()
// CHECK-NEXT: "memref.dealloc"(%{{.*}}) : (memref<10x2xindex>) -> ()
// CHECK-NEXT: "memref.dealloc"(%{{.*}}) : (memref<1xindex>) -> ()
// CHECK-NEXT: "memref.dealloc"(%{{.*}}) : (memref<64x64xindex, strided<[2, 4], offset: 6>, 2 : i32>) -> ()
// CHECK-NEXT: "memref.dealloc"(%{{.*}}) : (memref<64x64xindex, strided<[2, 4], offset: 6>, 2 : i32>) -> ()
// CHECK-NEXT: func.return
// CHECK-NEXT: }
// CHECK-NEXT: }
16 changes: 14 additions & 2 deletions xdsl/dialects/memref.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,12 +310,18 @@ def get(
return_type: Attribute,
alignment: int | None = None,
shape: Iterable[int | AnyIntegerAttr] | None = None,
layout: Attribute = NoneAttr(),
memory_space: Attribute = NoneAttr(),
) -> Alloc:
if shape is None:
shape = [1]
return Alloc.build(
operands=[[], []],
result_types=[MemRefType.from_element_type_and_shape(return_type, shape)],
result_types=[
MemRefType.from_element_type_and_shape(
return_type, shape, layout, memory_space
)
],
attributes={
"alignment": IntegerAttr.from_int_and_width(alignment, 64)
if alignment is not None
Expand Down Expand Up @@ -369,6 +375,8 @@ def get(
alignment: int | AnyIntegerAttr | None = None,
shape: Iterable[int | AnyIntegerAttr] | None = None,
dynamic_sizes: Sequence[SSAValue | Operation] | None = None,
layout: Attribute = NoneAttr(),
memory_space: Attribute = NoneAttr(),
) -> Alloca:
if shape is None:
shape = [1]
Expand All @@ -381,7 +389,11 @@ def get(

return Alloca.build(
operands=[dynamic_sizes, []],
result_types=[MemRefType.from_element_type_and_shape(return_type, shape)],
result_types=[
MemRefType.from_element_type_and_shape(
return_type, shape, layout, memory_space
)
],
properties={
"alignment": alignment,
},
Expand Down

0 comments on commit e3ae802

Please sign in to comment.