Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

backend: (riscv) Exclude FP registers from RISC-V regalloc in the presence of snitch_stream IO ops #1786

Merged
merged 5 commits into from
Nov 16, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
// RUN: xdsl-opt --split-input-file -p "riscv-allocate-registers{allocation_strategy=LivenessBlockNaive}" %s | filecheck %s
// RUN: xdsl-opt --split-input-file -p "riscv-allocate-registers{allocation_strategy=LivenessBlockNaive exclude_snitch_reserved=false}" %s | filecheck %s --check-prefix=CHECK-SNITCH-UNRESERVED

// Test case: the function contains a `snitch_stream.strided_read` op.
// With the default pass options the float values below are expected to be
// allocated starting from ft3 (ft0-ft2 are left untouched); when the pass is
// run with `exclude_snitch_reserved=false` they are expected to start at ft0.
riscv_func.func @main() {
%stride_pattern, %ptr0 = "test.op"() : () -> (!snitch_stream.stride_pattern_type, !riscv.reg<>)
%s0 = "snitch_stream.strided_read"(%ptr0, %stride_pattern) {"dm" = #builtin.int<0>, "rank" = #builtin.int<2>} : (!riscv.reg<>, !snitch_stream.stride_pattern_type) -> !stream.readable<!riscv.freg<>>
%v0, %v1, %v2 = "test.op"() : () -> (!riscv.freg<>, !riscv.freg<>, !riscv.freg<>)
%sum1 = riscv.fadd.s %v0, %v1 : (!riscv.freg<>, !riscv.freg<>) -> !riscv.freg<>
%sum2 = riscv.fadd.s %sum1, %v2 : (!riscv.freg<>, !riscv.freg<>) -> !riscv.freg<>
riscv_func.return
}

// CHECK: builtin.module {
// CHECK-NEXT: riscv_func.func @main() {
// CHECK-NEXT: %stride_pattern, %ptr0 = "test.op"() : () -> (!snitch_stream.stride_pattern_type, !riscv.reg<>)
// CHECK-NEXT: %s0 = "snitch_stream.strided_read"(%ptr0, %stride_pattern) {"dm" = #builtin.int<0>, "rank" = #builtin.int<2>} : (!riscv.reg<>, !snitch_stream.stride_pattern_type) -> !stream.readable<!riscv.freg<>>
// CHECK-NEXT: %v0, %v1, %v2 = "test.op"() : () -> (!riscv.freg<ft3>, !riscv.freg<ft5>, !riscv.freg<ft4>)
// CHECK-NEXT: %sum1 = riscv.fadd.s %v0, %v1 : (!riscv.freg<ft3>, !riscv.freg<ft5>) -> !riscv.freg<ft3>
// CHECK-NEXT: %sum2 = riscv.fadd.s %sum1, %v2 : (!riscv.freg<ft3>, !riscv.freg<ft4>) -> !riscv.freg<ft3>
// CHECK-NEXT: riscv_func.return
// CHECK-NEXT: }
// CHECK-NEXT: }

// CHECK-SNITCH-UNRESERVED: builtin.module {
// CHECK-SNITCH-UNRESERVED-NEXT: riscv_func.func @main() {
// CHECK-SNITCH-UNRESERVED-NEXT: %stride_pattern, %ptr0 = "test.op"() : () -> (!snitch_stream.stride_pattern_type, !riscv.reg<>)
// CHECK-SNITCH-UNRESERVED-NEXT: %s0 = "snitch_stream.strided_read"(%ptr0, %stride_pattern) {"dm" = #builtin.int<0>, "rank" = #builtin.int<2>} : (!riscv.reg<>, !snitch_stream.stride_pattern_type) -> !stream.readable<!riscv.freg<>>
// CHECK-SNITCH-UNRESERVED-NEXT: %v0, %v1, %v2 = "test.op"() : () -> (!riscv.freg<ft0>, !riscv.freg<ft2>, !riscv.freg<ft1>)
// CHECK-SNITCH-UNRESERVED-NEXT: %sum1 = riscv.fadd.s %v0, %v1 : (!riscv.freg<ft0>, !riscv.freg<ft2>) -> !riscv.freg<ft0>
// CHECK-SNITCH-UNRESERVED-NEXT: %sum2 = riscv.fadd.s %sum1, %v2 : (!riscv.freg<ft0>, !riscv.freg<ft1>) -> !riscv.freg<ft0>
// CHECK-SNITCH-UNRESERVED-NEXT: riscv_func.return
// CHECK-SNITCH-UNRESERVED-NEXT: }
// CHECK-SNITCH-UNRESERVED-NEXT: }

// -----

// Test case: the function contains a `snitch_stream.strided_write` op.
// Expected allocation mirrors the strided_read case: float values start at
// ft3 with the default pass options, and at ft0 when the pass is run with
// `exclude_snitch_reserved=false`.
riscv_func.func @main() {
%stride_pattern, %ptr0 = "test.op"() : () -> (!snitch_stream.stride_pattern_type, !riscv.reg<>)
%s0 = "snitch_stream.strided_write"(%ptr0, %stride_pattern) {"dm" = #builtin.int<0>, "rank" = #builtin.int<2>} : (!riscv.reg<>, !snitch_stream.stride_pattern_type) -> !stream.writable<!riscv.freg<>>
%v0, %v1, %v2 = "test.op"() : () -> (!riscv.freg<>, !riscv.freg<>, !riscv.freg<>)
%sum1 = riscv.fadd.s %v0, %v1 : (!riscv.freg<>, !riscv.freg<>) -> !riscv.freg<>
%sum2 = riscv.fadd.s %sum1, %v2 : (!riscv.freg<>, !riscv.freg<>) -> !riscv.freg<>
riscv_func.return
}

// CHECK: builtin.module {
// CHECK-NEXT: riscv_func.func @main() {
// CHECK-NEXT: %stride_pattern, %ptr0 = "test.op"() : () -> (!snitch_stream.stride_pattern_type, !riscv.reg<>)
// CHECK-NEXT: %s0 = "snitch_stream.strided_write"(%ptr0, %stride_pattern) {"dm" = #builtin.int<0>, "rank" = #builtin.int<2>} : (!riscv.reg<>, !snitch_stream.stride_pattern_type) -> !stream.writable<!riscv.freg<>>
// CHECK-NEXT: %v0, %v1, %v2 = "test.op"() : () -> (!riscv.freg<ft3>, !riscv.freg<ft5>, !riscv.freg<ft4>)
// CHECK-NEXT: %sum1 = riscv.fadd.s %v0, %v1 : (!riscv.freg<ft3>, !riscv.freg<ft5>) -> !riscv.freg<ft3>
// CHECK-NEXT: %sum2 = riscv.fadd.s %sum1, %v2 : (!riscv.freg<ft3>, !riscv.freg<ft4>) -> !riscv.freg<ft3>
// CHECK-NEXT: riscv_func.return
// CHECK-NEXT: }
// CHECK-NEXT: }

// CHECK-SNITCH-UNRESERVED: builtin.module {
// CHECK-SNITCH-UNRESERVED-NEXT: riscv_func.func @main() {
// CHECK-SNITCH-UNRESERVED-NEXT: %stride_pattern, %ptr0 = "test.op"() : () -> (!snitch_stream.stride_pattern_type, !riscv.reg<>)
// CHECK-SNITCH-UNRESERVED-NEXT: %s0 = "snitch_stream.strided_write"(%ptr0, %stride_pattern) {"dm" = #builtin.int<0>, "rank" = #builtin.int<2>} : (!riscv.reg<>, !snitch_stream.stride_pattern_type) -> !stream.writable<!riscv.freg<>>
// CHECK-SNITCH-UNRESERVED-NEXT: %v0, %v1, %v2 = "test.op"() : () -> (!riscv.freg<ft0>, !riscv.freg<ft2>, !riscv.freg<ft1>)
// CHECK-SNITCH-UNRESERVED-NEXT: %sum1 = riscv.fadd.s %v0, %v1 : (!riscv.freg<ft0>, !riscv.freg<ft2>) -> !riscv.freg<ft0>
// CHECK-SNITCH-UNRESERVED-NEXT: %sum2 = riscv.fadd.s %sum1, %v2 : (!riscv.freg<ft0>, !riscv.freg<ft1>) -> !riscv.freg<ft0>
// CHECK-SNITCH-UNRESERVED-NEXT: riscv_func.return
// CHECK-SNITCH-UNRESERVED-NEXT: }
// CHECK-SNITCH-UNRESERVED-NEXT: }
35 changes: 26 additions & 9 deletions xdsl/backend/riscv/register_allocation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@
from ordered_set import OrderedSet

from xdsl.backend.riscv.register_queue import RegisterQueue
from xdsl.dialects import riscv_func, riscv_scf
from xdsl.dialects import riscv_func, riscv_scf, snitch_stream
from xdsl.dialects.riscv import (
FloatRegisterType,
IntRegisterType,
RISCVOp,
RISCVRegisterType,
)
from xdsl.ir import Block, Operation, SSAValue
from xdsl.transforms.snitch_register_allocation import get_snitch_reserved


def gather_allocated(func: riscv_func.FuncOp) -> set[RISCVRegisterType]:
Expand All @@ -31,6 +32,16 @@ def gather_allocated(func: riscv_func.FuncOp) -> set[RISCVRegisterType]:
return allocated


def _uses_snitch_stream(func: riscv_func.FuncOp) -> bool:
    """
    Report whether `func` contains a `snitch_stream` read or write op.

    Iterates over `func.walk()` and returns True as soon as an instance of
    `snitch_stream.StridedReadOp` or `snitch_stream.StridedWriteOp` is
    encountered; returns False when the walk yields no such op.
    """
    stream_io_ops = snitch_stream.StridedReadOp | snitch_stream.StridedWriteOp
    return any(isinstance(nested, stream_io_ops) for nested in func.walk())
compor marked this conversation as resolved.
Show resolved Hide resolved


class RegisterAllocator(abc.ABC):
"""
Base class for register allocation strategies.
Expand Down Expand Up @@ -73,6 +84,7 @@ class RegisterAllocatorLivenessBlockNaive(RegisterAllocator):
live_ins_per_block: dict[Block, OrderedSet[SSAValue]]

exclude_preallocated: bool = True
exclude_snitch_reserved: bool = True

def __init__(self) -> None:
self.available_registers = RegisterQueue()
Expand Down Expand Up @@ -188,17 +200,22 @@ def allocate_func(self, func: riscv_func.FuncOp) -> None:
f"Cannot register allocate func with {len(func.body.blocks)} blocks."
)

preallocated: set[RISCVRegisterType] = set()

if self.exclude_preallocated:
preallocated = gather_allocated(func)
preallocated |= gather_allocated(func)

if self.exclude_snitch_reserved and _uses_snitch_stream(func):
preallocated |= get_snitch_reserved()

for pa_reg in preallocated:
if isinstance(pa_reg, IntRegisterType | FloatRegisterType):
self.available_registers.reserved_registers.add(pa_reg)
for pa_reg in preallocated:
if isinstance(pa_reg, IntRegisterType | FloatRegisterType):
self.available_registers.reserved_registers.add(pa_reg)

if pa_reg in self.available_registers.available_int_registers:
self.available_registers.available_int_registers.remove(pa_reg)
if pa_reg in self.available_registers.available_float_registers:
self.available_registers.available_float_registers.remove(pa_reg)
if pa_reg in self.available_registers.available_int_registers:
self.available_registers.available_int_registers.remove(pa_reg)
if pa_reg in self.available_registers.available_float_registers:
self.available_registers.available_float_registers.remove(pa_reg)

block = func.body.block

Expand Down
4 changes: 4 additions & 0 deletions xdsl/transforms/riscv_register_allocation.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ class RISCVRegisterAllocation(ModulePass):
are excluded completely from any further allocation decisions.
"""

exclude_snitch_reserved: bool = True
"""Excludes floating-point registers that are used by the Snitch ISA extensions."""

def apply(self, ctx: MLContext, op: ModuleOp) -> None:
allocator_strategies = {
"LivenessBlockNaive": RegisterAllocatorLivenessBlockNaive,
Expand All @@ -50,4 +53,5 @@ def apply(self, ctx: MLContext, op: ModuleOp) -> None:
if self.limit_registers is not None:
allocator.available_registers.limit_registers(self.limit_registers)
allocator.exclude_preallocated = self.exclude_preallocated
allocator.exclude_snitch_reserved = self.exclude_snitch_reserved
allocator.allocate_func(inner_op)
12 changes: 12 additions & 0 deletions xdsl/transforms/snitch_register_allocation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,18 @@
)


def get_snitch_reserved() -> set[riscv.FloatRegisterType]:
    """
    Return the floating-point registers considered reserved under the
    Snitch ISA assumptions.

    Currently these are the first 3 entries of `riscv.Registers.FT`.
    """
    reserved_count = 3
    ft_registers = riscv.Registers.FT
    # Internal invariant: there must be at least as many FT registers as we
    # intend to reserve.
    assert len(ft_registers) >= reserved_count

    reserved: set[riscv.FloatRegisterType] = set()
    for index in range(reserved_count):
        reserved.add(ft_registers[index])
    return reserved


class AllocateSnitchStridedStreamRegisters(RewritePattern):
"""
Allocates the register used by the stream as the one specified by the `dm`
Expand Down