Skip to content

Commit

Permalink
backend: (snitch) add snitch register allocation (#1741)
Browse files Browse the repository at this point in the history
  • Loading branch information
superlopuh authored Nov 3, 2023
1 parent 87a34d9 commit 5b746bf
Show file tree
Hide file tree
Showing 3 changed files with 121 additions and 0 deletions.
39 changes: 39 additions & 0 deletions tests/filecheck/transforms/snitch_register_allocation.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// RUN: xdsl-opt -p snitch-allocate-registers %s | filecheck %s

%stride_pattern, %ptr0, %ptr1, %ptr2 = "test.op"() : () -> (!snitch_stream.stride_pattern_type, !riscv.reg<>, !riscv.reg<>, !riscv.reg<>)
%s0 = "snitch_stream.strided_read"(%ptr0, %stride_pattern) {"dm" = #int<0>, "rank" = #int<2>} : (!riscv.reg<>, !snitch_stream.stride_pattern_type) -> !stream.readable<!riscv.freg<>>
%s1 = "snitch_stream.strided_read"(%ptr1, %stride_pattern) {"dm" = #int<1>, "rank" = #int<2>} : (!riscv.reg<>, !snitch_stream.stride_pattern_type) -> !stream.readable<!riscv.freg<>>
%s2 = "snitch_stream.strided_write"(%ptr2, %stride_pattern) {"dm" = #int<2>, "rank" = #int<2>} : (!riscv.reg<>, !snitch_stream.stride_pattern_type) -> !stream.writable<!riscv.freg<>>
%c128 = riscv.li 128 : () -> !riscv.reg<>

"snitch_stream.generic"(%c128, %s0, %s1, %s2) <{"operandSegmentSizes" = array<i32: 1, 2, 1>}> ({
^0(%x : !riscv.freg<>, %y : !riscv.freg<>):
%r0 = riscv.fadd.d %x, %y : (!riscv.freg<>, !riscv.freg<>) -> !riscv.freg<>
snitch_stream.yield %r0 : !riscv.freg<>
}) : (!riscv.reg<>, !stream.readable<!riscv.freg<>>, !stream.readable<!riscv.freg<>>, !stream.writable<!riscv.freg<>>) -> ()

"snitch_stream.generic"(%c128, %s1, %s0, %s2) <{"operandSegmentSizes" = array<i32: 1, 2, 1>}> ({
^0(%x : !riscv.freg<>, %y : !riscv.freg<>):
%r0 = riscv.fadd.d %x, %y : (!riscv.freg<>, !riscv.freg<>) -> !riscv.freg<>
snitch_stream.yield %r0 : !riscv.freg<>
}) : (!riscv.reg<>, !stream.readable<!riscv.freg<>>, !stream.readable<!riscv.freg<>>, !stream.writable<!riscv.freg<>>) -> ()

// CHECK: builtin.module {

// CHECK-NEXT: %stride_pattern, %ptr0, %ptr1, %ptr2 = "test.op"() : () -> (!snitch_stream.stride_pattern_type, !riscv.reg<>, !riscv.reg<>, !riscv.reg<>)
// CHECK-NEXT: %s0 = "snitch_stream.strided_read"(%ptr0, %stride_pattern) {"dm" = #int<0>, "rank" = #int<2>} : (!riscv.reg<>, !snitch_stream.stride_pattern_type) -> !stream.readable<!riscv.freg<ft0>>
// CHECK-NEXT: %s1 = "snitch_stream.strided_read"(%ptr1, %stride_pattern) {"dm" = #int<1>, "rank" = #int<2>} : (!riscv.reg<>, !snitch_stream.stride_pattern_type) -> !stream.readable<!riscv.freg<ft1>>
// CHECK-NEXT: %s2 = "snitch_stream.strided_write"(%ptr2, %stride_pattern) {"dm" = #int<2>, "rank" = #int<2>} : (!riscv.reg<>, !snitch_stream.stride_pattern_type) -> !stream.writable<!riscv.freg<ft2>>
// CHECK-NEXT: %c128 = riscv.li 128 : () -> !riscv.reg<>
// CHECK-NEXT: "snitch_stream.generic"(%c128, %s0, %s1, %s2) <{"operandSegmentSizes" = array<i32: 1, 2, 1>}> ({
// CHECK-NEXT: ^0(%{{.*}} : !riscv.freg<ft0>, %{{.*}} : !riscv.freg<ft1>):
// CHECK-NEXT: %{{.*}} = riscv.fadd.d %{{.*}}, %{{.*}} : (!riscv.freg<ft0>, !riscv.freg<ft1>) -> !riscv.freg<ft2>
// CHECK-NEXT: snitch_stream.yield %{{.*}} : !riscv.freg<ft2>
// CHECK-NEXT: }) : (!riscv.reg<>, !stream.readable<!riscv.freg<ft0>>, !stream.readable<!riscv.freg<ft1>>, !stream.writable<!riscv.freg<ft2>>) -> ()
// CHECK-NEXT: "snitch_stream.generic"(%c128, %s1, %s0, %s2) <{"operandSegmentSizes" = array<i32: 1, 2, 1>}> ({
// CHECK-NEXT: ^1(%{{.*}} : !riscv.freg<ft1>, %{{.*}} : !riscv.freg<ft0>):
// CHECK-NEXT: %{{.*}} = riscv.fadd.d %{{.*}}, %{{.*}} : (!riscv.freg<ft1>, !riscv.freg<ft0>) -> !riscv.freg<ft2>
// CHECK-NEXT: snitch_stream.yield %{{.*}} : !riscv.freg<ft2>
// CHECK-NEXT: }) : (!riscv.reg<>, !stream.readable<!riscv.freg<ft1>>, !stream.readable<!riscv.freg<ft0>>, !stream.writable<!riscv.freg<ft2>>) -> ()

// CHECK-NEXT: }
2 changes: 2 additions & 0 deletions xdsl/tools/command_line_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
reconcile_unrealized_casts,
riscv_register_allocation,
riscv_scf_loop_range_folding,
snitch_register_allocation,
)
from xdsl.transforms.experimental import (
convert_stencil_to_ll_mlir,
Expand Down Expand Up @@ -143,6 +144,7 @@ def get_all_passes() -> list[type[ModulePass]]:
reduce_register_pressure.RiscvReduceRegisterPressurePass,
riscv_register_allocation.RISCVRegisterAllocation,
riscv_scf_loop_range_folding.RiscvScfLoopRangeFoldingPass,
snitch_register_allocation.SnitchRegisterAllocation,
convert_arith_to_riscv.ConvertArithToRiscvPass,
convert_func_to_riscv_func.ConvertFuncToRiscvFuncPass,
convert_memref_to_riscv.ConvertMemrefToRiscvPass,
Expand Down
80 changes: 80 additions & 0 deletions xdsl/transforms/snitch_register_allocation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from dataclasses import dataclass
from typing import Any, cast

from xdsl.dialects import riscv, snitch_stream, stream
from xdsl.dialects.builtin import ModuleOp
from xdsl.ir import MLContext
from xdsl.passes import ModulePass
from xdsl.pattern_rewriter import (
PatternRewriter,
PatternRewriteWalker,
RewritePattern,
op_type_rewrite_pattern,
)


class AllocateSnitchStridedStreamRegisters(RewritePattern):
"""
Allocates the register used by the stream as the one specified by the `dm`
(data mover) attribute. Must be called before allocating the registers in the
`snitch_stream.generic` body.
"""

@op_type_rewrite_pattern
def match_and_rewrite(
self,
op: snitch_stream.StridedReadOp | snitch_stream.StridedWriteOp,
rewriter: PatternRewriter,
/,
):
stream_type = op.stream.type
assert isinstance(
stream_type, stream.ReadableStreamType | stream.WritableStreamType
)
stream_type = cast(stream.StreamType[Any], stream_type)
op.stream.type = type(stream_type)(riscv.Registers.FT[op.dm.data])


class AllocateSnitchGenericRegisters(RewritePattern):
"""
Allocates the registers in the body of a `snitch_stream.generic` operation by assigning
them to the ones specified by the streams.
"""

@op_type_rewrite_pattern
def match_and_rewrite(
self, op: snitch_stream.GenericOp, rewriter: PatternRewriter, /
):
block = op.body.block

for arg, input in zip(block.args, op.inputs):
assert isinstance(input.type, stream.ReadableStreamType)
input_type: stream.ReadableStreamType[Any] = input.type
arg.type = input_type.element_type

yield_op = block.last_op
assert isinstance(yield_op, snitch_stream.YieldOp)

for arg, output in zip(yield_op.values, op.outputs):
assert isinstance(output.type, stream.WritableStreamType)
output_type: stream.WritableStreamType[Any] = output.type
arg.type = output_type.element_type


@dataclass
class SnitchRegisterAllocation(ModulePass):
"""
Allocates unallocated registers for snitch operations.
"""

name = "snitch-allocate-registers"

def apply(self, ctx: MLContext, op: ModuleOp) -> None:
PatternRewriteWalker(
AllocateSnitchStridedStreamRegisters(),
apply_recursively=False,
).rewrite_module(op)
PatternRewriteWalker(
AllocateSnitchGenericRegisters(),
apply_recursively=False,
).rewrite_module(op)

0 comments on commit 5b746bf

Please sign in to comment.