Skip to content

Commit

Permalink
Merge pull request #298 from opencompl/sasha/relu-f32-linalg-snitch-s…
Browse files Browse the repository at this point in the history
…tream

add relu f32 linalg and snitch stream
  • Loading branch information
superlopuh authored Aug 8, 2024
2 parents e30a90f + a3fd9da commit f787eaf
Show file tree
Hide file tree
Showing 15 changed files with 63 additions and 22 deletions.
7 changes: 5 additions & 2 deletions Snakefile
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,9 @@ TESTSET_FAST = [
],
variant=["baseline", "linalg_xdsl"],
),
*expand("relu/4x8xf32/{variant}", variant=["baseline", "snrt"]),
*expand(
"relu/4x8xf32/{variant}", variant=["baseline", "linalg", "snrt", "snitch_stream"]
),
*expand(
"sum/4x8xf32/{variant}", variant=["baseline", "snrt", "linalg", "linalg_xdsl"]
),
Expand Down Expand Up @@ -563,7 +565,8 @@ rule xdsl_kernel_generate_source:
json="kernels/{kernel}/{shape}/params.json",
template="kernels/{kernel}/linalg.mlir.template",
output:
"kernels/{kernel}/{shape}/{variant}.xdsl.mlir",
# Restrict this rule to variant=linalg_xdsl to avoid ambiguous matches
"kernels/{kernel}/{shape}/linalg_xdsl.xdsl.mlir",
wildcard_constraints:
kernel="|".join(KERNEL_TEMPLATES),
params:
Expand Down
2 changes: 1 addition & 1 deletion kernels/matmul_transb/snitch_stream.xdsl.mlir.template
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ riscv.assembly_section ".text" {

snitch_stream.streaming_region {
stride_patterns = [
#snitch_stream.stride_pattern<ub = [{{M}}, {{N // 8}}, {{K // 2}}], strides = [{{K * 4}}, 0, 8], 8>,
#snitch_stream.stride_pattern<ub = [{{M}}, {{N // 8}}, {{K // 2}}], strides = [{{K * 4}}, 0, 8], repeat = 8>,
#snitch_stream.stride_pattern<ub = [{{M}}, {{N // 8}}, {{K // 2}}, {{8}}], strides = [0, {{4 * 8 * K}}, {{4 * 2}}, {{4 * K}}]>
]
} ins(%X_moved, %Y_moved : !riscv.reg, !riscv.reg) {
Expand Down
17 changes: 17 additions & 0 deletions kernels/relu/linalg.mlir.template
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
module {
func.func public @relu(%arg0: memref<{{M}}x{{N}}xf{{precision}}> {llvm.noalias}, %arg1: memref<{{M}}x{{N}}xf{{precision}}> {llvm.noalias}) -> memref<{{M}}x{{N}}xf{{precision}}> {
%cst = arith.constant 0.000000e+00 : f{{precision}}
linalg.generic {
indexing_maps = [
affine_map<(d0, d1) -> (d0, d1)>,
affine_map<(d0, d1) -> (d0, d1)>
],
iterator_types = ["parallel", "parallel"]
} ins(%arg0 : memref<{{M}}x{{N}}xf{{precision}}>) outs(%arg1 : memref<{{M}}x{{N}}xf{{precision}}>) {
^bb0(%in: f{{precision}}, %out: f{{precision}}):
%0 = arith.maxf %in, %cst : f{{precision}}
linalg.yield %0 : f{{precision}}
}
return %arg1 : memref<{{M}}x{{N}}xf{{precision}}>
}
}
11 changes: 0 additions & 11 deletions kernels/relu/linalg_xdsl.xdsl.mlir.template

This file was deleted.

28 changes: 28 additions & 0 deletions kernels/relu/snitch_stream.xdsl.mlir.template
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@

riscv.assembly_section ".text" {
riscv.directive ".globl" "relu"
riscv.directive ".p2align" "2"
riscv_func.func @relu(%X : !riscv.reg<a0>, %Y : !riscv.reg<a1>) {
%X_1 = riscv.mv %X : (!riscv.reg<a0>) -> !riscv.reg
%Y_1 = riscv.mv %Y : (!riscv.reg<a1>) -> !riscv.reg
%zero = riscv.get_register : !riscv.reg<zero>
%zero_float = riscv.fcvt.d.w %zero : (!riscv.reg<zero>) -> !riscv.freg
%zero_vector = riscv_snitch.vfcpka.s.s %zero_float, %zero_float : (!riscv.freg, !riscv.freg) -> !riscv.freg
snitch_stream.streaming_region {
patterns = [
#snitch_stream.stride_pattern<ub = [{{M * N // 2}}], strides = [8]>
]
} ins(%X_1 : !riscv.reg) outs(%Y_1 : !riscv.reg) {
^0(%x : !stream.readable<!riscv.freg<ft0>>, %0 : !stream.writable<!riscv.freg<ft1>>):
%niter = riscv.li {{M * N // 2}} : !riscv.reg
%c0 = riscv.li 0 : !riscv.reg
%c1 = riscv.li 1 : !riscv.reg
riscv_scf.for %i : !riscv.reg = %c0 to %niter step %c1 {
%x_1 = riscv_snitch.read from %x : !riscv.freg<ft0>
%y = riscv_snitch.vfmax.s %x_1, %zero_vector : (!riscv.freg<ft0>, !riscv.freg) -> !riscv.freg<ft1>
riscv_snitch.write %y to %0 : !riscv.freg<ft1>
}
}
riscv_func.return
}
}
2 changes: 2 additions & 0 deletions results/kernels.csv
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@ pooling_nchw_max_d1_s2_3x3,4x4xf64,linalg_xdsl,275,1023,1020,0.9943820224719101,
conv2d_d1_s1_3x3,4x4xf64,baseline,667,1460,1457,2.9793103448275864,1.3333333333333333,128,145,432,0.21739130434782608,0.5991735537190083,242,108,81,0.36281859070464767,0,16,1.0,1.0,1,0.0,242,0.9097744360902256,24,0,0,0.035982008995502246,0,794,0.0,0.3988005997001499,0.0
conv2d_d1_s1_3x3,4x4xf64,linalg_xdsl,308,1074,1071,2.6797752808988764,0.0,144,178,477,0.577922077922078,0.9888888888888889,180,0,0,0.5844155844155844,0,0,3.214285714285714,3.2142857142857144,1,0.0,56,0.4628099173553719,65,0,0,0.21103896103896103,0,767,0.0,0.7954545454545454,0.0
relu,4x8xf32,baseline,297,1042,1039,0.9897959183673469,1.0,0,98,97,0.32996632996632996,0.6049382716049383,162,32,32,0.5454545454545454,0,32,1.0,1.0,1,0.0,162,0.9585798816568047,7,0,0,0.02356902356902357,0,746,0.0,0.569023569023569,0.0
relu,4x8xf32,linalg,210,954,951,0.9705882352941176,1.0,0,34,33,0.1619047619047619,0.3469387755102041,98,32,32,0.4666666666666667,0,32,1.0,1.0,1,0.0,98,0.9158878504672897,9,0,0,0.04285714285714286,0,745,0.0,0.5095238095238095,0.0
relu,4x8xf32,snrt,85,826,823,0.9473684210526315,0.0,0,19,18,0.2235294117647059,0.9047619047619048,21,0,0,0.24705882352941178,0,0,3.0,3.0,1,0.0,7,0.35,13,0,0,0.15294117647058825,0,742,0.0,0.4,0.0
relu,4x8xf32,snitch_stream,67,801,798,0.9473684210526315,0.0,0,19,18,0.2835820895522388,0.9047619047619048,21,0,0,0.31343283582089554,0,0,3.0,3.0,1,0.0,7,0.2916666666666667,17,0,0,0.2537313432835821,0,735,0.0,0.5671641791044777,0.0
sum,4x8xf32,baseline,238,998,995,2.909090909090909,1.703125,0,33,96,0.13865546218487396,0.2558139534883721,129,109,64,0.542016806722689,0,32,1.0,1.0,1,0.0,129,0.9416058394160584,8,0,0,0.03361344537815126,0,761,0.0,0.5756302521008403,0.0
sum,4x8xf32,snrt,72,835,831,2.823529411764706,0.0,0,17,48,0.2361111111111111,0.8947368421052632,19,0,0,0.2638888888888889,0,0,3.8,3.8,1,0.0,5,0.25,15,0,0,0.20833333333333334,0,764,0.0,0.4722222222222222,0.0
sum,4x8xf32,linalg,247,1011,1008,2.909090909090909,1.703125,0,33,96,0.13360323886639677,0.2558139534883721,129,109,64,0.5222672064777328,0,32,1.0,1.0,1,0.0,129,0.9347826086956522,9,0,0,0.03643724696356275,0,765,0.0,0.5587044534412956,0.0
Expand Down
2 changes: 2 additions & 0 deletions results/kernels.fast.csv
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@ pooling_nchw_max_d1_s2_3x3,4x4xf64,linalg_xdsl,275,1023,1020,0.9943820224719101,
conv2d_d1_s1_3x3,4x4xf64,baseline,667,1460,1457,2.9793103448275864,1.3333333333333333,128,145,432,0.21739130434782608,0.5991735537190083,242,108,81,0.36281859070464767,0,16,1.0,1.0,1,0.0,242,0.9097744360902256,24,0,0,0.035982008995502246,0,794,0.0,0.3988005997001499,0.0
conv2d_d1_s1_3x3,4x4xf64,linalg_xdsl,308,1074,1071,2.6797752808988764,0.0,144,178,477,0.577922077922078,0.9888888888888889,180,0,0,0.5844155844155844,0,0,3.214285714285714,3.2142857142857144,1,0.0,56,0.4628099173553719,65,0,0,0.21103896103896103,0,767,0.0,0.7954545454545454,0.0
relu,4x8xf32,baseline,297,1042,1039,0.9897959183673469,1.0,0,98,97,0.32996632996632996,0.6049382716049383,162,32,32,0.5454545454545454,0,32,1.0,1.0,1,0.0,162,0.9585798816568047,7,0,0,0.02356902356902357,0,746,0.0,0.569023569023569,0.0
relu,4x8xf32,linalg,210,954,951,0.9705882352941176,1.0,0,34,33,0.1619047619047619,0.3469387755102041,98,32,32,0.4666666666666667,0,32,1.0,1.0,1,0.0,98,0.9158878504672897,9,0,0,0.04285714285714286,0,745,0.0,0.5095238095238095,0.0
relu,4x8xf32,snrt,85,826,823,0.9473684210526315,0.0,0,19,18,0.2235294117647059,0.9047619047619048,21,0,0,0.24705882352941178,0,0,3.0,3.0,1,0.0,7,0.35,13,0,0,0.15294117647058825,0,742,0.0,0.4,0.0
relu,4x8xf32,snitch_stream,67,801,798,0.9473684210526315,0.0,0,19,18,0.2835820895522388,0.9047619047619048,21,0,0,0.31343283582089554,0,0,3.0,3.0,1,0.0,7,0.2916666666666667,17,0,0,0.2537313432835821,0,735,0.0,0.5671641791044777,0.0
sum,4x8xf32,baseline,238,998,995,2.909090909090909,1.703125,0,33,96,0.13865546218487396,0.2558139534883721,129,109,64,0.542016806722689,0,32,1.0,1.0,1,0.0,129,0.9416058394160584,8,0,0,0.03361344537815126,0,761,0.0,0.5756302521008403,0.0
sum,4x8xf32,snrt,72,835,831,2.823529411764706,0.0,0,17,48,0.2361111111111111,0.8947368421052632,19,0,0,0.2638888888888889,0,0,3.8,3.8,1,0.0,5,0.25,15,0,0,0.20833333333333334,0,764,0.0,0.4722222222222222,0.0
sum,4x8xf32,linalg,247,1011,1008,2.909090909090909,1.703125,0,33,96,0.13360323886639677,0.2558139534883721,129,109,64,0.5222672064777328,0,32,1.0,1.0,1,0.0,129,0.9347826086956522,9,0,0,0.03643724696356275,0,765,0.0,0.5587044534412956,0.0
Expand Down
2 changes: 1 addition & 1 deletion results/pivoted.csv
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ matmul_transb 4x16x16xf32,3386,,,871,849
pooling_nchw_max_d1_s2_3x3 4x4xf64,442,,275,,
pooling_nchw_sum_d1_s2_3x3 4x4xf64,582,,271,,
relu 4x4xf64,142,,72,,
relu 4x8xf32,297,,,,85
relu 4x8xf32,297,210,,67,85
saxpy 64xf32,634,634,,,140
sum 4x4xf64,129,,87,,
sum 4x8xf32,238,247,87,,72
Expand Down
2 changes: 1 addition & 1 deletion results/pivoted.fast.csv
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ matmul_transb 4x16x16xf32,3386,,,871,849
pooling_nchw_max_d1_s2_3x3 4x4xf64,442,,275,,
pooling_nchw_sum_d1_s2_3x3 4x4xf64,582,,271,,
relu 4x4xf64,142,,72,,
relu 4x8xf32,297,,,,85
relu 4x8xf32,297,210,,67,85
saxpy 64xf32,634,634,,,140
sum 4x4xf64,129,,87,,
sum 4x8xf32,238,247,87,,72
Expand Down
2 changes: 1 addition & 1 deletion results/pivoted_fpu.csv
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ matmul_transb 4x16x16xf32,0.21,,,0.77,0.79
pooling_nchw_max_d1_s2_3x3 4x4xf64,0.33,,0.65,,
pooling_nchw_sum_d1_s2_3x3 4x4xf64,0.22,,0.66,,
relu 4x4xf64,0.13,,0.25,,
relu 4x8xf32,0.33,,,,0.22
relu 4x8xf32,0.33,0.16,,0.28,0.22
saxpy 64xf32,0.10,0.10,,,0.46
sum 4x4xf64,0.13,,0.20,,
sum 4x8xf32,0.14,0.13,0.20,,0.24
Expand Down
2 changes: 1 addition & 1 deletion results/pivoted_fpu.fast.csv
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ matmul_transb 4x16x16xf32,0.21,,,0.77,0.79
pooling_nchw_max_d1_s2_3x3 4x4xf64,0.33,,0.65,,
pooling_nchw_sum_d1_s2_3x3 4x4xf64,0.22,,0.66,,
relu 4x4xf64,0.13,,0.25,,
relu 4x8xf32,0.33,,,,0.22
relu 4x8xf32,0.33,0.16,,0.28,0.22
saxpy 64xf32,0.10,0.10,,,0.46
sum 4x4xf64,0.13,,0.20,,
sum 4x8xf32,0.14,0.13,0.20,,0.24
Expand Down
2 changes: 1 addition & 1 deletion results/pivoted_ipc.csv
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ matmul_transb 4x16x16xf32,0.95,,,0.92,0.88
pooling_nchw_max_d1_s2_3x3 4x4xf64,0.68,,0.84,,
pooling_nchw_sum_d1_s2_3x3 4x4xf64,0.45,,0.84,,
relu 4x4xf64,0.40,,0.53,,
relu 4x8xf32,0.57,,,,0.40
relu 4x8xf32,0.57,0.51,,0.57,0.40
saxpy 64xf32,0.93,0.93,,,0.65
sum 4x4xf64,0.57,,0.46,,
sum 4x8xf32,0.58,0.56,0.46,,0.47
Expand Down
2 changes: 1 addition & 1 deletion results/pivoted_ipc.fast.csv
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ matmul_transb 4x16x16xf32,0.95,,,0.92,0.88
pooling_nchw_max_d1_s2_3x3 4x4xf64,0.68,,0.84,,
pooling_nchw_sum_d1_s2_3x3 4x4xf64,0.45,,0.84,,
relu 4x4xf64,0.40,,0.53,,
relu 4x8xf32,0.57,,,,0.40
relu 4x8xf32,0.57,0.51,,0.57,0.40
saxpy 64xf32,0.93,0.93,,,0.65
sum 4x4xf64,0.57,,0.46,,
sum 4x8xf32,0.58,0.56,0.46,,0.47
Expand Down
2 changes: 1 addition & 1 deletion xdsl
Submodule xdsl updated 45 files
+5 −5 pyproject.toml
+29 −0 tests/dialects/test_tensor.py
+85 −0 tests/dialects/test_transform.py
+5 −0 tests/filecheck/dialects/arith/arith_ops.mlir
+13 −7 tests/filecheck/dialects/csl/csl-stencil-ops.mlir
+73 −45 tests/filecheck/dialects/linalg/linalg_ops.mlir
+20 −13 tests/filecheck/dialects/riscv_snitch/assembly_emission.mlir
+4 −0 tests/filecheck/dialects/riscv_snitch/ops.mlir
+1 −0 tests/filecheck/dialects/snitch_stream/ops.mlir
+9 −9 tests/filecheck/dialects/stencil/invalid.mlir
+26 −2 tests/filecheck/dialects/stencil/stencil_ops.mlir
+23 −0 tests/filecheck/dialects/tensor/invalid_ops.mlir
+26 −0 tests/filecheck/dialects/tensor/ops.mlir
+44 −10 tests/filecheck/dialects/transform/transform_types.mlir
+1 −1 tests/filecheck/mlir-conversion/with-mlir/dialects/linalg/invalid_ops.mlir
+8 −0 tests/filecheck/mlir-conversion/with-mlir/dialects/tensor/ops.mlir
+42 −10 tests/filecheck/mlir-conversion/with-mlir/dialects/transform/transform_types.mlir
+55 −3 tests/filecheck/projects/riscv-backend-paper/bottom_up.mlir
+31 −4 tests/filecheck/transforms/convert-stencil-to-ll-mlir.mlir
+4 −4 tests/filecheck/transforms/distribute-stencil.mlir
+15 −0 tests/filecheck/transforms/lift-arith-to-linalg.mlir
+469 −0 tests/filecheck/transforms/stencil-bufferize.mlir
+8 −8 tests/filecheck/transforms/stencil-shape-inference.mlir
+44 −46 tests/filecheck/transforms/stencil-tensorize-z-dimension.mlir
+9 −9 tests/filecheck/transforms/stencil-to-csl-stencil.mlir
+2 −5 tests/interpreters/test_linalg_interpreter.py
+10 −2 xdsl/dialects/arith.py
+161 −145 xdsl/dialects/builtin.py
+127 −112 xdsl/dialects/linalg.py
+16 −11 xdsl/dialects/riscv_snitch.py
+2 −0 xdsl/dialects/snitch_stream.py
+185 −49 xdsl/dialects/stencil.py
+112 −6 xdsl/dialects/tensor.py
+329 −9 xdsl/dialects/transform.py
+12 −11 xdsl/parser/attribute_parser.py
+12 −0 xdsl/tools/command_line_tool.py
+21 −1 xdsl/traits.py
+25 −3 xdsl/transforms/canonicalization_patterns/stencil.py
+1 −1 xdsl/transforms/dead_code_elimination.py
+17 −7 xdsl/transforms/experimental/convert_stencil_to_ll_mlir.py
+82 −31 xdsl/transforms/experimental/stencil_tensorize_z_dimension.py
+65 −0 xdsl/transforms/lift_arith_to_linalg.py
+525 −0 xdsl/transforms/stencil_bufferize.py
+0 −2 xdsl/transforms/stencil_to_csl_stencil.py
+9 −4 xdsl/xdsl_opt_main.py
2 changes: 1 addition & 1 deletion xdsl_commit.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
adfcb056f40e6f3811ce6d26d1ea50da69c98ca9
f64501c5760cde6f78fa66e64da961ea19c274d6

0 comments on commit f787eaf

Please sign in to comment.