From 158cfe1ea5b15d245cd1461f9d9299c054d57179 Mon Sep 17 00:00:00 2001 From: Nicolai Stawinoga <36768051+n-io@users.noreply.github.com> Date: Tue, 26 Nov 2024 15:02:44 +0100 Subject: [PATCH] bug: (lower-csl-stencil) Zero-out accumulator for full reduction access (#3520) This resulted in `nan`s after a few iterations. Co-authored-by: n-io --- .../transforms/lower-csl-stencil.mlir | 440 +++++++++--------- xdsl/transforms/lower_csl_stencil.py | 30 +- 2 files changed, 243 insertions(+), 227 deletions(-) diff --git a/tests/filecheck/transforms/lower-csl-stencil.mlir b/tests/filecheck/transforms/lower-csl-stencil.mlir index 7581cbd85e..84889df5a0 100644 --- a/tests/filecheck/transforms/lower-csl-stencil.mlir +++ b/tests/filecheck/transforms/lower-csl-stencil.mlir @@ -95,31 +95,33 @@ builtin.module { // CHECK-NEXT: "csl.export"() <{"var_name" = @gauss_seidel_func, "type" = () -> ()}> : () -> () // CHECK-NEXT: csl.func @gauss_seidel_func() { // CHECK-NEXT: %accumulator = memref.alloc() {"alignment" = 64 : i64} : memref<510xf32> -// CHECK-NEXT: %37 = arith.constant 2 : i16 -// CHECK-NEXT: %38 = "csl.addressof_fn"() <{"fn_name" = @receive_chunk_cb2}> : () -> !csl.ptr<(i16) -> (), #csl, #csl> -// CHECK-NEXT: %39 = "csl.addressof_fn"() <{"fn_name" = @done_exchange_cb2}> : () -> !csl.ptr<() -> (), #csl, #csl> -// CHECK-NEXT: %40 = memref.subview %arg0[1] [510] [1] : memref<512xf32> to memref<510xf32> -// CHECK-NEXT: "csl.member_call"(%34, %40, %37, %38, %39) <{"field" = "communicate"}> : (!csl.imported_module, memref<510xf32>, i16, !csl.ptr<(i16) -> (), #csl, #csl>, !csl.ptr<() -> (), #csl, #csl>) -> () +// CHECK-NEXT: %37 = arith.constant 0.000000e+00 : f32 +// CHECK-NEXT: "csl.fmovs"(%accumulator, %37) : (memref<510xf32>, f32) -> () +// CHECK-NEXT: %38 = arith.constant 2 : i16 +// CHECK-NEXT: %39 = "csl.addressof_fn"() <{"fn_name" = @receive_chunk_cb2}> : () -> !csl.ptr<(i16) -> (), #csl, #csl> +// CHECK-NEXT: %40 = "csl.addressof_fn"() <{"fn_name" = @done_exchange_cb2}> : () -> !csl.ptr<() -> (), #csl, #csl> +// CHECK-NEXT: %41 = memref.subview %arg0[1] [510] [1] : memref<512xf32> to memref<510xf32> +// CHECK-NEXT: "csl.member_call"(%34, %41, %38, %39, %40) <{"field" = "communicate"}> : (!csl.imported_module, memref<510xf32>, i16, !csl.ptr<(i16) -> (), #csl, #csl>, !csl.ptr<() -> (), #csl, #csl>) -> () // CHECK-NEXT: csl.return // CHECK-NEXT: } // CHECK-NEXT: csl.func @receive_chunk_cb2(%offset : i16) { // CHECK-NEXT: %offset_1 = arith.index_cast %offset : i16 to index -// CHECK-NEXT: %41 = memref.subview %accumulator[%offset_1] [255] [1] : memref<510xf32> to memref<255xf32, strided<[1], offset: ?>> -// CHECK-NEXT: %42 = arith.constant 4 : i16 -// CHECK-NEXT: %43 = "csl.get_mem_dsd"(%accumulator, %42, %29, %31) <{"strides" = [0 : i16, 0 : i16, 1 : i16]}> : (memref<510xf32>, i16, i16, i16) -> !csl -// CHECK-NEXT: %44 = arith.index_cast %offset_1 : index to si16 -// CHECK-NEXT: %45 = "csl.increment_dsd_offset"(%43, %44) <{"elem_type" = f32}> : (!csl, si16) -> !csl -// CHECK-NEXT: %46 = "csl.member_call"(%34) <{"field" = "getRecvBufDsd"}> : (!csl.imported_module) -> !csl -// CHECK-NEXT: "csl.fadds"(%45, %45, %46) : (!csl, !csl, !csl) -> () +// CHECK-NEXT: %42 = memref.subview %accumulator[%offset_1] [255] [1] : memref<510xf32> to memref<255xf32, strided<[1], offset: ?>> +// CHECK-NEXT: %43 = arith.constant 4 : i16 +// CHECK-NEXT: %44 = "csl.get_mem_dsd"(%accumulator, %43, %29, %31) <{"strides" = [0 : i16, 0 : i16, 1 : i16]}> : (memref<510xf32>, i16, i16, i16) -> !csl +// CHECK-NEXT: %45 = arith.index_cast %offset_1 : index to si16 +// CHECK-NEXT: %46 = "csl.increment_dsd_offset"(%44, %45) <{"elem_type" = f32}> : (!csl, si16) -> !csl +// CHECK-NEXT: %47 = "csl.member_call"(%34) <{"field" = "getRecvBufDsd"}> : (!csl.imported_module) -> !csl +// CHECK-NEXT: "csl.fadds"(%46, %46, %47) : (!csl, !csl, !csl) -> () // CHECK-NEXT: csl.return // CHECK-NEXT: } // CHECK-NEXT: csl.func @done_exchange_cb2() { -// CHECK-NEXT: %47 = memref.subview %arg0[2] [510] [1] : memref<512xf32> to memref<510xf32, strided<[1], offset: 2>> -// CHECK-NEXT: %48 = memref.subview %arg0[0] [510] [1] : memref<512xf32> to memref<510xf32, strided<[1]>> -// CHECK-NEXT: "csl.fadds"(%accumulator, %accumulator, %48) : (memref<510xf32>, memref<510xf32>, memref<510xf32, strided<[1]>>) -> () -// CHECK-NEXT: "csl.fadds"(%accumulator, %accumulator, %47) : (memref<510xf32>, memref<510xf32>, memref<510xf32, strided<[1], offset: 2>>) -> () -// CHECK-NEXT: %49 = arith.constant 1.666600e-01 : f32 -// CHECK-NEXT: "csl.fmuls"(%accumulator, %accumulator, %49) : (memref<510xf32>, memref<510xf32>, f32) -> () +// CHECK-NEXT: %48 = memref.subview %arg0[2] [510] [1] : memref<512xf32> to memref<510xf32, strided<[1], offset: 2>> +// CHECK-NEXT: %49 = memref.subview %arg0[0] [510] [1] : memref<512xf32> to memref<510xf32, strided<[1]>> +// CHECK-NEXT: "csl.fadds"(%accumulator, %accumulator, %49) : (memref<510xf32>, memref<510xf32>, memref<510xf32, strided<[1]>>) -> () +// CHECK-NEXT: "csl.fadds"(%accumulator, %accumulator, %48) : (memref<510xf32>, memref<510xf32>, memref<510xf32, strided<[1], offset: 2>>) -> () +// CHECK-NEXT: %50 = arith.constant 1.666600e-01 : f32 +// CHECK-NEXT: "csl.fmuls"(%accumulator, %accumulator, %50) : (memref<510xf32>, memref<510xf32>, f32) -> () // CHECK-NEXT: csl.return // CHECK-NEXT: } // CHECK-NEXT: "csl_wrapper.yield"() <{"fields" = []}> : () -> () @@ -236,52 +238,52 @@ builtin.module { // CHECK-NEXT: "csl_wrapper.module"() <{"height" = 512 : i16, "params" = [#csl_wrapper.param<"z_dim" default=512 : i16>, #csl_wrapper.param<"pattern" default=2 : i16>, #csl_wrapper.param<"num_chunks" default=1 : i16>, #csl_wrapper.param<"chunk_size" default=510 : i16>, #csl_wrapper.param<"padded_z_dim" default=510 : i16>], "program_name" = "loop", "width" = 1024 : i16}> ({ // CHECK-NEXT: ^2(%arg0_1 : i16, %arg1_1 : i16, %arg2 : i16, %arg3 : i16, %arg4 : i16, %arg5 : i16, %arg6 : i16, %arg7 : i16, %arg8 : i16): -// CHECK-NEXT: %50 = arith.constant 0 : i16 -// CHECK-NEXT: %51 = "csl.get_color"(%50) : (i16) -> !csl.color -// CHECK-NEXT: %52 = "csl_wrapper.import"(%arg2, %arg3, %51) <{"fields" = ["width", "height", "LAUNCH"], "module" = ""}> : (i16, i16, !csl.color) -> !csl.imported_module -// CHECK-NEXT: %53 = "csl_wrapper.import"(%arg5, %arg2, %arg3) <{"fields" = ["pattern", "peWidth", "peHeight"], "module" = "routes.csl"}> : (i16, i16, i16) -> !csl.imported_module -// CHECK-NEXT: %54 = "csl.member_call"(%53, %arg0_1, %arg1_1, %arg2, %arg3, %arg5) <{"field" = "computeAllRoutes"}> : (!csl.imported_module, i16, i16, i16, i16, i16) -> !csl.comptime_struct -// CHECK-NEXT: %55 = "csl.member_call"(%52, %arg0_1) <{"field" = "get_params"}> : (!csl.imported_module, i16) -> !csl.comptime_struct -// CHECK-NEXT: %56 = arith.constant 1 : i16 -// CHECK-NEXT: %57 = arith.subi %arg5, %56 : i16 -// CHECK-NEXT: %58 = arith.subi %arg2, %arg0_1 : i16 -// CHECK-NEXT: %59 = arith.subi %arg3, %arg1_1 : i16 -// CHECK-NEXT: %60 = arith.cmpi slt, %arg0_1, %57 : i16 -// CHECK-NEXT: %61 = arith.cmpi slt, %arg1_1, %57 : i16 -// CHECK-NEXT: %62 = arith.cmpi slt, %58, %arg5 : i16 +// CHECK-NEXT: %51 = arith.constant 0 : i16 +// CHECK-NEXT: %52 = "csl.get_color"(%51) : (i16) -> !csl.color +// CHECK-NEXT: %53 = "csl_wrapper.import"(%arg2, %arg3, %52) <{"fields" = ["width", "height", "LAUNCH"], "module" = ""}> : (i16, i16, !csl.color) -> !csl.imported_module +// CHECK-NEXT: %54 = "csl_wrapper.import"(%arg5, %arg2, %arg3) <{"fields" = ["pattern", "peWidth", "peHeight"], "module" = "routes.csl"}> : (i16, i16, i16) -> !csl.imported_module +// CHECK-NEXT: %55 = "csl.member_call"(%54, %arg0_1, %arg1_1, %arg2, %arg3, %arg5) <{"field" = "computeAllRoutes"}> : (!csl.imported_module, i16, i16, i16, i16, i16) -> !csl.comptime_struct +// CHECK-NEXT: %56 = "csl.member_call"(%53, %arg0_1) <{"field" = "get_params"}> : (!csl.imported_module, i16) -> !csl.comptime_struct +// CHECK-NEXT: %57 = arith.constant 1 : i16 +// CHECK-NEXT: %58 = arith.subi %arg5, %57 : i16 +// CHECK-NEXT: %59 = arith.subi %arg2, %arg0_1 : i16 +// CHECK-NEXT: %60 = arith.subi %arg3, %arg1_1 : i16 +// CHECK-NEXT: %61 = arith.cmpi slt, %arg0_1, %58 : i16 +// CHECK-NEXT: %62 = arith.cmpi slt, %arg1_1, %58 : i16 // CHECK-NEXT: %63 = arith.cmpi slt, %59, %arg5 : i16 -// CHECK-NEXT: %64 = arith.ori %60, %61 : i1 -// CHECK-NEXT: %65 = arith.ori %64, %62 : i1 +// CHECK-NEXT: %64 = arith.cmpi slt, %60, %arg5 : i16 +// CHECK-NEXT: %65 = arith.ori %61, %62 : i1 // CHECK-NEXT: %66 = arith.ori %65, %63 : i1 -// CHECK-NEXT: "csl_wrapper.yield"(%55, %54, %66) <{"fields" = ["memcpy_params", "stencil_comms_params", "isBorderRegionPE"]}> : (!csl.comptime_struct, !csl.comptime_struct, i1) -> () +// CHECK-NEXT: %67 = arith.ori %66, %64 : i1 +// CHECK-NEXT: "csl_wrapper.yield"(%56, %55, %67) <{"fields" = ["memcpy_params", "stencil_comms_params", "isBorderRegionPE"]}> : (!csl.comptime_struct, !csl.comptime_struct, i1) -> () // CHECK-NEXT: }, { // CHECK-NEXT: ^3(%arg0_2 : i16, %arg1_2 : i16, %arg2_1 : i16, %arg3_1 : i16, %arg4_1 : i16, %arg5_1 : i16, %arg6_1 : i16, %arg7_1 : !csl.comptime_struct, %arg8_1 : !csl.comptime_struct, %arg9 : i1): -// CHECK-NEXT: %67 = "csl_wrapper.import"(%arg7_1) <{"fields" = [""], "module" = ""}> : (!csl.comptime_struct) -> !csl.imported_module -// CHECK-NEXT: %68 = "csl_wrapper.import"(%arg3_1, %arg5_1, %arg8_1) <{"fields" = ["pattern", "chunkSize", ""], "module" = "stencil_comms.csl"}> : (i16, i16, !csl.comptime_struct) -> !csl.imported_module -// CHECK-NEXT: %69 = memref.alloc() : memref<512xf32> +// CHECK-NEXT: %68 = "csl_wrapper.import"(%arg7_1) <{"fields" = [""], "module" = ""}> : (!csl.comptime_struct) -> !csl.imported_module +// CHECK-NEXT: %69 = "csl_wrapper.import"(%arg3_1, %arg5_1, %arg8_1) <{"fields" = ["pattern", "chunkSize", ""], "module" = "stencil_comms.csl"}> : (i16, i16, !csl.comptime_struct) -> !csl.imported_module // CHECK-NEXT: %70 = memref.alloc() : memref<512xf32> -// CHECK-NEXT: %71 = "csl.addressof"(%69) : (memref<512xf32>) -> !csl.ptr, #csl> +// CHECK-NEXT: %71 = memref.alloc() : memref<512xf32> // CHECK-NEXT: %72 = "csl.addressof"(%70) : (memref<512xf32>) -> !csl.ptr, #csl> -// CHECK-NEXT: "csl.export"(%71) <{"type" = !csl.ptr, #csl>, "var_name" = "a"}> : (!csl.ptr, #csl>) -> () -// CHECK-NEXT: "csl.export"(%72) <{"type" = !csl.ptr, #csl>, "var_name" = "b"}> : (!csl.ptr, #csl>) -> () +// CHECK-NEXT: %73 = "csl.addressof"(%71) : (memref<512xf32>) -> !csl.ptr, #csl> +// CHECK-NEXT: "csl.export"(%72) <{"type" = !csl.ptr, #csl>, "var_name" = "a"}> : (!csl.ptr, #csl>) -> () +// CHECK-NEXT: "csl.export"(%73) <{"type" = !csl.ptr, #csl>, "var_name" = "b"}> : (!csl.ptr, #csl>) -> () // CHECK-NEXT: "csl.export"() <{"type" = () -> (), "var_name" = @gauss_seidel_func}> : () -> () -// CHECK-NEXT: %73 = "csl.variable"() <{"default" = 0 : i16}> : () -> !csl.var -// CHECK-NEXT: %74 = "csl.variable"() : () -> !csl.var> +// CHECK-NEXT: %74 = "csl.variable"() <{"default" = 0 : i16}> : () -> !csl.var // CHECK-NEXT: %75 = "csl.variable"() : () -> !csl.var> +// CHECK-NEXT: %76 = "csl.variable"() : () -> !csl.var> // CHECK-NEXT: csl.func @loop() { -// CHECK-NEXT: %76 = arith.constant 0 : index -// CHECK-NEXT: %77 = arith.constant 1000 : index -// CHECK-NEXT: %78 = arith.constant 1 : index -// CHECK-NEXT: "csl.store_var"(%74, %69) : (!csl.var>, memref<512xf32>) -> () +// CHECK-NEXT: %77 = arith.constant 0 : index +// CHECK-NEXT: %78 = arith.constant 1000 : index +// CHECK-NEXT: %79 = arith.constant 1 : index // CHECK-NEXT: "csl.store_var"(%75, %70) : (!csl.var>, memref<512xf32>) -> () +// CHECK-NEXT: "csl.store_var"(%76, %71) : (!csl.var>, memref<512xf32>) -> () // CHECK-NEXT: csl.activate local, 1 : i32 // CHECK-NEXT: csl.return // CHECK-NEXT: } // CHECK-NEXT: csl.task @for_cond0() attributes {"kind" = #csl, "id" = 1 : i5}{ -// CHECK-NEXT: %79 = arith.constant 1000 : i16 -// CHECK-NEXT: %80 = "csl.load_var"(%73) : (!csl.var) -> i16 -// CHECK-NEXT: %81 = arith.cmpi slt, %80, %79 : i16 -// CHECK-NEXT: scf.if %81 { +// CHECK-NEXT: %80 = arith.constant 1000 : i16 +// CHECK-NEXT: %81 = "csl.load_var"(%74) : (!csl.var) -> i16 +// CHECK-NEXT: %82 = arith.cmpi slt, %81, %80 : i16 +// CHECK-NEXT: scf.if %82 { // CHECK-NEXT: "csl.call"() <{"callee" = @for_body0}> : () -> () // CHECK-NEXT: } else { // CHECK-NEXT: "csl.call"() <{"callee" = @for_post0}> : () -> () @@ -289,60 +291,62 @@ builtin.module { // CHECK-NEXT: csl.return // CHECK-NEXT: } // CHECK-NEXT: csl.func @for_body0() { -// CHECK-NEXT: %arg10 = "csl.load_var"(%73) : (!csl.var) -> i16 -// CHECK-NEXT: %arg11 = "csl.load_var"(%74) : (!csl.var>) -> memref<512xf32> -// CHECK-NEXT: %arg12 = "csl.load_var"(%75) : (!csl.var>) -> memref<512xf32> +// CHECK-NEXT: %arg10 = "csl.load_var"(%74) : (!csl.var) -> i16 +// CHECK-NEXT: %arg11 = "csl.load_var"(%75) : (!csl.var>) -> memref<512xf32> +// CHECK-NEXT: %arg12 = "csl.load_var"(%76) : (!csl.var>) -> memref<512xf32> // CHECK-NEXT: %accumulator_1 = memref.alloc() {"alignment" = 64 : i64} : memref<510xf32> -// CHECK-NEXT: %82 = arith.constant 1 : i16 -// CHECK-NEXT: %83 = "csl.addressof_fn"() <{"fn_name" = @receive_chunk_cb3}> : () -> !csl.ptr<(i16) -> (), #csl, #csl> -// CHECK-NEXT: %84 = "csl.addressof_fn"() <{"fn_name" = @done_exchange_cb3}> : () -> !csl.ptr<() -> (), #csl, #csl> -// CHECK-NEXT: %85 = memref.subview %arg11[1] [510] [1] : memref<512xf32> to memref<510xf32> -// CHECK-NEXT: "csl.member_call"(%68, %85, %82, %83, %84) <{"field" = "communicate"}> : (!csl.imported_module, memref<510xf32>, i16, !csl.ptr<(i16) -> (), #csl, #csl>, !csl.ptr<() -> (), #csl, #csl>) -> () +// CHECK-NEXT: %83 = arith.constant 0.000000e+00 : f32 +// CHECK-NEXT: "csl.fmovs"(%accumulator_1, %83) : (memref<510xf32>, f32) -> () +// CHECK-NEXT: %84 = arith.constant 1 : i16 +// CHECK-NEXT: %85 = "csl.addressof_fn"() <{"fn_name" = @receive_chunk_cb3}> : () -> !csl.ptr<(i16) -> (), #csl, #csl> +// CHECK-NEXT: %86 = "csl.addressof_fn"() <{"fn_name" = @done_exchange_cb3}> : () -> !csl.ptr<() -> (), #csl, #csl> +// CHECK-NEXT: %87 = memref.subview %arg11[1] [510] [1] : memref<512xf32> to memref<510xf32> +// CHECK-NEXT: "csl.member_call"(%69, %87, %84, %85, %86) <{"field" = "communicate"}> : (!csl.imported_module, memref<510xf32>, i16, !csl.ptr<(i16) -> (), #csl, #csl>, !csl.ptr<() -> (), #csl, #csl>) -> () // CHECK-NEXT: csl.return // CHECK-NEXT: } // CHECK-NEXT: csl.func @receive_chunk_cb3(%offset_2 : i16) { // CHECK-NEXT: %offset_3 = arith.index_cast %offset_2 : i16 to index -// CHECK-NEXT: %86 = memref.subview %accumulator_1[%offset_3] [510] [1] : memref<510xf32> to memref<510xf32, strided<[1], offset: ?>> -// CHECK-NEXT: %87 = arith.constant 4 : i16 -// CHECK-NEXT: %88 = "csl.get_mem_dsd"(%accumulator_1, %87, %arg3_1, %arg5_1) <{"strides" = [0 : i16, 0 : i16, 1 : i16]}> : (memref<510xf32>, i16, i16, i16) -> !csl -// CHECK-NEXT: %89 = arith.index_cast %offset_3 : index to si16 -// CHECK-NEXT: %90 = "csl.increment_dsd_offset"(%88, %89) <{"elem_type" = f32}> : (!csl, si16) -> !csl -// CHECK-NEXT: %91 = "csl.member_call"(%68) <{"field" = "getRecvBufDsd"}> : (!csl.imported_module) -> !csl -// CHECK-NEXT: "csl.fadds"(%90, %90, %91) : (!csl, !csl, !csl) -> () -// CHECK-NEXT: "memref.copy"(%86, %86) : (memref<510xf32, strided<[1], offset: ?>>, memref<510xf32, strided<[1], offset: ?>>) -> () +// CHECK-NEXT: %88 = memref.subview %accumulator_1[%offset_3] [510] [1] : memref<510xf32> to memref<510xf32, strided<[1], offset: ?>> +// CHECK-NEXT: %89 = arith.constant 4 : i16 +// CHECK-NEXT: %90 = "csl.get_mem_dsd"(%accumulator_1, %89, %arg3_1, %arg5_1) <{"strides" = [0 : i16, 0 : i16, 1 : i16]}> : (memref<510xf32>, i16, i16, i16) -> !csl +// CHECK-NEXT: %91 = arith.index_cast %offset_3 : index to si16 +// CHECK-NEXT: %92 = "csl.increment_dsd_offset"(%90, %91) <{"elem_type" = f32}> : (!csl, si16) -> !csl +// CHECK-NEXT: %93 = "csl.member_call"(%69) <{"field" = "getRecvBufDsd"}> : (!csl.imported_module) -> !csl +// CHECK-NEXT: "csl.fadds"(%92, %92, %93) : (!csl, !csl, !csl) -> () +// CHECK-NEXT: "memref.copy"(%88, %88) : (memref<510xf32, strided<[1], offset: ?>>, memref<510xf32, strided<[1], offset: ?>>) -> () // CHECK-NEXT: csl.return // CHECK-NEXT: } // CHECK-NEXT: csl.func @done_exchange_cb3() { -// CHECK-NEXT: %arg12_1 = "csl.load_var"(%75) : (!csl.var>) -> memref<512xf32> -// CHECK-NEXT: %arg11_1 = "csl.load_var"(%74) : (!csl.var>) -> memref<512xf32> +// CHECK-NEXT: %arg12_1 = "csl.load_var"(%76) : (!csl.var>) -> memref<512xf32> +// CHECK-NEXT: %arg11_1 = "csl.load_var"(%75) : (!csl.var>) -> memref<512xf32> // CHECK-NEXT: scf.if %arg9 { // CHECK-NEXT: } else { -// CHECK-NEXT: %92 = memref.subview %arg11_1[2] [510] [1] : memref<512xf32> to memref<510xf32, strided<[1], offset: 2>> -// CHECK-NEXT: %93 = memref.subview %arg11_1[0] [510] [1] : memref<512xf32> to memref<510xf32, strided<[1]>> -// CHECK-NEXT: "csl.fadds"(%accumulator_1, %accumulator_1, %93) : (memref<510xf32>, memref<510xf32>, memref<510xf32, strided<[1]>>) -> () -// CHECK-NEXT: "csl.fadds"(%accumulator_1, %accumulator_1, %92) : (memref<510xf32>, memref<510xf32>, memref<510xf32, strided<[1], offset: 2>>) -> () -// CHECK-NEXT: %94 = arith.constant 1.666600e-01 : f32 -// CHECK-NEXT: "csl.fmuls"(%accumulator_1, %accumulator_1, %94) : (memref<510xf32>, memref<510xf32>, f32) -> () -// CHECK-NEXT: %95 = memref.subview %arg12_1[1] [510] [1] : memref<512xf32> to memref<510xf32> -// CHECK-NEXT: "memref.copy"(%accumulator_1, %95) : (memref<510xf32>, memref<510xf32>) -> () +// CHECK-NEXT: %94 = memref.subview %arg11_1[2] [510] [1] : memref<512xf32> to memref<510xf32, strided<[1], offset: 2>> +// CHECK-NEXT: %95 = memref.subview %arg11_1[0] [510] [1] : memref<512xf32> to memref<510xf32, strided<[1]>> +// CHECK-NEXT: "csl.fadds"(%accumulator_1, %accumulator_1, %95) : (memref<510xf32>, memref<510xf32>, memref<510xf32, strided<[1]>>) -> () +// CHECK-NEXT: "csl.fadds"(%accumulator_1, %accumulator_1, %94) : (memref<510xf32>, memref<510xf32>, memref<510xf32, strided<[1], offset: 2>>) -> () +// CHECK-NEXT: %96 = arith.constant 1.666600e-01 : f32 +// CHECK-NEXT: "csl.fmuls"(%accumulator_1, %accumulator_1, %96) : (memref<510xf32>, memref<510xf32>, f32) -> () +// CHECK-NEXT: %97 = memref.subview %arg12_1[1] [510] [1] : memref<512xf32> to memref<510xf32> +// CHECK-NEXT: "memref.copy"(%accumulator_1, %97) : (memref<510xf32>, memref<510xf32>) -> () // CHECK-NEXT: } // CHECK-NEXT: "csl.call"() <{"callee" = @for_inc0}> : () -> () // CHECK-NEXT: csl.return // CHECK-NEXT: } // CHECK-NEXT: csl.func @for_inc0() { -// CHECK-NEXT: %96 = arith.constant 1 : i16 -// CHECK-NEXT: %97 = "csl.load_var"(%73) : (!csl.var) -> i16 -// CHECK-NEXT: %98 = arith.addi %97, %96 : i16 -// CHECK-NEXT: "csl.store_var"(%73, %98) : (!csl.var, i16) -> () -// CHECK-NEXT: %99 = "csl.load_var"(%74) : (!csl.var>) -> memref<512xf32> -// CHECK-NEXT: %100 = "csl.load_var"(%75) : (!csl.var>) -> memref<512xf32> -// CHECK-NEXT: "csl.store_var"(%74, %100) : (!csl.var>, memref<512xf32>) -> () -// CHECK-NEXT: "csl.store_var"(%75, %99) : (!csl.var>, memref<512xf32>) -> () +// CHECK-NEXT: %98 = arith.constant 1 : i16 +// CHECK-NEXT: %99 = "csl.load_var"(%74) : (!csl.var) -> i16 +// CHECK-NEXT: %100 = arith.addi %99, %98 : i16 +// CHECK-NEXT: "csl.store_var"(%74, %100) : (!csl.var, i16) -> () +// CHECK-NEXT: %101 = "csl.load_var"(%75) : (!csl.var>) -> memref<512xf32> +// CHECK-NEXT: %102 = "csl.load_var"(%76) : (!csl.var>) -> memref<512xf32> +// CHECK-NEXT: "csl.store_var"(%75, %102) : (!csl.var>, memref<512xf32>) -> () +// CHECK-NEXT: "csl.store_var"(%76, %101) : (!csl.var>, memref<512xf32>) -> () // CHECK-NEXT: csl.activate local, 1 : i32 // CHECK-NEXT: csl.return // CHECK-NEXT: } // CHECK-NEXT: csl.func @for_post0() { -// CHECK-NEXT: "csl.member_call"(%67) <{"field" = "unblock_cmd_stream"}> : (!csl.imported_module) -> () +// CHECK-NEXT: "csl.member_call"(%68) <{"field" = "unblock_cmd_stream"}> : (!csl.imported_module) -> () // CHECK-NEXT: csl.return // CHECK-NEXT: } // CHECK-NEXT: "csl_wrapper.yield"() <{"fields" = []}> : () -> () @@ -407,71 +411,71 @@ builtin.module { }) : () -> () // CHECK-NEXT: "csl_wrapper.module"() <{"width" = 1022 : i16, "height" = 510 : i16, "params" = [#csl_wrapper.param<"z_dim" default=512 : i16>, #csl_wrapper.param<"pattern" default=2 : i16>, #csl_wrapper.param<"num_chunks" default=2 : i16>, #csl_wrapper.param<"chunk_size" default=255 : i16>, #csl_wrapper.param<"padded_z_dim" default=510 : i16>], "program_name" = "partial_access"}> ({ -// CHECK-NEXT: ^4(%101 : i16, %102 : i16, %103 : i16, %104 : i16, %105 : i16, %106 : i16, %107 : i16, %108 : i16, %109 : i16): -// CHECK-NEXT: %110 = arith.constant 0 : i16 -// CHECK-NEXT: %111 = "csl.get_color"(%110) : (i16) -> !csl.color -// CHECK-NEXT: %112 = "csl_wrapper.import"(%103, %104, %111) <{"module" = "", "fields" = ["width", "height", "LAUNCH"]}> : (i16, i16, !csl.color) -> !csl.imported_module -// CHECK-NEXT: %113 = "csl_wrapper.import"(%106, %103, %104) <{"module" = "routes.csl", "fields" = ["pattern", "peWidth", "peHeight"]}> : (i16, i16, i16) -> !csl.imported_module -// CHECK-NEXT: %114 = "csl.member_call"(%113, %101, %102, %103, %104, %106) <{"field" = "computeAllRoutes"}> : (!csl.imported_module, i16, i16, i16, i16, i16) -> !csl.comptime_struct -// CHECK-NEXT: %115 = "csl.member_call"(%112, %101) <{"field" = "get_params"}> : (!csl.imported_module, i16) -> !csl.comptime_struct -// CHECK-NEXT: %116 = arith.constant 1 : i16 -// CHECK-NEXT: %117 = arith.subi %106, %116 : i16 -// CHECK-NEXT: %118 = arith.subi %103, %101 : i16 -// CHECK-NEXT: %119 = arith.subi %104, %102 : i16 -// CHECK-NEXT: %120 = arith.cmpi slt, %101, %117 : i16 -// CHECK-NEXT: %121 = arith.cmpi slt, %102, %117 : i16 -// CHECK-NEXT: %122 = arith.cmpi slt, %118, %106 : i16 -// CHECK-NEXT: %123 = arith.cmpi slt, %119, %106 : i16 -// CHECK-NEXT: %124 = arith.ori %120, %121 : i1 -// CHECK-NEXT: %125 = arith.ori %124, %122 : i1 -// CHECK-NEXT: %126 = arith.ori %125, %123 : i1 -// CHECK-NEXT: "csl_wrapper.yield"(%115, %114, %126) <{"fields" = ["memcpy_params", "stencil_comms_params", "isBorderRegionPE"]}> : (!csl.comptime_struct, !csl.comptime_struct, i1) -> () +// CHECK-NEXT: ^4(%103 : i16, %104 : i16, %105 : i16, %106 : i16, %107 : i16, %108 : i16, %109 : i16, %110 : i16, %111 : i16): +// CHECK-NEXT: %112 = arith.constant 0 : i16 +// CHECK-NEXT: %113 = "csl.get_color"(%112) : (i16) -> !csl.color +// CHECK-NEXT: %114 = "csl_wrapper.import"(%105, %106, %113) <{"module" = "", "fields" = ["width", "height", "LAUNCH"]}> : (i16, i16, !csl.color) -> !csl.imported_module +// CHECK-NEXT: %115 = "csl_wrapper.import"(%108, %105, %106) <{"module" = "routes.csl", "fields" = ["pattern", "peWidth", "peHeight"]}> : (i16, i16, i16) -> !csl.imported_module +// CHECK-NEXT: %116 = "csl.member_call"(%115, %103, %104, %105, %106, %108) <{"field" = "computeAllRoutes"}> : (!csl.imported_module, i16, i16, i16, i16, i16) -> !csl.comptime_struct +// CHECK-NEXT: %117 = "csl.member_call"(%114, %103) <{"field" = "get_params"}> : (!csl.imported_module, i16) -> !csl.comptime_struct +// CHECK-NEXT: %118 = arith.constant 1 : i16 +// CHECK-NEXT: %119 = arith.subi %108, %118 : i16 +// CHECK-NEXT: %120 = arith.subi %105, %103 : i16 +// CHECK-NEXT: %121 = arith.subi %106, %104 : i16 +// CHECK-NEXT: %122 = arith.cmpi slt, %103, %119 : i16 +// CHECK-NEXT: %123 = arith.cmpi slt, %104, %119 : i16 +// CHECK-NEXT: %124 = arith.cmpi slt, %120, %108 : i16 +// CHECK-NEXT: %125 = arith.cmpi slt, %121, %108 : i16 +// CHECK-NEXT: %126 = arith.ori %122, %123 : i1 +// CHECK-NEXT: %127 = arith.ori %126, %124 : i1 +// CHECK-NEXT: %128 = arith.ori %127, %125 : i1 +// CHECK-NEXT: "csl_wrapper.yield"(%117, %116, %128) <{"fields" = ["memcpy_params", "stencil_comms_params", "isBorderRegionPE"]}> : (!csl.comptime_struct, !csl.comptime_struct, i1) -> () // CHECK-NEXT: }, { -// CHECK-NEXT: ^5(%127 : i16, %128 : i16, %129 : i16, %130 : i16, %131 : i16, %132 : i16, %133 : i16, %memcpy_params_1 : !csl.comptime_struct, %stencil_comms_params_1 : !csl.comptime_struct, %isBorderRegionPE_1 : i1): -// CHECK-NEXT: %134 = "csl_wrapper.import"(%memcpy_params_1) <{"module" = "", "fields" = [""]}> : (!csl.comptime_struct) -> !csl.imported_module -// CHECK-NEXT: %135 = "csl_wrapper.import"(%130, %132, %stencil_comms_params_1) <{"module" = "stencil_comms.csl", "fields" = ["pattern", "chunkSize", ""]}> : (i16, i16, !csl.comptime_struct) -> !csl.imported_module +// CHECK-NEXT: ^5(%129 : i16, %130 : i16, %131 : i16, %132 : i16, %133 : i16, %134 : i16, %135 : i16, %memcpy_params_1 : !csl.comptime_struct, %stencil_comms_params_1 : !csl.comptime_struct, %isBorderRegionPE_1 : i1): +// CHECK-NEXT: %136 = "csl_wrapper.import"(%memcpy_params_1) <{"module" = "", "fields" = [""]}> : (!csl.comptime_struct) -> !csl.imported_module +// CHECK-NEXT: %137 = "csl_wrapper.import"(%132, %134, %stencil_comms_params_1) <{"module" = "stencil_comms.csl", "fields" = ["pattern", "chunkSize", ""]}> : (i16, i16, !csl.comptime_struct) -> !csl.imported_module // CHECK-NEXT: %arg0_3 = memref.alloc() : memref<512xf32> // CHECK-NEXT: %arg1_3 = memref.alloc() : memref<512xf32> -// CHECK-NEXT: %136 = "csl.addressof"(%arg0_3) : (memref<512xf32>) -> !csl.ptr, #csl> -// CHECK-NEXT: %137 = "csl.addressof"(%arg1_3) : (memref<512xf32>) -> !csl.ptr, #csl> -// CHECK-NEXT: "csl.export"(%136) <{"var_name" = "arg0", "type" = !csl.ptr, #csl>}> : (!csl.ptr, #csl>) -> () -// CHECK-NEXT: "csl.export"(%137) <{"var_name" = "arg1", "type" = !csl.ptr, #csl>}> : (!csl.ptr, #csl>) -> () +// CHECK-NEXT: %138 = "csl.addressof"(%arg0_3) : (memref<512xf32>) -> !csl.ptr, #csl> +// CHECK-NEXT: %139 = "csl.addressof"(%arg1_3) : (memref<512xf32>) -> !csl.ptr, #csl> +// CHECK-NEXT: "csl.export"(%138) <{"var_name" = "arg0", "type" = !csl.ptr, #csl>}> : (!csl.ptr, #csl>) -> () +// CHECK-NEXT: "csl.export"(%139) <{"var_name" = "arg1", "type" = !csl.ptr, #csl>}> : (!csl.ptr, #csl>) -> () // CHECK-NEXT: "csl.export"() <{"var_name" = @gauss_seidel_func, "type" = () -> ()}> : () -> () // CHECK-NEXT: csl.func @partial_access() { // CHECK-NEXT: %accumulator_2 = memref.alloc() {"alignment" = 64 : i64} : memref<510xf32> -// CHECK-NEXT: %138 = arith.constant 2 : i16 -// CHECK-NEXT: %139 = "csl.addressof_fn"() <{"fn_name" = @receive_chunk_cb0}> : () -> !csl.ptr<(i16) -> (), #csl, #csl> -// CHECK-NEXT: %140 = "csl.addressof_fn"() <{"fn_name" = @done_exchange_cb0}> : () -> !csl.ptr<() -> (), #csl, #csl> -// CHECK-NEXT: %141 = memref.subview %arg0_3[1] [510] [1] : memref<512xf32> to memref<510xf32> -// CHECK-NEXT: "csl.member_call"(%135, %141, %138, %139, %140) <{"field" = "communicate"}> : (!csl.imported_module, memref<510xf32>, i16, !csl.ptr<(i16) -> (), #csl, #csl>, !csl.ptr<() -> (), #csl, #csl>) -> () +// CHECK-NEXT: %140 = arith.constant 2 : i16 +// CHECK-NEXT: %141 = "csl.addressof_fn"() <{"fn_name" = @receive_chunk_cb0}> : () -> !csl.ptr<(i16) -> (), #csl, #csl> +// CHECK-NEXT: %142 = "csl.addressof_fn"() <{"fn_name" = @done_exchange_cb0}> : () -> !csl.ptr<() -> (), #csl, #csl> +// CHECK-NEXT: %143 = memref.subview %arg0_3[1] [510] [1] : memref<512xf32> to memref<510xf32> +// CHECK-NEXT: "csl.member_call"(%137, %143, %140, %141, %142) <{"field" = "communicate"}> : (!csl.imported_module, memref<510xf32>, i16, !csl.ptr<(i16) -> (), #csl, #csl>, !csl.ptr<() -> (), #csl, #csl>) -> () // CHECK-NEXT: csl.return // CHECK-NEXT: } // CHECK-NEXT: csl.func @receive_chunk_cb0(%offset_4 : i16) { // CHECK-NEXT: %offset_5 = arith.index_cast %offset_4 : i16 to index -// CHECK-NEXT: %142 = arith.constant 1 : i16 -// CHECK-NEXT: %143 = "csl.get_dir"() <{"dir" = #csl}> : () -> !csl.direction -// CHECK-NEXT: %144 = "csl.member_call"(%135, %143, %142) <{"field" = "getRecvBufDsdByNeighbor"}> : (!csl.imported_module, !csl.direction, i16) -> !csl -// CHECK-NEXT: %145 = builtin.unrealized_conversion_cast %144 : !csl to memref<255xf32> -// CHECK-NEXT: %146 = arith.constant 1 : i16 -// CHECK-NEXT: %147 = "csl.get_dir"() <{"dir" = #csl}> : () -> !csl.direction -// CHECK-NEXT: %148 = "csl.member_call"(%135, %147, %146) <{"field" = "getRecvBufDsdByNeighbor"}> : (!csl.imported_module, !csl.direction, i16) -> !csl -// CHECK-NEXT: %149 = builtin.unrealized_conversion_cast %148 : !csl to memref<255xf32> -// CHECK-NEXT: %150 = arith.constant 1 : i16 -// CHECK-NEXT: %151 = "csl.get_dir"() <{"dir" = #csl}> : () -> !csl.direction -// CHECK-NEXT: %152 = "csl.member_call"(%135, %151, %150) <{"field" = "getRecvBufDsdByNeighbor"}> : (!csl.imported_module, !csl.direction, i16) -> !csl -// CHECK-NEXT: %153 = builtin.unrealized_conversion_cast %152 : !csl to memref<255xf32> -// CHECK-NEXT: %154 = memref.subview %accumulator_2[%offset_5] [255] [1] : memref<510xf32> to memref<255xf32, strided<[1], offset: ?>> -// CHECK-NEXT: "csl.fadds"(%154, %149, %153) : (memref<255xf32, strided<[1], offset: ?>>, memref<255xf32>, memref<255xf32>) -> () -// CHECK-NEXT: "csl.fadds"(%154, %154, %145) : (memref<255xf32, strided<[1], offset: ?>>, memref<255xf32, strided<[1], offset: ?>>, memref<255xf32>) -> () +// CHECK-NEXT: %144 = arith.constant 1 : i16 +// CHECK-NEXT: %145 = "csl.get_dir"() <{"dir" = #csl}> : () -> !csl.direction +// CHECK-NEXT: %146 = "csl.member_call"(%137, %145, %144) <{"field" = "getRecvBufDsdByNeighbor"}> : (!csl.imported_module, !csl.direction, i16) -> !csl +// CHECK-NEXT: %147 = builtin.unrealized_conversion_cast %146 : !csl to memref<255xf32> +// CHECK-NEXT: %148 = arith.constant 1 : i16 +// CHECK-NEXT: %149 = "csl.get_dir"() <{"dir" = #csl}> : () -> !csl.direction +// CHECK-NEXT: %150 = "csl.member_call"(%137, %149, %148) <{"field" = "getRecvBufDsdByNeighbor"}> : (!csl.imported_module, !csl.direction, i16) -> !csl +// CHECK-NEXT: %151 = builtin.unrealized_conversion_cast %150 : !csl to memref<255xf32> +// CHECK-NEXT: %152 = arith.constant 1 : i16 +// CHECK-NEXT: %153 = "csl.get_dir"() <{"dir" = #csl}> : () -> !csl.direction +// CHECK-NEXT: %154 = "csl.member_call"(%137, %153, %152) <{"field" = "getRecvBufDsdByNeighbor"}> : (!csl.imported_module, !csl.direction, i16) -> !csl +// CHECK-NEXT: %155 = builtin.unrealized_conversion_cast %154 : !csl to memref<255xf32> +// CHECK-NEXT: %156 = memref.subview %accumulator_2[%offset_5] [255] [1] : memref<510xf32> to memref<255xf32, strided<[1], offset: ?>> +// CHECK-NEXT: "csl.fadds"(%156, %151, %155) : (memref<255xf32, strided<[1], offset: ?>>, memref<255xf32>, memref<255xf32>) -> () +// CHECK-NEXT: "csl.fadds"(%156, %156, %147) : (memref<255xf32, strided<[1], offset: ?>>, memref<255xf32, strided<[1], offset: ?>>, memref<255xf32>) -> () // CHECK-NEXT: csl.return // CHECK-NEXT: } // CHECK-NEXT: csl.func @done_exchange_cb0() { -// CHECK-NEXT: %155 = memref.subview %arg0_3[2] [510] [1] : memref<512xf32> to memref<510xf32, strided<[1], offset: 2>> -// CHECK-NEXT: %156 = memref.subview %arg0_3[0] [510] [1] : memref<512xf32> to memref<510xf32, strided<[1]>> -// CHECK-NEXT: "csl.fadds"(%accumulator_2, %accumulator_2, %156) : (memref<510xf32>, memref<510xf32>, memref<510xf32, strided<[1]>>) -> () -// CHECK-NEXT: "csl.fadds"(%accumulator_2, %accumulator_2, %155) : (memref<510xf32>, memref<510xf32>, memref<510xf32, strided<[1], offset: 2>>) -> () -// CHECK-NEXT: %157 = arith.constant 1.666600e-01 : f32 -// CHECK-NEXT: "csl.fmuls"(%accumulator_2, %accumulator_2, %157) : (memref<510xf32>, memref<510xf32>, f32) -> () +// CHECK-NEXT: %157 = memref.subview %arg0_3[2] [510] [1] : memref<512xf32> to memref<510xf32, strided<[1], offset: 2>> +// CHECK-NEXT: %158 = memref.subview %arg0_3[0] [510] [1] : memref<512xf32> to memref<510xf32, strided<[1]>> +// CHECK-NEXT: "csl.fadds"(%accumulator_2, %accumulator_2, %158) : (memref<510xf32>, memref<510xf32>, memref<510xf32, strided<[1]>>) -> () +// CHECK-NEXT: "csl.fadds"(%accumulator_2, %accumulator_2, %157) : (memref<510xf32>, memref<510xf32>, memref<510xf32, strided<[1], offset: 2>>) -> () +// CHECK-NEXT: %159 = arith.constant 1.666600e-01 : f32 +// CHECK-NEXT: "csl.fmuls"(%accumulator_2, %accumulator_2, %159) : (memref<510xf32>, memref<510xf32>, f32) -> () // CHECK-NEXT: csl.return // CHECK-NEXT: } // CHECK-NEXT: "csl_wrapper.yield"() <{"fields" = []}> : () -> () @@ -581,52 +585,52 @@ builtin.module { // CHECK-NEXT: "csl_wrapper.module"() <{"height" = 512 : i16, "params" = [#csl_wrapper.param<"z_dim" default=511 : i16>, #csl_wrapper.param<"pattern" default=2 : i16>, #csl_wrapper.param<"num_chunks" default=1 : i16>, #csl_wrapper.param<"chunk_size" default=510 : i16>, #csl_wrapper.param<"padded_z_dim" default=510 : i16>], "program_name" = "chunk_reduce_only", "width" = 1024 : i16}> ({ // CHECK-NEXT: ^6(%arg0_4 : i16, %arg1_4 : i16, %arg2_2 : i16, %arg3_2 : i16, %arg4_2 : i16, %arg5_2 : i16, %arg6_2 : i16, %arg7_2 : i16, %arg8_2 : i16): -// CHECK-NEXT: %158 = arith.constant 1 : i16 -// CHECK-NEXT: %159 = arith.constant 0 : i16 -// CHECK-NEXT: %160 = "csl.get_color"(%159) : (i16) -> !csl.color -// CHECK-NEXT: %161 = "csl_wrapper.import"(%arg2_2, %arg3_2, %160) <{"fields" = ["width", "height", "LAUNCH"], "module" = ""}> : (i16, i16, !csl.color) -> !csl.imported_module -// CHECK-NEXT: %162 = "csl_wrapper.import"(%arg5_2, %arg2_2, %arg3_2) <{"fields" = ["pattern", "peWidth", "peHeight"], "module" = "routes.csl"}> : (i16, i16, i16) -> !csl.imported_module -// CHECK-NEXT: %163 = "csl.member_call"(%162, %arg0_4, %arg1_4, %arg2_2, %arg3_2, %arg5_2) <{"field" = "computeAllRoutes"}> : (!csl.imported_module, i16, i16, i16, i16, i16) -> !csl.comptime_struct -// CHECK-NEXT: %164 = "csl.member_call"(%161, %arg0_4) <{"field" = "get_params"}> : (!csl.imported_module, i16) -> !csl.comptime_struct -// CHECK-NEXT: %165 = arith.subi %arg5_2, %158 : i16 -// CHECK-NEXT: %166 = arith.subi %arg2_2, %arg0_4 : i16 -// CHECK-NEXT: %167 = arith.subi %arg3_2, %arg1_4 : i16 -// CHECK-NEXT: %168 = arith.cmpi slt, %arg0_4, %165 : i16 -// CHECK-NEXT: %169 = arith.cmpi slt, %arg1_4, %165 : i16 -// CHECK-NEXT: %170 = arith.cmpi slt, %166, %arg5_2 : i16 -// CHECK-NEXT: %171 = arith.cmpi slt, %167, %arg5_2 : i16 -// CHECK-NEXT: %172 = arith.ori %168, %169 : i1 -// CHECK-NEXT: %173 = arith.ori %172, %170 : i1 -// CHECK-NEXT: %174 = arith.ori %173, %171 : i1 -// CHECK-NEXT: "csl_wrapper.yield"(%164, %163, %174) <{"fields" = ["memcpy_params", "stencil_comms_params", "isBorderRegionPE"]}> : (!csl.comptime_struct, !csl.comptime_struct, i1) -> () +// CHECK-NEXT: %160 = arith.constant 1 : i16 +// CHECK-NEXT: %161 = arith.constant 0 : i16 +// CHECK-NEXT: %162 = "csl.get_color"(%161) : (i16) -> !csl.color +// CHECK-NEXT: %163 = "csl_wrapper.import"(%arg2_2, %arg3_2, %162) <{"fields" = ["width", "height", "LAUNCH"], "module" = ""}> : (i16, i16, !csl.color) -> !csl.imported_module +// CHECK-NEXT: %164 = "csl_wrapper.import"(%arg5_2, %arg2_2, %arg3_2) <{"fields" = ["pattern", "peWidth", "peHeight"], "module" = "routes.csl"}> : (i16, i16, i16) -> !csl.imported_module +// CHECK-NEXT: %165 = "csl.member_call"(%164, %arg0_4, %arg1_4, %arg2_2, %arg3_2, %arg5_2) <{"field" = "computeAllRoutes"}> : (!csl.imported_module, i16, i16, i16, i16, i16) -> !csl.comptime_struct +// CHECK-NEXT: %166 = "csl.member_call"(%163, %arg0_4) <{"field" = "get_params"}> : (!csl.imported_module, i16) -> !csl.comptime_struct +// CHECK-NEXT: %167 = arith.subi %arg5_2, %160 : i16 +// CHECK-NEXT: %168 = arith.subi %arg2_2, %arg0_4 : i16 +// CHECK-NEXT: %169 = arith.subi %arg3_2, %arg1_4 : i16 +// CHECK-NEXT: %170 = arith.cmpi slt, %arg0_4, %167 : i16 +// CHECK-NEXT: %171 = arith.cmpi slt, %arg1_4, %167 : i16 +// CHECK-NEXT: %172 = arith.cmpi slt, %168, %arg5_2 : i16 +// CHECK-NEXT: %173 = arith.cmpi slt, %169, %arg5_2 : i16 +// CHECK-NEXT: %174 = arith.ori %170, %171 : i1 +// CHECK-NEXT: %175 = arith.ori %174, %172 : i1 +// CHECK-NEXT: %176 = arith.ori %175, %173 : i1 +// CHECK-NEXT: "csl_wrapper.yield"(%166, %165, %176) <{"fields" = ["memcpy_params", "stencil_comms_params", "isBorderRegionPE"]}> : (!csl.comptime_struct, !csl.comptime_struct, i1) -> () // CHECK-NEXT: }, { // CHECK-NEXT: ^7(%arg0_5 : i16, %arg1_5 : i16, %arg2_3 : i16, %arg3_3 : i16, %arg4_3 : i16, %arg5_3 : i16, %arg6_3 : i16, %arg7_3 : !csl.comptime_struct, %arg8_3 : !csl.comptime_struct, %arg9_1 : i1): -// CHECK-NEXT: %175 = "csl_wrapper.import"(%arg7_3) <{"fields" = [""], "module" = ""}> : (!csl.comptime_struct) -> !csl.imported_module -// CHECK-NEXT: %176 = "csl_wrapper.import"(%arg3_3, %arg5_3, %arg8_3) <{"fields" = ["pattern", "chunkSize", ""], "module" = "stencil_comms.csl"}> : (i16, i16, !csl.comptime_struct) -> !csl.imported_module -// CHECK-NEXT: %177 = memref.alloc() : memref<511xf32> -// CHECK-NEXT: %178 = memref.alloc() : memref<511xf32> -// CHECK-NEXT: %179 = "csl.addressof"(%177) : (memref<511xf32>) -> !csl.ptr, #csl> -// CHECK-NEXT: %180 = "csl.addressof"(%178) : (memref<511xf32>) -> !csl.ptr, #csl> -// CHECK-NEXT: "csl.export"(%179) <{"type" = !csl.ptr, #csl>, "var_name" = "arg0"}> : (!csl.ptr, #csl>) -> () -// CHECK-NEXT: "csl.export"(%180) <{"type" = !csl.ptr, #csl>, "var_name" = "arg1"}> : (!csl.ptr, #csl>) -> () +// CHECK-NEXT: %177 = "csl_wrapper.import"(%arg7_3) <{"fields" = [""], "module" = ""}> : (!csl.comptime_struct) -> !csl.imported_module +// CHECK-NEXT: %178 = "csl_wrapper.import"(%arg3_3, %arg5_3, %arg8_3) <{"fields" = ["pattern", "chunkSize", ""], "module" = "stencil_comms.csl"}> : (i16, i16, !csl.comptime_struct) -> !csl.imported_module +// CHECK-NEXT: %179 = memref.alloc() : memref<511xf32> +// CHECK-NEXT: %180 = memref.alloc() : memref<511xf32> +// CHECK-NEXT: %181 = "csl.addressof"(%179) : (memref<511xf32>) -> !csl.ptr, #csl> +// CHECK-NEXT: %182 = "csl.addressof"(%180) : (memref<511xf32>) -> !csl.ptr, #csl> +// CHECK-NEXT: "csl.export"(%181) <{"type" = !csl.ptr, #csl>, "var_name" = "arg0"}> : (!csl.ptr, #csl>) -> () +// CHECK-NEXT: "csl.export"(%182) <{"type" = !csl.ptr, #csl>, "var_name" = "arg1"}> : (!csl.ptr, #csl>) -> () // CHECK-NEXT: "csl.export"() <{"type" = () -> (), "var_name" = @chunk_reduce_only}> : () -> () -// CHECK-NEXT: %181 = "csl.variable"() <{"default" = 0 : i16}> : () -> !csl.var -// CHECK-NEXT: %182 = "csl.variable"() : () -> !csl.var> -// CHECK-NEXT: %183 = "csl.variable"() : () -> !csl.var> +// CHECK-NEXT: %183 = "csl.variable"() <{"default" = 0 : i16}> : () -> !csl.var +// CHECK-NEXT: %184 = "csl.variable"() : () -> !csl.var> +// CHECK-NEXT: %185 = "csl.variable"() : () -> !csl.var> // CHECK-NEXT: csl.func @chunk_reduce_only() { -// CHECK-NEXT: %184 = arith.constant 0 : index -// CHECK-NEXT: %185 = arith.constant 1000 : index -// CHECK-NEXT: %186 = arith.constant 1 : index -// CHECK-NEXT: "csl.store_var"(%182, %177) : (!csl.var>, memref<511xf32>) -> () -// CHECK-NEXT: "csl.store_var"(%183, %178) : (!csl.var>, memref<511xf32>) -> () +// CHECK-NEXT: %186 = arith.constant 0 : index +// CHECK-NEXT: %187 = arith.constant 1000 : index +// CHECK-NEXT: %188 = arith.constant 1 : index +// CHECK-NEXT: "csl.store_var"(%184, %179) : (!csl.var>, memref<511xf32>) -> () +// CHECK-NEXT: "csl.store_var"(%185, %180) : (!csl.var>, memref<511xf32>) -> () // CHECK-NEXT: csl.activate local, 1 : i32 // CHECK-NEXT: csl.return // CHECK-NEXT: } // CHECK-NEXT: csl.task @for_cond0() attributes {"kind" = #csl, "id" = 1 : i5}{ -// CHECK-NEXT: %187 = arith.constant 1000 : i16 -// CHECK-NEXT: %188 = "csl.load_var"(%181) : (!csl.var) -> i16 -// CHECK-NEXT: %189 = arith.cmpi slt, %188, %187 : i16 -// CHECK-NEXT: scf.if %189 { +// CHECK-NEXT: %189 = arith.constant 1000 : i16 +// CHECK-NEXT: %190 = "csl.load_var"(%183) : (!csl.var) -> i16 +// CHECK-NEXT: %191 = arith.cmpi slt, %190, %189 : i16 +// CHECK-NEXT: scf.if %191 { // CHECK-NEXT: "csl.call"() <{"callee" = @for_body0}> : () -> () // CHECK-NEXT: } else { // CHECK-NEXT: "csl.call"() <{"callee" = @for_post0}> : () -> () @@ -634,72 +638,72 @@ builtin.module { // CHECK-NEXT: csl.return // CHECK-NEXT: } // CHECK-NEXT: csl.func @for_body0() { -// CHECK-NEXT: %arg10_1 = "csl.load_var"(%181) : (!csl.var) -> i16 -// CHECK-NEXT: %arg11_2 = "csl.load_var"(%182) : (!csl.var>) -> memref<511xf32> -// CHECK-NEXT: %arg12_2 = "csl.load_var"(%183) : (!csl.var>) -> memref<511xf32> +// CHECK-NEXT: %arg10_1 = "csl.load_var"(%183) : (!csl.var) -> i16 +// CHECK-NEXT: %arg11_2 = "csl.load_var"(%184) : (!csl.var>) -> memref<511xf32> +// CHECK-NEXT: %arg12_2 = "csl.load_var"(%185) : (!csl.var>) -> memref<511xf32> // CHECK-NEXT: %accumulator_3 = memref.alloc() {"alignment" = 64 : i64} : memref<510xf32> // CHECK-NEXT: %north = arith.constant dense<[0.000000e+00, 3.141500e-01]> : memref<2xf32> // CHECK-NEXT: %south = arith.constant dense<[0.000000e+00, 1.000000e+00]> : memref<2xf32> // CHECK-NEXT: %east = arith.constant dense<[0.000000e+00, 1.000000e+00]> : memref<2xf32> // CHECK-NEXT: %west = arith.constant dense<[0.000000e+00, 2.345678e-01]> : memref<2xf32> -// CHECK-NEXT: %190 = "csl.addressof"(%east) : (memref<2xf32>) -> !csl.ptr, #csl, #csl> -// CHECK-NEXT: %191 = "csl.addressof"(%west) : (memref<2xf32>) -> !csl.ptr, #csl, #csl> -// CHECK-NEXT: %192 = "csl.addressof"(%south) : (memref<2xf32>) -> !csl.ptr, #csl, #csl> -// CHECK-NEXT: %193 = "csl.addressof"(%north) : (memref<2xf32>) -> !csl.ptr, #csl, #csl> -// CHECK-NEXT: %194 = arith.constant false -// CHECK-NEXT: "csl.member_call"(%176, %190, %191, %192, %193, %194) <{"field" = "setCoeffs"}> : (!csl.imported_module, !csl.ptr, #csl, #csl>, !csl.ptr, #csl, #csl>, !csl.ptr, #csl, #csl>, !csl.ptr, #csl, #csl>, i1) -> () -// CHECK-NEXT: %195 = arith.constant 1 : i16 -// CHECK-NEXT: %196 = "csl.addressof_fn"() <{"fn_name" = @receive_chunk_cb1}> : () -> !csl.ptr<(i16) -> (), #csl, #csl> -// CHECK-NEXT: %197 = "csl.addressof_fn"() <{"fn_name" = @done_exchange_cb1}> : () -> !csl.ptr<() -> (), #csl, #csl> -// CHECK-NEXT: %198 = memref.subview %arg11_2[0] [510] [1] : memref<511xf32> to memref<510xf32> -// CHECK-NEXT: "csl.member_call"(%176, %198, %195, %196, %197) <{"field" = "communicate"}> : (!csl.imported_module, memref<510xf32>, i16, !csl.ptr<(i16) -> (), #csl, #csl>, !csl.ptr<() -> (), #csl, #csl>) -> () +// CHECK-NEXT: %192 = "csl.addressof"(%east) : (memref<2xf32>) -> !csl.ptr, #csl, #csl> +// CHECK-NEXT: %193 = "csl.addressof"(%west) : (memref<2xf32>) -> !csl.ptr, #csl, #csl> +// CHECK-NEXT: %194 = "csl.addressof"(%south) : (memref<2xf32>) -> !csl.ptr, #csl, #csl> +// CHECK-NEXT: %195 = "csl.addressof"(%north) : (memref<2xf32>) -> !csl.ptr, #csl, #csl> +// CHECK-NEXT: %196 = arith.constant false +// CHECK-NEXT: "csl.member_call"(%178, %192, %193, %194, %195, %196) <{"field" = "setCoeffs"}> : (!csl.imported_module, !csl.ptr, #csl, #csl>, !csl.ptr, #csl, #csl>, !csl.ptr, #csl, #csl>, !csl.ptr, #csl, #csl>, i1) -> () +// CHECK-NEXT: %197 = arith.constant 1 : i16 +// CHECK-NEXT: %198 = "csl.addressof_fn"() <{"fn_name" = @receive_chunk_cb1}> : () -> !csl.ptr<(i16) -> (), #csl, #csl> +// CHECK-NEXT: %199 = "csl.addressof_fn"() <{"fn_name" = @done_exchange_cb1}> : () -> !csl.ptr<() -> (), #csl, #csl> +// CHECK-NEXT: %200 = memref.subview %arg11_2[0] [510] [1] : memref<511xf32> to memref<510xf32> +// CHECK-NEXT: "csl.member_call"(%178, %200, %197, %198, %199) <{"field" = "communicate"}> : (!csl.imported_module, memref<510xf32>, i16, !csl.ptr<(i16) -> (), #csl, #csl>, !csl.ptr<() -> (), #csl, #csl>) -> () // CHECK-NEXT: csl.return // CHECK-NEXT: } // CHECK-NEXT: csl.func @receive_chunk_cb1(%offset_6 : i16) { // CHECK-NEXT: %offset_7 = arith.index_cast %offset_6 : i16 to index -// CHECK-NEXT: %arg11_3 = "csl.load_var"(%182) : (!csl.var>) -> memref<511xf32> -// CHECK-NEXT: %199 = arith.constant dense<1.234500e-01> : memref<510xf32> -// CHECK-NEXT: %200 = arith.constant 1 : i16 -// CHECK-NEXT: %201 = "csl.get_dir"() <{"dir" = #csl}> : () -> !csl.direction -// CHECK-NEXT: %202 = "csl.member_call"(%176, %201, %200) <{"field" = "getRecvBufDsdByNeighbor"}> : (!csl.imported_module, !csl.direction, i16) -> !csl -// CHECK-NEXT: %203 = builtin.unrealized_conversion_cast %202 : !csl to memref<510xf32> -// CHECK-NEXT: %204 = memref.subview %arg11_3[1] [510] [1] : memref<511xf32> to memref<510xf32, strided<[1], offset: 1>> -// CHECK-NEXT: %205 = arith.constant 1 : i16 -// CHECK-NEXT: %206 = "csl.get_dir"() <{"dir" = #csl}> : () -> !csl.direction -// CHECK-NEXT: %207 = "csl.member_call"(%176, %206, %205) <{"field" = "getRecvBufDsdByNeighbor"}> : (!csl.imported_module, !csl.direction, i16) -> !csl -// CHECK-NEXT: %208 = builtin.unrealized_conversion_cast %207 : !csl to memref<510xf32> -// CHECK-NEXT: %209 = memref.subview %accumulator_3[%offset_7] [510] [1] : memref<510xf32> to memref<510xf32, strided<[1], offset: ?>> -// CHECK-NEXT: "csl.fadds"(%209, %204, %208) : (memref<510xf32, strided<[1], offset: ?>>, memref<510xf32, strided<[1], offset: 1>>, memref<510xf32>) -> () -// CHECK-NEXT: "csl.fadds"(%209, %209, %203) : (memref<510xf32, strided<[1], offset: ?>>, memref<510xf32, strided<[1], offset: ?>>, memref<510xf32>) -> () -// CHECK-NEXT: %210 = arith.constant 1.234500e-01 : f32 -// CHECK-NEXT: "csl.fmuls"(%209, %209, %210) : (memref<510xf32, strided<[1], offset: ?>>, memref<510xf32, strided<[1], offset: ?>>, f32) -> () +// CHECK-NEXT: %arg11_3 = "csl.load_var"(%184) : (!csl.var>) -> memref<511xf32> +// CHECK-NEXT: %201 = arith.constant dense<1.234500e-01> : memref<510xf32> +// CHECK-NEXT: %202 = arith.constant 1 : i16 +// CHECK-NEXT: %203 = "csl.get_dir"() <{"dir" = #csl}> : () -> !csl.direction +// CHECK-NEXT: %204 = "csl.member_call"(%178, %203, %202) <{"field" = "getRecvBufDsdByNeighbor"}> : (!csl.imported_module, !csl.direction, i16) -> !csl +// CHECK-NEXT: %205 = builtin.unrealized_conversion_cast %204 : !csl to memref<510xf32> +// CHECK-NEXT: %206 = memref.subview %arg11_3[1] [510] [1] : memref<511xf32> to memref<510xf32, strided<[1], offset: 1>> +// CHECK-NEXT: %207 = arith.constant 1 : i16 +// CHECK-NEXT: %208 = "csl.get_dir"() <{"dir" = #csl}> : () -> !csl.direction +// CHECK-NEXT: %209 = "csl.member_call"(%178, %208, %207) <{"field" = "getRecvBufDsdByNeighbor"}> : (!csl.imported_module, !csl.direction, i16) -> !csl +// CHECK-NEXT: %210 = builtin.unrealized_conversion_cast %209 : !csl to memref<510xf32> +// CHECK-NEXT: %211 = memref.subview %accumulator_3[%offset_7] [510] [1] : memref<510xf32> to memref<510xf32, strided<[1], offset: ?>> +// CHECK-NEXT: "csl.fadds"(%211, %206, %210) : (memref<510xf32, strided<[1], offset: ?>>, memref<510xf32, strided<[1], offset: 1>>, memref<510xf32>) -> () +// CHECK-NEXT: "csl.fadds"(%211, %211, %205) : (memref<510xf32, strided<[1], offset: ?>>, memref<510xf32, strided<[1], offset: ?>>, memref<510xf32>) -> () +// CHECK-NEXT: %212 = arith.constant 1.234500e-01 : f32 +// CHECK-NEXT: "csl.fmuls"(%211, %211, %212) : (memref<510xf32, strided<[1], offset: ?>>, memref<510xf32, strided<[1], offset: ?>>, f32) -> () // CHECK-NEXT: csl.return // CHECK-NEXT: } // CHECK-NEXT: csl.func @done_exchange_cb1() { -// CHECK-NEXT: %arg12_3 = "csl.load_var"(%183) : (!csl.var>) -> memref<511xf32> -// CHECK-NEXT: %arg11_4 = "csl.load_var"(%182) : (!csl.var>) -> memref<511xf32> +// CHECK-NEXT: %arg12_3 = "csl.load_var"(%185) : (!csl.var>) -> memref<511xf32> +// CHECK-NEXT: %arg11_4 = "csl.load_var"(%184) : (!csl.var>) -> memref<511xf32> // CHECK-NEXT: scf.if %arg9_1 { // CHECK-NEXT: } else { -// CHECK-NEXT: %211 = memref.subview %arg12_3[0] [510] [1] : memref<511xf32> to memref<510xf32> -// CHECK-NEXT: "memref.copy"(%accumulator_3, %211) : (memref<510xf32>, memref<510xf32>) -> () +// CHECK-NEXT: %213 = memref.subview %arg12_3[0] [510] [1] : memref<511xf32> to memref<510xf32> +// CHECK-NEXT: "memref.copy"(%accumulator_3, %213) : (memref<510xf32>, memref<510xf32>) -> () // CHECK-NEXT: } // CHECK-NEXT: "csl.call"() <{"callee" = @for_inc0}> : () -> () // CHECK-NEXT: csl.return // CHECK-NEXT: } // CHECK-NEXT: csl.func @for_inc0() { -// CHECK-NEXT: %212 = arith.constant 1 : i16 -// CHECK-NEXT: %213 = "csl.load_var"(%181) : (!csl.var) -> i16 -// CHECK-NEXT: %214 = arith.addi %213, %212 : i16 -// CHECK-NEXT: "csl.store_var"(%181, %214) : (!csl.var, i16) -> () -// CHECK-NEXT: %215 = "csl.load_var"(%182) : (!csl.var>) -> memref<511xf32> -// CHECK-NEXT: %216 = "csl.load_var"(%183) : (!csl.var>) -> memref<511xf32> -// CHECK-NEXT: "csl.store_var"(%182, %216) : (!csl.var>, memref<511xf32>) -> () -// CHECK-NEXT: "csl.store_var"(%183, %215) : (!csl.var>, memref<511xf32>) -> () +// CHECK-NEXT: %214 = arith.constant 1 : i16 +// CHECK-NEXT: %215 = "csl.load_var"(%183) : (!csl.var) -> i16 +// CHECK-NEXT: %216 = arith.addi %215, %214 : i16 +// CHECK-NEXT: "csl.store_var"(%183, %216) : (!csl.var, i16) -> () +// CHECK-NEXT: %217 = "csl.load_var"(%184) : (!csl.var>) -> memref<511xf32> +// CHECK-NEXT: %218 = "csl.load_var"(%185) : (!csl.var>) -> memref<511xf32> +// CHECK-NEXT: "csl.store_var"(%184, %218) : (!csl.var>, memref<511xf32>) -> () +// CHECK-NEXT: "csl.store_var"(%185, %217) : (!csl.var>, memref<511xf32>) -> () // CHECK-NEXT: csl.activate local, 1 : i32 // CHECK-NEXT: csl.return // CHECK-NEXT: } // CHECK-NEXT: csl.func @for_post0() { -// CHECK-NEXT: "csl.member_call"(%175) <{"field" = "unblock_cmd_stream"}> : (!csl.imported_module) -> () +// CHECK-NEXT: "csl.member_call"(%177) <{"field" = "unblock_cmd_stream"}> : (!csl.imported_module) -> () // CHECK-NEXT: csl.return // CHECK-NEXT: } // CHECK-NEXT: "csl_wrapper.yield"() <{"fields" = []}> : () -> () diff --git a/xdsl/transforms/lower_csl_stencil.py b/xdsl/transforms/lower_csl_stencil.py index 02d7f35cd7..1d429df342 100644 --- a/xdsl/transforms/lower_csl_stencil.py +++ b/xdsl/transforms/lower_csl_stencil.py @@ -5,8 +5,10 @@ from xdsl.dialects import arith, func, memref, stencil from xdsl.dialects.builtin import ( AnyFloatAttr, + AnyMemRefType, ArrayAttr, DenseIntOrFPElementsAttr, + Float16Type, Float32Type, FloatAttr, FunctionType, @@ -388,6 +390,7 @@ class FullStencilAccessImmediateReductionOptimization(RewritePattern): * each access is immediately processed by the same (type of) reduction op * each reduction op uses the same accumulator to store a result * each reduction op uses no inputs except from the above access ops + * if this is inside a loop, we need to zero-out the accumulator buffer either before or after the loop * todo: the data of the accumulator is not itself an input of the reduction * todo: no other ops modify the accumulator in-between reduction ops """ @@ -410,9 +413,7 @@ def match_and_rewrite(self, op: csl_stencil.ApplyOp, rewriter: PatternRewriter, return # find potential 'reduction' ops - reduction_ops: set[Operation] = set( - u.operation for a in access_ops for u in a.result.uses - ) + reduction_ops = set(u.operation for a in access_ops for u in a.result.uses) # check if reduction ops are of the same type red_op_ts = set(type(r) for r in reduction_ops) @@ -421,7 +422,7 @@ def match_and_rewrite(self, op: csl_stencil.ApplyOp, rewriter: PatternRewriter, csl.FmulsOp, ]: return - red_ops = cast(set[csl.BuiltinDsdOp], reduction_ops) + reduction_ops = cast(set[csl.BuiltinDsdOp], reduction_ops) # check: only apply rewrite if each access has exactly one use if any(len(a.result.uses) != 1 for a in access_ops): @@ -429,18 +430,18 @@ def match_and_rewrite(self, op: csl_stencil.ApplyOp, rewriter: PatternRewriter, # check: only apply rewrite if reduction ops use `access` ops only (plus one other, checked below) # note, we have already checked that each access op is only consumed once, which by implication is here - red_args = set(arg for r in red_ops for arg in r.ops) + red_args = set(arg for r in reduction_ops for arg in r.ops) nonaccess_args = red_args - set(a.result for a in access_ops) if len(nonaccess_args) > 1: return # check: only apply rewrite if the non-`access` op is an accumulator and the result param in all reduction ops accumulator = nonaccess_args.pop() - if any(accumulator != r.ops[0] for r in red_ops): + if any(accumulator != r.ops[0] for r in reduction_ops): return if ( - not isattr(accumulator.type, memref.MemRefType) + not isattr(accumulator.type, AnyMemRefType) or not isinstance(op.accumulator, OpResult) or not isinstance(alloc := op.accumulator.op, memref.Alloc) ): @@ -488,12 +489,23 @@ def match_and_rewrite(self, op: csl_stencil.ApplyOp, rewriter: PatternRewriter, rewriter.insert_op( [*new_ops, full_stencil_dsd, reduction_op], - InsertPoint.after(list(red_ops)[-1]), + InsertPoint.after(list(reduction_ops)[-1]), ) - for e in [*access_ops, *red_ops]: + for e in [*access_ops, *reduction_ops]: rewriter.erase_op(e, safe_erase=False) + # housekeeping: this strategy requires zeroing out the accumulator iff the apply is inside a loop + assert (elem_t := accumulator.type.get_element_type()) in [ + Float16Type(), + Float32Type(), + ] + zero = arith.Constant(FloatAttr(0.0, elem_t)) + mov_op = csl.FmovsOp if elem_t == Float32Type() else csl.FmovhOp + rewriter.insert_op( + [zero, mov_op(operands=[[op.accumulator, zero]])], InsertPoint.before(op) + ) + @staticmethod def is_full_2d_starshaped_access( offsets: set[tuple[int, ...]], max_offset: int