transformations: convert type_offsets in ptr to arith.constant #3394

Open · wants to merge 2 commits into main
79 changes: 79 additions & 0 deletions tests/filecheck/mlir-conversion/with-mlir/ptr_loop_folding.mlir
@@ -0,0 +1,79 @@
// RUN: xdsl-opt -p convert-memref-to-ptr,convert-ptr-type-offsets,mlir-opt[scf-for-loop-canonicalization,scf-for-loop-range-folding,scf-for-loop-canonicalization],scf-for-loop-flatten,mlir-opt[scf-for-loop-canonicalization,scf-for-loop-range-folding,scf-for-loop-canonicalization] --split-input-file %s | filecheck %s
Member:

How much are we missing in xDSL from the mlir-opt pipeline here? If it's not a lot, I'd much rather have this logic in xDSL.

Collaborator (Author):

I believe both scf-for-loop-canonicalization and scf-for-loop-range-folding are missing. Do you think it's worth porting them?

Member:

Yep, definitely. I'd love to keep the Snitch compilation flow working entirely without MLIR, both for environments like WASM and for when we start the work on schedule exploration; whatever hackery we need to make it fast will be much easier to play with in a single environment.

Member:

scf-for-loop-range-folding is already in, and it was less buggy than the MLIR one until recently, when I upstreamed a bug fix to MLIR after noticing the difference with xDSL :)

Collaborator (Author):

How about we benchmark it first? I'm curious to see whether it actually makes an improvement; if there's no speedup, the port may not be worth it.

Collaborator (Author):

Oh, right, there is. But I think it's only for the riscv loops. Do we want a general scf version?

Member:

I can guarantee speed improvements, and register-pressure improvements in the final assembly, if that's what you mean.

Member:

Yep, I think it would be worth having an scf version of this.

Collaborator (Author):

OK, I'll port them to xDSL.

Member:

Do we now have everything we need?


func.func @fill(%m: memref<10xi32>) {
  %c0 = arith.constant 0 : index
  %end = arith.constant 10 : index
  %c1 = arith.constant 1 : index
  %val = arith.constant 100 : i32
  scf.for %i = %c0 to %end step %c1 {
    memref.store %val, %m[%i] : memref<10xi32>
  }
  return
}

// CHECK: func.func @fill(%arg4 : memref<10xi32>) {
// CHECK-NEXT: %0 = arith.constant 0 : index
// CHECK-NEXT: %1 = arith.constant 100 : i32
// CHECK-NEXT: %2 = arith.constant 40 : index
// CHECK-NEXT: %3 = arith.constant 4 : index
// CHECK-NEXT: scf.for %arg5 = %0 to %2 step %3 {
// CHECK-NEXT: %4 = ptr_xdsl.to_ptr %arg4 : memref<10xi32> -> !ptr_xdsl.ptr
// CHECK-NEXT: %5 = ptr_xdsl.ptradd %4, %arg5 : (!ptr_xdsl.ptr, index) -> !ptr_xdsl.ptr
// CHECK-NEXT: ptr_xdsl.store %1, %5 : i32, !ptr_xdsl.ptr
// CHECK-NEXT: }
// CHECK-NEXT: func.return
// CHECK-NEXT: }

func.func @fill2d(%m: memref<10x10xi32>) {
  %c0 = arith.constant 0 : index
  %end = arith.constant 10 : index
  %c1 = arith.constant 1 : index
  %val = arith.constant 100 : i32
  scf.for %i = %c0 to %end step %c1 {
    scf.for %j = %c0 to %end step %c1 {
      memref.store %val, %m[%i, %j] : memref<10x10xi32>
    }
  }
  return
}

// CHECK-NEXT: func.func @fill2d(%arg2 : memref<10x10xi32>) {
// CHECK-NEXT: %0 = arith.constant 0 : index
// CHECK-NEXT: %1 = arith.constant 100 : i32
// CHECK-NEXT: %2 = arith.constant 400 : index
// CHECK-NEXT: %3 = arith.constant 4 : index
// CHECK-NEXT: scf.for %arg3 = %0 to %2 step %3 {
// CHECK-NEXT: %4 = ptr_xdsl.to_ptr %arg2 : memref<10x10xi32> -> !ptr_xdsl.ptr
// CHECK-NEXT: %5 = ptr_xdsl.ptradd %4, %arg3 : (!ptr_xdsl.ptr, index) -> !ptr_xdsl.ptr
// CHECK-NEXT: ptr_xdsl.store %1, %5 : i32, !ptr_xdsl.ptr
// CHECK-NEXT: }
// CHECK-NEXT: func.return
// CHECK-NEXT: }

func.func @fill3d(%m: memref<10x10x10xi32>) {
  %c0 = arith.constant 0 : index
  %end = arith.constant 10 : index
  %c1 = arith.constant 1 : index
  %val = arith.constant 100 : i32
  scf.for %i = %c0 to %end step %c1 {
    scf.for %j = %c0 to %end step %c1 {
      scf.for %k = %c0 to %end step %c1 {
        memref.store %val, %m[%i, %j, %k] : memref<10x10x10xi32>
      }
    }
  }
  return
}

// CHECK-NEXT: func.func @fill3d(%arg0 : memref<10x10x10xi32>) {
// CHECK-NEXT: %0 = arith.constant 0 : index
// CHECK-NEXT: %1 = arith.constant 100 : i32
// CHECK-NEXT: %2 = arith.constant 4000 : index
// CHECK-NEXT: %3 = arith.constant 4 : index
// CHECK-NEXT: scf.for %arg1 = %0 to %2 step %3 {
// CHECK-NEXT: %4 = ptr_xdsl.to_ptr %arg0 : memref<10x10x10xi32> -> !ptr_xdsl.ptr
// CHECK-NEXT: %5 = ptr_xdsl.ptradd %4, %arg1 : (!ptr_xdsl.ptr, index) -> !ptr_xdsl.ptr
// CHECK-NEXT: ptr_xdsl.store %1, %5 : i32, !ptr_xdsl.ptr
// CHECK-NEXT: }
// CHECK-NEXT: func.return
// CHECK-NEXT: }
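
The flattened bounds in the CHECK lines above follow from collapsing each loop nest into a single loop over byte offsets. The small sketch below (plain Python, not part of the PR) reproduces the expected constants:

# Reproduce the flattened loop bounds seen in the CHECK lines: each nest
# collapses to one loop over byte offsets, stepping by sizeof(i32) = 4.
elem_bytes = 4  # byte size of i32, as produced by convert-ptr-type-offsets
for shape in [(10,), (10, 10), (10, 10, 10)]:
    trip_count = 1
    for dim in shape:
        trip_count *= dim
    upper = trip_count * elem_bytes  # 40, 400, 4000 respectively
    print(f"memref<{'x'.join(map(str, shape))}xi32>: scf.for 0 to {upper} step {elem_bytes}")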
12 changes: 12 additions & 0 deletions tests/filecheck/transforms/convert_ptr_type_offsets.mlir
@@ -0,0 +1,12 @@
// RUN: xdsl-opt -p convert-ptr-type-offsets --split-input-file --verify-diagnostics %s | filecheck %s

%a1 = ptr_xdsl.type_offset i32 : index
// CHECK: %a1 = arith.constant 4 : index

%a2 = ptr_xdsl.type_offset f128 : index
// CHECK-NEXT: %a2 = arith.constant 16 : index

// -----

%a3 = ptr_xdsl.type_offset tensor<4xi32> : index
// CHECK: Type offset is currently only supported for fixed size types
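
The expected constants follow directly from the element bitwidths, while tensor<4xi32> is rejected because it is not a FixedBitwidthType. A hypothetical sketch of the byte-size rule (assuming xdsl's FixedBitwidthType reports its size as the bitwidth rounded up to whole bytes):

# Assumed rule: byte size = ceil(bitwidth / 8); matches the CHECK lines above.
def byte_size(bitwidth: int) -> int:
    return -(-bitwidth // 8)  # ceiling division

assert byte_size(32) == 4    # i32  -> arith.constant 4 : index
assert byte_size(128) == 16  # f128 -> arith.constant 16 : index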
6 changes: 6 additions & 0 deletions xdsl/transforms/__init__.py
@@ -333,6 +333,11 @@ def get_convert_print_format_to_riscv_debug():

    return convert_print_format_to_riscv_debug.ConvertPrintFormatToRiscvDebugPass

def get_convert_ptr_type_offsets():
    from xdsl.transforms import convert_ptr_type_offsets

    return convert_ptr_type_offsets.ConvertPtrTypeOffsetsPass

def get_convert_qref_to_qssa():
    from xdsl.transforms import convert_qref_to_qssa

@@ -468,6 +473,7 @@ def get_varith_fuse_repeated_operands():
"convert-scf-to-openmp": get_convert_scf_to_openmp,
"convert-scf-to-riscv-scf": get_convert_scf_to_riscv_scf,
"convert-snitch-stream-to-snitch": get_convert_snitch_stream_to_snitch,
"convert-ptr-type-offsets": get_convert_ptr_type_offsets,
"convert-stencil-to-csl-stencil": get_convert_stencil_to_csl_stencil,
"inline-snrt": get_convert_snrt_to_riscv,
"convert-stencil-to-ll-mlir": get_convert_stencil_to_ll_mlir,
38 changes: 38 additions & 0 deletions xdsl/transforms/convert_ptr_type_offsets.py
@@ -0,0 +1,38 @@
from dataclasses import dataclass

from xdsl.context import MLContext
from xdsl.dialects import arith, ptr
from xdsl.dialects.builtin import FixedBitwidthType, IndexType, ModuleOp
from xdsl.passes import ModulePass
from xdsl.pattern_rewriter import (
    GreedyRewritePatternApplier,
    PatternRewriter,
    PatternRewriteWalker,
    RewritePattern,
    op_type_rewrite_pattern,
)
from xdsl.utils.exceptions import DiagnosticException


@dataclass
class ConvertTypeOffsetOp(RewritePattern):
    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: ptr.TypeOffsetOp, rewriter: PatternRewriter, /):
        # Only types with a known, fixed bitwidth have a static byte size.
        if not isinstance(op.elem_type, FixedBitwidthType):
            raise DiagnosticException(
                "Type offset is currently only supported for fixed size types"
            )
        # Materialize the offset as its concrete byte size, e.g.
        # `ptr_xdsl.type_offset i32` becomes `arith.constant 4 : index`.
        rewriter.replace_matched_op(
            arith.Constant.from_int_and_width(op.elem_type.size, IndexType())
        )


class ConvertPtrTypeOffsetsPass(ModulePass):
    name = "convert-ptr-type-offsets"

    def apply(self, ctx: MLContext, op: ModuleOp) -> None:
        PatternRewriteWalker(
            GreedyRewritePatternApplier([ConvertTypeOffsetOp()]),
        ).rewrite_module(op)
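
A minimal sketch of driving the new pass from Python rather than via xdsl-opt. The pass API is as defined in this PR; the Parser and dialect-loading details are assumptions and may differ between xdsl versions:

# Parse a one-op module, run the pass, and print the folded result.
from xdsl.context import MLContext
from xdsl.dialects import arith, ptr
from xdsl.dialects.builtin import Builtin
from xdsl.parser import Parser
from xdsl.transforms.convert_ptr_type_offsets import ConvertPtrTypeOffsetsPass

ctx = MLContext()
ctx.load_dialect(Builtin)
ctx.load_dialect(arith.Arith)
ctx.load_dialect(ptr.Ptr)

module = Parser(ctx, "%o = ptr_xdsl.type_offset i32 : index").parse_module()
ConvertPtrTypeOffsetsPass().apply(ctx, module)
print(module)  # the type_offset is now: %o = arith.constant 4 : index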