From 5b746bfc35510ea3be589d3f183cadcc0b3bef7b Mon Sep 17 00:00:00 2001 From: Sasha Lopoukhine Date: Fri, 3 Nov 2023 10:34:34 +0000 Subject: [PATCH] backend: (snitch) add snitch register allocation (#1741) --- .../snitch_register_allocation.mlir | 39 +++++++++ xdsl/tools/command_line_tool.py | 2 + xdsl/transforms/snitch_register_allocation.py | 80 +++++++++++++++++++ 3 files changed, 121 insertions(+) create mode 100644 tests/filecheck/transforms/snitch_register_allocation.mlir create mode 100644 xdsl/transforms/snitch_register_allocation.py diff --git a/tests/filecheck/transforms/snitch_register_allocation.mlir b/tests/filecheck/transforms/snitch_register_allocation.mlir new file mode 100644 index 0000000000..d96ba1a146 --- /dev/null +++ b/tests/filecheck/transforms/snitch_register_allocation.mlir @@ -0,0 +1,39 @@ +// RUN: xdsl-opt -p snitch-allocate-registers %s | filecheck %s + +%stride_pattern, %ptr0, %ptr1, %ptr2 = "test.op"() : () -> (!snitch_stream.stride_pattern_type, !riscv.reg<>, !riscv.reg<>, !riscv.reg<>) +%s0 = "snitch_stream.strided_read"(%ptr0, %stride_pattern) {"dm" = #int<0>, "rank" = #int<2>} : (!riscv.reg<>, !snitch_stream.stride_pattern_type) -> !stream.readable> +%s1 = "snitch_stream.strided_read"(%ptr1, %stride_pattern) {"dm" = #int<1>, "rank" = #int<2>} : (!riscv.reg<>, !snitch_stream.stride_pattern_type) -> !stream.readable> +%s2 = "snitch_stream.strided_write"(%ptr2, %stride_pattern) {"dm" = #int<2>, "rank" = #int<2>} : (!riscv.reg<>, !snitch_stream.stride_pattern_type) -> !stream.writable> +%c128 = riscv.li 128 : () -> !riscv.reg<> + +"snitch_stream.generic"(%c128, %s0, %s1, %s2) <{"operandSegmentSizes" = array}> ({ +^0(%x : !riscv.freg<>, %y : !riscv.freg<>): + %r0 = riscv.fadd.d %x, %y : (!riscv.freg<>, !riscv.freg<>) -> !riscv.freg<> + snitch_stream.yield %r0 : !riscv.freg<> +}) : (!riscv.reg<>, !stream.readable>, !stream.readable>, !stream.writable>) -> () + +"snitch_stream.generic"(%c128, %s1, %s0, %s2) <{"operandSegmentSizes" = array}> ({ +^0(%x : !riscv.freg<>, %y : !riscv.freg<>): + %r0 = riscv.fadd.d %x, %y : (!riscv.freg<>, !riscv.freg<>) -> !riscv.freg<> + snitch_stream.yield %r0 : !riscv.freg<> +}) : (!riscv.reg<>, !stream.readable>, !stream.readable>, !stream.writable>) -> () + +// CHECK: builtin.module { + +// CHECK-NEXT: %stride_pattern, %ptr0, %ptr1, %ptr2 = "test.op"() : () -> (!snitch_stream.stride_pattern_type, !riscv.reg<>, !riscv.reg<>, !riscv.reg<>) +// CHECK-NEXT: %s0 = "snitch_stream.strided_read"(%ptr0, %stride_pattern) {"dm" = #int<0>, "rank" = #int<2>} : (!riscv.reg<>, !snitch_stream.stride_pattern_type) -> !stream.readable> +// CHECK-NEXT: %s1 = "snitch_stream.strided_read"(%ptr1, %stride_pattern) {"dm" = #int<1>, "rank" = #int<2>} : (!riscv.reg<>, !snitch_stream.stride_pattern_type) -> !stream.readable> +// CHECK-NEXT: %s2 = "snitch_stream.strided_write"(%ptr2, %stride_pattern) {"dm" = #int<2>, "rank" = #int<2>} : (!riscv.reg<>, !snitch_stream.stride_pattern_type) -> !stream.writable> +// CHECK-NEXT: %c128 = riscv.li 128 : () -> !riscv.reg<> +// CHECK-NEXT: "snitch_stream.generic"(%c128, %s0, %s1, %s2) <{"operandSegmentSizes" = array}> ({ +// CHECK-NEXT: ^0(%{{.*}} : !riscv.freg, %{{.*}} : !riscv.freg): +// CHECK-NEXT: %{{.*}} = riscv.fadd.d %{{.*}}, %{{.*}} : (!riscv.freg, !riscv.freg) -> !riscv.freg +// CHECK-NEXT: snitch_stream.yield %{{.*}} : !riscv.freg +// CHECK-NEXT: }) : (!riscv.reg<>, !stream.readable>, !stream.readable>, !stream.writable>) -> () +// CHECK-NEXT: "snitch_stream.generic"(%c128, %s1, %s0, %s2) <{"operandSegmentSizes" = array}> ({ +// CHECK-NEXT: ^1(%{{.*}} : !riscv.freg, %{{.*}} : !riscv.freg): +// CHECK-NEXT: %{{.*}} = riscv.fadd.d %{{.*}}, %{{.*}} : (!riscv.freg, !riscv.freg) -> !riscv.freg +// CHECK-NEXT: snitch_stream.yield %{{.*}} : !riscv.freg +// CHECK-NEXT: }) : (!riscv.reg<>, !stream.readable>, !stream.readable>, !stream.writable>) -> () + +// CHECK-NEXT: } diff --git a/xdsl/tools/command_line_tool.py b/xdsl/tools/command_line_tool.py index d9036d3461..95565807f3 100644 --- a/xdsl/tools/command_line_tool.py +++ b/xdsl/tools/command_line_tool.py @@ -67,6 +67,7 @@ reconcile_unrealized_casts, riscv_register_allocation, riscv_scf_loop_range_folding, + snitch_register_allocation, ) from xdsl.transforms.experimental import ( convert_stencil_to_ll_mlir, @@ -143,6 +144,7 @@ def get_all_passes() -> list[type[ModulePass]]: reduce_register_pressure.RiscvReduceRegisterPressurePass, riscv_register_allocation.RISCVRegisterAllocation, riscv_scf_loop_range_folding.RiscvScfLoopRangeFoldingPass, + snitch_register_allocation.SnitchRegisterAllocation, convert_arith_to_riscv.ConvertArithToRiscvPass, convert_func_to_riscv_func.ConvertFuncToRiscvFuncPass, convert_memref_to_riscv.ConvertMemrefToRiscvPass, diff --git a/xdsl/transforms/snitch_register_allocation.py b/xdsl/transforms/snitch_register_allocation.py new file mode 100644 index 0000000000..9bd9524f4e --- /dev/null +++ b/xdsl/transforms/snitch_register_allocation.py @@ -0,0 +1,80 @@ +from dataclasses import dataclass +from typing import Any, cast + +from xdsl.dialects import riscv, snitch_stream, stream +from xdsl.dialects.builtin import ModuleOp +from xdsl.ir import MLContext +from xdsl.passes import ModulePass +from xdsl.pattern_rewriter import ( + PatternRewriter, + PatternRewriteWalker, + RewritePattern, + op_type_rewrite_pattern, +) + + +class AllocateSnitchStridedStreamRegisters(RewritePattern): + """ + Allocates the register used by the stream as the one specified by the `dm` + (data mover) attribute. Must be called before allocating the registers in the + `snitch_stream.generic` body. + """ + + @op_type_rewrite_pattern + def match_and_rewrite( + self, + op: snitch_stream.StridedReadOp | snitch_stream.StridedWriteOp, + rewriter: PatternRewriter, + /, + ): + stream_type = op.stream.type + assert isinstance( + stream_type, stream.ReadableStreamType | stream.WritableStreamType + ) + stream_type = cast(stream.StreamType[Any], stream_type) + op.stream.type = type(stream_type)(riscv.Registers.FT[op.dm.data]) + + +class AllocateSnitchGenericRegisters(RewritePattern): + """ + Allocates the registers in the body of a `snitch_stream.generic` operation by assigning + them to the ones specified by the streams. + """ + + @op_type_rewrite_pattern + def match_and_rewrite( + self, op: snitch_stream.GenericOp, rewriter: PatternRewriter, / + ): + block = op.body.block + + for arg, input in zip(block.args, op.inputs): + assert isinstance(input.type, stream.ReadableStreamType) + input_type: stream.ReadableStreamType[Any] = input.type + arg.type = input_type.element_type + + yield_op = block.last_op + assert isinstance(yield_op, snitch_stream.YieldOp) + + for arg, output in zip(yield_op.values, op.outputs): + assert isinstance(output.type, stream.WritableStreamType) + output_type: stream.WritableStreamType[Any] = output.type + arg.type = output_type.element_type + + +@dataclass +class SnitchRegisterAllocation(ModulePass): + """ + Allocates unallocated registers for snitch operations. + """ + + name = "snitch-allocate-registers" + + def apply(self, ctx: MLContext, op: ModuleOp) -> None: + PatternRewriteWalker( + AllocateSnitchStridedStreamRegisters(), + apply_recursively=False, + ).rewrite_module(op) + PatternRewriteWalker( + AllocateSnitchGenericRegisters(), + apply_recursively=False, + ).rewrite_module(op)