-
Notifications
You must be signed in to change notification settings - Fork 79
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
backend: (snitch) add snitch register allocation (#1741)
- Loading branch information
1 parent
87a34d9
commit 5b746bf
Showing
3 changed files
with
121 additions
and
0 deletions.
There are no files selected for viewing
39 changes: 39 additions & 0 deletions
39
tests/filecheck/transforms/snitch_register_allocation.mlir
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
// RUN: xdsl-opt -p snitch-allocate-registers %s | filecheck %s | ||
|
||
%stride_pattern, %ptr0, %ptr1, %ptr2 = "test.op"() : () -> (!snitch_stream.stride_pattern_type, !riscv.reg<>, !riscv.reg<>, !riscv.reg<>) | ||
%s0 = "snitch_stream.strided_read"(%ptr0, %stride_pattern) {"dm" = #int<0>, "rank" = #int<2>} : (!riscv.reg<>, !snitch_stream.stride_pattern_type) -> !stream.readable<!riscv.freg<>> | ||
%s1 = "snitch_stream.strided_read"(%ptr1, %stride_pattern) {"dm" = #int<1>, "rank" = #int<2>} : (!riscv.reg<>, !snitch_stream.stride_pattern_type) -> !stream.readable<!riscv.freg<>> | ||
%s2 = "snitch_stream.strided_write"(%ptr2, %stride_pattern) {"dm" = #int<2>, "rank" = #int<2>} : (!riscv.reg<>, !snitch_stream.stride_pattern_type) -> !stream.writable<!riscv.freg<>> | ||
%c128 = riscv.li 128 : () -> !riscv.reg<> | ||
|
||
"snitch_stream.generic"(%c128, %s0, %s1, %s2) <{"operandSegmentSizes" = array<i32: 1, 2, 1>}> ({ | ||
^0(%x : !riscv.freg<>, %y : !riscv.freg<>): | ||
%r0 = riscv.fadd.d %x, %y : (!riscv.freg<>, !riscv.freg<>) -> !riscv.freg<> | ||
snitch_stream.yield %r0 : !riscv.freg<> | ||
}) : (!riscv.reg<>, !stream.readable<!riscv.freg<>>, !stream.readable<!riscv.freg<>>, !stream.writable<!riscv.freg<>>) -> () | ||
|
||
"snitch_stream.generic"(%c128, %s1, %s0, %s2) <{"operandSegmentSizes" = array<i32: 1, 2, 1>}> ({ | ||
^0(%x : !riscv.freg<>, %y : !riscv.freg<>): | ||
%r0 = riscv.fadd.d %x, %y : (!riscv.freg<>, !riscv.freg<>) -> !riscv.freg<> | ||
snitch_stream.yield %r0 : !riscv.freg<> | ||
}) : (!riscv.reg<>, !stream.readable<!riscv.freg<>>, !stream.readable<!riscv.freg<>>, !stream.writable<!riscv.freg<>>) -> () | ||
|
||
// CHECK: builtin.module { | ||
|
||
// CHECK-NEXT: %stride_pattern, %ptr0, %ptr1, %ptr2 = "test.op"() : () -> (!snitch_stream.stride_pattern_type, !riscv.reg<>, !riscv.reg<>, !riscv.reg<>) | ||
// CHECK-NEXT: %s0 = "snitch_stream.strided_read"(%ptr0, %stride_pattern) {"dm" = #int<0>, "rank" = #int<2>} : (!riscv.reg<>, !snitch_stream.stride_pattern_type) -> !stream.readable<!riscv.freg<ft0>> | ||
// CHECK-NEXT: %s1 = "snitch_stream.strided_read"(%ptr1, %stride_pattern) {"dm" = #int<1>, "rank" = #int<2>} : (!riscv.reg<>, !snitch_stream.stride_pattern_type) -> !stream.readable<!riscv.freg<ft1>> | ||
// CHECK-NEXT: %s2 = "snitch_stream.strided_write"(%ptr2, %stride_pattern) {"dm" = #int<2>, "rank" = #int<2>} : (!riscv.reg<>, !snitch_stream.stride_pattern_type) -> !stream.writable<!riscv.freg<ft2>> | ||
// CHECK-NEXT: %c128 = riscv.li 128 : () -> !riscv.reg<> | ||
// CHECK-NEXT: "snitch_stream.generic"(%c128, %s0, %s1, %s2) <{"operandSegmentSizes" = array<i32: 1, 2, 1>}> ({ | ||
// CHECK-NEXT: ^0(%{{.*}} : !riscv.freg<ft0>, %{{.*}} : !riscv.freg<ft1>): | ||
// CHECK-NEXT: %{{.*}} = riscv.fadd.d %{{.*}}, %{{.*}} : (!riscv.freg<ft0>, !riscv.freg<ft1>) -> !riscv.freg<ft2> | ||
// CHECK-NEXT: snitch_stream.yield %{{.*}} : !riscv.freg<ft2> | ||
// CHECK-NEXT: }) : (!riscv.reg<>, !stream.readable<!riscv.freg<ft0>>, !stream.readable<!riscv.freg<ft1>>, !stream.writable<!riscv.freg<ft2>>) -> () | ||
// CHECK-NEXT: "snitch_stream.generic"(%c128, %s1, %s0, %s2) <{"operandSegmentSizes" = array<i32: 1, 2, 1>}> ({ | ||
// CHECK-NEXT: ^1(%{{.*}} : !riscv.freg<ft1>, %{{.*}} : !riscv.freg<ft0>): | ||
// CHECK-NEXT: %{{.*}} = riscv.fadd.d %{{.*}}, %{{.*}} : (!riscv.freg<ft1>, !riscv.freg<ft0>) -> !riscv.freg<ft2> | ||
// CHECK-NEXT: snitch_stream.yield %{{.*}} : !riscv.freg<ft2> | ||
// CHECK-NEXT: }) : (!riscv.reg<>, !stream.readable<!riscv.freg<ft1>>, !stream.readable<!riscv.freg<ft0>>, !stream.writable<!riscv.freg<ft2>>) -> () | ||
|
||
// CHECK-NEXT: } |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
from dataclasses import dataclass | ||
from typing import Any, cast | ||
|
||
from xdsl.dialects import riscv, snitch_stream, stream | ||
from xdsl.dialects.builtin import ModuleOp | ||
from xdsl.ir import MLContext | ||
from xdsl.passes import ModulePass | ||
from xdsl.pattern_rewriter import ( | ||
PatternRewriter, | ||
PatternRewriteWalker, | ||
RewritePattern, | ||
op_type_rewrite_pattern, | ||
) | ||
|
||
|
||
class AllocateSnitchStridedStreamRegisters(RewritePattern): | ||
""" | ||
Allocates the register used by the stream as the one specified by the `dm` | ||
(data mover) attribute. Must be called before allocating the registers in the | ||
`snitch_stream.generic` body. | ||
""" | ||
|
||
@op_type_rewrite_pattern | ||
def match_and_rewrite( | ||
self, | ||
op: snitch_stream.StridedReadOp | snitch_stream.StridedWriteOp, | ||
rewriter: PatternRewriter, | ||
/, | ||
): | ||
stream_type = op.stream.type | ||
assert isinstance( | ||
stream_type, stream.ReadableStreamType | stream.WritableStreamType | ||
) | ||
stream_type = cast(stream.StreamType[Any], stream_type) | ||
op.stream.type = type(stream_type)(riscv.Registers.FT[op.dm.data]) | ||
|
||
|
||
class AllocateSnitchGenericRegisters(RewritePattern): | ||
""" | ||
Allocates the registers in the body of a `snitch_stream.generic` operation by assigning | ||
them to the ones specified by the streams. | ||
""" | ||
|
||
@op_type_rewrite_pattern | ||
def match_and_rewrite( | ||
self, op: snitch_stream.GenericOp, rewriter: PatternRewriter, / | ||
): | ||
block = op.body.block | ||
|
||
for arg, input in zip(block.args, op.inputs): | ||
assert isinstance(input.type, stream.ReadableStreamType) | ||
input_type: stream.ReadableStreamType[Any] = input.type | ||
arg.type = input_type.element_type | ||
|
||
yield_op = block.last_op | ||
assert isinstance(yield_op, snitch_stream.YieldOp) | ||
|
||
for arg, output in zip(yield_op.values, op.outputs): | ||
assert isinstance(output.type, stream.WritableStreamType) | ||
output_type: stream.WritableStreamType[Any] = output.type | ||
arg.type = output_type.element_type | ||
|
||
|
||
@dataclass | ||
class SnitchRegisterAllocation(ModulePass): | ||
""" | ||
Allocates unallocated registers for snitch operations. | ||
""" | ||
|
||
name = "snitch-allocate-registers" | ||
|
||
def apply(self, ctx: MLContext, op: ModuleOp) -> None: | ||
PatternRewriteWalker( | ||
AllocateSnitchStridedStreamRegisters(), | ||
apply_recursively=False, | ||
).rewrite_module(op) | ||
PatternRewriteWalker( | ||
AllocateSnitchGenericRegisters(), | ||
apply_recursively=False, | ||
).rewrite_module(op) |