diff --git a/tests/filecheck/transforms/stencil-inlining.mlir b/tests/filecheck/transforms/stencil-inlining.mlir new file mode 100644 index 0000000000..a59e16808c --- /dev/null +++ b/tests/filecheck/transforms/stencil-inlining.mlir @@ -0,0 +1,353 @@ +// RUN: xdsl-opt %s --split-input-file -p stencil-inlining | filecheck %s + +func.func @simple(%arg0: !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64>, %arg1: !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64>) attributes {stencil.program} { + %2 = stencil.load %arg0 : !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64> -> !stencil.temp<[0,66]x[0,66]x[0,63]xf64> + %3 = stencil.apply (%arg2 = %2 : !stencil.temp<[0,66]x[0,66]x[0,63]xf64>) -> (!stencil.temp<[1,65]x[2,66]x[3,63]xf64>) { + %5 = stencil.access %arg2 [-1, 0, 0] : !stencil.temp<[0,66]x[0,66]x[0,63]xf64> + %6 = stencil.access %arg2 [1, 0, 0] : !stencil.temp<[0,66]x[0,66]x[0,63]xf64> + %7 = arith.addf %5, %6 : f64 + %8 = stencil.store_result %7 : !stencil.result + stencil.return %8 : !stencil.result + } + %4 = stencil.apply (%arg2 = %2 : !stencil.temp<[0,66]x[0,66]x[0,63]xf64>, %arg3 = %3 : !stencil.temp<[1,65]x[2,66]x[3,63]xf64>) -> (!stencil.temp<[0,64]x[0,64]x[0,60]xf64>) { + %5 = stencil.access %arg2 [0, 0, 0] : !stencil.temp<[0,66]x[0,66]x[0,63]xf64> + %6 = stencil.access %arg3 [1, 2, 3] : !stencil.temp<[1,65]x[2,66]x[3,63]xf64> + %7 = arith.addf %5, %6 : f64 + %8 = stencil.store_result %7 : !stencil.result + stencil.return %8 : !stencil.result + } + stencil.store %4 to %arg1(<[0, 0, 0], [64, 64, 60]>) : !stencil.temp<[0,64]x[0,64]x[0,60]xf64> to !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64> + return +} + +// CHECK: func.func @simple(%arg0 : !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64>, %arg1 : !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64>) attributes {"stencil.program"}{ +// CHECK-NEXT: %0 = stencil.load %arg0 : !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64> -> !stencil.temp<[0,66]x[0,66]x[0,63]xf64> +// CHECK-NEXT: %1 = stencil.apply(%arg2 = %0 : 
!stencil.temp<[0,66]x[0,66]x[0,63]xf64>) -> (!stencil.temp<[0,64]x[0,64]x[0,60]xf64>) { +// CHECK-NEXT: %2 = stencil.access %arg2[0, 0, 0] : !stencil.temp<[0,66]x[0,66]x[0,63]xf64> +// CHECK-NEXT: %3 = stencil.access %arg2[0, 2, 3] : !stencil.temp<[0,66]x[0,66]x[0,63]xf64> +// CHECK-NEXT: %4 = stencil.access %arg2[2, 2, 3] : !stencil.temp<[0,66]x[0,66]x[0,63]xf64> +// CHECK-NEXT: %5 = arith.addf %3, %4 : f64 +// CHECK-NEXT: %6 = arith.addf %2, %5 : f64 +// CHECK-NEXT: %7 = stencil.store_result %6 : !stencil.result +// CHECK-NEXT: stencil.return %7 : !stencil.result +// CHECK-NEXT: } +// CHECK-NEXT: stencil.store %1 to %arg1(<[0, 0, 0], [64, 64, 60]>) : !stencil.temp<[0,64]x[0,64]x[0,60]xf64> to !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64> +// CHECK-NEXT: func.return +// CHECK-NEXT: } + +// ----- + +func.func @simple_index(%arg0: f64, %arg1: !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64>) { + %1 = stencil.apply (%arg2 = %arg0 : f64) -> (!stencil.temp<[1,65]x[2,66]x[3,63]xf64>) { + %3 = stencil.index 2 <[2, -1, 1]> + %c20 = arith.constant 20 : index + %cst = arith.constant 0.000000e+00 : f64 + %4 = arith.cmpi slt, %3, %c20 : index + %5 = arith.select %4, %arg2, %cst : f64 + %6 = stencil.store_result %5 : !stencil.result + stencil.return %6 : !stencil.result + } + %2 = stencil.apply (%arg2 = %arg0 : f64, %arg3 = %1 : !stencil.temp<[1,65]x[2,66]x[3,63]xf64>) -> (!stencil.temp<[0,64]x[0,64]x[0,60]xf64>) { + %7 = stencil.access %arg3 [1, 2, 3] : !stencil.temp<[1,65]x[2,66]x[3,63]xf64> + %8 = arith.addf %7, %arg2 : f64 + %9 = stencil.store_result %8 : !stencil.result + stencil.return %9 : !stencil.result + } + stencil.store %2 to %arg1(<[0, 0, 0], [64, 64, 60]>) : !stencil.temp<[0,64]x[0,64]x[0,60]xf64> to !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64> + return +} + +// CHECK: func.func @simple_index(%arg0 : f64, %arg1 : !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64>) { +// CHECK-NEXT: %0 = stencil.apply(%arg2 = %arg0 : f64) -> (!stencil.temp<[0,64]x[0,64]x[0,60]xf64>) { +// 
CHECK-NEXT: %1 = stencil.index 2 <[3, 1, 4]> +// CHECK-NEXT: %c20 = arith.constant 20 : index +// CHECK-NEXT: %cst = arith.constant 0.000000e+00 : f64 +// CHECK-NEXT: %2 = arith.cmpi slt, %1, %c20 : index +// CHECK-NEXT: %3 = arith.select %2, %arg2, %cst : f64 +// CHECK-NEXT: %4 = arith.addf %3, %arg2 : f64 +// CHECK-NEXT: %5 = stencil.store_result %4 : !stencil.result +// CHECK-NEXT: stencil.return %5 : !stencil.result +// CHECK-NEXT: } +// CHECK-NEXT: stencil.store %0 to %arg1(<[0, 0, 0], [64, 64, 60]>) : !stencil.temp<[0,64]x[0,64]x[0,60]xf64> to !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64> +// CHECK-NEXT: func.return +// CHECK-NEXT: } + +// ----- + +func.func @simple_ifelse(%arg0: f64, %arg1: !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64>) attributes {stencil.program} { + %1 = stencil.apply (%arg2 = %arg0 : f64) -> (!stencil.temp<[0,64]x[0,64]x[0,60]xf64>) { + %true = arith.constant true + %3 = "scf.if"(%true) ({ + %4 = stencil.store_result %arg2 : !stencil.result + scf.yield %4 : !stencil.result + }, { + %4 = stencil.store_result %arg2 : !stencil.result + scf.yield %4 : !stencil.result + }) : (i1) -> (!stencil.result) + stencil.return %3 : !stencil.result + } + %2 = stencil.apply (%arg2 = %1 : !stencil.temp<[0,64]x[0,64]x[0,60]xf64>) -> (!stencil.temp<[0,64]x[0,64]x[0,60]xf64>) { + %3 = stencil.access %arg2 [0, 0, 0] : !stencil.temp<[0,64]x[0,64]x[0,60]xf64> + %4 = stencil.store_result %3 : !stencil.result + stencil.return %4 : !stencil.result + } + stencil.store %2 to %arg1(<[0, 0, 0], [64, 64, 60]>) : !stencil.temp<[0,64]x[0,64]x[0,60]xf64> to !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64> + return +} + +// CHECK: func.func @simple_ifelse(%arg0 : f64, %arg1 : !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64>) attributes {"stencil.program"}{ +// CHECK-NEXT: %0 = stencil.apply(%arg2 = %arg0 : f64) -> (!stencil.temp<[0,64]x[0,64]x[0,60]xf64>) { +// CHECK-NEXT: %true = arith.constant true +// CHECK-NEXT: %1 = scf.if %true -> (f64) { +// CHECK-NEXT: scf.yield %arg2 : f64 +// 
CHECK-NEXT: } else { +// CHECK-NEXT: scf.yield %arg2 : f64 +// CHECK-NEXT: } +// CHECK-NEXT: %2 = stencil.store_result %1 : !stencil.result +// CHECK-NEXT: stencil.return %2 : !stencil.result +// CHECK-NEXT: } +// CHECK-NEXT: stencil.store %0 to %arg1(<[0, 0, 0], [64, 64, 60]>) : !stencil.temp<[0,64]x[0,64]x[0,60]xf64> to !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64> +// CHECK-NEXT: func.return +// CHECK-NEXT: } + +// ----- + +func.func @multiple_edges(%arg0: !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64>, %arg1: !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64>, %arg2: !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64>) attributes {stencil.program} { + %3 = stencil.load %arg0 : !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64> -> !stencil.temp<[-1,65]x[0,64]x[0,60]xf64> + %4:2 = stencil.apply (%arg3 = %3 : !stencil.temp<[-1,65]x[0,64]x[0,60]xf64>) -> (!stencil.temp<[0,64]x[0,64]x[0,60]xf64>, !stencil.temp<[0,64]x[0,64]x[0,60]xf64>) { + %7 = stencil.access %arg3 [-1, 0, 0] : !stencil.temp<[-1,65]x[0,64]x[0,60]xf64> + %8 = stencil.access %arg3 [1, 0, 0] : !stencil.temp<[-1,65]x[0,64]x[0,60]xf64> + %9 = stencil.store_result %7 : !stencil.result + %10 = stencil.store_result %8 : !stencil.result + stencil.return %9, %10 : !stencil.result, !stencil.result + } + %5 = stencil.load %arg1 : !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64> -> !stencil.temp<[0,64]x[0,64]x[0,60]xf64> + %6 = stencil.apply (%arg3 = %3 : !stencil.temp<[-1,65]x[0,64]x[0,60]xf64>, %arg4 = %4#0 : !stencil.temp<[0,64]x[0,64]x[0,60]xf64>, %arg5 = %4#1 : !stencil.temp<[0,64]x[0,64]x[0,60]xf64>, %arg6 = %5 : !stencil.temp<[0,64]x[0,64]x[0,60]xf64>) -> (!stencil.temp<[0,64]x[0,64]x[0,60]xf64>) { + %7 = stencil.access %arg3 [0, 0, 0] : !stencil.temp<[-1,65]x[0,64]x[0,60]xf64> + %8 = stencil.access %arg4 [0, 0, 0] : !stencil.temp<[0,64]x[0,64]x[0,60]xf64> + %9 = stencil.access %arg5 [0, 0, 0] : !stencil.temp<[0,64]x[0,64]x[0,60]xf64> + %10 = stencil.access %arg6 [0, 0, 0] : !stencil.temp<[0,64]x[0,64]x[0,60]xf64> + %11 = arith.addf 
%7, %8 : f64 + %12 = arith.addf %9, %10 : f64 + %13 = arith.addf %11, %12 : f64 + %14 = stencil.store_result %13 : !stencil.result + stencil.return %14 : !stencil.result + } + stencil.store %6 to %arg2(<[0, 0, 0], [64, 64, 60]>) : !stencil.temp<[0,64]x[0,64]x[0,60]xf64> to !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64> + return +} + +// CHECK: func.func @multiple_edges(%arg0 : !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64>, %arg1 : !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64>, %arg2 : !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64>) attributes {"stencil.program"}{ +// CHECK-NEXT: %0 = stencil.load %arg0 : !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64> -> !stencil.temp<[-1,65]x[0,64]x[0,60]xf64> +// CHECK-NEXT: %1 = stencil.load %arg1 : !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64> -> !stencil.temp<[0,64]x[0,64]x[0,60]xf64> +// CHECK-NEXT: %2 = stencil.apply(%arg3 = %0 : !stencil.temp<[-1,65]x[0,64]x[0,60]xf64>, %arg6 = %1 : !stencil.temp<[0,64]x[0,64]x[0,60]xf64>) -> (!stencil.temp<[0,64]x[0,64]x[0,60]xf64>) { +// CHECK-NEXT: %3 = stencil.access %arg3[0, 0, 0] : !stencil.temp<[-1,65]x[0,64]x[0,60]xf64> +// CHECK-NEXT: %4 = stencil.access %arg3[-1, 0, 0] : !stencil.temp<[-1,65]x[0,64]x[0,60]xf64> +// CHECK-NEXT: %5 = stencil.access %arg3[1, 0, 0] : !stencil.temp<[-1,65]x[0,64]x[0,60]xf64> +// CHECK-NEXT: %6 = stencil.access %arg6[0, 0, 0] : !stencil.temp<[0,64]x[0,64]x[0,60]xf64> +// CHECK-NEXT: %7 = arith.addf %3, %4 : f64 +// CHECK-NEXT: %8 = arith.addf %5, %6 : f64 +// CHECK-NEXT: %9 = arith.addf %7, %8 : f64 +// CHECK-NEXT: %10 = stencil.store_result %9 : !stencil.result +// CHECK-NEXT: stencil.return %10 : !stencil.result +// CHECK-NEXT: } +// CHECK-NEXT: stencil.store %2 to %arg2(<[0, 0, 0], [64, 64, 60]>) : !stencil.temp<[0,64]x[0,64]x[0,60]xf64> to !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64> +// CHECK-NEXT: func.return +// CHECK-NEXT: } + +// ----- + +func.func @avoid_redundant(%arg0: !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64>, %arg1: 
!stencil.field<[-3,67]x[-3,67]x[-3,67]xf64>) attributes {stencil.program} { + %2 = stencil.load %arg0 : !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64> -> !stencil.temp<[-1,65]x[0,64]x[0,60]xf64> + %3 = stencil.apply (%arg2 = %2 : !stencil.temp<[-1,65]x[0,64]x[0,60]xf64>) -> (!stencil.temp<[0,64]x[0,64]x[0,60]xf64>) { + %5 = stencil.access %arg2 [-1, 0, 0] : !stencil.temp<[-1,65]x[0,64]x[0,60]xf64> + %6 = stencil.access %arg2 [1, 0, 0] : !stencil.temp<[-1,65]x[0,64]x[0,60]xf64> + %7 = arith.addf %5, %6 : f64 + %8 = stencil.store_result %7 : !stencil.result + stencil.return %8 : !stencil.result + } + %4 = stencil.apply (%arg2 = %3 : !stencil.temp<[0,64]x[0,64]x[0,60]xf64>) -> (!stencil.temp<[0,64]x[0,64]x[0,60]xf64>) { + %5 = stencil.access %arg2 [0, 0, 0] : !stencil.temp<[0,64]x[0,64]x[0,60]xf64> + %6 = stencil.access %arg2 [0, 0, 0] : !stencil.temp<[0,64]x[0,64]x[0,60]xf64> + %7 = arith.addf %5, %6 : f64 + %8 = stencil.store_result %7 : !stencil.result + stencil.return %8 : !stencil.result + } + stencil.store %4 to %arg1(<[0, 0, 0], [64, 64, 60]>) : !stencil.temp<[0,64]x[0,64]x[0,60]xf64> to !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64> + return +} + +// CHECK: func.func @avoid_redundant(%arg0 : !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64>, %arg1 : !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64>) attributes {"stencil.program"}{ +// CHECK-NEXT: %0 = stencil.load %arg0 : !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64> -> !stencil.temp<[-1,65]x[0,64]x[0,60]xf64> +// CHECK-NEXT: %1 = stencil.apply(%arg2 = %0 : !stencil.temp<[-1,65]x[0,64]x[0,60]xf64>) -> (!stencil.temp<[0,64]x[0,64]x[0,60]xf64>) { +// CHECK-NEXT: %2 = stencil.access %arg2[-1, 0, 0] : !stencil.temp<[-1,65]x[0,64]x[0,60]xf64> +// CHECK-NEXT: %3 = stencil.access %arg2[1, 0, 0] : !stencil.temp<[-1,65]x[0,64]x[0,60]xf64> +// CHECK-NEXT: %4 = arith.addf %2, %3 : f64 +// CHECK-NEXT: %5 = stencil.access %arg2[-1, 0, 0] : !stencil.temp<[-1,65]x[0,64]x[0,60]xf64> +// CHECK-NEXT: %6 = stencil.access %arg2[1, 0, 0] : 
!stencil.temp<[-1,65]x[0,64]x[0,60]xf64> +// CHECK-NEXT: %7 = arith.addf %5, %6 : f64 +// CHECK-NEXT: %8 = arith.addf %4, %7 : f64 +// CHECK-NEXT: %9 = stencil.store_result %8 : !stencil.result +// CHECK-NEXT: stencil.return %9 : !stencil.result +// CHECK-NEXT: } +// CHECK-NEXT: stencil.store %1 to %arg1(<[0, 0, 0], [64, 64, 60]>) : !stencil.temp<[0,64]x[0,64]x[0,60]xf64> to !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64> +// CHECK-NEXT: func.return +// CHECK-NEXT: } + +// ----- + +func.func @reroute(%arg0: !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64>, %arg1: !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64>, %arg2: !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64>) attributes {stencil.program} { + %3 = stencil.load %arg0 : !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64> -> !stencil.temp<[-1,66]x[0,66]x[0,63]xf64> + %4 = stencil.apply (%arg3 = %3 : !stencil.temp<[-1,66]x[0,66]x[0,63]xf64>) -> (!stencil.temp<[0,65]x[0,66]x[0,63]xf64>) { + %6 = stencil.access %arg3 [-1, 0, 0] : !stencil.temp<[-1,66]x[0,66]x[0,63]xf64> + %7 = stencil.access %arg3 [1, 0, 0] : !stencil.temp<[-1,66]x[0,66]x[0,63]xf64> + %8 = arith.addf %6, %7 : f64 + %9 = stencil.store_result %8 : !stencil.result + stencil.return %9 : !stencil.result + } + %5 = stencil.apply (%arg4 = %4 : !stencil.temp<[0,65]x[0,66]x[0,63]xf64>) -> (!stencil.temp<[0,64]x[0,64]x[0,60]xf64>) { + %6 = stencil.access %arg4 [0, 0, 0] : !stencil.temp<[0,65]x[0,66]x[0,63]xf64> + %7 = stencil.access %arg4 [1, 2, 3] : !stencil.temp<[0,65]x[0,66]x[0,63]xf64> + %8 = arith.addf %6, %7 : f64 + %9 = stencil.store_result %8 : !stencil.result + stencil.return %9 : !stencil.result + } + stencil.store %4 to %arg1(<[0, 0, 0], [65, 66, 63]>) : !stencil.temp<[0,65]x[0,66]x[0,63]xf64> to !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64> + stencil.store %5 to %arg2(<[0, 0, 0], [64, 64, 60]>) : !stencil.temp<[0,64]x[0,64]x[0,60]xf64> to !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64> + return +} + +// CHECK: func.func @reroute(%arg0 : 
!stencil.field<[-3,67]x[-3,67]x[-3,67]xf64>, %arg1 : !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64>, %arg2 : !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64>) attributes {"stencil.program"}{ +// CHECK-NEXT: %0 = stencil.load %arg0 : !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64> -> !stencil.temp<[-1,66]x[0,66]x[0,63]xf64> +// CHECK-NEXT: %1, %2 = stencil.apply(%arg3 = %0 : !stencil.temp<[-1,66]x[0,66]x[0,63]xf64>) -> (!stencil.temp<[0,65]x[0,66]x[0,63]xf64>, !stencil.temp<[0,65]x[0,66]x[0,63]xf64>) { +// CHECK-NEXT: %3 = stencil.access %arg3[-1, 0, 0] : !stencil.temp<[-1,66]x[0,66]x[0,63]xf64> +// CHECK-NEXT: %4 = stencil.access %arg3[1, 0, 0] : !stencil.temp<[-1,66]x[0,66]x[0,63]xf64> +// CHECK-NEXT: %5 = arith.addf %3, %4 : f64 +// CHECK-NEXT: %6 = stencil.access %arg3[0, 2, 3] : !stencil.temp<[-1,66]x[0,66]x[0,63]xf64> +// CHECK-NEXT: %7 = stencil.access %arg3[2, 2, 3] : !stencil.temp<[-1,66]x[0,66]x[0,63]xf64> +// CHECK-NEXT: %8 = arith.addf %6, %7 : f64 +// CHECK-NEXT: %9 = arith.addf %5, %8 : f64 +// CHECK-NEXT: %10 = stencil.store_result %9 : !stencil.result +// CHECK-NEXT: %11 = stencil.access %arg3[-1, 0, 0] : !stencil.temp<[-1,66]x[0,66]x[0,63]xf64> +// CHECK-NEXT: %12 = stencil.access %arg3[1, 0, 0] : !stencil.temp<[-1,66]x[0,66]x[0,63]xf64> +// CHECK-NEXT: %13 = arith.addf %11, %12 : f64 +// CHECK-NEXT: stencil.return %10, %13 : !stencil.result, f64 +// CHECK-NEXT: } +// CHECK-NEXT: stencil.store %2 to %arg1(<[0, 0, 0], [65, 66, 63]>) : !stencil.temp<[0,65]x[0,66]x[0,63]xf64> to !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64> +// CHECK-NEXT: stencil.store %1 to %arg2(<[0, 0, 0], [64, 64, 60]>) : !stencil.temp<[0,65]x[0,66]x[0,63]xf64> to !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64> +// CHECK-NEXT: func.return +// CHECK-NEXT: } + +// ----- + +func.func @root(%arg0: !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64>, %arg1: !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64>, %arg2: !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64>) attributes {stencil.program} { + %3 = stencil.load %arg0 : 
!stencil.field<[-3,67]x[-3,67]x[-3,67]xf64> -> !stencil.temp<[0,65]x[0,66]x[0,63]xf64> + %4 = stencil.apply (%arg3 = %3 : !stencil.temp<[0,65]x[0,66]x[0,63]xf64>) -> (!stencil.temp<[0,64]x[0,64]x[0,60]xf64>) { + %6 = stencil.access %arg3 [0, 0, 0] : !stencil.temp<[0,65]x[0,66]x[0,63]xf64> + %7 = stencil.store_result %6 : !stencil.result + stencil.return %7 : !stencil.result + } + %5 = stencil.apply (%arg3 = %3 : !stencil.temp<[0,65]x[0,66]x[0,63]xf64>) -> (!stencil.temp<[0,64]x[0,64]x[0,60]xf64>) { + %6 = stencil.access %arg3 [1, 2, 3] : !stencil.temp<[0,65]x[0,66]x[0,63]xf64> + %7 = stencil.store_result %6 : !stencil.result + stencil.return %7 : !stencil.result + } + stencil.store %4 to %arg1(<[0, 0, 0], [64, 64, 60]>) : !stencil.temp<[0,64]x[0,64]x[0,60]xf64> to !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64> + stencil.store %5 to %arg2(<[0, 0, 0], [64, 64, 60]>) : !stencil.temp<[0,64]x[0,64]x[0,60]xf64> to !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64> + return +} + +// CHECK: func.func @root(%arg0 : !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64>, %arg1 : !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64>, %arg2 : !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64>) attributes {"stencil.program"}{ +// CHECK-NEXT: %0 = stencil.load %arg0 : !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64> -> !stencil.temp<[0,65]x[0,66]x[0,63]xf64> +// CHECK-NEXT: %1, %2 = stencil.apply(%arg3 = %0 : !stencil.temp<[0,65]x[0,66]x[0,63]xf64>) -> (!stencil.temp<[0,64]x[0,64]x[0,60]xf64>, !stencil.temp<[0,64]x[0,64]x[0,60]xf64>) { +// CHECK-NEXT: %3 = stencil.access %arg3[1, 2, 3] : !stencil.temp<[0,65]x[0,66]x[0,63]xf64> +// CHECK-NEXT: %4 = stencil.store_result %3 : !stencil.result +// CHECK-NEXT: %5 = stencil.access %arg3[0, 0, 0] : !stencil.temp<[0,65]x[0,66]x[0,63]xf64> +// CHECK-NEXT: stencil.return %4, %5 : !stencil.result, f64 +// CHECK-NEXT: } +// CHECK-NEXT: stencil.store %2 to %arg1(<[0, 0, 0], [64, 64, 60]>) : !stencil.temp<[0,64]x[0,64]x[0,60]xf64> to !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64> +// 
CHECK-NEXT: stencil.store %1 to %arg2(<[0, 0, 0], [64, 64, 60]>) : !stencil.temp<[0,64]x[0,64]x[0,60]xf64> to !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64> +// CHECK-NEXT: func.return +// CHECK-NEXT: } + +// ----- + +func.func @dyn_access(%arg0: !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64>, %arg1: !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64>) { + %2 = stencil.load %arg0 : !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64> -> !stencil.temp<[-2,66]x[-1,65]x[-2,60]xf64> + %3 = stencil.apply (%arg2 = %2 : !stencil.temp<[-2,66]x[-1,65]x[-2,60]xf64>) -> (!stencil.temp<[-2,66]x[-1,65]x[-1,61]xf64>) { + %6 = stencil.access %arg2 [0, 0, -1] : !stencil.temp<[-2,66]x[-1,65]x[-2,60]xf64> + %7 = stencil.store_result %6 : !stencil.result + stencil.return %7 : !stencil.result + } + %4 = stencil.apply (%arg3 = %3 : !stencil.temp<[-2,66]x[-1,65]x[-1,61]xf64>) -> (!stencil.temp<[-1,65]x[0,64]x[0,60]xf64>) { + %7 = stencil.index 0 <[0, 0, 0]> + %8 = stencil.dyn_access %arg3[%7, %7, %7] in <[-1, -1, -1]> : <[1, 1, 1]> : !stencil.temp<[-2,66]x[-1,65]x[-1,61]xf64> + %9 = stencil.store_result %8 : !stencil.result + stencil.return %9 : !stencil.result + } + %5 = stencil.apply (%arg4 = %4 : !stencil.temp<[-1,65]x[0,64]x[0,60]xf64>) -> (!stencil.temp<[0,64]x[0,64]x[0,60]xf64>) { + %10 = stencil.access %arg4 [-1, 0, 0] : !stencil.temp<[-1,65]x[0,64]x[0,60]xf64> + %11 = stencil.access %arg4 [1, 0, 0] : !stencil.temp<[-1,65]x[0,64]x[0,60]xf64> + %12 = arith.addf %11, %10 : f64 + %13 = stencil.store_result %12 : !stencil.result + stencil.return %13 : !stencil.result + } + stencil.store %5 to %arg1(<[0, 0, 0], [64, 64, 60]>): !stencil.temp<[0,64]x[0,64]x[0,60]xf64> to !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64> + return +} + +// CHECK: func.func @dyn_access(%arg0 : !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64>, %arg1 : !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64>) { +// CHECK-NEXT: %0 = stencil.load %arg0 : !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64> -> !stencil.temp<[-2,66]x[-1,65]x[-2,60]xf64> +// 
CHECK-NEXT: %1 = stencil.apply(%arg2 = %0 : !stencil.temp<[-2,66]x[-1,65]x[-2,60]xf64>) -> (!stencil.temp<[-2,66]x[-1,65]x[-1,61]xf64>) { +// CHECK-NEXT: %2 = stencil.access %arg2[0, 0, -1] : !stencil.temp<[-2,66]x[-1,65]x[-2,60]xf64> +// CHECK-NEXT: %3 = stencil.store_result %2 : !stencil.result +// CHECK-NEXT: stencil.return %3 : !stencil.result +// CHECK-NEXT: } +// CHECK-NEXT: %2 = stencil.apply(%arg3 = %1 : !stencil.temp<[-2,66]x[-1,65]x[-1,61]xf64>) -> (!stencil.temp<[0,64]x[0,64]x[0,60]xf64>) { +// CHECK-NEXT: %3 = stencil.index 0 <[-1, 0, 0]> +// CHECK-NEXT: %4 = stencil.dyn_access %arg3[%3, %3, %3] in <[-2, -1, -1]> : <[0, 1, 1]> : !stencil.temp<[-2,66]x[-1,65]x[-1,61]xf64> +// CHECK-NEXT: %5 = stencil.index 0 <[1, 0, 0]> +// CHECK-NEXT: %6 = stencil.dyn_access %arg3[%5, %5, %5] in <[0, -1, -1]> : <[2, 1, 1]> : !stencil.temp<[-2,66]x[-1,65]x[-1,61]xf64> +// CHECK-NEXT: %7 = arith.addf %6, %4 : f64 +// CHECK-NEXT: %8 = stencil.store_result %7 : !stencil.result +// CHECK-NEXT: stencil.return %8 : !stencil.result +// CHECK-NEXT: } +// CHECK-NEXT: stencil.store %2 to %arg1(<[0, 0, 0], [64, 64, 60]>) : !stencil.temp<[0,64]x[0,64]x[0,60]xf64> to !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64> +// CHECK-NEXT: func.return +// CHECK-NEXT: } + +// ----- + +func.func @simple_buffer(%arg0: !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64>, %arg1: !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64>) attributes {stencil.program} { + %2 = stencil.load %arg0 : !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64> -> !stencil.temp<[0,64]x[0,64]x[0,60]xf64> + %3 = stencil.apply (%arg2 = %2 : !stencil.temp<[0,64]x[0,64]x[0,60]xf64>) -> (!stencil.temp<[0,64]x[0,64]x[0,60]xf64>) { + %6 = stencil.access %arg2 [0, 0, 0] : !stencil.temp<[0,64]x[0,64]x[0,60]xf64> + %7 = stencil.store_result %6 : !stencil.result + stencil.return %7 : !stencil.result + } + %4 = stencil.buffer %3 : !stencil.temp<[0,64]x[0,64]x[0,60]xf64> -> !stencil.temp<[0,64]x[0,64]x[0,60]xf64> + %5 = stencil.apply (%arg2 = %4 : 
!stencil.temp<[0,64]x[0,64]x[0,60]xf64>) -> (!stencil.temp<[0,64]x[0,64]x[0,60]xf64>) { + %6 = stencil.access %arg2 [0, 0, 0] : !stencil.temp<[0,64]x[0,64]x[0,60]xf64> + %7 = stencil.store_result %6 : !stencil.result + stencil.return %7 : !stencil.result + } + stencil.store %5 to %arg1(<[0, 0, 0], [64, 64, 60]>) : !stencil.temp<[0,64]x[0,64]x[0,60]xf64> to !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64> + return +} + +// CHECK: func.func @simple_buffer(%arg0 : !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64>, %arg1 : !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64>) attributes {"stencil.program"}{ +// CHECK-NEXT: %0 = stencil.load %arg0 : !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64> -> !stencil.temp<[0,64]x[0,64]x[0,60]xf64> +// CHECK-NEXT: %1 = stencil.apply(%arg2 = %0 : !stencil.temp<[0,64]x[0,64]x[0,60]xf64>) -> (!stencil.temp<[0,64]x[0,64]x[0,60]xf64>) { +// CHECK-NEXT: %2 = stencil.access %arg2[0, 0, 0] : !stencil.temp<[0,64]x[0,64]x[0,60]xf64> +// CHECK-NEXT: %3 = stencil.store_result %2 : !stencil.result +// CHECK-NEXT: stencil.return %3 : !stencil.result +// CHECK-NEXT: } +// CHECK-NEXT: %2 = stencil.buffer %1 : !stencil.temp<[0,64]x[0,64]x[0,60]xf64> -> !stencil.temp<[0,64]x[0,64]x[0,60]xf64> +// CHECK-NEXT: %3 = stencil.apply(%arg2 = %2 : !stencil.temp<[0,64]x[0,64]x[0,60]xf64>) -> (!stencil.temp<[0,64]x[0,64]x[0,60]xf64>) { +// CHECK-NEXT: %4 = stencil.access %arg2[0, 0, 0] : !stencil.temp<[0,64]x[0,64]x[0,60]xf64> +// CHECK-NEXT: %5 = stencil.store_result %4 : !stencil.result +// CHECK-NEXT: stencil.return %5 : !stencil.result +// CHECK-NEXT: } +// CHECK-NEXT: stencil.store %3 to %arg1(<[0, 0, 0], [64, 64, 60]>) : !stencil.temp<[0,64]x[0,64]x[0,60]xf64> to !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64> +// CHECK-NEXT: func.return +// CHECK-NEXT: } diff --git a/xdsl/transforms/__init__.py b/xdsl/transforms/__init__.py index b61dccf042..ed0a62faef 100644 --- a/xdsl/transforms/__init__.py +++ b/xdsl/transforms/__init__.py @@ -448,6 +448,11 @@ def get_stencil_bufferize(): 
return stencil_bufferize.StencilBufferize + def get_stencil_inlining(): + from xdsl.transforms import stencil_inlining + + return stencil_inlining.StencilInliningPass + def get_stencil_shape_minimize(): from xdsl.transforms import stencil_shape_minimize @@ -574,6 +579,7 @@ def get_varith_fuse_repeated_operands(): "shape-inference": get_shape_inference, "snitch-allocate-registers": get_snitch_allocate_registers, "stencil-bufferize": get_stencil_bufferize, + "stencil-inlining": get_stencil_inlining, "stencil-shape-minimize": get_stencil_shape_minimize, "stencil-storage-materialization": get_stencil_storage_materialization, "stencil-tensorize-z-dimension": get_stencil_tensorize_z_dimension, diff --git a/xdsl/transforms/stencil_inlining.py b/xdsl/transforms/stencil_inlining.py new file mode 100644 index 0000000000..19a2846ff4 --- /dev/null +++ b/xdsl/transforms/stencil_inlining.py @@ -0,0 +1,366 @@ +from collections.abc import Sequence +from dataclasses import dataclass +from typing import cast + +from xdsl.context import MLContext +from xdsl.dialects import builtin, scf +from xdsl.dialects.stencil import ( + AccessOp, + ApplyOp, + DynAccessOp, + ResultType, + ReturnOp, + StencilBoundsAttr, + StoreResultOp, + TempType, +) +from xdsl.ir import ( + Attribute, + Block, + BlockArgument, + Operation, + OpResult, +) +from xdsl.passes import ModulePass +from xdsl.pattern_rewriter import ( + GreedyRewritePatternApplier, + PatternRewriter, + PatternRewriteWalker, + RewritePattern, + op_type_rewrite_pattern, +) +from xdsl.rewriter import InsertPoint +from xdsl.transforms.canonicalization_patterns.stencil import ( + ApplyRedundantOperands, + ApplyUnusedOperands, + ApplyUnusedResults, +) +from xdsl.transforms.shape_inference_patterns.stencil import update_result_size +from xdsl.transforms.stencil_unroll import offseted_block_clone + + +def is_before_in_block(op1: Operation, op2: Operation): + """ + Check if op1 is before op2 in the same block. 
+ """ + block = op1.parent + assert block is not None + assert block is op2.parent + return block.get_operation_index(op1) < block.get_operation_index(op2) + + +class StencilStoreResultForwardPattern(RewritePattern): + """ + Replace non-empty `stencil.store_result`s by their argument. + """ + + @op_type_rewrite_pattern + def match_and_rewrite(self, op: StoreResultOp, rewriter: PatternRewriter, /): + if op.arg is None: + return + rewriter.replace_matched_op([], [op.arg]) + + +class StencilIfResultForwardPattern(RewritePattern): + """ + Replace `!stencil.result`-typed scf.if by `T`-typed scf.if. + """ + + @op_type_rewrite_pattern + def match_and_rewrite(self, op: scf.IfOp, rewriter: PatternRewriter, /): + result_types = [r.type for r in op.output] + new_result_types = [ + t.elem if isinstance(t, ResultType) else t for t in result_types + ] + if new_result_types == result_types: + return + rewriter.replace_matched_op( + scf.IfOp( + op.cond, + new_result_types, + op.detach_region(0), + op.detach_region(0), + ) + ) + + +def has_single_consumer(producer: ApplyOp, consumer: ApplyOp): + """ + Check if the producer has a single consumer. + """ + return all( + isinstance(u.operation, ApplyOp) and u.operation == consumer + for r in producer.results + for u in r.uses + ) + + +def is_rerouting_possible(producer: ApplyOp, consumer: ApplyOp): + """ + Check if rerouting is possible. + """ + # Perform producer consumer inlining instead + if has_single_consumer(producer, consumer): + return False + return not any( + isinstance(operand.owner, Operation) + and (operand.owner is not producer) + and is_before_in_block(producer, operand.owner) + for operand in consumer.operands + ) + + +def is_inlining_possible(producer: ApplyOp, consumer: ApplyOp): + """ + Check if inlining is possible. + """ + # Don't inline any producer with conditional writes. 
+ r = not any( + store_result.arg is None + for store_result in producer.walk() + if isinstance(store_result, StoreResultOp) + ) and not any( + # Don't inline any dynamic accesses. + isinstance(use.operation, DynAccessOp) + for consumer_operand in consumer.operands + if consumer_operand.owner is producer + for use in consumer.region.block.args[ + consumer.operands.index(consumer_operand) + ].uses + ) + + return r + + +class StencilReroutingPattern(RewritePattern): + """ + Reroute the producer's results through the consumer to enable inlining: + ``` + a b a b + │ │ │ │ + ┌───▼─────────────┐ │ ┌───▼─────────────┐ │ + │ producer ├─┐ │ │ producer ├─┐ │ + └──┬──────────────┘ │e │ rerouting └─────────────┬───┘ │e │ + │ │ │ ──────────► │c' │ │ + │ ┌─▼───────────▼──┐ └──►┌─▼───────────▼──┐ + │ │ consumer │ │ consumer │ + │ └────────────┬───┘ ┌──────────────┴────────────┬───┘ + │ │ │ │ + ▼ ▼ ▼ ▼ + c d c d + ``` + """ + + def redirect_store( + self, producer: ApplyOp, consumer: ApplyOp, rewriter: PatternRewriter + ): + # We want to replace the consumer adding the producer's results to its operands + # and results + new_operands = list(consumer.args) + list(producer.results) + new_results = list(r.type for r in consumer.res + producer.res) + + new_consumer = ApplyOp.get( + new_operands, + Block(arg_types=[o.type for o in new_operands]), + cast(Sequence[TempType[Attribute]], new_results), + ) + + # The new consumer contains the computation of the inital one + rewriter.inline_block( + consumer.region.block, + InsertPoint.at_end(new_consumer.region.block), + new_consumer.region.block.args[: len(consumer.args)], + ) + + # Update the bounds if needed + producer_bounds = cast(TempType[Attribute], producer.res[0].type).bounds + consumer_bounds = cast(TempType[Attribute], consumer.res[0].type).bounds + if isinstance(producer_bounds, StencilBoundsAttr): + new_bounds = producer_bounds | consumer_bounds + elif isinstance(consumer_bounds, StencilBoundsAttr): + new_bounds = 
producer_bounds | consumer_bounds
+        else:
+            new_bounds = None
+        if isinstance(new_bounds, StencilBoundsAttr):
+            update_result_size(new_consumer.res[0], new_bounds, rewriter)
+
+        # Reroute new arguments to the new apply's return
+        return_op = cast(ReturnOp, new_consumer.region.block.last_op)
+        return_operands = list(return_op.arg)
+        zero_offset = [0] * new_consumer.get_rank()
+        for arg in new_consumer.region.block.args[-len(producer.res) :]:
+            access = AccessOp.get(arg, zero_offset)
+            rewriter.insert_op(access, InsertPoint.before(return_op))
+            return_operands.append(access.res)
+        rewriter.replace_op(return_op, ReturnOp.get(return_operands))
+
+        # Replace the producer's results by the rerouted consumer results
+        rerouted_results = new_consumer.res[-len(producer.res) :]
+        for pres, rres in zip(producer.res, rerouted_results, strict=True):
+            for use in list(pres.uses):
+                if use.operation is new_consumer:
+                    continue
+                use.operation.operands[use.index] = rres
+
+        rewriter.replace_op(
+            consumer, new_consumer, new_consumer.res[: len(consumer.res)]
+        )
+
+    @op_type_rewrite_pattern
+    def match_and_rewrite(self, op: ApplyOp, rewriter: PatternRewriter):
+        """
+        Reroute the results of one producer apply through `op`, taken as the
+        consumer.
+
+        Candidate producers are looked for in two places, in this order:
+
+        1. other applies that use a result of an operation defining one of the
+           consumer's operands, kept only if
+           `is_before_in_block(producer, consumer)` holds;
+        2. applies that directly define one of the consumer's operands.
+
+        The first candidate for which both `is_inlining_possible` and
+        `is_rerouting_possible` hold is handed to `redirect_store`, and the
+        method returns. If no candidate qualifies, nothing is rewritten.
+        """
+        consumer = op
+
+        # Reroute input dependency
+        # That is, two applies share some operands but have no direct def-use link.
+        for operand in consumer.operands:
+            # Block arguments have no defining operation to look through.
+            if isinstance(operand.owner, Operation):
+                for res in operand.owner.results:
+                    for use in res.uses:
+                        # Only consider other apply operations
+                        if isinstance(producer := use.operation, ApplyOp):
+                            # Only consider other consumers before the apply op
+                            if consumer is producer:
+                                continue
+                            if not is_before_in_block(producer, consumer):
+                                continue
+
+                            if is_inlining_possible(
+                                producer, consumer
+                            ) and is_rerouting_possible(producer, consumer):
+                                return self.redirect_store(producer, consumer, rewriter)
+
+        # Reroute output dependency
+        # That is, the consumer is already using some of the producer's results
+        for operand in consumer.operands:
+            producer = operand.owner
+            if isinstance(producer, ApplyOp):
+                if is_inlining_possible(producer, consumer) and is_rerouting_possible(
+                    producer, consumer
+                ):
+                    return self.redirect_store(producer, consumer, rewriter)
+
+
+@dataclass
+class StencilInliningPattern(RewritePattern):
+    """
+    Inline a producer apply in a consumer apply, to use in the simple case where the
+    consumer is the only user of the producer's results:
+    ```
+         a b          c
+         │ │          │
+        ┌▼─▼────────┐ │                  a b       c
+        │ producer  │ │                  │ │       │
+        └─────┬─┬───┘ │    inlining     ┌▼─▼───────▼┐
+             d│ │e    │   ─────────►    │  inlined  │
+            ┌─▼─▼─────▼┐                └─────────┬─┘
+            │ consumer │                          │
+            └────────┬─┘                          ▼
+                     │                         output
+                     ▼
+                  output
+    ```
+    """
+
+    # Walker greedily applying the two result-forwarding patterns; it is run on
+    # the producer's region at the start of `inline_producer`.
+    # The name carries no annotation, so `@dataclass` does not turn it into a
+    # field: it stays a plain class attribute shared by all instances.
+    result_type_cleaner = PatternRewriteWalker(
+        GreedyRewritePatternApplier(
+            [StencilIfResultForwardPattern(), StencilStoreResultForwardPattern()]
+        )
+    )
+
+    def inline_producer(
+        self, producer: ApplyOp, consumer: ApplyOp, rewriter: PatternRewriter
+    ):
+        """
+        Inline the producer into the consumer.
+
+        A single merged block is built whose arguments are the consumer's
+        operands followed by the producer's operands. The consumer's block is
+        inlined into it, then every access to one of the producer's results is
+        replaced by a clone of the producer's block, obtained from
+        `offseted_block_clone` with that access's offset. Block arguments left
+        without uses are erased together with the matching operands. Finally
+        the consumer is replaced by the new apply and the producer is erased.
+        """
+
+        self.result_type_cleaner.rewrite_region(producer.region)
+
+        # Concatenate both applies' operand lists.
+        operands = list(consumer.operands) + list(producer.operands)
+
+        # Create a new apply with the concatenated operands
+        # Corresponding block arguments, and only the consumer's results.
+        # (The producer's results are only used in the consumer by assumption)
+        merged_block = Block(arg_types=[o.type for o in operands])
+
+        # Prepare the list of block arguments corresponding to the producer's operands.
+        merged_producer_arguments = merged_block.args[len(consumer.operands) :]
+
+        # Inline the consumer's block to begin with.
+        rewriter.inline_block(
+            consumer.region.block,
+            InsertPoint.at_start(merged_block),
+            merged_block.args[: len(consumer.operands)],
+        )
+
+        # Store the list of consumer accesses
+        # The list is built before anything from the producer is inlined, so the
+        # loop below only visits accesses that came from the consumer's block.
+        consumer_accesses = [
+            op for op in merged_block.walk(reverse=True) if isinstance(op, AccessOp)
+        ]
+
+        # Start inlining accesses to the producer
+        for access in consumer_accesses:
+            # Skip accesses whose accessed operand is not defined by the producer.
+            # The block argument's index is used to look up the consumer operand
+            # it stands for.
+            temp = consumer.args[cast(BlockArgument, access.temp).index]
+            if temp.owner is not producer:
+                continue
+            # Make pyright happy about temp being an OpResult
+            temp = cast(OpResult, temp)
+            # Find the index of the producer's result
+            producer_index = producer.res.index(temp)
+
+            # Clone the producer's block, offset according to the access offset.
+            offsetted_block = offseted_block_clone(producer, list(access.offset))
+
+            # Get the return op's accessed operand.
+            return_op = cast(ReturnOp, offsetted_block.last_op)
+            accessed = return_op.arg[producer_index]
+
+            # Remove the return, inline the computation, replace the access.
+            rewriter.erase_op(return_op)
+            rewriter.inline_block(
+                offsetted_block, InsertPoint.before(access), merged_producer_arguments
+            )
+            rewriter.replace_op(access, [], [accessed])
+
+        # `new_operands` is the same list object as `operands`, not a copy.
+        new_operands = operands
+        # Drop unused block arguments and their operands. Going back to front
+        # means a removal never shifts the position of an entry still to be
+        # visited.
+        for arg in reversed(list(merged_block.args)):
+            if not arg.uses:
+                new_operands.pop(arg.index)
+                merged_block.erase_arg(arg)
+        new_apply = ApplyOp.get(
+            new_operands,
+            merged_block,
+            [cast(TempType[Attribute], r.type) for r in consumer.results],
+        )
+        rewriter.replace_op(consumer, new_apply)
+        rewriter.erase_op(producer)
+
+    @op_type_rewrite_pattern
+    def match_and_rewrite(self, op: ApplyOp, rewriter: PatternRewriter, /):
+        """
+        Inline into `op` the first apply defining one of its operands for which
+        both `has_single_consumer` and `is_inlining_possible` hold.
+
+        At most one producer is inlined per call; if none qualifies, nothing is
+        rewritten.
+        """
+        for operand in (consumer := op).operands:
+            if isinstance(producer := operand.owner, ApplyOp):
+                if has_single_consumer(producer, consumer) and is_inlining_possible(
+                    producer, consumer
+                ):
+                    return self.inline_producer(producer, consumer, rewriter)
+
+
+@dataclass(frozen=True)
+class StencilInliningPass(ModulePass):
+    """
+    Module pass named "stencil-inlining".
+
+    Greedily applies the rerouting and inlining patterns, together with the
+    ApplyUnusedResults, ApplyUnusedOperands and ApplyRedundantOperands
+    patterns, over the whole module.
+    """
+
+    name = "stencil-inlining"
+
+    def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None:
+        """Run the pattern set on `op`; `ctx` is not used."""
+        walker = PatternRewriteWalker(
+            GreedyRewritePatternApplier(
+                [
+                    StencilReroutingPattern(),
+                    StencilInliningPattern(),
+                    ApplyUnusedResults(),
+                    ApplyUnusedOperands(),
+                    ApplyRedundantOperands(),
+                ]
+            )
+        )
+        walker.rewrite_module(op)
diff --git a/xdsl/transforms/stencil_unroll.py b/xdsl/transforms/stencil_unroll.py
index f65d91bb90..d90de9fe72 100644
--- a/xdsl/transforms/stencil_unroll.py
+++ b/xdsl/transforms/stencil_unroll.py
@@ -42,10 +42,10 @@ def offseted_block_clone(apply: ApplyOp, unroll_offset: Sequence[int]):
                     offset_mapping = list(range(0, len(op.offset)))
                 else:
                     offset_mapping = op.offset_mapping
-
-                new_offset = [o for o in op.offset]
-                for i in offset_mapping:
-                    new_offset[i] += unroll_offset[i]
+                new_offset = [
+                    o + unroll_offset[m]
+                    for o, m in zip(op.offset, offset_mapping, strict=True)
+                ]
                 op.offset = IndexAttr.get(*new_offset)
             case DynAccessOp():