Skip to content

Commit

Permalink
backend: (riscv) add 1d, 3d, and 4d snitch stream lowerings (#1781)
Browse files Browse the repository at this point in the history
Tested in the experiments repo to confirm that this works.
  • Loading branch information
superlopuh authored Nov 14, 2023
1 parent 622d448 commit d67edc1
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,35 +5,77 @@
%A, %B, %C = "test.op"() : () -> (!riscv.reg<>, !riscv.reg<>, !riscv.reg<>)
// CHECK-NEXT: %A, %B, %C = "test.op"() : () -> (!riscv.reg<>, !riscv.reg<>, !riscv.reg<>)

%0 = "snitch_stream.stride_pattern"() {"ub" = [#builtin.int<2>, #builtin.int<3>], "strides" = [#builtin.int<24>, #builtin.int<8>], "dm" = #builtin.int<31>} : () -> !snitch_stream.stride_pattern_type
// CHECK-NEXT: %0 = riscv.li 2 : () -> !riscv.reg<>
// CHECK-NEXT: %1 = riscv.li 3 : () -> !riscv.reg<>
// CHECK-NEXT: %2 = riscv.li 24 : () -> !riscv.reg<>
// CHECK-NEXT: %3 = riscv.li 8 : () -> !riscv.reg<>
// CHECK-NEXT: %4 = riscv.addi %0, -1 : (!riscv.reg<>) -> !riscv.reg<>
// CHECK-NEXT: %5 = riscv.addi %1, -1 : (!riscv.reg<>) -> !riscv.reg<>
// CHECK-NEXT: "snitch.ssr_set_dimension_bound"(%4) {"dm" = #builtin.int<31>, "dimension" = #builtin.int<0>} : (!riscv.reg<>) -> ()
// CHECK-NEXT: "snitch.ssr_set_dimension_bound"(%5) {"dm" = #builtin.int<31>, "dimension" = #builtin.int<1>} : (!riscv.reg<>) -> ()
// CHECK-NEXT: "snitch.ssr_set_dimension_stride"(%2) {"dm" = #builtin.int<31>, "dimension" = #builtin.int<0>} : (!riscv.reg<>) -> ()
// CHECK-NEXT: %6 = riscv.mul %4, %2 : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg<>
// CHECK-NEXT: %7 = riscv.sub %3, %6 : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg<>
// CHECK-NEXT: "snitch.ssr_set_dimension_stride"(%7) {"dm" = #builtin.int<31>, "dimension" = #builtin.int<1>} : (!riscv.reg<>) -> ()


%1 = "snitch_stream.strided_read"(%A, %0) {"dm" = #builtin.int<0>, "rank" = #builtin.int<2>} : (!riscv.reg<>, !snitch_stream.stride_pattern_type) -> !stream.readable<!riscv.freg<ft0>>
%sp1 = "snitch_stream.stride_pattern"() {"ub" = [#builtin.int<2>], "strides" = [#builtin.int<8>], "dm" = #builtin.int<31>} : () -> !snitch_stream.stride_pattern_type
// CHECK-NEXT: %{{.*}} = riscv.li 2 : () -> !riscv.reg<>
// CHECK-NEXT: %{{.*}} = riscv.li 8 : () -> !riscv.reg<>
// CHECK-NEXT: %{{.*}} = riscv.addi %{{.*}}, -1 : (!riscv.reg<>) -> !riscv.reg<>
// CHECK-NEXT: "snitch.ssr_set_dimension_bound"(%{{.*}}) {"dm" = #builtin.int<31>, "dimension" = #builtin.int<0>} : (!riscv.reg<>) -> ()
// CHECK-NEXT: "snitch.ssr_set_dimension_stride"(%{{.*}}) {"dm" = #builtin.int<31>, "dimension" = #builtin.int<0>} : (!riscv.reg<>) -> ()
// CHECK-NEXT: %{{.*}} = riscv.li 0 : () -> !riscv.reg<>

%sp2 = "snitch_stream.stride_pattern"() {"ub" = [#builtin.int<2>, #builtin.int<3>], "strides" = [#builtin.int<24>, #builtin.int<8>], "dm" = #builtin.int<31>} : () -> !snitch_stream.stride_pattern_type
// CHECK-NEXT: %{{.*}} = riscv.li 2 : () -> !riscv.reg<>
// CHECK-NEXT: %{{.*}} = riscv.li 3 : () -> !riscv.reg<>
// CHECK-NEXT: %{{.*}} = riscv.li 24 : () -> !riscv.reg<>
// CHECK-NEXT: %{{.*}} = riscv.li 8 : () -> !riscv.reg<>
// CHECK-NEXT: %{{.*}} = riscv.addi %{{.*}}, -1 : (!riscv.reg<>) -> !riscv.reg<>
// CHECK-NEXT: %{{.*}} = riscv.addi %{{.*}}, -1 : (!riscv.reg<>) -> !riscv.reg<>
// CHECK-NEXT: "snitch.ssr_set_dimension_bound"(%{{.*}}) {"dm" = #builtin.int<31>, "dimension" = #builtin.int<0>} : (!riscv.reg<>) -> ()
// CHECK-NEXT: "snitch.ssr_set_dimension_bound"(%{{.*}}) {"dm" = #builtin.int<31>, "dimension" = #builtin.int<1>} : (!riscv.reg<>) -> ()
// CHECK-NEXT: "snitch.ssr_set_dimension_stride"(%{{.*}}) {"dm" = #builtin.int<31>, "dimension" = #builtin.int<0>} : (!riscv.reg<>) -> ()
// CHECK-NEXT: %{{.*}} = riscv.li 0 : () -> !riscv.reg<>
// CHECK-NEXT: %{{.*}} = riscv.mul %{{.*}}, %{{.*}} : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg<>
// CHECK-NEXT: %{{.*}} = riscv.add %{{.*}}, %{{.*}} : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg<>
// CHECK-NEXT: %{{.*}} = riscv.sub %{{.*}}, %{{.*}} : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg<>
// CHECK-NEXT: "snitch.ssr_set_dimension_stride"(%{{.*}}) {"dm" = #builtin.int<31>, "dimension" = #builtin.int<1>} : (!riscv.reg<>) -> ()


%sp4 = "snitch_stream.stride_pattern"() {"ub" = [#builtin.int<2>, #builtin.int<3>, #builtin.int<4>, #builtin.int<5>], "strides" = [#builtin.int<480>, #builtin.int<160>, #builtin.int<40>, #builtin.int<8>], "dm" = #builtin.int<31>} : () -> !snitch_stream.stride_pattern_type
// CHECK-NEXT: %{{.*}} = riscv.li 2 : () -> !riscv.reg<>
// CHECK-NEXT: %{{.*}} = riscv.li 3 : () -> !riscv.reg<>
// CHECK-NEXT: %{{.*}} = riscv.li 4 : () -> !riscv.reg<>
// CHECK-NEXT: %{{.*}} = riscv.li 5 : () -> !riscv.reg<>
// CHECK-NEXT: %{{.*}} = riscv.li 480 : () -> !riscv.reg<>
// CHECK-NEXT: %{{.*}} = riscv.li 160 : () -> !riscv.reg<>
// CHECK-NEXT: %{{.*}} = riscv.li 40 : () -> !riscv.reg<>
// CHECK-NEXT: %{{.*}} = riscv.li 8 : () -> !riscv.reg<>
// CHECK-NEXT: %{{.*}} = riscv.addi %{{.*}}, -1 : (!riscv.reg<>) -> !riscv.reg<>
// CHECK-NEXT: %{{.*}} = riscv.addi %{{.*}}, -1 : (!riscv.reg<>) -> !riscv.reg<>
// CHECK-NEXT: %{{.*}} = riscv.addi %{{.*}}, -1 : (!riscv.reg<>) -> !riscv.reg<>
// CHECK-NEXT: %{{.*}} = riscv.addi %{{.*}}, -1 : (!riscv.reg<>) -> !riscv.reg<>
// CHECK-NEXT: "snitch.ssr_set_dimension_bound"(%{{.*}}) {"dm" = #builtin.int<31>, "dimension" = #builtin.int<0>} : (!riscv.reg<>) -> ()
// CHECK-NEXT: "snitch.ssr_set_dimension_bound"(%{{.*}}) {"dm" = #builtin.int<31>, "dimension" = #builtin.int<1>} : (!riscv.reg<>) -> ()
// CHECK-NEXT: "snitch.ssr_set_dimension_bound"(%{{.*}}) {"dm" = #builtin.int<31>, "dimension" = #builtin.int<2>} : (!riscv.reg<>) -> ()
// CHECK-NEXT: "snitch.ssr_set_dimension_bound"(%{{.*}}) {"dm" = #builtin.int<31>, "dimension" = #builtin.int<3>} : (!riscv.reg<>) -> ()
// CHECK-NEXT: "snitch.ssr_set_dimension_stride"(%{{.*}}) {"dm" = #builtin.int<31>, "dimension" = #builtin.int<0>} : (!riscv.reg<>) -> ()
// CHECK-NEXT: %{{.*}} = riscv.li 0 : () -> !riscv.reg<>
// CHECK-NEXT: %{{.*}} = riscv.mul %{{.*}}, %{{.*}} : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg<>
// CHECK-NEXT: %{{.*}} = riscv.add %{{.*}}, %{{.*}} : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg<>
// CHECK-NEXT: %{{.*}} = riscv.sub %{{.*}}, %{{.*}} : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg<>
// CHECK-NEXT: "snitch.ssr_set_dimension_stride"(%{{.*}}) {"dm" = #builtin.int<31>, "dimension" = #builtin.int<1>} : (!riscv.reg<>) -> ()
// CHECK-NEXT: %{{.*}} = riscv.mul %{{.*}}, %{{.*}} : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg<>
// CHECK-NEXT: %{{.*}} = riscv.add %{{.*}}, %{{.*}} : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg<>
// CHECK-NEXT: %{{.*}} = riscv.sub %{{.*}}, %{{.*}} : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg<>
// CHECK-NEXT: "snitch.ssr_set_dimension_stride"(%{{.*}}) {"dm" = #builtin.int<31>, "dimension" = #builtin.int<2>} : (!riscv.reg<>) -> ()
// CHECK-NEXT: %{{.*}} = riscv.mul %{{.*}}, %{{.*}} : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg<>
// CHECK-NEXT: %{{.*}} = riscv.add %{{.*}}, %{{.*}} : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg<>
// CHECK-NEXT: %{{.*}} = riscv.sub %{{.*}}, %{{.*}} : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg<>
// CHECK-NEXT: "snitch.ssr_set_dimension_stride"(%{{.*}}) {"dm" = #builtin.int<31>, "dimension" = #builtin.int<3>} : (!riscv.reg<>) -> ()

%1 = "snitch_stream.strided_read"(%A, %sp2) {"dm" = #builtin.int<0>, "rank" = #builtin.int<2>} : (!riscv.reg<>, !snitch_stream.stride_pattern_type) -> !stream.readable<!riscv.freg<ft0>>
// CHECK-NEXT: "snitch.ssr_set_dimension_source"(%A) {"dm" = #builtin.int<0>, "dimension" = #builtin.int<1>} : (!riscv.reg<>) -> ()
// CHECK-NEXT: %a = riscv.get_float_register : () -> !riscv.freg<ft0>

%2 = "snitch_stream.strided_read"(%B, %0) {"dm" = #builtin.int<1>, "rank" = #builtin.int<2>} : (!riscv.reg<>, !snitch_stream.stride_pattern_type) -> !stream.readable<!riscv.freg<ft1>>
%2 = "snitch_stream.strided_read"(%B, %sp2) {"dm" = #builtin.int<1>, "rank" = #builtin.int<2>} : (!riscv.reg<>, !snitch_stream.stride_pattern_type) -> !stream.readable<!riscv.freg<ft1>>
// CHECK-NEXT: "snitch.ssr_set_dimension_source"(%B) {"dm" = #builtin.int<1>, "dimension" = #builtin.int<1>} : (!riscv.reg<>) -> ()
// CHECK-NEXT: %b = riscv.get_float_register : () -> !riscv.freg<ft1>

%3 = "snitch_stream.strided_write"(%C, %0) {"dm" = #builtin.int<2>, "rank" = #builtin.int<2>} : (!riscv.reg<>, !snitch_stream.stride_pattern_type) -> !stream.writable<!riscv.freg<ft2>>
%3 = "snitch_stream.strided_write"(%C, %sp2) {"dm" = #builtin.int<2>, "rank" = #builtin.int<2>} : (!riscv.reg<>, !snitch_stream.stride_pattern_type) -> !stream.writable<!riscv.freg<ft2>>
// CHECK-NEXT: "snitch.ssr_set_dimension_destination"(%C) {"dm" = #builtin.int<2>, "dimension" = #builtin.int<1>} : (!riscv.reg<>) -> ()


%4 = riscv.li 6 : () -> !riscv.reg<>
// CHECK-NEXT: %8 = riscv.li 6 : () -> !riscv.reg<>
// CHECK-NEXT: %{{.*}} = riscv.li 6 : () -> !riscv.reg<>


"snitch_stream.generic"(%4, %1, %2, %3) <{"operandSegmentSizes" = array<i32: 1, 2, 1>}> ({
Expand All @@ -42,8 +84,8 @@
snitch_stream.yield %sum : !riscv.freg<ft2>
}) : (!riscv.reg<>, !stream.readable<!riscv.freg<ft0>>, !stream.readable<!riscv.freg<ft1>>, !stream.writable<!riscv.freg<ft2>>) -> ()
// CHECK-NEXT: "snitch.ssr_enable"() : () -> ()
// CHECK-NEXT: %9 = riscv.addi %8, -1 : (!riscv.reg<>) -> !riscv.reg<>
// CHECK-NEXT: riscv_snitch.frep_outer %9, 0, 0 ({
// CHECK-NEXT: %{{.*}} = riscv.addi %{{.*}}, -1 : (!riscv.reg<>) -> !riscv.reg<>
// CHECK-NEXT: riscv_snitch.frep_outer %{{.*}}, 0, 0 ({
// CHECK-NEXT: %sum = riscv.fadd.d %a, %b : (!riscv.freg<ft0>, !riscv.freg<ft1>) -> !riscv.freg<ft2>
// CHECK-NEXT: riscv_snitch.frep_yield %sum : (!riscv.freg<ft2>) -> ()
// CHECK-NEXT: }) : (!riscv.reg<>) -> ()
Expand Down
83 changes: 57 additions & 26 deletions xdsl/backend/riscv/lowering/convert_snitch_stream_to_snitch.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
snitch_stream,
stream,
)
from xdsl.ir import MLContext
from xdsl.ir import MLContext, Operation
from xdsl.passes import ModulePass
from xdsl.pattern_rewriter import (
GreedyRewritePatternApplier,
Expand All @@ -29,45 +29,76 @@ def match_and_rewrite(
self, op: snitch_stream.StridePatternOp, rewriter: PatternRewriter, /
):
# reference implementation:
# // Configure an SSR data mover for a 2D loop nest.
# https://github.com/pulp-platform/snitch/blob/d026f47843f0ea6c269244c4e6851e0e09141ec3/sw/snRuntime/src/ssr.h#L73
#
# inline void snrt_ssr_loop_2d(enum snrt_ssr_dm dm, size_t b0, size_t b1,
# size_t s0, size_t s1) {
# 4d loop reproduced here:
#
# // Configure an SSR data mover for a 4D loop nest.
# // b0: Inner-most bound (limit of loop)
# // b3: Outer-most bound (limit of loop)
# // s0: increment size of inner-most loop
# inline void snrt_ssr_loop_4d(enum snrt_ssr_dm dm, size_t b0, size_t b1,
# size_t b2, size_t b3, size_t s0, size_t s1,
# size_t s2, size_t s3) {
# --b0;
# --b1;
# --b2;
# --b3;
# write_ssr_cfg(REG_BOUNDS + 0, dm, b0);
# write_ssr_cfg(REG_BOUNDS + 1, dm, b1);
# write_ssr_cfg(REG_BOUNDS + 2, dm, b2);
# write_ssr_cfg(REG_BOUNDS + 3, dm, b3);
# size_t a = 0;
# write_ssr_cfg(REG_STRIDES + 0, dm, s0 - a);
# a += s0 * b0;
# write_ssr_cfg(REG_STRIDES + 1, dm, s1 - a);
# a += s1 * b1;
# write_ssr_cfg(REG_STRIDES + 2, dm, s2 - a);
# a += s2 * b2;
# write_ssr_cfg(REG_STRIDES + 3, dm, s3 - a);
# a += s3 * b3;
# }
dim = len(op.ub)
if dim != 2:
raise NotImplementedError("Only 2d loop stride patterns are supported")

int_0 = builtin.IntAttr(0)
int_1 = builtin.IntAttr(1)
rank = len(op.ub)
if rank > 4:
raise NotImplementedError(
"Only 1d, 2d, 3d, or 4d loop stride patterns are supported"
)

b = tuple(b.data for b in op.ub.data)
s = tuple(s.data for s in op.strides.data)
ints = tuple(builtin.IntAttr(i) for i in range(rank))

rewriter.insert_op_before_matched_op(
[
b0 := riscv.LiOp(b[0]),
b1 := riscv.LiOp(b[1]),
s0 := riscv.LiOp(s[0]),
s1 := riscv.LiOp(s[1]),
new_b0 := riscv.AddiOp(b0, -1),
new_b1 := riscv.AddiOp(b1, -1),
snitch.SsrSetDimensionBoundOp(new_b0, op.dm, int_0),
snitch.SsrSetDimensionBoundOp(new_b1, op.dm, int_1),
snitch.SsrSetDimensionStrideOp(s0, op.dm, int_0),
a0 := riscv.MulOp(new_b0, s0, rd=riscv.IntRegisterType.unallocated()),
stride_1 := riscv.SubOp(s1, a0, rd=riscv.IntRegisterType.unallocated()),
snitch.SsrSetDimensionStrideOp(stride_1, op.dm, int_1),
],
b_ops = tuple(riscv.LiOp(b.data) for b in op.ub.data)
s_ops = tuple(riscv.LiOp(s.data) for s in op.strides.data)
new_b_ops = tuple(riscv.AddiOp(b_op.rd, -1) for b_op in b_ops)
set_bound_ops = tuple(
snitch.SsrSetDimensionBoundOp(new_b_op, op.dm, i)
for (i, new_b_op) in zip(ints, new_b_ops)
)

new_ops: list[Operation] = [
*b_ops,
*s_ops,
*new_b_ops,
*set_bound_ops,
snitch.SsrSetDimensionStrideOp(s_ops[0], op.dm, ints[0]),
a_op := riscv.LiOp(0, rd=riscv.IntRegisterType.unallocated()),
]

for i in range(1, rank):
a_inc_op = riscv.MulOp(
new_b_ops[i - 1], s_ops[i - 1], rd=riscv.IntRegisterType.unallocated()
)
new_a_op = riscv.AddOp(
a_op, a_inc_op, rd=riscv.IntRegisterType.unallocated()
)
stride_op = riscv.SubOp(
s_ops[i], new_a_op, rd=riscv.IntRegisterType.unallocated()
)
set_stride_op = snitch.SsrSetDimensionStrideOp(stride_op.rd, op.dm, ints[i])
new_ops.extend((a_inc_op, new_a_op, stride_op, set_stride_op))
a_op = new_a_op

rewriter.insert_op_before_matched_op(new_ops)
rewriter.erase_matched_op()


Expand Down

0 comments on commit d67edc1

Please sign in to comment.