diff --git a/.gitmodules b/.gitmodules index 57b211151..da9a6bc53 100644 --- a/.gitmodules +++ b/.gitmodules @@ -3,7 +3,7 @@ url = https://github.com/rise-lang/executor.git [submodule "lib/yacx"] path = lib/yacx - url = https://github.com/ZerataX/yacx.git + url = https://github.com/rise-lang/yacx.git [submodule "lib/elevate"] path = lib/elevate url = https://github.com/elevate-lang/elevate.git diff --git a/build.sbt b/build.sbt index 3752ce94b..b6c2dbead 100644 --- a/build.sbt +++ b/build.sbt @@ -1,82 +1,65 @@ -ThisBuild / scalaVersion := "2.13.3" -ThisBuild / organization := "org.rise-lang" - -lazy val commonSettings = Seq( - scalacOptions ++= Seq( - "-Wunused:nowarn", - "-Xfatal-warnings", - "-Xlint:-unused", - "-Ymacro-annotations", - "-unchecked", - "-deprecation", - "-feature", - "-language:reflectiveCalls", - ), - fork := true -) - lazy val riseAndShine = (project in file(".")) - .aggregate(executor, CUexecutor) - .dependsOn(meta, arithExpr, executor, CUexecutor, elevate) + .aggregate(meta, executor, CUexecutor) + .dependsOn(arithExpr, executor, CUexecutor, elevate) .settings( name := "riseAndShine", + organization := "org.rise-lang", version := "1.0", + scalaVersion := "3.0.1", javaOptions ++= Seq("-Djava.library.path=lib/yacx/build:lib/executor/lib/Executor/build", "-DexecuteCudaTests=false", "-Xss26m"), - commonSettings, + scalacOptions ++= Seq( + // "-Xfatal-warnings", + // "-rewrite", + "-source:3.0-migration", + // "-indent", + // "-new-syntax", + "-deprecation", + "-feature", + "-unchecked", + "-language:reflectiveCalls", + ), + + fork := true, libraryDependencies ++= Seq( - // scala - "org.scala-lang" % "scala-reflect" % scalaVersion.value, - "org.scala-lang" % "scala-compiler" % scalaVersion.value, - "org.scala-lang" % "scala-library" % scalaVersion.value, - "org.scala-lang.modules" %% "scala-xml" % "1.3.0", - "org.scala-lang.modules" %% "scala-parallel-collections" % "0.2.0", - // testing - "junit" % "junit" % "4.11", - "org.scalatest" %% "scalatest" % "3.1.0" % "test", - "org.apache.logging.log4j" % "log4j-core" % "2.14.1", - "org.apache.logging.log4j" %% "log4j-api-scala" % "12.0", - // json - "com.typesafe.play" %% "play-json" % "2.9.1", - // subprocess communication - "com.lihaoyi" %% "os-lib" % "0.7.3" + "org.scala-lang.modules" %% "scala-parallel-collections" % "1.0.3", + // testing + "junit" % "junit" % "4.11", + "org.scalatest" %% "scalatest" % "3.2.9" % "test", + "org.apache.logging.log4j" % "log4j-core" % "2.14.1", + "org.wvlet.airframe" %% "airframe-log" % "21.5.4", + // os + ("com.lihaoyi" %% "os-lib" % "0.7.8").cross(CrossVersion.for3Use2_13), + // json + ("com.typesafe.play" %% "play-json" % "2.9.2").cross(CrossVersion.for3Use2_13), + // xml + "org.scala-lang.modules" %% "scala-xml" % "2.0.1" ), - compile := ((compile in Compile) dependsOn (generateRISEPrimitives, clap)).value, - test := ((test in Test) dependsOn (generateRISEPrimitives, clap)).value + compile := ((Compile / compile) dependsOn generateRISEPrimitives).value, + test := ((Test / test) dependsOn generateRISEPrimitives).value ) lazy val generateRISEPrimitives = taskKey[Unit]("Generate RISE Primitives") -generateRISEPrimitives := { - runner.value.run("meta.generator.RisePrimitives", - (dependencyClasspath in Compile).value.files, - Seq((scalaSource in Compile).value.getAbsolutePath), - streams.value.log).failed foreach (sys error _.getMessage) -} +generateRISEPrimitives := (Def.taskDyn { + (meta / Compile / runMain).toTask( + " meta.generator.RisePrimitives " + (Compile / scalaSource).value.getAbsolutePath + ) +}).value lazy val generateDPIAPrimitives = taskKey[Unit]("Generate DPIA Primitives") -generateDPIAPrimitives := { - runner.value.run("meta.generator.DPIAPrimitives", - (dependencyClasspath in Compile).value.files, - Seq((scalaSource in Compile).value.getAbsolutePath), - streams.value.log).failed foreach (sys error _.getMessage) -} - -lazy val meta = (project in file("meta")) - .settings( - name := "meta", - version := "1.0", - commonSettings, - libraryDependencies += "com.lihaoyi" %% "fastparse" % "2.2.2", - libraryDependencies += "com.lihaoyi" %% "scalaparse" % "2.2.2", - libraryDependencies += "com.lihaoyi" %% "os-lib" % "0.7.3", - libraryDependencies += "org.scalameta" %% "scalameta" % "4.4.10", +generateDPIAPrimitives := (Def.taskDyn { + (meta / Compile / runMain).toTask( + " meta.generator.DPIAPrimitives " + (Compile / scalaSource).value.getAbsolutePath ) +}).value + +lazy val meta = (project in file("meta")) lazy val arithExpr = project in file("lib/arithexpr") @@ -86,22 +69,11 @@ lazy val CUexecutor = project in file("lib/yacx") lazy val elevate = project in file("lib/elevate") -lazy val docs = (project in file("riseAndShine-docs")) +lazy val docs = (project in file("riseAndShine-docs")) .settings( moduleName := "riseAndShine-docs", mdocOut := file("docs-website/docs"), + scalaVersion := "3.0.0", ) .enablePlugins(MdocPlugin) .dependsOn(riseAndShine) - -lazy val clap = taskKey[Unit]("Builds Clap library") - -clap := { - import scala.language.postfixOps - import scala.sys.process._ - //noinspection PostfixMethodCall - "echo y" #| (baseDirectory.value + "/lib/clap/buildClap.sh") ! -} - - - diff --git a/lib/arithexpr b/lib/arithexpr index 40b70e34a..f6206db6d 160000 --- a/lib/arithexpr +++ b/lib/arithexpr @@ -1 +1 @@ -Subproject commit 40b70e34a56bdddf7ce6de7d8a3ed949167cfd32 +Subproject commit f6206db6d438d397d2af3b9b4ad0fc0651784c8c diff --git a/lib/elevate b/lib/elevate index 6f74f57f4..d29d56948 160000 --- a/lib/elevate +++ b/lib/elevate @@ -1 +1 @@ -Subproject commit 6f74f57f49b99efa7d914e42b2faaed610fcaef4 +Subproject commit d29d5694874eb374b6637561ab15936ac9f4ae9f diff --git a/lib/executor b/lib/executor index 0deb6cba9..f16b0cf90 160000 --- a/lib/executor +++ b/lib/executor @@ -1 +1 @@ -Subproject commit 0deb6cba9970f710f54c043567f0568e8f7e3dc5 +Subproject commit f16b0cf90059cb226eddf0ed8a206fce7e2dba82 diff --git a/lib/yacx b/lib/yacx index da81fe8f8..deeec1d25 160000 --- a/lib/yacx +++ b/lib/yacx @@ -1 +1 @@ -Subproject commit da81fe8f814151b1c2024617f0bc9891f210cd84 +Subproject commit deeec1d25b79cb691ca88f0dfeb93bd8fc65bea7 diff --git a/meta/build.sbt b/meta/build.sbt new file mode 100644 index 000000000..cdf53e4ec --- /dev/null +++ b/meta/build.sbt @@ -0,0 +1,21 @@ +lazy val meta = (project in file(".")) + .settings( + name := "meta", + version := "1.0", + scalaVersion := "2.13.6", + scalacOptions ++= Seq( + "-Wunused:nowarn", + "-Xfatal-warnings", + "-Xlint:-unused", + "-Ymacro-annotations", + "-unchecked", + "-deprecation", + "-feature", + "-language:reflectiveCalls", + ), + fork := true, + libraryDependencies += "com.lihaoyi" %% "fastparse" % "2.2.2", + libraryDependencies += "com.lihaoyi" %% "scalaparse" % "2.2.2", + libraryDependencies += "com.lihaoyi" %% "os-lib" % "0.7.8", + libraryDependencies += "org.scalameta" %% "scalameta" % "4.4.10", + ) diff --git a/project/build.properties b/project/build.properties index 0b2e09c5a..10fd9eee0 100644 --- a/project/build.properties +++ b/project/build.properties @@ -1 +1 @@ -sbt.version=1.4.7 +sbt.version=1.5.5 diff --git a/project/plugins.sbt b/project/plugins.sbt index 76d8825fd..b1da9e3f5 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -1,3 +1,3 @@ -addSbtPlugin("ch.epfl.scala" % "sbt-bloop" % "1.4.0-RC1") -addSbtPlugin("org.scalameta" % "sbt-scalafmt" % "2.3.0") -addSbtPlugin("org.scalameta" % "sbt-mdoc" % "2.2.17") +addSbtPlugin("ch.epfl.scala" % "sbt-bloop" % "1.4.8") +addSbtPlugin("org.scalameta" % "sbt-scalafmt" % "2.4.2") +addSbtPlugin("org.scalameta" % "sbt-mdoc" % "2.2.21") \ No newline at end of file diff --git a/src/main/scala/apps/cameraPipelineRewrite.scala b/src/main/scala/apps/cameraPipelineRewrite.scala index e27aad62a..b51290e16 100644 --- a/src/main/scala/apps/cameraPipelineRewrite.scala +++ b/src/main/scala/apps/cameraPipelineRewrite.scala @@ -46,7 +46,7 @@ object cameraPipelineRewrite { } def takeDropTowardsInput: Strategy[Rise] = { - normalize.apply( + normalize( gentleBetaReduction() <+ etaReduction() <+ takeAll <+ dropNothing <+ mapFusion <+ mapIdentity <+ @@ -161,7 +161,7 @@ object cameraPipelineRewrite { case _ => Failure(idxAfterF) } - def normalizeSingleInput: Strategy[Rise] = normalize.apply( + def normalizeSingleInput: Strategy[Rise] = normalize( dropBeforeTake <+ dropBeforeMap <+ takeBeforeMap <+ slideBeforeMap <+ mapFusion // <+ TODO // (not(isAppliedMap) `;` idxAfterF `;` debugS("idx")) @@ -293,13 +293,13 @@ object cameraPipelineRewrite { argument(argument(normalizeInput) `;` repeat(mapFusion)) } - def stronglyReducedForm: Strategy[Rise] = normalize.apply( + def stronglyReducedForm: Strategy[Rise] = normalize( betaReduction <+ etaReduction() <+ removeTransposePair <+ mapFusion <+ idxReduction <+ fstReduction <+ sndReduction ) - def gentlyReducedForm: Strategy[Rise] = normalize.apply( + def gentlyReducedForm: Strategy[Rise] = normalize( gentleBetaReduction() <+ etaReduction() <+ removeTransposePair <+ mapFusion <+ idxReduction <+ fstReduction <+ sndReduction @@ -307,7 +307,7 @@ object cameraPipelineRewrite { def demosaicCircularBuffers: Strategy[Rise] = { rewriteSteps(scala.collection.Seq( - normalize.apply(gentleBetaReduction()), + normalize(gentleBetaReduction()), takeDropTowardsInput, @@ -373,7 +373,7 @@ object cameraPipelineRewrite { def precomputeSharpenStrengthX32: Strategy[Rise] = { // |> toMem() |> letf(fun(strength_x32 => - normalize.apply(gentleBetaReduction()) `;` + normalize(gentleBetaReduction()) `;` afterTopLevel( function(argument( // sharpen ??? @@ -399,7 +399,7 @@ object cameraPipelineRewrite { } def precomputeColorCorrectionMatrix: Strategy[Rise] = { - normalize.apply(gentleBetaReduction()) `;` + normalize(gentleBetaReduction()) `;` afterTopLevel( argument(argument({ case expr @ App(Lambda(x, color_correct), matrix) => @@ -407,14 +407,14 @@ object cameraPipelineRewrite { p.mapSeq(p.mapSeq(fun(x => x)))(matrix)) !: expr.t) case _ => Failure(precomputeColorCorrectionMatrix) })) `;` - normalize.apply(gentleBetaReduction() <+ letHoist) + normalize(gentleBetaReduction() <+ letHoist) ) } def precomputeCurve: Strategy[Rise] = { // TODO: apply_curve curve: // |> mapSeq(fun(x => x)) |> letf(fun(curve => - normalize.apply(gentleBetaReduction()) `;` + normalize(gentleBetaReduction()) `;` afterTopLevel( argument(function(argument( topDown( diff --git a/src/main/scala/apps/convolution.scala b/src/main/scala/apps/convolution.scala index 9d45d1242..3dcfbc5eb 100644 --- a/src/main/scala/apps/convolution.scala +++ b/src/main/scala/apps/convolution.scala @@ -11,6 +11,7 @@ import rise.openCL.DSL._ import rise.openCL.primitives.oclReduceSeqUnroll import util.{Time, TimeSpan} import shine.OpenCL.KernelExecutor._ +import reflect.Selectable.reflectiveSelectable object convolution { private val id = fun(x => x) diff --git a/src/main/scala/apps/gemv.scala b/src/main/scala/apps/gemv.scala index b9f8642c3..35b9ec98d 100644 --- a/src/main/scala/apps/gemv.scala +++ b/src/main/scala/apps/gemv.scala @@ -7,6 +7,7 @@ import rise.core.primitives.{let => _, _} import rise.core.types._ import rise.core.types.DataType._ import HighLevelConstructs.reorderWithStride +import reflect.Selectable.reflectiveSelectable object gemv { // we can use implicit type parameters and type annotations to specify the function type of mult diff --git a/src/main/scala/apps/harrisCornerDetection.scala b/src/main/scala/apps/harrisCornerDetection.scala index 5b222b51f..beab0d8bf 100644 --- a/src/main/scala/apps/harrisCornerDetection.scala +++ b/src/main/scala/apps/harrisCornerDetection.scala @@ -9,6 +9,7 @@ import rise.openCL.DSL._ import rise.openCL.primitives.oclRotateValues import util.gen import shine.OpenCL.KernelExecutor._ +import reflect.Selectable.reflectiveSelectable import scala.reflect.ClassTag diff --git a/src/main/scala/apps/harrisCornerDetectionHalideRewrite.scala b/src/main/scala/apps/harrisCornerDetectionHalideRewrite.scala index e98914a49..e82c42355 100644 --- a/src/main/scala/apps/harrisCornerDetectionHalideRewrite.scala +++ b/src/main/scala/apps/harrisCornerDetectionHalideRewrite.scala @@ -32,16 +32,16 @@ object harrisCornerDetectionHalideRewrite { }) } - val unrollDots: Strategy[Rise] = normalize.apply(lowering.reduceSeqUnroll) + val unrollDots: Strategy[Rise] = normalize(lowering.reduceSeqUnroll) def someGentleReduction: Strategy[Rise] = gentleBetaReduction() <+ etaReduction() <+ idxReduction <+ fstReduction <+ sndReduction <+ removeTransposePair - def reducedFusedForm: Strategy[Rise] = normalize.apply( + def reducedFusedForm: Strategy[Rise] = normalize( someGentleReduction <+ mapFusion ) - def reducedFissionedForm: Strategy[Rise] = normalize.apply( + def reducedFissionedForm: Strategy[Rise] = normalize( someGentleReduction <+ mapLastFission() ) @@ -95,7 +95,7 @@ object harrisCornerDetectionHalideRewrite { depFunction(isEqualTo(rise.core.primitives.padEmpty.primitive)) object ocl { - val unrollDots: Strategy[Rise] = normalize.apply( + val unrollDots: Strategy[Rise] = normalize( lowering.ocl.reduceSeqUnroll(AddressSpace.Private)) val lineBuffer: Strategy[Rise] = lowering.ocl.circularBuffer(AddressSpace.Global) @@ -114,7 +114,7 @@ object harrisCornerDetectionHalideRewrite { afterTopLevel( // zip unzip simplification topDown(argument(isAppliedUnzip) `;` betaReduction) `;` - normalize.apply( + normalize( someGentleReduction <+ mapFusion <+ zipUnzipAccessSimplification <+ mapProjZipUnification() ) @@ -123,7 +123,7 @@ object harrisCornerDetectionHalideRewrite { def harrisIxWithIy: Strategy[Rise] = afterTopLevel(afterDefs( - normalize.apply( + normalize( someGentleReduction <+ takeOutsidePair <+ vectorize.asScalarOutsidePair ) `;` @@ -139,7 +139,7 @@ object harrisCornerDetectionHalideRewrite { ): Strategy[Rise] = { topDown(lowering.iterateStream) `;` repeatNTimes(2)(argumentsTd(function(lineBuffer))) `;` - normalize.apply(lowering.ocl.circularBufferLoadFusion) `;` + normalize(lowering.ocl.circularBufferLoadFusion) `;` reducedFusedForm `;` argument(argument(topDown(lowering.mapSeq))) `;` argument(function(argument( @@ -196,10 +196,10 @@ object harrisCornerDetectionHalideRewrite { def vectorizeRoundUpAndNormalize(vwidth: Int): Strategy[Rise] = { vectorize.roundUpAfter(vwidth) `;` - normalize.apply(padEmptyBeforeMap <+ padEmptyBeforeTranspose) + normalize(padEmptyBeforeMap <+ padEmptyBeforeTranspose) } - def normalizeVectorized: Strategy[Rise] = normalize.apply( + def normalizeVectorized: Strategy[Rise] = normalize( someGentleReduction <+ mapFusion <+ transposeBeforeMapJoin <+ mapMapFBeforeTranspose() <+ vectorize.beforeMap @@ -207,7 +207,7 @@ object harrisCornerDetectionHalideRewrite { def vectorizeReductions(vwidth: Int): Strategy[Rise] = { afterTopLevel( - normalize.apply( + normalize( isAppliedMap `;` topDown(reduceMapFusion) `;` reducedFissionedForm `;` ( @@ -217,7 +217,7 @@ object harrisCornerDetectionHalideRewrite { topDown(vectorize.beforeMapDot) `;` normalizeVectorized ) `;` - normalize.apply( + normalize( isAppliedMap `;` function(argument(isReduceFI <+ body(isAppliedReduce))) `;` reducedFissionedForm `;` @@ -225,7 +225,7 @@ object harrisCornerDetectionHalideRewrite { argument(argument(vectorize.beforeMapReduce)) `;` normalizeVectorized ) `;` - normalize.apply( + normalize( takeOutisdeZip <+ takeAfterMap <+ removeTakeBeforePadEmpty ) `;` @@ -235,7 +235,7 @@ object harrisCornerDetectionHalideRewrite { argument(argument(isAppliedZip)) `;` vectorize.after(vwidth) `;` argument(vectorize.beforeMap) `;` - normalize.apply( + normalize( someGentleReduction <+ mapFusion <+ unzipZipIsPair <+ vectorize.asScalarAsVectorId ) @@ -245,7 +245,7 @@ object harrisCornerDetectionHalideRewrite { def movePadEmpty: Strategy[Rise] = afterTopLevel( - normalize.apply( + normalize( someGentleReduction <+ mapFusion <+ padEmptyBeforeTranspose <+ padEmptyBeforeMap <+ padEmptyBeforeSlide <+ padEmptyBeforeZip <+ @@ -279,7 +279,7 @@ object harrisCornerDetectionHalideRewrite { topDown(slideBeforeMapMapF) ) `;` reducedFusedForm `;` - normalize.apply( + normalize( someGentleReduction <+ mapFstBeforeMapSnd <+ mapFstFusion <+ mapSndFusion <+ removeTakeBeforePadEmpty <+ padEmptyFusion @@ -290,19 +290,19 @@ object harrisCornerDetectionHalideRewrite { afterTopLevel( topDown( isAppliedZip `;` argument(isAppliedZip) `;` - normalize.apply( + normalize( someGentleReduction <+ mapFusion <+ transposeBeforeMapJoin <+ slideBeforeMap <+ mapMapFBeforeTranspose() ) `;` - normalize.apply( + normalize( someGentleReduction <+ mapLastFission() <+ mapMapFBeforeJoin ) `;` - normalize.apply( + normalize( someGentleReduction <+ mapFusion <+ vectorize.beforeMap <+ slideBeforeMap ) `;` - normalize.apply( + normalize( someGentleReduction <+ mapLastFission() <+ mapMapFBeforeTranspose() ) `;` @@ -350,18 +350,18 @@ object harrisCornerDetectionHalideRewrite { def alignLoads: Strategy[Rise] = afterTopLevel( - normalize.apply( + normalize( isAppliedMap `;` argument(function(isEqualToUntyped(rise.core.primitives.transpose.primitive))) `;` reducedFissionedForm `;` topDown(vectorize.alignSlide) `;` reducedFusedForm `;` - normalize.apply( + normalize( someGentleReduction <+ mapFusion <+ vectorize.padEmptyBeforeAsScalar <+ vectorize.asScalarAsVectorId <+ padEmptyBeforeMap <+ padEmptyBeforeTranspose <+ removeTakeBeforePadEmpty ) `;` - normalize.apply( + normalize( someGentleReduction <+ mapLastFission() <+ mapMapFBeforeTranspose() ) `;` reducedFusedForm @@ -382,7 +382,7 @@ object harrisCornerDetectionHalideRewrite { afterTopLevel( topDown(lowering.mapGlobal()) `;` topDown(harrisBufferedLowering()) `;` - normalize.apply(vectorize.mapAfterShuffle) `;` + normalize(vectorize.mapAfterShuffle) `;` storeSlidingWindowsToPrivate ) )) @@ -394,7 +394,7 @@ object harrisCornerDetectionHalideRewrite { reducedFissionedForm `;` topDown(mapSlideBeforeTranspose) `;` reducedFusedForm `;` reducedFissionedForm `;` - normalize.apply(slideBeforeMapMapF) `;` + normalize(slideBeforeMapMapF) `;` reducedFusedForm def separateReductions: Strategy[Rise] = diff --git a/src/main/scala/apps/kmeans.scala b/src/main/scala/apps/kmeans.scala index 6e3ee6303..7627310d2 100644 --- a/src/main/scala/apps/kmeans.scala +++ b/src/main/scala/apps/kmeans.scala @@ -9,6 +9,7 @@ import rise.core.types.DataType._ import rise.openCL.DSL._ import rise.openCL.primitives.oclReduceSeq import shine.OpenCL.KernelExecutor._ +import reflect.Selectable.reflectiveSelectable object kmeans { private val update = fun(f32 ->: (f32 x f32) ->: f32)((dist, pair) => diff --git a/src/main/scala/apps/mm.scala b/src/main/scala/apps/mm.scala index 252a5f995..d646ee27b 100644 --- a/src/main/scala/apps/mm.scala +++ b/src/main/scala/apps/mm.scala @@ -8,6 +8,7 @@ import rise.core.types._ import rise.core.types.DataType._ import rise.openCL.DSL._ import rise.openCL.primitives.oclReduceSeq +import reflect.Selectable.reflectiveSelectable object mm { private val id = fun(x => x) diff --git a/src/main/scala/apps/molecularDynamics.scala b/src/main/scala/apps/molecularDynamics.scala index c500d1622..ed3770436 100644 --- a/src/main/scala/apps/molecularDynamics.scala +++ b/src/main/scala/apps/molecularDynamics.scala @@ -8,6 +8,7 @@ import rise.core.types._ import rise.core.types.DataType._ import rise.openCL.DSL._ import rise.openCL.primitives.oclReduceSeq +import reflect.Selectable.reflectiveSelectable object molecularDynamics { private val mdCompute = foreignFun("updateF", diff --git a/src/main/scala/apps/mriQ.scala b/src/main/scala/apps/mriQ.scala index 6966f5ef4..c23edf902 100644 --- a/src/main/scala/apps/mriQ.scala +++ b/src/main/scala/apps/mriQ.scala @@ -8,6 +8,7 @@ import rise.core.types._ import rise.core.types.DataType._ import rise.openCL.DSL._ import rise.openCL.primitives.oclReduceSeq +import reflect.Selectable.reflectiveSelectable object mriQ { private val phiMag = foreignFun("phiMag", diff --git a/src/main/scala/apps/nbody.scala b/src/main/scala/apps/nbody.scala index d91b61b7c..e012010ff 100644 --- a/src/main/scala/apps/nbody.scala +++ b/src/main/scala/apps/nbody.scala @@ -8,6 +8,7 @@ import rise.core.types._ import rise.core.types.DataType._ import rise.openCL.DSL._ import rise.openCL.primitives.oclReduceSeq +import reflect.Selectable.reflectiveSelectable object nbody { private val id = fun(x => x) diff --git a/src/main/scala/apps/nearestNeighbour.scala b/src/main/scala/apps/nearestNeighbour.scala index 4d8a89f1b..18034e8c4 100644 --- a/src/main/scala/apps/nearestNeighbour.scala +++ b/src/main/scala/apps/nearestNeighbour.scala @@ -7,6 +7,7 @@ import rise.core.types._ import rise.core.types.DataType._ import rise.core.primitives._ import rise.openCL.DSL._ +import reflect.Selectable.reflectiveSelectable object nearestNeighbour { private val distance = foreignFun("distance_", diff --git a/src/main/scala/exploration/strategies/defaultStrategies.scala b/src/main/scala/exploration/strategies/defaultStrategies.scala index f17b2145a..a28eca2d8 100644 --- a/src/main/scala/exploration/strategies/defaultStrategies.scala +++ b/src/main/scala/exploration/strategies/defaultStrategies.scala @@ -22,14 +22,14 @@ import rise.elevate.NormalizedThen object defaultStrategies { val outermost: (Strategy[Rise]) => (Strategy[Rise]) => Strategy[Rise] = - traversal.outermost(default.RiseTraversable) + traversal.outermost(using default.RiseTraversable) val innermost: (Strategy[Rise]) => (Strategy[Rise]) => Strategy[Rise] = - traversal.innermost(default.RiseTraversable) + traversal.innermost(using default.RiseTraversable) // -- BASELINE --------------------------------------------------------------- - val baseline: Strategy[Rise] = DFNF()(default.RiseTraversable) `;` + val baseline: Strategy[Rise] = DFNF()(using default.RiseTraversable) `;` fuseReduceMap `@` topDown[Rise] @@ -69,7 +69,7 @@ object defaultStrategies { val permuteB: Strategy[Rise] = splitJoin2(32) `;` DFNF() `;` argument(idAfter) `;` topDown(liftId()) `;` topDown(createTransposePair) `;` RNF() `;` - argument(argument(idAfter)) `;` normalize.apply(liftId()) `;` + argument(argument(idAfter)) `;` normalize(liftId()) `;` topDown(idToCopy) val packB: Strategy[Rise] = diff --git a/src/main/scala/rise/autotune/errors.scala b/src/main/scala/rise/autotune/errors.scala index 886151ed5..c365ec668 100644 --- a/src/main/scala/rise/autotune/errors.scala +++ b/src/main/scala/rise/autotune/errors.scala @@ -8,3 +8,5 @@ case object SUBSTITUTION_ERROR extends AutoTuningErrorLevel case object CODE_GENERATION_ERROR extends AutoTuningErrorLevel case object COMPILATION_ERROR extends AutoTuningErrorLevel case object EXECUTION_ERROR extends AutoTuningErrorLevel + +case class AutoTuningError(errorLevel: AutoTuningErrorLevel, message: Option[String]) \ No newline at end of file diff --git a/src/main/scala/rise/autotune/package.scala b/src/main/scala/rise/autotune/package.scala index fac2a17fe..c528fa09d 100644 --- a/src/main/scala/rise/autotune/package.scala +++ b/src/main/scala/rise/autotune/package.scala @@ -35,7 +35,6 @@ package object autotune { // todo add meta information (configuration, times, samples, ...) case class TuningResult(samples: Seq[Sample]) - case class AutoTuningError(errorLevel: AutoTuningErrorLevel, message: Option[String]) type Parameters = Set[NatIdentifier] // should we allow tuning params to be substituted during type inference? diff --git a/src/main/scala/rise/core/Builder.scala b/src/main/scala/rise/core/Builder.scala index 25ef2047c..9b483bedf 100644 --- a/src/main/scala/rise/core/Builder.scala +++ b/src/main/scala/rise/core/Builder.scala @@ -1,6 +1,7 @@ package rise.core import types._ +import rise.core.DSL trait Builder { def apply: DSL.ToBeTyped[Primitive] = diff --git a/src/main/scala/rise/core/DSL/package.scala b/src/main/scala/rise/core/DSL/package.scala index 01c9329b2..38888a701 100644 --- a/src/main/scala/rise/core/DSL/package.scala +++ b/src/main/scala/rise/core/DSL/package.scala @@ -536,7 +536,7 @@ package object DSL { toBeTyped(topLevel(toExpr(d))) def eraseTypeFromExpr[T <: Expr](e: T): T = - traverse(e, new PureExprTraversal { + rise.core.traverse.traverse(e, new PureExprTraversal { override def identifier[I <: Identifier] : VarType => I => Pure[I] = vt => i => return_(i.setType(TypePlaceholder).asInstanceOf[I]) override def expr : Expr => Pure[Expr] = { diff --git a/src/main/scala/rise/core/IsClosedForm.scala b/src/main/scala/rise/core/IsClosedForm.scala index cf8846664..e2d747b01 100644 --- a/src/main/scala/rise/core/IsClosedForm.scala +++ b/src/main/scala/rise/core/IsClosedForm.scala @@ -105,12 +105,12 @@ object IsClosedForm { } def freeVars(expr: Expr): (OrderedSet[Identifier], OrderedSet[Kind.Identifier]) = { - val ((fV, fT), _) = traverse(expr, Visitor(Set(), Set())) + val ((fV, fT), _) = rise.core.traverse.traverse(expr, Visitor(Set(), Set())) (fV, fT) } def freeVars(t: ExprType): OrderedSet[Kind.Identifier] = { - val ((_, ftv), _) = traverse(t, Visitor(Set(), Set())) + val ((_, ftv), _) = rise.core.traverse.traverse(t, Visitor(Set(), Set())) ftv } diff --git a/src/main/scala/rise/core/package.scala b/src/main/scala/rise/core/package.scala index 7ec3ce164..5b5708d77 100644 --- a/src/main/scala/rise/core/package.scala +++ b/src/main/scala/rise/core/package.scala @@ -49,4 +49,5 @@ package object core { implicit def primitiveBuilderToPrimitive(pb: Builder ): DSL.ToBeTyped[Primitive] = pb.apply + } diff --git a/src/main/scala/rise/core/substitute.scala b/src/main/scala/rise/core/substitute.scala index ce69486d1..8a3d25cf5 100644 --- a/src/main/scala/rise/core/substitute.scala +++ b/src/main/scala/rise/core/substitute.scala @@ -125,7 +125,7 @@ object substitute { if (`for` =~= t) { return_(ty.asInstanceOf[T]) } else super.`type`(t) } } - traverse(in, Visitor) + rise.core.traverse.traverse(in, Visitor) } def natsInType[T <: ExprType](subs: Map[Nat, Nat], in: T): T = { @@ -133,7 +133,7 @@ object substitute { override def nat: Nat => Pure[Nat] = in1 => return_(substitute.natsInNat(subs, in1)) } - traverse(in, Visitor) + rise.core.traverse.traverse(in, Visitor) } def natInType[T <: ExprType](n: Nat, `for`: Nat, in: T): T = @@ -152,7 +152,7 @@ object substitute { override def addressSpace: AddressSpace => Pure[AddressSpace] = b => if (`for` == b) return_(a) else super.addressSpace(b) } - traverse(in, Visitor) + rise.core.traverse.traverse(in, Visitor) } def n2nInType[T <: ExprType](n2n: NatToNat, `for`: NatToNatIdentifier, in: T ): T = { @@ -160,7 +160,7 @@ object substitute { override def natToNat: NatToNat => Pure[NatToNat] = n => if (`for` == n) return_(n2n) else super.natToNat(n) } - traverse(in, Visitor) + rise.core.traverse.traverse(in, Visitor) } def n2dInType[T <: ExprType](n2d: NatToData, `for`: NatToDataIdentifier, in: T): T = { @@ -168,7 +168,7 @@ object substitute { override def natToData: NatToData => Pure[NatToData] = n => if (`for` == n) return_(n2d) else super.natToData(n) } - traverse(in, Visitor) + rise.core.traverse.traverse(in, Visitor) } // substitute in Nat diff --git a/src/main/scala/rise/core/traverse.scala b/src/main/scala/rise/core/traverse.scala index aadfdad2a..ef6354824 100644 --- a/src/main/scala/rise/core/traverse.scala +++ b/src/main/scala/rise/core/traverse.scala @@ -196,8 +196,8 @@ object traverse { trait PureTraversal extends Traversal[Pure] {override def monad : PureMonad.type = PureMonad } trait PureExprTraversal extends PureTraversal with ExprTraversal[Pure] - trait AccumulatorTraversal[F,M[_]] extends Traversal[InMonad[M]#SetFst[F]#Type] { - type Pair[T] = InMonad[M]#SetFst[F]#Type[T] + trait AccumulatorTraversal[F,M[_]] extends Traversal[[S] =>> M[(F, S)]] { + type Pair[T] = M[(F, T)] implicit val accumulator : Monoid[F] implicit val wrapperMonad : Monad[M] def accumulate[T] : F => T => Pair[T] = f => t => wrapperMonad.return_((f, t)) diff --git a/src/main/scala/rise/core/uniqueNames.scala b/src/main/scala/rise/core/uniqueNames.scala index b73b86316..37d2dee9a 100644 --- a/src/main/scala/rise/core/uniqueNames.scala +++ b/src/main/scala/rise/core/uniqueNames.scala @@ -19,7 +19,7 @@ object uniqueNames { } def check(e: Expr): Boolean = { - val ((vs, ts), _) = traverse(e, collectNames) + val ((vs, ts), _) = rise.core.traverse.traverse(e, collectNames) vs == vs.distinct && ts == ts.distinct } diff --git a/src/main/scala/rise/elevate/meta/fission.scala b/src/main/scala/rise/elevate/meta/fission.scala index b2ca1359d..5d4ab11b9 100644 --- a/src/main/scala/rise/elevate/meta/fission.scala +++ b/src/main/scala/rise/elevate/meta/fission.scala @@ -3,29 +3,29 @@ package rise.elevate.meta import elevate.core.strategies.Traversable import elevate.core.strategies.basic._ import elevate.core.{Strategy, Success} -import elevate.macros.RuleMacro.rule +import elevate.core.macros.rule import rise.elevate.Rise import rise.elevate.rules.traversal.{argument, argumentOf, body, function} object fission { - @rule def bodyFission: Strategy[Strategy[Rise]] = { - case body(Seq(f, s)) => Success(seq(body(f))(body(s))) - } + def bodyFission: Strategy[Strategy[Rise]] = rule("bodyFission", { + case body(Seq(f: Strategy[Rise]@unchecked, s: Strategy[Rise]@unchecked)) => Success(seq(body(f))(body(s))) + }) - @rule def functionFission: Strategy[Strategy[Rise]] = { - case function(Seq(f,s)) => Success(seq(function(f))(function(s))) - } + def functionFission: Strategy[Strategy[Rise]] = rule("functionFission", { + case function(Seq(f: Strategy[Rise]@unchecked, s: Strategy[Rise]@unchecked)) => Success(seq(function(f))(function(s))) + }) - @rule def argumentFission: Strategy[Strategy[Rise]] = { - case argument(Seq(f,s)) => Success(seq(argument(f))(argument(s))) - } + def argumentFission: Strategy[Strategy[Rise]] = rule("argumentFission", { + case argument(Seq(f: Strategy[Rise]@unchecked, s: Strategy[Rise]@unchecked)) => Success(seq(argument(f))(argument(s))) + }) - @rule def argumentOfFission: Strategy[Strategy[Rise]] = { - case argumentOf(x,Seq(f,s)) => Success(seq(argumentOf(x,f))(argumentOf(x,s))) - } + def argumentOfFission: Strategy[Strategy[Rise]] = rule("argumentOfFission", { + case argumentOf(x,Seq(f: Strategy[Rise]@unchecked, s: Strategy[Rise]@unchecked)) => Success(seq(argumentOf(x,f))(argumentOf(x,s))) + }) // Fissioned-Normal-Form: Every single strategy application starts from the root - def FNF(implicit ev: Traversable[Strategy[Rise]]): Strategy[Strategy[Rise]] = - normalize(ev)(bodyFission <+ functionFission <+ argumentFission <+ argumentOfFission) + def FNF(using ev: Traversable[Strategy[Rise]]): Strategy[Strategy[Rise]] = + normalize(bodyFission <+ functionFission <+ argumentFission <+ argumentOfFission) } diff --git a/src/main/scala/rise/elevate/rules/algorithmic.scala b/src/main/scala/rise/elevate/rules/algorithmic.scala index 656e2f77b..f2fa4de69 100644 --- a/src/main/scala/rise/elevate/rules/algorithmic.scala +++ b/src/main/scala/rise/elevate/rules/algorithmic.scala @@ -2,18 +2,20 @@ package rise.elevate.rules import arithexpr.arithmetic.{ArithExpr, Cst} import elevate.core._ -import elevate.core.strategies.Traversable +import elevate.core.macros.rule +import elevate.core.RewriteResult._ import elevate.core.strategies.predicate._ import elevate.core.strategies.traversal.tryAll -import elevate.macros.RuleMacro.rule +import rise.elevate._ +import rise.elevate.strategies.normalForm.DFNF +import rise.elevate.strategies.predicate._ import rise.core.DSL._ import rise.core._ import rise.core.primitives._ -import rise.core.types.DataType._ import rise.core.types._ -import rise.elevate._ -import rise.elevate.strategies.normalForm.DFNF -import rise.elevate.strategies.predicate._ +import rise.core.types.DataType._ +import elevate.core.strategies.Traversable + // noinspection MutatorLikeMethodIsParameterless object algorithmic { @@ -28,44 +30,44 @@ object algorithmic { // divide & conquer def splitJoin(n: Nat): Strategy[Rise] = `*f -> S >> **f >> J`(n: Nat) - @rule def `*f -> S >> **f >> J`(n: Nat): Strategy[Rise] = { + def `*f -> S >> **f >> J`(n: Nat): Strategy[Rise] = rule("*f -> S >> **f >> J", { case e @ App(map(), f) => Success((split(n) >> map(map(f)) >> join) !: e.t) - } + }) - @rule def splitJoin2(n: Nat): Strategy[Rise] = e => e.t match { + def splitJoin2(n: Nat): Strategy[Rise] = rule("splitJoin2", e => e.t match { case ArrayType(_,_) => Success( (toBeTyped(e) |> split(n) |> join) !: e.t ) case _ => Failure(splitJoin2(n)) - } + }) // fusion / fission def mapFusion: Strategy[Rise] = `*g >> *f -> *(g >> f)` - @rule def `*g >> *f -> *(g >> f)`: Strategy[Rise] = { + def `*g >> *f -> *(g >> f)`: Strategy[Rise] = rule("*g >> *f -> *(g >> f)", { case e @ App(App(map(), f), App(App(map(), g), arg)) => Success(map(preserveType(g) >> f)(arg) !: e.t) - } + }) // mapFst g >> mapFst f -> mapFst (g >> f) - @rule def mapFstFusion: Strategy[Rise] = { + def mapFstFusion: Strategy[Rise] = rule("mapFstFusion", { case e @ App(App(mapFst(), f), App(App(mapFst(), g), in)) => Success(mapFst(preserveType(g) >> f)(in) !: e.t) - } + }) // mapSnd g >> mapSnd f -> mapSnd (g >> f) - @rule def mapSndFusion: Strategy[Rise] = { + def mapSndFusion: Strategy[Rise] = rule("mapSndFusion", { case e @ App(App(mapSnd(), f), App(App(mapSnd(), g), in)) => Success(mapSnd(preserveType(g) >> f)(in) !: e.t) - } + }) // padEmpty n >> padEmpty m -> padEmpty n + m - @rule def padEmptyFusion: Strategy[Rise] = { + def padEmptyFusion: Strategy[Rise] = rule("padEmptyFusion", { case e @ App(DepApp(NatKind, padEmpty(), m: Nat), App(DepApp(NatKind, padEmpty(), n: Nat), in)) => Success(padEmpty(n+m)(in) !: e.t) - } + }) def `map >> reduce -> reduce`: Strategy[Rise] = reduceMapFusion // *g >> reduce f init -> reduce (acc, x => f acc (g x)) init - @rule def reduceMapFusion: Strategy[Rise] = { + def reduceMapFusion: Strategy[Rise] = rule("reduceMapFusion", { case e @ App(App(App(r @ ReduceX(), f), init), App(App(map(), g), in)) => val red = (r, g.t) match { case (reduce(), FunType(i, o)) if i =~= o => reduce @@ -73,23 +75,23 @@ object algorithmic { } Success(red(fun(acc => fun(x => preserveType(f)(acc)(preserveType(g)(x)))))(init)(in) !: e.t) - } + }) def fuseReduceMap: Strategy[Rise] = reduceMapFusion - @rule def reduceMapFission()(implicit ev: Traversable[Rise]): Strategy[Rise] = { + def reduceMapFission()(implicit ev: Traversable[Rise]): Strategy[Rise] = rule("reduceMapFission", { case e @ App(App(ReduceX(), Lambda(acc, Lambda(y, App(App(op, acc2), f@App(_, y2))))), init) if acc =~= acc2 && contains[Rise](y).apply(y2) => Success((reduce(op)(init) o map(lambda(ToBeTyped[Identifier](y), preserveType(f)))) !: e.t ) - } + }) // fission of the last function to be applied inside a map // *(g >> .. >> f) -> *(g >> ..) >> *f - @rule def mapLastFission()(implicit ev: Traversable[Rise]): Strategy[Rise] = { + def mapLastFission()(implicit ev: Traversable[Rise]): Strategy[Rise] = rule("mapLastFission", { // this is an example where we don't want to fission if gx == Identifier: // (map λe4. (((((zip: (K.float -> (K.float -> K.(float, float)))) // (e3: K.float)): (K.float -> K.(float, float))) @@ -104,52 +106,52 @@ object algorithmic { Success((app(map, lambda(eraseType(x), gx)) >> map(f)) !: e.t) case _ => Failure(mapLastFission()) } - } + }) // identities - @rule def idAfter: Strategy[Rise] = e => Success((preserveType(e) |> id) !: e.t) + def idAfter: Strategy[Rise] = rule("idAfter", e => Success((preserveType(e) |> id) !: e.t)) - @rule def idToCopy: Strategy[Rise] = { + def idToCopy: Strategy[Rise] = rule("idToCopy", { case App(id() ::: FunType(in: ScalarType, out: ScalarType), arg ::: (argT: ScalarType)) if in =~= out && in =~= argT => Success(fun(x => x) $ arg) - } + }) - @rule def liftId()(implicit ev: Traversable[Rise]): Strategy[Rise] = { - case App(id() ::: FunType(ArrayType(_, _), _), arg) => Success(DFNF()(ev)((map(id) $ arg)).get) - } + def liftId()(using ev: Traversable[Rise]): Strategy[Rise] = rule("liftId", { + case App(id() ::: FunType(ArrayType(_, _), _), arg) => Success(DFNF()((map(id) $ arg)).get) + }) - @rule def createTransposePair: Strategy[Rise] = { + def createTransposePair: Strategy[Rise] = rule("createTransposePair", { case e @ App(id(), arg) => Success(app(transpose >> transpose, arg) !: e.t) - } + }) // _-> T >> T def transposePairAfter: Strategy[Rise] = idAfter `;` createTransposePair - @rule def removeTransposePair: Strategy[Rise] = { + def removeTransposePair: Strategy[Rise] = rule("removeTransposePair", { case e @ App(transpose(), App(transpose(), x)) => Success(x !: e.t) - } + }) // overlapped tiling // constraint: n - m = u - v // v = u + m - n - @rule def slideOverlap(u: Nat): Strategy[Rise] = { + def slideOverlap(u: Nat): Strategy[Rise] = rule("slideOverlap", { case e @ DepApp(NatKind, DepApp(NatKind, slide(), n: Nat), m: Nat) => val v = u + m - n Success((slide(u)(v) >> map(slide(n)(m)) >> join) !: e.t) - } + }) // slide widening // slide n 1 >> drop l -> slide (n+l) 1 >> map(drop l) - @rule def dropInSlide: Strategy[Rise] = { + def dropInSlide: Strategy[Rise] = rule("dropInSlide", { case e@App(DepApp(NatKind, drop(), l: Nat), App(DepApp(NatKind, DepApp(NatKind, slide(), n: Nat), Cst(1)), in)) => Success(app(map(drop(l)), app(slide(n + l)(1), preserveType(in))) !: e.t) - } + }) // slide n 1 >> take (N - r) -> slide (n+r) 1 >> map(take (n - r)) - @rule def takeInSlide: Strategy[Rise] = { + def takeInSlide: Strategy[Rise] = rule("takeInSlide", { case e@App(t@DepApp(NatKind, take(), rem: Nat), App(DepApp(NatKind, DepApp(NatKind, slide(), n: Nat), Cst(1)), in)) => t.t match { case FunType(ArrayType(size, _), _) => @@ -157,59 +159,59 @@ object algorithmic { Success(app(map(take(n)), app(slide(n + r)(1), preserveType(in))) !: e.t) case _ => throw new Exception("this should not happen") } - } + }) - @rule def dropNothing: Strategy[Rise] = { + def dropNothing: Strategy[Rise] = rule("dropNothing", { case expr @ DepApp(NatKind, drop(), Cst(0)) => Success(fun(x => x) !: expr.t) - } + }) - @rule def takeAll: Strategy[Rise] = { + def takeAll: Strategy[Rise] = rule("takeAll", { case expr @ DepApp(NatKind, take(), n: Nat) => expr.t match { case FunType(ArrayType(m, _), _) if n == m => Success(fun(x => x) !: expr.t) case _ => Failure(takeAll) } - } + }) - @rule def padEmptyNothing: Strategy[Rise] = { + def padEmptyNothing: Strategy[Rise] = rule("padEmptyNothing", { case e @ DepApp(NatKind, padEmpty(), Cst(0)) => Success(fun(x => x) !: e.t) - } + }) - @rule def mapIdentity: Strategy[Rise] = { + def mapIdentity: Strategy[Rise] = rule("mapIdentity", { case expr @ App(map(), Lambda(x1, x2)) if x1 =~= x2 => Success(fun(x => x) !: expr.t) - } + }) // x -> join (slide 1 1 x) - @rule def slideAfter: Strategy[Rise] = e => Success(join(slide(1: Nat)(1: Nat)(e)) !: e.t) + def slideAfter: Strategy[Rise] = rule("slideAfter", e => Success(join(slide(1: Nat)(1: Nat)(e)) !: e.t)) - @rule def slideAfter2: Strategy[Rise] = e => Success(map(fun(x => x `@` lidx(0, 1)))(slide(1: Nat)(1: Nat)(e)) !: e.t) + def slideAfter2: Strategy[Rise] = rule("slideAfter2", e => Success(map(fun(x => x `@` lidx(0, 1)))(slide(1: Nat)(1: Nat)(e)) !: e.t)) // s -> map snd (zip f s) - @rule def zipFstAfter(f: Rise): Strategy[Rise] = s => (f.t, s.t) match { + def zipFstAfter(f: Rise): Strategy[Rise] = rule("zipFstAfter", s => (f.t, s.t) match { case (ArrayType(n, _), ArrayType(m, _)) if n == m => Success(map(snd)(zip(f)(s)) !: s.t) case _ => Failure(zipFstAfter(f)) - } + }) // f -> map fst (zip f s) - @rule def zipSndAfter(s: Rise): Strategy[Rise] = f => (f.t, s.t) match { + def zipSndAfter(s: Rise): Strategy[Rise] = rule("zipSndAfter", f => (f.t, s.t) match { case (ArrayType(n, _), ArrayType(m, _)) if n == m => Success(map(fst)(zip(f)(s)) !: f.t) case _ => Failure(zipSndAfter(s)) - } + }) // J >> drop d -> drop (d / m) >> J >> drop (d % m) - @rule def dropBeforeJoin: Strategy[Rise] = { + def dropBeforeJoin: Strategy[Rise] = rule("dropBeforeJoin", { case e @ App(DepApp(NatKind, drop(), d: Nat), App(join(), in)) => in.t match { case ArrayType(_, ArrayType(m, _)) => Success(app(drop(d % m), join(drop(d / m)(in))) !: e.t) case _ => throw new Exception("this should not happen") } - } + }) // J >> take (n*m - d) // -> dropLast (d / m) >> J >> dropLast (d % m) // -> take (n - d / m) >> J >> take ((n - d / m)*m - d % m) - @rule def takeBeforeJoin: Strategy[Rise] = { + def takeBeforeJoin: Strategy[Rise] = rule("takeBeforeJoin", { case e @ App(DepApp(NatKind, take(), nmd: Nat), App(join(), in)) => in.t match { case ArrayType(n, ArrayType(m, _)) => val d = n*m - nmd @@ -218,10 +220,10 @@ object algorithmic { Success(app(take(t2), join(take(t1)(in))) !: e.t) case _ => throw new Exception("this should not happen") } - } + }) // take n >> padEmpty m -> padEmpty m' - @rule def removeTakeBeforePadEmpty: Strategy[Rise] = { + def removeTakeBeforePadEmpty: Strategy[Rise] = rule("removeTakeBeforePadEmpty", { case e @ App(DepApp(NatKind, padEmpty(), m: Nat), App(DepApp(NatKind, take(), n: Nat), in)) => in.t match { case ArrayType(size, _) @@ -231,11 +233,11 @@ object algorithmic { case _ => Failure(removeTakeBeforePadEmpty) } - } + }) // makeArray(n)(map f1 e)..(map fn e) // -> e |> map(x => makeArray(n)(f1 x)..(fn x)) |> transpose - @rule def mapOutsideMakeArray: Strategy[Rise] = expr => { + def mapOutsideMakeArray: Strategy[Rise] = rule("mapOutsideMakeArray", expr => { def matchExpectedMakeArray(mka: Rise): Option[Rise] = mka match { case App(makeArray(_), App(App(map(), _), e)) => Some(e) case App(f, App(App(map(), _), e2)) => @@ -257,11 +259,11 @@ object algorithmic { !: expr.t) case None => Failure(mapOutsideMakeArray) } - } + }) // generate (i => select t (map f e) (map g e)) // -> e |> map (x => generate (i => select t (f x) (g x))) |> transpose - @rule def mapOutsideGenerateSelect()(implicit ev: Traversable[Rise]): Strategy[Rise] = { + def mapOutsideGenerateSelect()(implicit ev: Traversable[Rise]): Strategy[Rise] = rule("mapOutsideGenerateSelect", { case expr @ App(generate(), Lambda(i, App(App(App(select(), t), App(App(map(), f), e1)), App(App(map(), g), e2)))) @@ -270,19 +272,19 @@ object algorithmic { fun(x => generate(lambda(eraseType(i), select(t)(app(f, x), app(g, x))))))( e1 )) !: expr.t) - } + }) // select t (f a) (f b) -> f (select t a b) - @rule def fOutsideSelect: Strategy[Rise] = { + def fOutsideSelect: Strategy[Rise] = rule("fOutsideSelect", { case expr @ App(App(App(select(), t), App(f1, a)), App(f2, b)) if f1 =~= f2 => f1.t match { case FunType(_: DataType, _: DataType) => Success(app(f1, select(t)(a)(b)) !: expr.t) case _ => Failure(fOutsideSelect) } - } + }) // makeArray (f e1) .. (f en) -> map f (makeArray e1 .. en) - @rule def fOutsideMakeArray: Strategy[Rise] = expr => { + def fOutsideMakeArray: Strategy[Rise] = rule("fOutsideMakeArray", expr => { def matchExpectedMakeArray(mka: Rise): Option[(Int, Rise)] = mka match { case App(makeArray(n), App(f, _)) => f.t match { @@ -307,91 +309,91 @@ object algorithmic { Success(app(map(f), transformMakeArray(expr)) !: expr.t) case _ => Failure(fOutsideMakeArray) } - } + }) // zip (map fa a) (map fb b) -> zip a b >> map (p => pair (fa (fst p)) (fb (snd p))) - @rule def mapOutsideZip: Strategy[Rise] = { + def mapOutsideZip: Strategy[Rise] = rule("mapOutsideZip", { case expr @ App(App(zip(), App(App(map(), fa), a)), App(App(map(), fb), b)) => Success(map(fun(p => makePair(app(fa, fst(p)))(app(fb, snd(p)))))(zip(a)(b)) !: expr.t) - } + }) // pair (map fa a) (map fb b) // -> zip a b >> map (p => pair (fa (fst p)) (fb (snd p))) >> unzip - @rule def mapOutsidePair: Strategy[Rise] = { + def mapOutsidePair: Strategy[Rise] = rule("mapOutsidePair", { case expr @ App(App(makePair(), App(App(map(), fa), a)), App(App(map(), fb), b)) => Success(unzip(map(fun(p => makePair(app(fa, fst(p)))(app(fb, snd(p)))))(zip(a)(b))) !: expr.t) - } + }) // zip a a -> map (x => pair(x, x)) a - @rule def zipSame: Strategy[Rise] = { + def zipSame: Strategy[Rise] = rule("zipSame", { case expr @ App(App(zip(), a), a2) if a =~= a2 => Success(map(fun(x => makePair(x)(x)))(a) !: expr.t) - } + }) // zip(a, b) -> map (x => pair(snd(x), fst(x))) zip(b, a) - @rule def zipSwap: Strategy[Rise] = { + def zipSwap: Strategy[Rise] = rule("zipSwap", { case expr @ App(App(zip(), a), b) => Success(map(fun(x => makePair(snd(x))(fst(x))))(zip(b)(a)) !: expr.t) - } + }) // zip(a, zip(b, c)) -> map (x => pair(.., pair(..))) zip(zip(a, b), c) - @rule def zipRotateLeft: Strategy[Rise] = { + def zipRotateLeft: Strategy[Rise] = rule("zipRotateLeft", { case expr @ App(App(zip(), a), App(App(zip(), b), c)) => Success(map( fun(x => makePair(fst(fst(x)))(makePair(snd(fst(x)))(snd(x)))))( zip(zip(a)(b))(c) ) !: expr.t) - } + }) // zip(zip(a, b), c) -> map (x => pair(pair(..), ..)) zip(a, zip(b, c)) - @rule def zipRotateRight: Strategy[Rise] = { + def zipRotateRight: Strategy[Rise] = rule("zipRotateRight", { case expr @ App(App(zip(), App(App(zip(), a), b)), c) => Success(map( fun(x => makePair(makePair(fst(x))(fst(snd(x))))(snd(snd(x)))))( zip(a)(zip(b)(c)) ) !: expr.t) - } + }) def zipRotate: Strategy[Rise] = zipRotateLeft <+ zipRotateRight // e -> map (x => x) e - @rule def mapIdentityAfter: Strategy[Rise] = expr => expr.t match { + def mapIdentityAfter: Strategy[Rise] = rule("mapIdentityAfter", expr => expr.t match { case ArrayType(_, _) => Success(map(fun(x => x))(expr) !: expr.t) case _ => Failure(mapIdentityAfter) - } + }) // fst (pair a b) -> a - @rule def fstReduction: Strategy[Rise] = { + def fstReduction: Strategy[Rise] = rule("fstReduction", { case expr @ App(fst(), App(App(makePair(), a), _)) => Success(a !: expr.t) - } + }) // snd (pair a b) -> b - @rule def sndReduction: Strategy[Rise] = { + def sndReduction: Strategy[Rise] = rule("sndReduction", { case expr @ App(snd(), App(App(makePair(), _), b)) => Success(b !: expr.t) - } + }) // zip (slide n m a) (slide n m b) -> map unzip (slide n m (zip a b)) - @rule def slideOutsideZip: Strategy[Rise] = { + def slideOutsideZip: Strategy[Rise] = rule("slideOutsideZip", { case expr @ App(App(zip(), App(DepApp(NatKind, DepApp(NatKind, slide(), n: Nat), m: Nat), a)), App(DepApp(NatKind, DepApp(NatKind, slide(), n2: Nat), m2: Nat), b) ) if n == n2 && m == m2 => Success(map(unzip)(slide(n)(m)(zip(a)(b))) !: expr.t) - } + }) // slide n m (zip a b) -> map zip (zip (slide n m a) (slide n m b)) - @rule def slideInsideZip: Strategy[Rise] = { + def slideInsideZip: Strategy[Rise] = rule("slideInsideZip", { case expr @ App(DepApp(NatKind, DepApp(NatKind, slide(), n: Nat), m: Nat), App(App(zip(), a), b) ) => Success(map(fun(p => zip(fst(p))(snd(p))))( zip(slide(n)(m)(a))(slide(n)(m)(b))) !: expr.t) - } + }) // TODO? // map (x => g (f (fst x))) (zip a b) -> map (x => g (fst x)) (zip (map f a) b) // def fBeforeZipMapFst: Strategy[Rise] = // map (x => g (f (snd x))) (zip a b) -> map (x => g (snd x)) (zip a (map f b)) // def fBeforeZipMapSnd: Strategy[Rise] = - @rule def fBeforeZipMap: Strategy[Rise] = { + def fBeforeZipMap: Strategy[Rise] = rule("fBeforeZipMap", { case expr @ App( App(map(), Lambda(x, App(App(zip(), App(f, App(fst(), x2))), @@ -402,44 +404,44 @@ object algorithmic { map(fun(x => zip(fst(x))(snd(x)))), zip(map(f)(a))(map(g)(b)) ) !: expr.t) - } + }) // a |> map (zip b) |> transpose // -> transpose a |> zip2D (map (x => generate (_ => x) b) - @rule def transposeBeforeMapZip: Strategy[Rise] = { + def transposeBeforeMapZip: Strategy[Rise] = rule("transposeBeforeMapZip", { case e @ App(transpose(), App(App(map(), App(zip(), b)), a)) => Success(map(fun(p => zip(fst(p))(snd(p))))( zip(map(fun(x => generate(fun(_ => x))))(b))(transpose(a))) !: e.t) - } + }) // unzip (zip a b) -> pair a b - @rule def unzipZipIsPair: Strategy[Rise] = { + def unzipZipIsPair: Strategy[Rise] = rule("unzipZipIsPair", { case e @ App(unzip(), App(App(zip(), a), b)) => Success(makePair(a)(b) !: e.t) - } + }) // FIXME: fighting against beta-reduction // unzip ((p => zip (fst p) (snd p)) in) -> in - @rule def unzipZipIdentity: Strategy[Rise] = { + def unzipZipIdentity: Strategy[Rise] = rule("unzipZipIdentity", { case e @ App(unzip(), App(Lambda(p, App(App(zip(), App(fst(), p2)), App(snd(), p3))), in)) if p =~= p2 && p =~= p3 => Success(in !: e.t) - } + }) // FIXME: this is very specific // zip (fst/snd unzip e) (fst/snd unzip e) // -> map (p => pair (fst/snd p) (fst/snd p)) e - @rule def zipUnzipAccessSimplification: Strategy[Rise] = { + def zipUnzipAccessSimplification: Strategy[Rise] = rule("zipUnzipAccessSimplification", { case e @ App(App(zip(), App(a1 @ (fst() | snd()), App(unzip(), e1))), App(a2 @ (fst() | snd()), App(unzip(), e2)) ) if e1 =~= e2 => Success(map(fun(p => makePair(eraseType(a1)(p))(eraseType(a2)(p))))(e1) !: e.t) - } + }) // FIXME: this is very specific - @rule def zipAsVectorUnzipSimplification: Strategy[Rise] = { + def zipAsVectorUnzipSimplification: Strategy[Rise] = rule("zipAsVectorUnzipSimplification", { case e @ App( Lambda(x, App(App(zip(), App(DepApp(NatKind, asVector(), v: Nat), App(fst(), x2))), @@ -452,12 +454,12 @@ object algorithmic { mapFst(asVectorAligned(v)) |> mapSnd(asVectorAligned(v)) |> fun(p => zip(fst(p))(snd(p))) Success(r !: e.t) - } + }) // FIXME: this is very specific // map (p => g (fst p) (snd p)) (zip (fst/snd e) (fst/snd e)) // -> map (p => g (fst/snd p) (fst/snd p)) (zip (fst e) (snd e)) - @rule def mapProjZipUnification()(implicit ev: Traversable[Rise]): Strategy[Rise] = { + def mapProjZipUnification()(implicit ev: Traversable[Rise]): Strategy[Rise] = rule("mapProjZipUnification", { case e @ App(App(map(), Lambda(p, App(App(g, App(fst(), p1)), App(snd(), p2)))), App(App(zip(), @@ -468,7 +470,7 @@ object algorithmic { => Success(map(fun(p => preserveType(g)(eraseType(a1)(p), eraseType(a2)(p))))( zip(fst(e1))(snd(e1))) !: e.t) - } + }) // TODO: should not be in this file? // broadly speaking, f(x) -> x |> fun(y => f(y)) @@ -489,15 +491,15 @@ object algorithmic { // the inner strategies shouldn't be accessible from the outside // because they might change the semantics of a program - @rule def freshLambdaIdentifier()(implicit ev: Traversable[Rise]): Strategy[Rise] = e => { - @rule def freshIdentifier: Strategy[Rise] = { + def freshLambdaIdentifier()(implicit ev: Traversable[Rise]): Strategy[Rise] = rule("freshLambdaIdentifier", e => { + def freshIdentifier: Strategy[Rise] = rule("freshIdentifier", { case Identifier(name) ::: t => Success(Identifier(freshName("fresh_"+ name))(t)) - } + }) - @rule def replaceIdentifier(curr: Identifier, newId: Identifier): Strategy[Rise] = { + def replaceIdentifier(curr: Identifier, newId: Identifier): Strategy[Rise] = rule("replaceIdentifier", { case x: Identifier if curr =~= x => Success(newId) - } + }) e match { case Lambda(x,e) ::: t if contains[Rise](x).apply(e) => @@ -506,21 +508,21 @@ object algorithmic { Success(Lambda(newX, newE)(t)) case _ => Failure(freshLambdaIdentifier()) } - } + }) // different name for ICFP'20 - def splitStrategy(n: Nat)(implicit ev: Traversable[Rise]): Strategy[Rise] = blockedReduce(n) - @rule def blockedReduce(n: Nat)(implicit ev: Traversable[Rise]): Strategy[Rise] = { + def splitStrategy(n: Nat)(using ev: Traversable[Rise]): Strategy[Rise] = blockedReduce(n) + def blockedReduce(n: Nat)(using ev: Traversable[Rise]): Strategy[Rise] = rule("blockedReduce", { case App(App(App(reduce(), op ::: FunType(yT, FunType(initT, outT))), init), arg) if yT =~= outT => // avoid having two lambdas using the same identifiers val freshOp = tryAll(freshLambdaIdentifier()).apply(op).get - DFNF()(ev)( + DFNF()( (reduceSeq(fun((acc, y) => preserveType(op)(acc, reduce(freshOp)(init)(y))))(init) o split(n)) $ arg ) - } + }) @@ -530,7 +532,7 @@ object algorithmic { zip(a)(b) |> map(mulT) |> sum )) // TODO: check separability property? - @rule def separateDotHV(weights2d: Expr, wH: Expr, wV: Expr): Strategy[Rise] = { + def separateDotHV(weights2d: Expr, wH: Expr, wV: Expr): Strategy[Rise] = rule("separateDotHV", { case e @ App(App(App(reduce(), rf), init), App(App(map(), mf), App(App(zip(), App(join(), weights)), App(join(), nbh)) )) if rf =~= ((add !: rf.t): Expr) && @@ -539,9 +541,9 @@ object algorithmic { weights =~= weights2d => Success((preserveType(nbh) |> map(dot(wH)) |> dot(wV)) !: e.t) - } + }) - @rule def separateDotVH(weights2d: Expr, wV: Expr, wH: Expr): Strategy[Rise] = { + def separateDotVH(weights2d: Expr, wV: Expr, wH: Expr): Strategy[Rise] = rule("separateDotVH", { case e @ App(App(App(reduce(), rf), init), App(App(map(), mf), App(App(zip(), App(join(), weights)), App(join(), nbh)) )) if rf =~= ((add !: rf.t): Expr) && @@ -550,15 +552,15 @@ object algorithmic { weights =~= weights2d => Success((preserveType(nbh) |> transpose |> map(dot(wV)) |> dot(wH)) !: e.t) - } + }) - @rule def separateSumHV: Strategy[Rise] = { + def separateSumHV: Strategy[Rise] = rule("separateSumHV", { case e @ App(sum2, App(join(), in)) if sum2 =~= ((sum !: sum2.t): Expr) => Success((preserveType(in) |> map(sum) |> sum) !: e.t) - } + }) - @rule def separateSumVH: Strategy[Rise] = { + def separateSumVH: Strategy[Rise] = rule("separateSumVH", { case e @ App(sum2, App(join(), in)) if sum2 =~= ((sum !: sum2.t): Expr) => Success((preserveType(in) |> transpose |> map(sum) |> sum) !: e.t) - } + }) } diff --git a/src/main/scala/rise/elevate/rules/lowering.scala b/src/main/scala/rise/elevate/rules/lowering.scala index 4e6bc5555..11b41390f 100644 --- a/src/main/scala/rise/elevate/rules/lowering.scala +++ b/src/main/scala/rise/elevate/rules/lowering.scala @@ -5,9 +5,8 @@ import elevate.core.strategies.basic._ import elevate.core.strategies.predicate._ import elevate.core.strategies.traversal._ import elevate.core.strategies.{Traversable, predicate} -import elevate.core.{Failure, Strategy, Success} -import elevate.macros.RuleMacro.rule -import elevate.macros.StrategyMacro.strategy +import elevate.core._ +import elevate.core.macros._ import rise.core.DSL._ import rise.core.primitives.{not => _, _} import rise.core.types.DataType._ @@ -33,73 +32,73 @@ object lowering { } def `map -> mapSeq`: Strategy[Rise] = mapSeq - @rule def mapSeq: Strategy[Rise] = { + def mapSeq: Strategy[Rise] = rule("mapSeq", { case m@map() => Success(p.mapSeq !: m.t) - } + }) def `map -> mapPar`: Strategy[Rise] = mapPar - @rule def mapPar: Strategy[Rise] = { + def mapPar: Strategy[Rise] = rule("mapPar", { case m@map() => Success(omp.mapPar !: m.t) - } + }) def `map -> mapStream`: Strategy[Rise] = mapStream - @rule def mapStream: Strategy[Rise] = { + def mapStream: Strategy[Rise] = rule("mapStream", { case m@map() => Success(p.mapStream !: m.t) - } + }) def `map -> iterateStream`: Strategy[Rise] = iterateStream - @rule def iterateStream: Strategy[Rise] = { + def iterateStream: Strategy[Rise] = rule("iterateStream", { case m@map() => Success(p.iterateStream !: m.t) - } + }) def `map -> mapSeqUnroll`: Strategy[Rise] = mapSeqUnroll - @rule def mapSeqUnroll: Strategy[Rise] = { + def mapSeqUnroll: Strategy[Rise] = rule("mapSeqUnroll", { case m@map() => Success(p.mapSeqUnroll !: m.t) - } + }) def `map -> mapGlobal`(dim: Int = 0): Strategy[Rise] = mapGlobal(dim) - @rule def mapGlobal(dim: Int = 0): Strategy[Rise] = { + def mapGlobal(dim: Int = 0): Strategy[Rise] = rule("mapGlobal", { case m@map() => Success(rise.openCL.DSL.mapGlobal(dim) !: m.t) - } + }) def `reduce -> reduceSeq`: Strategy[Rise] = reduceSeq - @rule def reduceSeq: Strategy[Rise] = { + def reduceSeq: Strategy[Rise] = rule("reduceSeq", { case e@reduce() => Success(p.reduceSeq !: e.t) - } + }) def `reduce -> reduceSeqUnroll`: Strategy[Rise] = reduceSeqUnroll - @rule def reduceSeqUnroll: Strategy[Rise] = { + def reduceSeqUnroll: Strategy[Rise] = rule("reduceSeqUnroll", { case e@reduce() => Success(p.reduceSeqUnroll !: e.t) - } + }) // Specialized Lowering - @rule def mapSeqCompute()(implicit ev: Traversable[Rise]): Strategy[Rise] = { + def mapSeqCompute()(implicit ev: Traversable[Rise]): Strategy[Rise] = rule("mapSeqCompute", { case e@App(map(), f) if containsComputation()(ev)(f) && predicate.not(isMappingZip)(f) => Success(p.mapSeq(f) !: e.t) - } + }) - @rule def isMappingZip: Strategy[Rise] = { + def isMappingZip: Strategy[Rise] = rule("isMappingZip", { case l@Lambda(_, App(App(zip(), a), b)) => Success(l) case m@Lambda(_, App(App(map(), f), arg)) => isMappingZip(f) - } + }) // TODO: load identity instead, then change with other rules? - @rule def circularBuffer(load: Expr): Strategy[Rise] = { + def circularBuffer(load: Expr): Strategy[Rise] = rule("circularBuffer", { case e@DepApp(NatKind, DepApp(NatKind, slide(), sz: Nat), Cst(1)) => Success( p.circularBuffer(sz)(sz)(eraseType(load)) !: e.t) - } + }) - @rule def rotateValues(write: Expr): Strategy[Rise] = { + def rotateValues(write: Expr): Strategy[Rise] = rule("rotateValues", { case e@DepApp(NatKind, DepApp(NatKind, slide(), sz: Nat), Cst(1)) => Success( p.rotateValues(sz)(eraseType(write)) !: e.t) - } + }) - @rule def containsComputation()(implicit ev: Traversable[Rise]): Strategy[Rise] = - topDown(isComputation())(ev) + def containsComputation()(implicit ev: Traversable[Rise]): Strategy[Rise] = + rule("containsComputation", topDown(isComputation())(ev)) // requires type information! - @rule def isComputation()(implicit ev: Traversable[Rise]): Strategy[Rise] = e => { + def isComputation()(implicit ev: Traversable[Rise]): Strategy[Rise] = rule("isComputation", e => { def isPairOrBasicType(t: ExprType): Boolean = t match { case _ if typeHasTrivialCopy(t) => true case PairType(a, b) => isPairOrBasicType(a) && isPairOrBasicType(b) @@ -123,7 +122,7 @@ object lowering { case f@foreignFunction(_, _) => Success(f) case _ => Failure(containsComputation()) } - } + }) // case class slideSeq(rot: SlideSeq.Rotate, write_dt1: Expr) extends Strategy[Rise] { @@ -139,36 +138,36 @@ object lowering { // writing to memory // TODO: think about more complex cases - @rule def mapSeqUnrollWrite: Strategy[Rise] = e => e.t match { + def mapSeqUnrollWrite: Strategy[Rise] = rule("mapSeqUnrollWrite", e => e.t match { case ArrayType(_, t) if typeHasTrivialCopy(t) => Success(app(p.mapSeqUnroll(fun(x => x)), preserveType(e)) !: e.t) case _ => Failure(mapSeqUnrollWrite) - } + }) - @rule def toMemAfterMapSeq: Strategy[Rise] = { + def toMemAfterMapSeq: Strategy[Rise] = rule("toMemAfterMapSeq", { case a@App(App(p.mapSeq(), _), _) => Success((preserveType(a) |> p.toMem) !: a.t) - } + }) // Lowerings used in PLDI submission // adds copy after every generate - def materializeGenerate()(implicit ev: Traversable[Rise]): Strategy[Rise] = - normalize(ev)( + def materializeGenerate()(using ev: Traversable[Rise]): Strategy[Rise] = + normalize( argument(function(isGenerate)) `;` not(isCopy) `;` argument(copyAfterGenerate) ) // adds explicit copies for every init value in reductions - def materializeInitOfReduce()(implicit ev: Traversable[Rise]): Strategy[Rise] = - normalize(ev)( + def materializeInitOfReduce()(using ev: Traversable[Rise]): Strategy[Rise] = + normalize( function(function(isReduceX)) `;` argument(not(isCopy) `;` insertCopyAfter) ) - @rule def insertCopyAfter: Strategy[Rise] = e => { + def insertCopyAfter: Strategy[Rise] = rule("insertCopyAfter", e => { def constructCopy(t: ExprType): ToBeTyped[Rise] = t match { case ArrayType(_, dt) => p.mapSeq(fun(x => constructCopy(dt) $ x)) case _ if typeHasTrivialCopy(t) => fun(x => x) @@ -176,35 +175,35 @@ object lowering { } Success(constructCopy(e.t) $ e) - } + }) // todo currently only works for mapSeq - @rule def isCopy: Strategy[Rise] = { + def isCopy: Strategy[Rise] = rule("isCopy", { case c@App(p.let(), id) if isId(id) => Success(c) case c@App(App(p.mapSeq(), id), etaInput) if isId(id) => Success(c) case App(App(p.mapSeq(), Lambda(_, f)), etaInput) => isCopy(f) case c@App(id, _) if isId(id) => Success(c) - } + }) - @rule def isId: Strategy[Rise] = { + def isId: Strategy[Rise] = rule("isId", { case l@Lambda(x1, x2) if x1 =~= x2 => Success(l) - } + }) // requires expr to be in LCNF - def specializeSeq()(implicit ev: Traversable[Rise]): Strategy[Rise] = - normalize(ev)(lowering.mapSeqCompute() <+ lowering.reduceSeq) + def specializeSeq()(using ev: Traversable[Rise]): Strategy[Rise] = + normalize(lowering.mapSeqCompute() <+ lowering.reduceSeq) - def addRequiredCopies()(implicit ev: Traversable[Rise]): Strategy[Rise] = + def addRequiredCopies()(using ev: Traversable[Rise]): Strategy[Rise] = // `try`(oncetd(copyAfterReduce)) `;` LCNF `;` materializeInitOfReduce tryAll(copyAfterReduce) `;` DFNF() `;` materializeInitOfReduce() // todo gotta use a normalform for introducing copies! e.g., if we have two reduce primitives - def lowerToC(implicit ev: Traversable[Rise]): Strategy[Rise] = + def lowerToC(using ev: Traversable[Rise]): Strategy[Rise] = addRequiredCopies() `;` specializeSeq() // todo currently only works for mapSeq - @rule def copyAfterReduce: Strategy[Rise] = e => { + def copyAfterReduce: Strategy[Rise] = rule("copyAfterReduce", e => { def constructCopy(t: ExprType): ToBeTyped[Rise] = t match { case _ if typeHasTrivialCopy(t) => letf(fun(x => x)) case ArrayType(_, b) if typeHasTrivialCopy(b) => p.mapSeq(fun(x => x)) @@ -217,9 +216,9 @@ object lowering { Success((preserveType(e) |> constructCopy(reduceResult.t) ) !: e.t) case _ => Failure(copyAfterReduce) } - } + }) - @rule def copyAfterReduceInit: Strategy[Rise] = e => { + def copyAfterReduceInit: Strategy[Rise] = rule("copyAfterReduceInit", e => { def constructCopy(t: ExprType): ToBeTyped[Rise] = t match { case _ if typeHasTrivialCopy(t) => letf(fun(x => x)) case ArrayType(_, b) if typeHasTrivialCopy(b) => p.mapSeq(fun(x => x)) @@ -232,10 +231,10 @@ object lowering { Success((preserveType(init) |> constructCopy(init.t) |> a) !: e.t) case _ => Failure(copyAfterReduceInit) } - } + }) // todo currently only works for mapSeq - @rule def copyAfterGenerate: Strategy[Rise] = e => { + def copyAfterGenerate: Strategy[Rise] = rule("copyAfterGenerate", e => { def constructCopy(t: ExprType): ToBeTyped[Rise] = t match { case ArrayType(_, dt) => p.mapSeq(fun(x => constructCopy(dt) $ x)) case _ if typeHasTrivialCopy(t) => fun(x => x) @@ -247,31 +246,29 @@ object lowering { Success((preserveType(a) |> constructCopy(a.t)) !: e.t) case _ => Failure(copyAfterGenerate) } - } + }) - @rule def toMemAfterAsScalar: Strategy[Rise] = { + def toMemAfterAsScalar: Strategy[Rise] = rule("toMemAfterAsScalar", { case a@App(asScalar(), _) => Success((preserveType(a) |> p.toMem) !: a.t) - } + }) - @rule def toMemAfter: Strategy[Rise] = - e => Success((preserveType(e) |> p.toMem) !: e.t) + def toMemAfter: Strategy[Rise] = rule("toMemAfter", + e => Success((preserveType(e) |> p.toMem) !: e.t)) - @rule def toMemBefore: Strategy[Rise] = { + def toMemBefore: Strategy[Rise] = rule("toMemBefore", { case a@App(f, e) => Success((p.toMem(e) |> preserveType(f)) !: a.t) - } + }) - @strategy def storeTempsAsScalars: Strategy[Rise] = - innermost(isApplied(isPrimitive(asScalar)))(toMemAfter) + strategy("storeTempsAsScalars", innermost(isApplied(isPrimitive(asScalar)))(toMemAfter)) - @strategy def storeTempAsVectors: Strategy[Rise] = - innermost(isApplied(isPrimitive(asScalar)))(toMemBefore) + strategy("storeTempAsVectors", innermost(isApplied(isPrimitive(asScalar)))(toMemBefore)) def `map(f) -> asVector >> map(f_vec) >> asScalar`(n: Nat): Strategy[Rise] = vectorize(n)(default.RiseTraversable) - @rule def vectorize(n: Nat)(implicit ev: Traversable[Rise]): Strategy[Rise] = { + def vectorize(n: Nat)(implicit ev: Traversable[Rise]): Strategy[Rise] = rule("vectorize", { case a@App(App(map(), f), input) if isComputation()(ev)(f) && !isVectorArray(a.t) => @@ -296,49 +293,49 @@ object lowering { toBeTyped(input) |> vectorizeArrayBasedOnType(input.t) |> (map(newF) >> asScalar) ) case _ => Failure(vectorize(n)) - } + }) - @rule def untype: Strategy[Rise] = p => Success(p.setType(TypePlaceholder)) + def untype: Strategy[Rise] = rule("untype", p => Success(p.setType(TypePlaceholder))) def parallel()(implicit ev: Traversable[Rise]): Strategy[Rise] = mapParCompute() - @rule def mapParCompute()(implicit ev: Traversable[Rise]): Strategy[Rise] = { + def mapParCompute()(implicit ev: Traversable[Rise]): Strategy[Rise] = rule("mapParCompute", { case e@App(map(), f) if containsComputation()(ev)(f) => Success(omp.mapPar(f) !: e.t) - } + }) - @rule def unroll: Strategy[Rise] = { + def unroll: Strategy[Rise] = rule("unroll", { case e@p.reduceSeq() => Success(p.reduceSeqUnroll !: e.t) - } + }) object ocl { import rise.core.types.AddressSpace import rise.openCL.primitives._ // TODO shall we allow lowering from an already lowered reduceSeq? - @rule def reduceSeqUnroll(a: AddressSpace): Strategy[Rise] = { + def reduceSeqUnroll(a: AddressSpace): Strategy[Rise] = rule("reduceSeqUnroll", { case e@reduce() => Success(oclReduceSeqUnroll(a) !: e.t) case e@p.reduceSeq() => Success(oclReduceSeqUnroll(a) !: e.t) - } + }) - @rule def circularBuffer(a: AddressSpace): Strategy[Rise] = { + def circularBuffer(a: AddressSpace): Strategy[Rise] = rule("circularBuffer", { case e@DepApp(NatKind, DepApp(NatKind, slide(), n: Nat), Cst(1)) => Success( oclCircularBuffer(a)(n)(n)(fun(x => x)) !: e.t) - } + }) - @rule def circularBufferLoadFusion: Strategy[Rise] = { + def circularBufferLoadFusion: Strategy[Rise] = rule("circularBufferLoadFusion", { case e@App(App( cb @ DepApp(NatKind, DepApp(NatKind, DepApp(AddressSpaceKind, oclCircularBuffer(), _), _), _), load), App(App(map(), f), in) ) => Success(eraseType(cb)(preserveType(f) >> load, in) !: e.t) - } + }) - @rule def rotateValues(a: AddressSpace, write: Expr): Strategy[Rise] = { + def rotateValues(a: AddressSpace, write: Expr): Strategy[Rise] = rule("rotateValues", { case e@DepApp(NatKind, DepApp(NatKind, slide(), n: Nat), Cst(1)) => Success( oclRotateValues(a)(n)(eraseType(write)) !: e.t) - } + }) } } diff --git a/src/main/scala/rise/elevate/rules/movement.scala b/src/main/scala/rise/elevate/rules/movement.scala index 5dacc725c..14a5eb4ee 100644 --- a/src/main/scala/rise/elevate/rules/movement.scala +++ b/src/main/scala/rise/elevate/rules/movement.scala @@ -3,7 +3,7 @@ package rise.elevate.rules import elevate.core.strategies.Traversable import elevate.core.strategies.predicate._ import elevate.core._ -import elevate.macros.RuleMacro.rule +import elevate.core.macros.rule import rise.elevate._ import rise.core._ import rise.core.types._ @@ -32,7 +32,7 @@ object movement { // transpose def mapMapFBeforeTranspose()(implicit ev: Traversable[Rise]): Strategy[Rise] = `**f >> T -> T >> **f`()(ev) - @rule def `**f >> T -> T >> **f`()(implicit ev: Traversable[Rise]): Strategy[Rise] = { + def `**f >> T -> T >> **f`()(implicit ev: Traversable[Rise]): Strategy[Rise] = rule("**f >> T -> T >> **f", { case e@App( transpose(), App(App(map(), App(map(), f)), y)) => @@ -46,13 +46,13 @@ object movement { ) if etaReduction()(ev)(lamA) && etaReduction()(ev)(lamB) => // Success((typed(arg) |> transpose |> map(map(f))) :: e.t) Success((preserveType(arg) |> transpose |> map(fun(a => map(fun(b => preserveType(f)(b)))(a)))) !: e.t) - } + }) def transposeBeforeMapMapF: Strategy[Rise] = `T >> **f -> **f >> T` - @rule def `T >> **f -> **f >> T`: Strategy[Rise] = { + def `T >> **f -> **f >> T`: Strategy[Rise] = rule("T >> **f -> **f >> T", { case e@App(App(map(), App(map(), f)), App(transpose(), y)) => Success((preserveType(y) |> map(map(f)) |> transpose) !: e.t) - } + }) // split/slide @@ -64,241 +64,241 @@ object movement { } def slideBeforeMapMapF: Strategy[Rise] = `S >> **f -> *f >> S` - @rule def `S >> **f -> *f >> S`: Strategy[Rise] = { + def `S >> **f -> *f >> S`: Strategy[Rise] = rule("S >> **f -> *f >> S", { case e@App(App(map(), App(map(), f)), App(s, y)) if isSplitOrSlide(s) => Success((preserveType(y) |> map(f) |> eraseType(s)) !: e.t) - } + }) def slideBeforeMap: Strategy[Rise] = `*f >> S -> S >> **f` - @rule def `*f >> S -> S >> **f`: Strategy[Rise] = { + def `*f >> S -> S >> **f`: Strategy[Rise] = rule("*f >> S -> S >> **f", { case e@App(s @ DepApp(NatKind, DepApp(NatKind, slide(), _: Nat), _: Nat), App(App(map(), f), y)) => Success((preserveType(y) |> eraseType(s) |> map(map(f))) !: e.t) - } + }) // *f >> S -> S >> **f - @rule def splitBeforeMap: Strategy[Rise] = { + def splitBeforeMap: Strategy[Rise] = rule("splitBeforeMap", { case e@App(s @ DepApp(NatKind, split(), _: Nat), App(App(map(), f), y)) => Success((preserveType(y) |> eraseType(s) |> map(map(f))) !: e.t) - } + }) // join def joinBeforeMapF: Strategy[Rise] = `J >> *f -> **f >> J` - @rule def `J >> *f -> **f >> J`: Strategy[Rise] = { + def `J >> *f -> **f >> J`: Strategy[Rise] = rule("J >> *f -> **f >> J", { case e@App(App(map(), f),App(join(), y)) => Success((preserveType(y) |> map(map(f)) >> join) !: e.t) - } + }) def mapMapFBeforeJoin: Strategy[Rise] = `**f >> J -> J >> *f` - @rule def `**f >> J -> J >> *f`: Strategy[Rise] = { + def `**f >> J -> J >> *f`: Strategy[Rise] = rule("**f >> J -> J >> *f", { case e@App(join(), App(App(map(), App(map(), f)), y)) => Success((preserveType(y) |> join |> map(f)) !: e.t) - } + }) // drop and take def dropBeforeMap: Strategy[Rise] = `*f >> drop n -> drop n >> *f` - @rule def `*f >> drop n -> drop n >> *f`: Strategy[Rise] = { + def `*f >> drop n -> drop n >> *f`: Strategy[Rise] = rule("*f >> drop n -> drop n >> *f", { case expr @ App(DepApp(NatKind, drop(), n: Nat), App(App(map(), f), in)) => Success(app(map(f), app(drop(n), preserveType(in))) !: expr.t) - } + }) def takeBeforeMap: Strategy[Rise] = `*f >> take n -> take n >> *f` - @rule def `*f >> take n -> take n >> *f`: Strategy[Rise] = { + def `*f >> take n -> take n >> *f`: Strategy[Rise] = rule("*f >> take n -> take n >> *f", { case expr @ App(DepApp(NatKind, take(), n: Nat), App(App(map(), f), in)) => Success(app(map(f), app(take(n), preserveType(in))) !: expr.t) - } + }) // take n >> *f -> *f >> take n - @rule def takeAfterMap: Strategy[Rise] = { + def takeAfterMap: Strategy[Rise] = rule("takeAfterMap", { case e @ App(App(map(), f), App(DepApp(NatKind, take(), n: Nat), in)) => Success(take(n)(map(f)(in)) !: e.t) - } + }) def takeInZip: Strategy[Rise] = `take n (zip a b) -> zip (take n a) (take n b)` - @rule def `take n (zip a b) -> zip (take n a) (take n b)`: Strategy[Rise] = { + def `take n (zip a b) -> zip (take n a) (take n b)`: Strategy[Rise] = rule("take n (zip a b) -> zip (take n a) (take n b)", { case expr @ App(DepApp(NatKind, take(), n), App(App(zip(), a), b)) => Success(zip(depApp(NatKind, take, n)(a))(depApp(NatKind, take, n)(b)) !: expr.t) - } + }) // zip (take n a) (take n b) -> take n (zip a b) - @rule def takeOutisdeZip: Strategy[Rise] = { + def takeOutisdeZip: Strategy[Rise] = rule("takeOutisdeZip", { case e @ App(App(zip(), App(DepApp(NatKind, take(), n1: Nat), a)), App(DepApp(NatKind, take(), n2: Nat), b) ) if n1 == n2 => Success(take(n1)(zip(a)(b)) !: e.t) - } + }) // pair (take n a) (take m b) -> pair a b >> mapFst take n >> mapSnd take m // TODO: can get any function out, see asScalarOutsidePair - @rule def takeOutsidePair: Strategy[Rise] = { + def takeOutsidePair: Strategy[Rise] = rule("takeOutsidePair", { case e @ App(App(makePair(), App(DepApp(NatKind, take(), n: Nat), a)), App(DepApp(NatKind, take(), m: Nat), b) ) => Success((makePair(a)(b) |> mapFst(take(n)) |> mapSnd(take(m))) !: e.t) - } + }) def dropInZip: Strategy[Rise] = `drop n (zip a b) -> zip (drop n a) (drop n b)` - @rule def `drop n (zip a b) -> zip (drop n a) (drop n b)`: Strategy[Rise] = { + def `drop n (zip a b) -> zip (drop n a) (drop n b)`: Strategy[Rise] = rule("drop n (zip a b) -> zip (drop n a) (drop n b)", { case expr @ App(DepApp(NatKind, drop(), n), App(App(zip(), a), b)) => Success(zip(depApp(NatKind, drop, n)(a))(depApp(NatKind, drop, n)(b)) !: expr.t) - } + }) def takeInSelect: Strategy[Rise] = `take n (select t a b) -> select t (take n a) (take n b)` - @rule def `take n (select t a b) -> select t (take n a) (take n b)`: Strategy[Rise] = { + def `take n (select t a b) -> select t (take n a) (take n b)`: Strategy[Rise] = rule("take n (select t a b) -> select t (take n a) (take n b)", { case expr @ App(DepApp(NatKind, take(), n), App(App(App(select(), t), a), b)) => Success(select(t)(depApp(NatKind, take, n)(a), depApp(NatKind, take, n)(b)) !: expr.t) - } + }) def dropInSelect: Strategy[Rise] = `drop n (select t a b) -> select t (drop n a) (drop n b)` - @rule def `drop n (select t a b) -> select t (drop n a) (drop n b)`: Strategy[Rise] = { + def `drop n (select t a b) -> select t (drop n a) (drop n b)`: Strategy[Rise] = rule("drop n (select t a b) -> select t (drop n a) (drop n b)", { case expr @ App(DepApp(NatKind, drop(), n), App(App(App(select(), t), a), b)) => Success(select(t)(depApp(NatKind, drop, n)(a), depApp(NatKind, drop, n)(b)) !: expr.t) - } + }) def dropBeforeTake: Strategy[Rise] = `take (n+m) >> drop m -> drop m >> take n` - @rule def `take (n+m) >> drop m -> drop m >> take n`: Strategy[Rise] = { + def `take (n+m) >> drop m -> drop m >> take n`: Strategy[Rise] = rule("take (n+m) >> drop m -> drop m >> take n", { case expr @ App(DepApp(NatKind, drop(), m: Nat), App(DepApp(NatKind, take(), nm: Nat), in)) => Success(app(take(nm - m), app(drop(m), preserveType(in))) !: expr.t) - } + }) def takeBeforeDrop: Strategy[Rise] = `drop m >> take n -> take (n+m) >> drop m` - @rule def `drop m >> take n -> take (n+m) >> drop m`: Strategy[Rise] = { + def `drop m >> take n -> take (n+m) >> drop m`: Strategy[Rise] = rule("drop m >> take n -> take (n+m) >> drop m", { case expr @ App(DepApp(NatKind, take(), n: Nat), App(DepApp(NatKind, drop(), m: Nat), in)) => Success(app(drop(m), app(take(n+m), preserveType(in))) !: expr.t) - } + }) def takeBeforeSlide: Strategy[Rise] = `slide n m >> take t -> take (m * (t - 1) + n) >> slide n m` - @rule def `slide n m >> take t -> take (m * (t - 1) + n) >> slide n m`: Strategy[Rise] = { + def `slide n m >> take t -> take (m * (t - 1) + n) >> slide n m`: Strategy[Rise] = rule("slide n m >> take t -> take (m * (t - 1) + n) >> slide n m", { case expr @ App(DepApp(NatKind, take(), t: Nat), App(DepApp(NatKind, DepApp(NatKind, slide(), n: Nat), m: Nat), in)) => Success(app(slide(n)(m), take(m * (t - 1) + n)(in)) !: expr.t) - } + }) def dropBeforeSlide: Strategy[Rise] = `slide n m >> drop d -> drop (d * m) >> slide n m` - @rule def `slide n m >> drop d -> drop (d * m) >> slide n m`: Strategy[Rise] = { + def `slide n m >> drop d -> drop (d * m) >> slide n m`: Strategy[Rise] = rule("slide n m >> drop d -> drop (d * m) >> slide n m", { case expr @ App(DepApp(NatKind, drop(), d: Nat), App(DepApp(NatKind, DepApp(NatKind, slide(), n: Nat), m: Nat), in)) => Success(app(slide(n)(m), drop(d * m)(in)) !: expr.t) - } + }) // slide n m >> padEmpty p -> padEmpty (p * m) >> slide n m - @rule def padEmptyBeforeSlide: Strategy[Rise] = { + def padEmptyBeforeSlide: Strategy[Rise] = rule("padEmptyBeforeSlide", { case e @ App(DepApp(NatKind, padEmpty(), p: Nat), App(DepApp(NatKind, DepApp(NatKind, slide(), n: Nat), m: Nat), in) ) => Success(slide(n)(m)(padEmpty(p * m)(in)) !: e.t) - } + }) // map f >> padEmpty n -> padEmpty n >> map f - @rule def padEmptyBeforeMap: Strategy[Rise] = { + def padEmptyBeforeMap: Strategy[Rise] = rule("padEmptyBeforeMap", { case e @ App(DepApp(NatKind, padEmpty(), n: Nat), App(App(map(), f), in)) => Success(map(f)(padEmpty(n)(in)) !: e.t) - } + }) // transpose >> padEmpty n -> map (padEmpty n) >> transpose - @rule def padEmptyBeforeTranspose: Strategy[Rise] = { + def padEmptyBeforeTranspose: Strategy[Rise] = rule("padEmptyBeforeTranspose", { case e @ App(DepApp(NatKind, padEmpty(), n: Nat), App(transpose(), in)) => Success(transpose(map(padEmpty(n))(in)) !: e.t) - } + }) // padEmpty n (zip a b) -> zip (padEmpty n a) (padEmpty n b) - @rule def padEmptyInsideZip: Strategy[Rise] = { + def padEmptyInsideZip: Strategy[Rise] = rule("padEmptyInsideZip", { case e @ App(DepApp(NatKind, padEmpty(), n: Nat), App(App(zip(), a), b)) => Success(zip(padEmpty(n)(a))(padEmpty(n)(b)) !: e.t) - } + }) // FIXME: this is very specific // zip (fst e) (snd e) |> padEmpty n -> // (mapFst padEmpty n) (mapSnd padEmpty n) |> fun(p => zip (fst p) (snd(p)) - @rule def padEmptyBeforeZip: Strategy[Rise] = { + def padEmptyBeforeZip: Strategy[Rise] = rule("padEmptyBeforeZip", { case e @ App(DepApp(NatKind, padEmpty(), n: Nat), App(App(zip(), App(fst(), e1)), App(snd(), e2))) if e1 =~= e2 => Success((preserveType(e1) |> mapFst(padEmpty(n)) |> mapSnd(padEmpty(n)) |> fun(p => zip(fst(p))(snd(p)))) !: e.t) - } + }) // special-cases // slide + transpose def transposeBeforeSlide: Strategy[Rise] = `T >> S -> *S >> T >> *T` - @rule def `T >> S -> *S >> T >> *T`: Strategy[Rise] = { + def `T >> S -> *S >> T >> *T`: Strategy[Rise] = rule("T >> S -> *S >> T >> *T", { case e@App(s, App(transpose(), y)) if isSplitOrSlide(s) => Success((preserveType(y) |> map(eraseType(s)) |> transpose.apply |> map(transpose)) !: e.t) - } + }) def transposeBeforeMapSlide: Strategy[Rise] = `T >> *S -> S >> *T >> T` - @rule def `T >> *S -> S >> *T >> T`: Strategy[Rise] = { + def `T >> *S -> S >> *T >> T`: Strategy[Rise] = rule("T >> *S -> S >> *T >> T", { case e@App(App(map(), s), App(transpose(), y)) if isSplitOrSlide(s) => Success((preserveType(y) |> eraseType(s) |> map(transpose) |> transpose) !: e.t) - } + }) def mapSlideBeforeTranspose: Strategy[Rise] = `*S >> T -> T >> S >> *T` - @rule def `*S >> T -> T >> S >> *T`: Strategy[Rise] = { + def `*S >> T -> T >> S >> *T`: Strategy[Rise] = rule("*S >> T -> T >> S >> *T", { case e@App(transpose(), App(App(map(), s), y)) if isSplitOrSlide(s) => Success((preserveType(y) |> transpose.apply |> eraseType(s) |> map(transpose)) !: e.t) - } + }) // transpose + join def joinBeforeTranspose: Strategy[Rise] = `J >> T -> *T >> T >> *J` - @rule def `J >> T -> *T >> T >> *J`: Strategy[Rise] = { + def `J >> T -> *T >> T >> *J`: Strategy[Rise] = rule("J >> T -> *T >> T >> *J", { case e@App(transpose(), App(join(), y)) => Success((preserveType(y) |> map(transpose) |> transpose |> map(join)) !: e.t) - } + }) def transposeBeforeMapJoin: Strategy[Rise] = `T >> *J -> *T >> J >> T` - @rule def `T >> *J -> *T >> J >> T`: Strategy[Rise] = { + def `T >> *J -> *T >> J >> T`: Strategy[Rise] = rule("T >> *J -> *T >> J >> T", { case e@App(App(map(), join()), App(transpose(), y)) => Success((preserveType(y) |> map(transpose) |> join |> transpose) !: e.t) - } + }) def mapTransposeBeforeJoin: Strategy[Rise] = `*T >> J -> T >> *J >> T` - @rule def `*T >> J -> T >> *J >> T`: Strategy[Rise] = { + def `*T >> J -> T >> *J >> T`: Strategy[Rise] = rule("*T >> J -> T >> *J >> T", { case e@App(join(), App(App(map(), transpose()), y)) => Success((preserveType(y) |> transpose |> map(join) |> transpose) !: e.t) - } + }) def mapJoinBeforeTranspose: Strategy[Rise] = `*J >> T -> T >> *T >> J` - @rule def `*J >> T -> T >> *T >> J`: Strategy[Rise] = { + def `*J >> T -> T >> *T >> J`: Strategy[Rise] = rule("*J >> T -> T >> *T >> J", { case e@App(transpose(), App(App(map(), join()), y)) => Success((preserveType(y) |> transpose |> map(transpose) |> join) !: e.t) - } + }) // join + join def joinBeforeJoin: Strategy[Rise] = `J >> J -> *J >> J` - @rule def `J >> J -> *J >> J`: Strategy[Rise] = { + def `J >> J -> *J >> J`: Strategy[Rise] = rule("J >> J -> *J >> J", { case e@App(join(), App(join(), y)) => Success((preserveType(y) |> map(join) >> join) !: e.t) - } + }) def mapJoinBeforeJoin: Strategy[Rise] = `*J >> J -> J >> J` - @rule def `*J >> J -> J >> J`: Strategy[Rise] = { + def `*J >> J -> J >> J`: Strategy[Rise] = rule("*J >> J -> J >> J", { case e@App(join(), App(App(map(), join()), y)) => Success((preserveType(y) |> join |> join) !: e.t) - } + }) // split + slide def slideBeforeSplit: Strategy[Rise] = `slide(n)(s) >> split(k) -> slide(k+n-s)(k) >> map(slide(n)(s))` - @rule def `slide(n)(s) >> split(k) -> slide(k+n-s)(k) >> map(slide(n)(s))`: Strategy[Rise] = { + def `slide(n)(s) >> split(k) -> slide(k+n-s)(k) >> map(slide(n)(s))`: Strategy[Rise] = rule("slide(n)(s) >> split(k) -> slide(k+n-s)(k) >> map(slide(n)(s))", { case e@App(DepApp(NatKind, split(), k: Nat), App(DepApp(NatKind, DepApp(NatKind, slide(), n: Nat), s: Nat), y)) => Success((preserveType(y) |> slide(k + n - s)(k) |> map(slide(n)(s))) !: e.t) - } + }) // TODO: what if s != 1? // slide(n)(s=1) >> slide(m)(k) -> slide(m+n-1)(k) >> map(slide(n)(1)) - @rule def slideBeforeSlide: Strategy[Rise] = { + def slideBeforeSlide: Strategy[Rise] = rule("slideBeforeSlide", { case e@App(DepApp(NatKind, DepApp(NatKind, slide(), m: Nat), k: Nat), App(DepApp(NatKind, DepApp(NatKind, slide(), n: Nat), s: Nat), in) ) if s == (1: Nat) => Success((preserveType(in) |> slide(m+n-s)(k) |> map(slide(n)(s))) !: e.t) - } + }) // nested map + reduce // different variants for rewriting map(reduce) to reduce(map) // todo what makes them different? can we decompose them into simpler rules? - @rule def liftReduce: Strategy[Rise] = { + def liftReduce: Strategy[Rise] = rule("liftReduce", { // 2D array of pairs ---------------------------------------------------- case e@App(map(), Lambda(_, @@ -398,11 +398,11 @@ object movement { ) result } - } + }) // mapSnd f >> mapFst g -> mapFst g >> mapSnd f - @rule def mapFstBeforeMapSnd: Strategy[Rise] = { + def mapFstBeforeMapSnd: Strategy[Rise] = rule("mapFstBeforeMapSnd", { case e @ App(App(mapFst(), g), App(App(mapSnd(), f), in)) => Success(mapSnd(f)(mapFst(g)(in)) !: e.t) - } + }) } diff --git a/src/main/scala/rise/elevate/rules/package.scala b/src/main/scala/rise/elevate/rules/package.scala index 80b0666f2..cd853a33f 100644 --- a/src/main/scala/rise/elevate/rules/package.scala +++ b/src/main/scala/rise/elevate/rules/package.scala @@ -4,7 +4,7 @@ import elevate.core.strategies.Traversable import elevate.core.strategies.predicate._ import elevate.core.strategies.traversal._ import elevate.core.{Failure, Strategy, Success} -import elevate.macros.RuleMacro.rule +import elevate.core.macros.rule import rise.core.DSL._ import rise.core._ import rise.core.types._ @@ -44,14 +44,14 @@ package object rules { case _ => Failure(etaReduction()) } - @rule def etaAbstraction: Strategy[Rise] = f => f.t match { + def etaAbstraction: Strategy[Rise] = rule("etaAbstraction", f => f.t match { case FunType(_, _) => val x = identifier(freshName("η")) Success(lambda(x, app(f, x)) !: f.t) case _ => Failure(etaAbstraction) - } + }) - @rule def idxReduction: Strategy[Rise] = e => { + def idxReduction: Strategy[Rise] = rule("idxReduction", e => { import arithexpr.arithmetic._ import rise.core.primitives._ import rise.core.semantics._ @@ -77,13 +77,13 @@ package object rules { case _ => Failure(idxReduction) } - } + }) - @rule def checkType(msg: String = ""): Strategy[Rise] = e => { + def checkType(msg: String = ""): Strategy[Rise] = rule("checkType", e => { types.check(e) match { case scala.util.Success(_) => Success(e) case scala.util.Failure(exception) => Failure(checkType(exception.getMessage)) } - } + }) } diff --git a/src/main/scala/rise/elevate/rules/vectorize.scala b/src/main/scala/rise/elevate/rules/vectorize.scala index e51b5a661..a93bcb0cb 100644 --- a/src/main/scala/rise/elevate/rules/vectorize.scala +++ b/src/main/scala/rise/elevate/rules/vectorize.scala @@ -2,7 +2,7 @@ package rise.elevate.rules import arithexpr.arithmetic.Cst import elevate.core._ -import elevate.macros.RuleMacro.rule +import elevate.core.macros.rule import rise.core.DSL._ import rise.core._ import rise.core.primitives._ @@ -14,48 +14,48 @@ object vectorize { // FIXME: sometimes assuming loads or stores will be aligned // _ -> asVector >> asScalar - @rule def after(n: Nat): Strategy[Rise] = e => e.t match { + def after(n: Nat): Strategy[Rise] = rule("after", e => e.t match { // FIXME: m + n hack case ArrayType(m, _: ScalarType) if (m + n) % n == (0: Nat) => Success(asScalar(asVector(n)(e)) !: e.t) - } + }) // _ -> padEmpty >> asVector >> asScalar >> take - @rule def roundUpAfter(n: Nat): Strategy[Rise] = e => e.t match { + def roundUpAfter(n: Nat): Strategy[Rise] = rule("roundUpAfter", e => e.t match { case ArrayType(m, _: ScalarType) => val roundUp = padEmpty(n - ((m + n) % n)) // FIXME: m + n hack Success(take(m)(asScalar(asVector(n)(roundUp(e)))) !: e.t) case _ => Failure(after(n)) - } + }) // _ -> asVectorAligned >> asScalar - @rule def alignedAfter(n: Nat): Strategy[Rise] = e => e.t match { + def alignedAfter(n: Nat): Strategy[Rise] = rule("alignedAfter", e => e.t match { // FIXME: m + n hack case ArrayType(m, _: ScalarType) if (m + n) % n == (0: Nat) => Success(asScalar(asVectorAligned(n)(e)) !: e.t) case _ => Failure(alignedAfter(n)) - } + }) // asScalar >> asVector -> _ - @rule def asScalarAsVectorId: Strategy[Rise] = { + def asScalarAsVectorId: Strategy[Rise] = rule("asScalarAsVectorId", { case e @ App(v, App(asScalar(), in)) if isAsVector(v) && e.t =~= in.t => Success(in) - } + }) // map (reduce f init) >> asVector -> asVector >> map (reduce f init) - @rule def beforeMapReduce: Strategy[Rise] = { + def beforeMapReduce: Strategy[Rise] = rule("beforeMapReduce", { case e @ App(v, App(App(map(), App(App(reduce(), f), init)), in)) if isAsVector(v) && isScalarFun(f.t) => // TODO: generalize? val inV = preserveType(in) |> transpose |> map(eraseType(v)) |> transpose val fV = vectorizeScalarFun(f, Set()) Success(map(reduce(fV)(vectorFromScalar(init)))(inV) !: e.t) - } + }) // TODO: express as a combination of beforeMapReduce, beforeMap, and others. // a |> map (zip b) |> map (reduce f init) |> asVector // -> a |> transpose |> map(asVector) |> transpose |> .. - @rule def beforeMapDot: Strategy[Rise] = { + def beforeMapDot: Strategy[Rise] = rule("beforeMapDot", { case e @ App(v, App(App(map(), App(r @ App(ReduceX(), f), init)), App(App(map(), App(zip(), b)), a) )) if isAsVector(v) && isScalarFun(f.t) => @@ -63,51 +63,51 @@ object vectorize { val bV = map(vectorFromScalar)(b) val rV = vectorizeScalarFun(r, Set()) Success(map(zip(bV) >> rV(vectorFromScalar(init)))(aV) !: e.t) - } + }) // map f >> asVector -> asVector >> map f - @rule def beforeMap: Strategy[Rise] = { + def beforeMap: Strategy[Rise] = rule("beforeMap", { case e @ App(v, App(App(map(), f), in)) if isAsVector(v) && isScalarFun(f.t) => val inV = makeAsVector(v)(in.t)(in) val fV = vectorizeScalarFun(f, Set()) Success(map(fV)(inV) !: e.t) - } + }) // pair (asScalar a) (asScalar b) // -> pair a b >> mapFst asScalar >> mapSnd asScalar // TODO: can get any function out, see takeOutsidePair - @rule def asScalarOutsidePair: Strategy[Rise] = { + def asScalarOutsidePair: Strategy[Rise] = rule("asScalarOutsidePair", { case e @ App(App(makePair(), App(asScalar(), a)), App(asScalar(), b)) => Success((makePair(a)(b) |> mapFst(asScalar) |> mapSnd(asScalar)) !: e.t) - } + }) // zip (asScalar a) (asScalar b) // -> pair a b >> mapFst asScalar >> mapSnd asScalar - @rule def asScalarOutsideZip: Strategy[Rise] = { + def asScalarOutsideZip: Strategy[Rise] = rule("asScalarOutsideZip", { case e @ App(App(makePair(), App(asScalar(), a)), App(asScalar(), b)) => Success((makePair(a)(b) |> mapFst(asScalar) |> mapSnd(asScalar)) !: e.t) - } + }) // padEmpty (p*v) (asScalar in) -> asScalar (padEmpty p in) - @rule def padEmptyBeforeAsScalar: Strategy[Rise] = { + def padEmptyBeforeAsScalar: Strategy[Rise] = rule("padEmptyBeforeAsScalar", { case App(DepApp(NatKind, padEmpty(), pv: Nat), App(asScalar(), in)) => in.t match { case ArrayType(_, VectorType(v, _)) if (pv % v) == (0: Nat) => Success(asScalar(padEmpty(pv / v)(in))) case _ => Failure(padEmptyBeforeAsScalar) } - } + }) // padEmpty p (asVector v in) -> asVector v (padEmpty (p*v) in) - @rule def padEmptyBeforeAsVector: Strategy[Rise] = { + def padEmptyBeforeAsVector: Strategy[Rise] = rule("padEmptyBeforeAsVector", { case e @ App(DepApp(NatKind, padEmpty(), p: Nat), App(asV @ DepApp(NatKind, _, v: Nat), in)) if isAsVector(asV) => Success(eraseType(asV)(padEmpty(p*v)(in)) !: e.t) - } + }) // TODO: express as a combination of smaller rules - @rule def alignSlide: Strategy[Rise] = { + def alignSlide: Strategy[Rise] = rule("alignSlide", { case e @ App(transpose(), App(App(map(), DepApp(NatKind, asVector(), Cst(v))), App(join(), App(App(map(), transpose()), @@ -152,11 +152,11 @@ object vectorize { padEmpty(pV) >> asVectorAligned(v) >> slide(2)(1) >> map(asScalar >> take(v+2) >> slide(v)(1) >> join >> asVector(v)) Success(r !: e.t) - } + }) // TODO: express as a combination of smaller rules // FIXME: function f needs to be element-wise (a hidden mapVec) - @rule def mapAfterShuffle: Strategy[Rise] = { + def mapAfterShuffle: Strategy[Rise] = rule("mapAfterShuffle", { case e @ App(DepApp(NatKind, asVector(), v: Nat), App(join(), App(DepApp(NatKind, DepApp(NatKind, slide(), v2: Nat), Cst(1)), App(DepApp(NatKind, take(), t: Nat), App(asScalar(), @@ -169,10 +169,10 @@ object vectorize { slide(v)(1) >> join >> asVector(v) )(in.t) Success((preserveType(in) |> shuffle |> map(f)) !: e.t) - } + }) // FIXME: this is very specific - @rule def padEmptyBeforeZipAsVector: Strategy[Rise] = { + def padEmptyBeforeZipAsVector: Strategy[Rise] = rule("padEmptyBeforeZipAsVector", { case e @ App(DepApp(NatKind, padEmpty(), p: Nat), App( Lambda(x, App(App(zip(), App(asV @ DepApp(NatKind, _, v: Nat), App(fst(), x2))), @@ -184,7 +184,7 @@ object vectorize { // FIXME: aligning although we have no alignment information fun(p => zip(asVectorAligned(v)(fst(p)))(asVectorAligned(v)(snd(p)))) ) !: e.t) - } + }) def isAsVector: Rise => Boolean = { case DepApp(NatKind, asVector(), _: Nat) => true diff --git a/src/main/scala/rise/elevate/strategies/algorithmic.scala b/src/main/scala/rise/elevate/strategies/algorithmic.scala index 9ae1b4c93..a58bbfd27 100644 --- a/src/main/scala/rise/elevate/strategies/algorithmic.scala +++ b/src/main/scala/rise/elevate/strategies/algorithmic.scala @@ -5,7 +5,7 @@ import elevate.core.strategies.basic.{applyNTimes, id} import elevate.core.strategies.traversal._ import rise.elevate.strategies.traversal._ import elevate.core.{Failure, RewriteResult, Strategy, Success} -import elevate.macros.StrategyMacro.strategy +import elevate.core.macros._ import rise.elevate.Rise import rise.elevate.rules.algorithmic.fuseReduceMap import rise.elevate.rules.movement._ @@ -24,7 +24,7 @@ object algorithmic { // fission of the first function to be applied inside a map // *(g >> .. >> f) -> *g >> *(.. >> f) - @strategy def mapFirstFission: Strategy[Rise] = e => { + def mapFirstFission: Strategy[Rise] = strategy("mapFirstFission", e => { // TODO: this should be expressed with elevate strategies @scala.annotation.tailrec def mapFirstFissionRec(x: Identifier, f: ToBeTyped[Rise], gx: Rise): RewriteResult[Rise] = { @@ -43,11 +43,11 @@ object algorithmic { case App(primitives.map(), Lambda(x, gx)) => mapFirstFissionRec(x, fun(e => e), gx) case _ => Failure(mapFirstFission) } - } + }) // fission of all the functions chained inside a map // *(g >> .. >> f) -> *g >> .. >> *f - @strategy def mapFullFission: Strategy[Rise] = e => { + def mapFullFission: Strategy[Rise] = strategy("mapFullFission", e => { // TODO: this should be expressed with elevate strategies def mapFullFissionRec(x: Identifier, gx: Rise): Option[ToBeTyped[Rise]] = { gx match { @@ -68,7 +68,7 @@ object algorithmic { } case _ => Failure(mapFullFission) } - } + }) //scalastyle:off def normForReorder(implicit ev: Traversable[Rise]): Strategy[Rise] = @@ -76,9 +76,10 @@ object algorithmic { (fuseReduceMap `@` topDown[Rise]) `;;` (fuseReduceMap `@` topDown[Rise]) `;;` RNF() - @strategy def reorder(l: List[Int])(implicit ev: Traversable[Rise]): Strategy[Rise] = normForReorder `;` (reorderRec(l) `@` topDown[Rise]) + def reorder(l: List[Int])(implicit ev: Traversable[Rise]): Strategy[Rise] = + strategy("reorder", normForReorder `;` (reorderRec(l) `@` topDown[Rise])) - @strategy def reorderRec(l: List[Int])(implicit ev: Traversable[Rise]): Strategy[Rise] = e => { + def reorderRec(l: List[Int])(implicit ev: Traversable[Rise]): Strategy[Rise] = strategy("reorderRec", e => { def freduce(s: Strategy[Rise]): Strategy[Rise] = function(function(argumentOf(reduceSeq.primitive, body(body(s))))) @@ -114,5 +115,5 @@ object algorithmic { case Nil => id(e) case _ => Failure(reorderRec(l)) } - } + }) } diff --git a/src/main/scala/rise/elevate/strategies/normalForm.scala b/src/main/scala/rise/elevate/strategies/normalForm.scala index 9d49c3de4..6948ac388 100644 --- a/src/main/scala/rise/elevate/strategies/normalForm.scala +++ b/src/main/scala/rise/elevate/strategies/normalForm.scala @@ -5,7 +5,7 @@ import elevate.core.strategies.Traversable import elevate.core.strategies.basic._ import elevate.core.strategies.predicate._ import elevate.core.strategies.traversal.one -import elevate.macros.StrategyMacro.strategy +import elevate.core.macros.strategy import rise.core.{primitives => p} import rise.elevate.Rise import rise.elevate.rules._ @@ -17,30 +17,30 @@ import rise.elevate.strategies.predicate._ object normalForm { // Beta-Eta-Normal-Form - @strategy def BENF()(implicit ev: Traversable[Rise]): Strategy[Rise] = - normalize(ev)(etaReduction() <+ betaReduction) + def BENF()(using ev: Traversable[Rise]): Strategy[Rise] = + strategy("BEND", normalize(etaReduction() <+ betaReduction)) // Data-Flow-Normal-Form - @strategy def DFNF()(implicit ev: Traversable[Rise]): Strategy[Rise] = - (BENF() `;` + def DFNF()(using ev: Traversable[Rise]): Strategy[Rise] = + strategy("DFNF", (BENF() `;` // there is no argument of a map which is not eta-abstracted, i.e., every argument of a map is a lambda - normalize(ev)(argumentOf(p.map.primitive, (not(isLambda) `;` etaAbstraction))) `;` + normalize(argumentOf(p.map.primitive, (not(isLambda) `;` etaAbstraction))) `;` // a reduce always contains two lambdas declaring y and acc - normalize(ev)(argumentOf(p.reduce.primitive, (not(isLambda) `;` etaAbstraction))) `;` - normalize(ev)(argumentOf(p.reduce.primitive, body((not(isLambda) `;` etaAbstraction)))) `;` + normalize(argumentOf(p.reduce.primitive, (not(isLambda) `;` etaAbstraction))) `;` + normalize(argumentOf(p.reduce.primitive, body((not(isLambda) `;` etaAbstraction)))) `;` // there is no map(f) without an argument == there is no way to get to a map without visiting two applies // same for reduce and three applies - normalize(ev)( + normalize( one(function(isMap) <+ one(function(isReduce))) `;` // there is a map in two hops, i.e, Something(Apply(map, f)) not(isApply) `;` // and the current node is not an Apply i.e. Something != Apply one((function(isMap) <+ one(function(isReduce))) `;` etaAbstraction) // eta-abstract the inner Apply - )) + ))) // Rewrite-Normal-Form (Fission all maps) - @strategy def RNF()(implicit ev: Traversable[Rise]): Strategy[Rise] = - normalize(ev)(DFNF() `;` mapLastFission()) `;` DFNF() + def RNF()(using ev: Traversable[Rise]): Strategy[Rise] = + strategy("RNF", normalize(DFNF() `;` mapLastFission()) `;` DFNF()) // Codegen-Normal-Form (Fuse all maps) - @strategy def CNF()(implicit ev: Traversable[Rise]): Strategy[Rise] = - normalize(ev)(mapFusion) + def CNF()(using ev: Traversable[Rise]): Strategy[Rise] = + strategy("CNF", normalize(mapFusion)) } diff --git a/src/main/scala/rise/elevate/strategies/predicate.scala b/src/main/scala/rise/elevate/strategies/predicate.scala index 1a866688e..8b55e26a5 100644 --- a/src/main/scala/rise/elevate/strategies/predicate.scala +++ b/src/main/scala/rise/elevate/strategies/predicate.scala @@ -3,6 +3,7 @@ package rise.elevate.strategies import elevate.core._ import elevate.core.strategies.predicate._ import elevate.core.strategies.{Traversable, basic} +import elevate.core.RewriteResult._ import rise.core.DSL.ToBeTyped import rise.core._ import rise.core.primitives._ diff --git a/src/main/scala/rise/elevate/strategies/traversal.scala b/src/main/scala/rise/elevate/strategies/traversal.scala index 3fc08b268..ea4abee75 100644 --- a/src/main/scala/rise/elevate/strategies/traversal.scala +++ b/src/main/scala/rise/elevate/strategies/traversal.scala @@ -6,8 +6,7 @@ import rise.core.primitives._ import rise.elevate.rules.algorithmic._ import elevate.core.strategies.traversal._ import elevate.core.strategies.basic._ -import elevate.macros.CombinatorMacro.combinator -import elevate.macros.StrategyMacro.strategy +import elevate.core.macros._ import rise.elevate.Rise import rise.elevate.rules.traversal._ import rise.elevate.strategies.algorithmic._ @@ -21,31 +20,28 @@ object traversal { // (map λe14. (transpose ((map (map e12)) e14))) // result of `function` // λe14. (transpose ((map (map e12)) e14)) // result of `argument` // (transpose ((map (map e12)) e14)) // result of 'body' -> here we can apply s - @combinator + def fmap: Strategy[Rise] => Strategy[Rise] = - s => function(argumentOf(map.primitive, body(s))) + transformer("fmap", s => function(argumentOf(map.primitive, body(s)))) // fmap applied for expressions in rewrite normal form: // fuse -> fmap -> fission - @combinator - def fmapRNF(implicit ev: Traversable[Rise]): Strategy[Rise] => Strategy[Rise] = + def fmapRNF(implicit ev: Traversable[Rise]): Strategy[Rise] => Strategy[Rise] = transformer("fmapRNF", s => DFNF() `;` mapFusion `;` DFNF() `;` fmap(s) `;` - DFNF() `;` one(mapFullFission) + DFNF() `;` one(mapFullFission)) // applying a strategy to an expression nested in one or multiple lift `map`s - @combinator def mapped(implicit ev: Traversable[Rise]): Strategy[Rise] => Strategy[Rise] = - s => s <+ (e => fmapRNF(ev)(mapped(ev)(s))(e)) + transformer("mapped", s => s <+ (e => fmapRNF(ev)(mapped(ev)(s))(e))) // moves along RNF-normalized expression // e.g., expr == ***f o ****g o *h // move(0)(s) == s(***f o ****g o *h) // move(1)(s) == s(****g o *h) // move(2)(s) == s(*h) - @combinator def moveTowardsArgument(i: Int): Strategy[Rise] => Strategy[Rise] = - s => applyNTimes(i)((e: Strategy[Rise]) => argument(e))(s) + transformer("moveTowardsArgument", s => applyNTimes(i)((e: Strategy[Rise]) => argument(e))(s)) // TRAVERSAL DSL as described in ICFP'20 ///////////////////////////////////// implicit class AtHelper[P](s: Strategy[P]) { @@ -55,39 +51,30 @@ object traversal { traversal(s) } - def outermost: Strategy[Rise] => Strategy[Rise] => Strategy[Rise] = - outermost(default.RiseTraversable) - - @combinator - def outermost(implicit ev: Traversable[Rise]): Strategy[Rise] => Strategy[Rise] => Strategy[Rise] = { - predicate => s => topDown(predicate `;` s) - } + def outermost(using ev: Traversable[Rise] = default.RiseTraversable): Strategy[Rise] => Strategy[Rise] => Strategy[Rise] = + combinator("outermost", { + predicate => s => topDown(predicate `;` s) + }) - def innermost: Strategy[Rise] => Strategy[Rise] => Strategy[Rise] = - traversal.innermost(default.RiseTraversable) - - @combinator - def innermost(implicit ev: Traversable[Rise]): Strategy[Rise] => Strategy[Rise] => Strategy[Rise] = { - predicate => s => bottomUp(predicate `;` s) - } + def innermost(using ev: Traversable[Rise] = default.RiseTraversable): Strategy[Rise] => Strategy[Rise] => Strategy[Rise] = + combinator("innermost", { + predicate => s => bottomUp(predicate `;` s) + }) - @combinator def everywhere: Strategy[Rise] => Strategy[Rise] = - s => basic.normalize(default.RiseTraversable)(s) + transformer("everywhere", s => basic.normalize(s)(using default.RiseTraversable)) - @combinator def check: Strategy[Rise] => Strategy[Rise] => Strategy[Rise] = - predicate => predicate `;` _ + combinator("check", predicate => predicate `;` _) - @strategy - def mapNest(d: Int): Strategy[Rise] = p => (d match { + def mapNest(d: Int): Strategy[Rise] = strategy("mapNest", p => (d match { case x if x == 0 => Success(p) case x if x < 0 => Failure(mapNest(d)) case _ => fmap(mapNest(d-1))(p) - }) + })) - def blocking(implicit ev: Traversable[Rise]): Strategy[Rise] = { - basic.id `@` outermost(ev)(mapNest(2)) - basic.id `@` outermost(ev)(isReduce) + def blocking(using ev: Traversable[Rise]): Strategy[Rise] = { + basic.id `@` outermost(using ev)(mapNest(2)) + basic.id `@` outermost(using ev)(isReduce) } } diff --git a/src/main/scala/shine/C/Compilation/CodeGenerator.scala b/src/main/scala/shine/C/Compilation/CodeGenerator.scala index 2efd4a8a5..0d65b6271 100644 --- a/src/main/scala/shine/C/Compilation/CodeGenerator.scala +++ b/src/main/scala/shine/C/Compilation/CodeGenerator.scala @@ -1046,8 +1046,8 @@ class CodeGenerator(val decls: CodeGenerator.Declarations, case BoolExpr.False => cont(C.AST.Literal("false")) case BoolExpr.ArithPredicate(lhs, rhs, op) => val cOp = op match { - case ArithPredicate.Operator.!= => C.AST.BinaryOperator.!= - case ArithPredicate.Operator.== => C.AST.BinaryOperator.== + case ArithPredicate.Operator.notEqual => C.AST.BinaryOperator.!= + case ArithPredicate.Operator.equal => C.AST.BinaryOperator.== case ArithPredicate.Operator.< => C.AST.BinaryOperator.< case ArithPredicate.Operator.<= => C.AST.BinaryOperator.<= case ArithPredicate.Operator.> => C.AST.BinaryOperator.> @@ -1227,7 +1227,7 @@ class CodeGenerator(val decls: CodeGenerator.Declarations, // FIXME: we should know that (i - l) is in [0; n[ here array |> exp(env, CIntExpr(i - l) :: ps, arrayExpr => { - def cOperator(op:ArithPredicate.Operator.Value):C.AST.BinaryOperator.Value = op match { + def cOperator(op:ArithPredicate.Operator):C.AST.BinaryOperator.Value = op match { case ArithPredicate.Operator.< => C.AST.BinaryOperator.< case ArithPredicate.Operator.> => C.AST.BinaryOperator.> case ArithPredicate.Operator.>= => C.AST.BinaryOperator.>= @@ -1235,7 +1235,7 @@ class CodeGenerator(val decls: CodeGenerator.Declarations, } def genBranch(lhs:ArithExpr, rhs:ArithExpr, - operator:ArithPredicate.Operator.Value, taken:Expr, notTaken:Expr): Expr = { + operator:ArithPredicate.Operator, taken:Expr, notTaken:Expr): Expr = { import BoolExpr._ arithPredicate(lhs, rhs, operator) match { case True => taken diff --git a/src/main/scala/shine/DPIA/Compilation/AcceptorTranslation.scala b/src/main/scala/shine/DPIA/Compilation/AcceptorTranslation.scala index d9967dd8e..438265ea7 100644 --- a/src/main/scala/shine/DPIA/Compilation/AcceptorTranslation.scala +++ b/src/main/scala/shine/DPIA/Compilation/AcceptorTranslation.scala @@ -1,26 +1,26 @@ package shine.DPIA.Compilation -import shine.DPIA.Compilation.TranslationToImperative._ -import shine.DPIA.DSL._ -import shine.DPIA.Phrases._ +import shine.DPIA.Compilation.TranslationToImperative.* +import shine.DPIA.DSL.* +import shine.DPIA.Phrases.* import rise.core.types.{DataType, Fragment, MatrixLayout, NatIdentifier, NatKind, read, write} -import rise.core.DSL.Type._ -import rise.core.types.DataType._ -import rise.core.substitute.{natInType => substituteNatInType} -import shine.DPIA.Types.{AccType, CommType, ExpType, TypeCheck, comm} -import rise.core.types.DataTypeOps._ -import shine.DPIA._ -import shine.DPIA.primitives.functional._ -import shine.DPIA.primitives.imperative.{Seq => _, _} -import shine.DPIA.primitives.intermediate._ -import shine.OpenMP.primitives.{functional => omp} -import shine.OpenMP.primitives.{intermediate => ompI} -import shine.OpenCL.primitives.{functional => ocl} -import shine.OpenCL.primitives.{intermediate => oclI} -import shine.OpenCL.primitives.{imperative => oclImp} -import shine.cuda.primitives.{functional => cuda} -import shine.cuda.primitives.{intermediate => cudaI} -import shine.cuda.primitives.{imperative => cudaImp} +import rise.core.DSL.Type.* +import rise.core.types.DataType.* +import rise.core.substitute.natInType as substituteNatInType +import shine.DPIA.Types.{AccType, CommType, DepFunType, ExpType, TypeCheck, comm} +import rise.core.types.DataTypeOps.* +import shine.DPIA.* +import shine.DPIA.primitives.functional.* +import shine.DPIA.primitives.imperative.{Seq as _, *} +import shine.DPIA.primitives.intermediate.* +import shine.OpenMP.primitives.functional as omp +import shine.OpenMP.primitives.intermediate as ompI +import shine.OpenCL.primitives.functional as ocl +import shine.OpenCL.primitives.intermediate as oclI +import shine.OpenCL.primitives.imperative as oclImp +import shine.cuda.primitives.functional as cuda +import shine.cuda.primitives.intermediate as cudaI +import shine.cuda.primitives.imperative as cudaImp object AcceptorTranslation { def acc(E: Phrase[ExpType]) @@ -138,7 +138,7 @@ object AcceptorTranslation { case IterateStream(n, dt1, dt2, f, array) => val fI = fun(expT(dt1, read))(x => fun(accT(dt2))(o => acc(f(x))(o))) val i = NatIdentifier(freshName("i")) - str(array)(fun((i: NatIdentifier) ->: + str(array)(fun[DepFunType[NatIdentifier, (ExpType ->: CommType) ->: CommType]]((i: NatIdentifier) ->: (expT(dt1, read) ->: (comm: CommType)) ->: (comm: CommType) )(next => comment("iterateStream") `;` diff --git a/src/main/scala/shine/DPIA/Compilation/StreamTranslation.scala b/src/main/scala/shine/DPIA/Compilation/StreamTranslation.scala index c1610ba09..8fcad5b77 100644 --- a/src/main/scala/shine/DPIA/Compilation/StreamTranslation.scala +++ b/src/main/scala/shine/DPIA/Compilation/StreamTranslation.scala @@ -45,7 +45,7 @@ object StreamTranslation { (implicit context: TranslationContext): Phrase[CommType] = E match { case CircularBuffer(n, alloc, sz, dt1, dt2, load, input) => val i = NatIdentifier(freshName("i")) - str(input)(fun((i: NatIdentifier) ->: + str(input)(fun[DepFunType[NatIdentifier, (ExpType ->: CommType) ->: CommType]]((i: NatIdentifier) ->: (expT(dt1, read) ->: (comm: CommType)) ->: (comm: CommType) )(nextIn => CircularBufferI(n, sz, 1, dt1, dt2, @@ -55,7 +55,7 @@ object StreamTranslation { case MapStream(n, dt1, dt2, f, array) => val i = NatIdentifier(freshName("i")) - str(array)(fun((i: NatIdentifier) ->: + str(array)(fun[DepFunType[NatIdentifier, (ExpType ->: CommType) ->: CommType]]((i: NatIdentifier) ->: (expT(dt1, read) ->: (comm: CommType)) ->: (comm: CommType) )(next => C(nFun(i => @@ -67,7 +67,7 @@ object StreamTranslation { case RotateValues(n, sz, dt, write, input) => val i = NatIdentifier(freshName("i")) - str(input)(fun((i: NatIdentifier) ->: + str(input)(fun[DepFunType[NatIdentifier, (ExpType ->: CommType) ->: CommType]]((i: NatIdentifier) ->: (expT(dt, read) ->: (comm: CommType)) ->: (comm: CommType) )(nextIn => RotateValuesI(n, sz, 1, dt, dt, @@ -77,10 +77,10 @@ object StreamTranslation { case Zip(n, dt1, dt2, _, e1, e2) => val i = NatIdentifier("i") - str(e1)(fun((i: NatIdentifier) ->: + str(e1)(fun[DepFunType[NatIdentifier, (ExpType ->: CommType) ->: CommType]]((i: NatIdentifier) ->: (expT(dt1, read) ->: (comm: CommType)) ->: (comm: CommType) )(next1 => - str(e2)(fun((i: NatIdentifier) ->: + str(e2)(fun[DepFunType[NatIdentifier, (ExpType ->: CommType) ->: CommType]]((i: NatIdentifier) ->: (expT(dt2, read) ->: (comm: CommType)) ->: (comm: CommType) )(next2 => C(nFun(i => fun(expT(dt1 x dt2, read) ->: (comm: CommType))(k => @@ -95,7 +95,7 @@ object StreamTranslation { // OpenCL case ocl.CircularBuffer(a, n, alloc, sz, dt1, dt2, load, input) => val i = NatIdentifier(freshName("i")) - str(input)(fun((i: NatIdentifier) ->: + str(input)(fun[DepFunType[NatIdentifier, (ExpType ->: CommType) ->: CommType]]((i: NatIdentifier) ->: (expT(dt1, read) ->: (comm: CommType)) ->: (comm: CommType) )(nextIn => oclI.CircularBufferI(a, n, alloc, sz, dt1, dt2, @@ -105,7 +105,7 @@ object StreamTranslation { case ocl.RotateValues(a, n, sz, dt, write, input) => val i = NatIdentifier(freshName("i")) - str(input)(fun((i: NatIdentifier) ->: + str(input)(fun[DepFunType[NatIdentifier, (ExpType ->: CommType) ->: CommType]]((i: NatIdentifier) ->: (expT(dt, read) ->: (comm: CommType)) ->: (comm: CommType) )(nextIn => oclI.RotateValuesI(a, n, sz, dt, diff --git a/src/main/scala/shine/DPIA/NatFunCall.scala b/src/main/scala/shine/DPIA/NatFunCall.scala index c5a390df1..26b6f873f 100644 --- a/src/main/scala/shine/DPIA/NatFunCall.scala +++ b/src/main/scala/shine/DPIA/NatFunCall.scala @@ -45,7 +45,7 @@ class NatFunCall(val fun:LetNatIdentifier, val args:Seq[NatFunArg]) extends Arit override lazy val toString = s"⌈${this.callAndParameterListString}⌉" - override val HashSeed = 0x31111112 + override lazy val HashSeed = 0x31111112 override def equals(that: Any): Boolean = that match { case f: NatFunCall => this.name.equals(f.name) && this.args == f.args diff --git a/src/main/scala/shine/DPIA/fromRise.scala b/src/main/scala/shine/DPIA/fromRise.scala index de7fc6b44..926912913 100644 --- a/src/main/scala/shine/DPIA/fromRise.scala +++ b/src/main/scala/shine/DPIA/fromRise.scala @@ -17,12 +17,12 @@ import shine.DPIA.primitives.functional._ import scala.collection.mutable object fromRise { - def apply(expr: r.Expr)(implicit ev: Traversable[Rise]): Phrase[_ <: PhraseType] = { + def apply(expr: r.Expr)(using ev: Traversable[Rise]): Phrase[_ <: PhraseType] = { if (!r.IsClosedForm(expr)) { val (fV, fT) = r.IsClosedForm.varsToClose(expr) throw new Exception(s"expression is not in closed form: $expr\n\n with type ${expr.t}\n free vars: $fV\n free type vars: $fT\n\n") } - val bnfExpr = normalize(ev).apply(betaReduction)(expr).get + val bnfExpr = normalize(betaReduction)(expr).get val rwMap = inferAccess(bnfExpr) expression(bnfExpr, rwMap) } diff --git a/src/main/scala/shine/OpenCL/BuiltInFunctionCall.scala b/src/main/scala/shine/OpenCL/BuiltInFunctionCall.scala index 75d159460..228a5db1c 100644 --- a/src/main/scala/shine/OpenCL/BuiltInFunctionCall.scala +++ b/src/main/scala/shine/OpenCL/BuiltInFunctionCall.scala @@ -14,7 +14,7 @@ class BuiltInFunctionCall private(name: String, val param: Int, range: Range) override lazy val digest: Int = HashSeed ^ /*range.digest() ^*/ name.hashCode ^ param - override val HashSeed = 0x31111111 + override lazy val HashSeed = 0x31111111 override def equals(that: Any): Boolean = that match { case f: BuiltInFunctionCall => diff --git a/src/main/scala/util/gen.scala b/src/main/scala/util/gen.scala index 6bb8d6872..e600a36db 100644 --- a/src/main/scala/util/gen.scala +++ b/src/main/scala/util/gen.scala @@ -14,7 +14,7 @@ object gen { type Phrase = DPIA.Phrases.Phrase[_ <: DPIA.Types.PhraseType] private def exprToPhrase: Expr => Phrase = - shine.DPIA.fromRise(_)(default.RiseTraversable) + shine.DPIA.fromRise(_)(using default.RiseTraversable) type CModule = C.Module diff --git a/src/main/scala/util/monads.scala b/src/main/scala/util/monads.scala index ba8d618a9..f1cb6d12f 100644 --- a/src/main/scala/util/monads.scala +++ b/src/main/scala/util/monads.scala @@ -57,16 +57,15 @@ object monads { def append : Map[K,V] => Map[K,V] => Map[K,V] = x => y => x ++ y } - implicit def PairMonoid[F,S](fst : Monoid[F], snd : Monoid[S]) : Monoid[Tuple2[F,S]] = new Monoid[Tuple2[F,S]] { - override def empty : Tuple2[F,S] = (fst.empty, snd.empty) - override def append : Tuple2[F,S] => Tuple2[F,S] => Tuple2[F,S] = { + implicit def PairMonoid[F, S](fst : Monoid[F], snd : Monoid[S]) : Monoid[(F, S)] = new Monoid[(F, S)] { + override def empty : (F, S) = (fst.empty, snd.empty) + override def append : ((F, S)) => ((F, S)) => (F, S) = { case (f1, s1) => { case (f2, s2) => (fst.append(f1)(f2), snd.append(s1)(s2)) } } } - trait InMonad[M[_]] { trait SetFst[F] { type Type[S] = M[Tuple2[F, S]] } } - trait PairMonoidMonad[F, M[_]] extends Monad[InMonad[M]#SetFst[F]#Type] { - type Pair[T] = InMonad[M]#SetFst[F]#Type[T] + trait PairMonoidMonad[F, M[_]] extends Monad[[S] =>> M[(F, S)]] { + type Pair[T] = M[(F, T)] implicit val monoid : Monoid[F] implicit val monad : Monad[M] override def return_[T]: T => Pair[T] = t => monad.return_((monoid.empty, t)) diff --git a/src/test/scala/apps/Acoustic3D.scala b/src/test/scala/apps/Acoustic3D.scala index 55757f45e..4e0643661 100644 --- a/src/test/scala/apps/Acoustic3D.scala +++ b/src/test/scala/apps/Acoustic3D.scala @@ -3,6 +3,7 @@ package apps import acoustic3D._ import shine.OpenCL._ import util.{Time, TimeSpan, gen} +import reflect.Selectable.reflectiveSelectable class Acoustic3D extends test_util.TestsWithExecutor { private val N = 128 @@ -63,7 +64,7 @@ class Acoustic3D extends test_util.TestsWithExecutor { runsWithSameResult(Seq( ("original", runOriginalKernel("acoustic3D.cl", mat1, mat2)), ("originalMSS", runOriginalKernel("acoustic3DMSS.cl", mat1, mat2)), - ("dpia", runKernel(gen.opencl.kernel.fromExpr(stencil), mat1, mat2)), + ("dpia", runKernel(gen.opencl.kernel.fromExpr(acoustic3D.stencil), mat1, mat2)), ("dpiaMSS", runKernel(gen.opencl.kernel.fromExpr(stencilMSS), mat1, mat2)) )) } diff --git a/src/test/scala/apps/asum.scala b/src/test/scala/apps/asum.scala index f764bf61c..a33702b98 100644 --- a/src/test/scala/apps/asum.scala +++ b/src/test/scala/apps/asum.scala @@ -1,19 +1,20 @@ package apps +import rise.core.* +import rise.core.DSL.* +import rise.core.DSL.HighLevelConstructs.reorderWithStride +import rise.core.DSL.Type.* +import rise.core.primitives.* +import rise.core.types.* +import rise.core.types.DataType.* +import rise.elevate.rules.traversal.default.* import shine.DPIA.Types.ExpType -import shine.OpenCL.{GlobalSize, LocalSize} -import rise.core._ -import rise.core.types._ -import rise.core.types.DataType._ -import rise.core.DSL._ -import rise.core.primitives._ -import Type._ -import HighLevelConstructs.reorderWithStride -import util.{SyntaxChecker, gen} -import rise.elevate.rules.traversal.default._ import shine.OpenCL.KernelExecutor.KernelNoSizes.fromKernelModule +import shine.OpenCL.{GlobalSize, LocalSize} import util.gen.c.function +import util.{SyntaxChecker, gen} +import scala.reflect.Selectable.reflectiveSelectable import scala.util.Random //noinspection TypeAnnotation @@ -45,7 +46,7 @@ class asum extends test_util.TestsWithExecutor { // OpenMP code gen test("Intel derived no warp compiles to syntactically correct OpenMP code") { - import rise.openMP.primitives._ + import rise.openMP.primitives.* val intelDerivedNoWarp1 = depFun((n: Nat) => fun(inputT(n))(input => @@ -69,7 +70,7 @@ class asum extends test_util.TestsWithExecutor { test( "Second kernel of Intel derived compiles to syntactically correct OpenMP code" ) { - import rise.openMP.primitives._ + import rise.openMP.primitives.* val intelDerived2 = depFun((n: Nat) => fun(inputT(n))(input => @@ -87,7 +88,7 @@ class asum extends test_util.TestsWithExecutor { test( "AMD/Nvidia second kernel derived compiles to syntactically correct OpenMP code" ) { - import rise.openMP.primitives._ + import rise.openMP.primitives.* val amdNvidiaDerived2 = depFun((n: Nat) => fun(inputT(n))(input => @@ -110,8 +111,8 @@ class asum extends test_util.TestsWithExecutor { } { // OpenCL code gen - import rise.openCL.DSL._ - import rise.openCL.primitives.{oclReduceSeq, oclIterate} + import rise.openCL.DSL.* + import rise.openCL.primitives.{oclIterate, oclReduceSeq} import shine.OpenCL val random = new Random() @@ -126,7 +127,7 @@ class asum extends test_util.TestsWithExecutor { localSize: LocalSize, globalSize: GlobalSize )(n: Int, input: Array[Float]): Array[Float] = { - import shine.OpenCL._ + import shine.OpenCL.* val k = gen.opencl.kernel.fromExpr(kernel) val runKernel = k.as[Args `(` Int `,` Array[Float], Array[Float]] val (output, _) = runKernel(localSize, globalSize)(n `,` input) diff --git a/src/test/scala/apps/cameraPipelineCheck.scala b/src/test/scala/apps/cameraPipelineCheck.scala index 08e42e261..15124ca3d 100644 --- a/src/test/scala/apps/cameraPipelineCheck.scala +++ b/src/test/scala/apps/cameraPipelineCheck.scala @@ -98,8 +98,8 @@ float clamp_f32(float v, float l, float h) { #define pow_f32 powf """ - val DFNF = rise.elevate.strategies.normalForm.DFNF()(alternative.RiseTraversable) - val CNF = rise.elevate.strategies.normalForm.CNF()(alternative.RiseTraversable) + val DFNF = rise.elevate.strategies.normalForm.DFNF()(using alternative.RiseTraversable) + val CNF = rise.elevate.strategies.normalForm.CNF()(using alternative.RiseTraversable) def check( lowered: Rise, callCFun: String => String, diff --git a/src/test/scala/apps/convolution1D.scala b/src/test/scala/apps/convolution1D.scala index 87489ab9c..34629508a 100644 --- a/src/test/scala/apps/convolution1D.scala +++ b/src/test/scala/apps/convolution1D.scala @@ -1,18 +1,20 @@ package apps import apps.separableConvolution2D._ -import rise.core.primitives._ -import rise.core.DSL._ -import rise.core.DSL.HighLevelConstructs._ -import rise.core.DSL.Type._ +import rise.core.primitives.* +import rise.core.DSL.* +import rise.core.DSL.HighLevelConstructs.* +import rise.core.DSL.Type.* import rise.openCL.primitives.oclReduceSeqUnroll -import rise.openCL.DSL._ -import rise.core._ -import rise.core.types._ -import rise.core.types.DataType._ +import rise.openCL.DSL.* +import rise.core.* +import rise.core.types.* +import rise.core.types.DataType.* import shine.OpenCL.KernelExecutor.KernelNoSizes.fromKernelModule import util.gen +import reflect.Selectable.reflectiveSelectable + class convolution1D extends test_util.Tests { val binomialWeights = binomialWeightsV @@ -22,7 +24,7 @@ class convolution1D extends test_util.Tests { )) val binomial: ToBeTyped[Expr] = - slide(3)(1) >> map(fun(nbh => dot(nbh)(binomialWeights))) + slide(3)(1) >> map(fun(nbh => separableConvolution2D.dot(nbh)(binomialWeights))) val binomialSeq: ToBeTyped[Expr] = slide(3)(1) >> mapSeq(fun(nbh => dotSeq(nbh)(binomialWeights))) @@ -53,7 +55,7 @@ class convolution1D extends test_util.Tests { val binomialTileDep: ToBeTyped[Expr] = impl{ n: Nat => // depSlide(34)(32) >> depTile(32)( - depMapSeq(depFun { i: Nat => // TODO: depMapGlobal(0) + depMapSeq(depFun { (i: Nat) => // TODO: depMapGlobal(0) import arithexpr.arithmetic.IfThenElse import arithexpr.arithmetic.BoolExpr.arithPredicate import arithexpr.arithmetic.BoolExpr.ArithPredicate.Operator diff --git a/src/test/scala/apps/dot.scala b/src/test/scala/apps/dot.scala index a0ab165b7..277d39073 100644 --- a/src/test/scala/apps/dot.scala +++ b/src/test/scala/apps/dot.scala @@ -37,7 +37,7 @@ class dot extends test_util.Tests { test("Simple dot product translation to phrase works and preserves types") { import rise.core.types.DataType._ import shine.DPIA._ - val phrase = shine.DPIA.fromRise(simpleDotProduct)(default.RiseTraversable) + val phrase = shine.DPIA.fromRise(simpleDotProduct)(using default.RiseTraversable) val N = phrase.t.asInstanceOf[`(nat)->:`[ExpType ->: ExpType]].x val dt = f32 diff --git a/src/test/scala/apps/gemmTensorCheck.scala b/src/test/scala/apps/gemmTensorCheck.scala index f7a594f5b..d251ceee0 100644 --- a/src/test/scala/apps/gemmTensorCheck.scala +++ b/src/test/scala/apps/gemmTensorCheck.scala @@ -7,6 +7,8 @@ import shine.OpenCL._ import shine.cuda.KernelExecutor.{KernelNoSizes, KernelWithSizes} import util._ +import reflect.Selectable.reflectiveSelectable + //Cause some TypeChecking-Bugs the execution of the entire test-class could be fail //Running each test individually should be successfull class gemmTensorCheck extends test_util.TestWithCUDA { diff --git a/src/test/scala/apps/harrisCornerDetectionHalideCheck.scala b/src/test/scala/apps/harrisCornerDetectionHalideCheck.scala index c923237d0..3aaaa1b6d 100644 --- a/src/test/scala/apps/harrisCornerDetectionHalideCheck.scala +++ b/src/test/scala/apps/harrisCornerDetectionHalideCheck.scala @@ -7,6 +7,8 @@ import rise.core._ import shine.OpenCL.KernelExecutor.KernelNoSizes.fromKernelModule import util.gen +import reflect.Selectable.reflectiveSelectable + class harrisCornerDetectionHalideCheck extends test_util.TestsWithExecutor { diff --git a/src/test/scala/apps/localLaplacianCheck.scala b/src/test/scala/apps/localLaplacianCheck.scala index 55acb5005..d43dc4470 100644 --- a/src/test/scala/apps/localLaplacianCheck.scala +++ b/src/test/scala/apps/localLaplacianCheck.scala @@ -14,7 +14,7 @@ class localLaplacianCheck extends test_util.TestsWithExecutor { private val beta = 1.0f test("localLaplacian typechecks") { - logger.debug(localLaplacian(2).toExpr.t) + logger.debug(localLaplacian.localLaplacian(2).toExpr.t) } def lowerOMP(e: ToBeTyped[Expr]): Expr = diff --git a/src/test/scala/apps/mmTensorCheck.scala b/src/test/scala/apps/mmTensorCheck.scala index fb7302f25..69fd84298 100644 --- a/src/test/scala/apps/mmTensorCheck.scala +++ b/src/test/scala/apps/mmTensorCheck.scala @@ -6,6 +6,8 @@ import shine.OpenCL._ import shine.cuda.KernelExecutor.{KernelNoSizes, KernelWithSizes} import util._ +import reflect.Selectable.reflectiveSelectable + class mmTensorCheck extends test_util.TestWithCUDA { import mmCheckUtils._ diff --git a/src/test/scala/apps/separableConvolution2DCheck.scala b/src/test/scala/apps/separableConvolution2DCheck.scala index 5440cf84a..fcefbf92d 100644 --- a/src/test/scala/apps/separableConvolution2DCheck.scala +++ b/src/test/scala/apps/separableConvolution2DCheck.scala @@ -12,6 +12,8 @@ import shine.OpenCL.KernelExecutor.KernelNoSizes.fromKernelModule import util.gen import util.gen.c.function +import reflect.Selectable.reflectiveSelectable + object separableConvolution2DCheck { def wrapExpr(e: ToBeTyped[Expr]): ToBeTyped[Expr] = { import arithexpr.arithmetic.{PosInf, RangeAdd} diff --git a/src/test/scala/apps/separableConvolution2DNaiveEqsat.scala b/src/test/scala/apps/separableConvolution2DNaiveEqsat.scala index 76d831f43..8b34c3546 100644 --- a/src/test/scala/apps/separableConvolution2DNaiveEqsat.scala +++ b/src/test/scala/apps/separableConvolution2DNaiveEqsat.scala @@ -28,7 +28,7 @@ class separableConvolution2DNaiveEqsat extends test_util.Tests { private val separateDotT: Strategy[Rise] = separateDotVH(weights2d, weightsV, weightsH) - private val BENF = rise.elevate.strategies.normalForm.BENF()(alternative.RiseTraversable) + private val BENF = rise.elevate.strategies.normalForm.BENF()(using alternative.RiseTraversable) case class ExprWrapper(e: Expr) { override def hashCode(): Int = exprAlphaEq(typeErasure).hash(e) diff --git a/src/test/scala/apps/separableConvolution2DRewrite.scala b/src/test/scala/apps/separableConvolution2DRewrite.scala index 195f1e9ec..2d3651937 100644 --- a/src/test/scala/apps/separableConvolution2DRewrite.scala +++ b/src/test/scala/apps/separableConvolution2DRewrite.scala @@ -29,10 +29,10 @@ class separableConvolution2DRewrite extends test_util.Tests { private val P = padClamp2D(1) private val Sh = slide(3)(1) private val Sv = slide(3)(1) - private val Dh = dot(weightsH) - private val Dv = dot(weightsV) + private val Dh = separableConvolution2D.dot(weightsH) + private val Dv = separableConvolution2D.dot(weightsV) - private val BENF = rise.elevate.strategies.normalForm.BENF()(alternative.RiseTraversable) + private val BENF = rise.elevate.strategies.normalForm.BENF()(using alternative.RiseTraversable) private def ben_eq(a: Expr, b: Expr): Boolean = { val na = BENF(a).get @@ -73,7 +73,7 @@ class separableConvolution2DRewrite extends test_util.Tests { test("base to scanline") { rewrite_steps(base(weights2d), scala.collection.Seq( idS - -> (P >> *(Sh) >> Sv >> *(T) >> *(*(fun(nbh => dot(join(weights2d))(join(nbh)))))), + -> (P >> *(Sh) >> Sv >> *(T) >> *(*(fun(nbh => separableConvolution2D.dot(join(weights2d))(join(nbh)))))), topDown(separateDotT) -> (P >> *(Sh) >> Sv >> *(T) >> *(*(T >> *(Dv) >> Dh))), topDown(`*f >> S -> S >> **f`) @@ -113,7 +113,7 @@ class separableConvolution2DRewrite extends test_util.Tests { test("base to scanline (mapLastFission)") { rewrite_steps(base(weights2d), scala.collection.Seq( idS - -> (P >> *(Sh) >> Sv >> *(T) >> *(*(fun(nbh => dot(join(weights2d))(join(nbh)))))), + -> (P >> *(Sh) >> Sv >> *(T) >> *(*(fun(nbh => separableConvolution2D.dot(join(weights2d))(join(nbh)))))), topDown(separateDotT) -> (P >> *(Sh) >> Sv >> *(T) >> *(*(T >> *(Dv) >> Dh))), topDown(`*f >> S -> S >> **f`) diff --git a/src/test/scala/apps/stencil.scala b/src/test/scala/apps/stencil.scala index 43f3cb097..4d30f773f 100644 --- a/src/test/scala/apps/stencil.scala +++ b/src/test/scala/apps/stencil.scala @@ -15,6 +15,7 @@ import shine.OpenCL.{GlobalSize, KernelExecutor, LocalSize} import util.Time.ms import util.gen.c.function import util.{Display, TimeSpan, gen} +import arithexpr.arithmetic.Cst import scala.util.Random @@ -125,9 +126,9 @@ class stencil extends test_util.Tests { input |> padCst(padSize)(padSize)(lf32(0.0f)) |> slide(stencilSize)(1) |> - partition(3)(n2nFun(m => + partition(Cst(3))(n2nFun(m => SteppedCase(m, - Seq(padSize, n - 2 * padSize + ((1 + stencilSize) % 2), padSize) + Seq[Nat](padSize, n - 2 * padSize + ((1 + stencilSize) % 2), padSize) ) )) |> depMapSeq(mapGlobal(fun(nbh => @@ -199,8 +200,8 @@ class stencil extends test_util.Tests { padCst2D(padSize)(lf32(0.0f)) |> slide2D(stencilSize, 1) |> // partition2D(padSize, N - 2*padSize + ((1 + stencilSize) % 2)) :>> - partition(3)(n2nFun(m => - SteppedCase(m, Seq(padSize, n - 2 * padSize, padSize)) + partition(Cst(3))(n2nFun(m => + SteppedCase(m, Seq[Nat](padSize, n - 2 * padSize, padSize)) )) |> depMapSeq( // mapGlobal(0)(depMapSeqUnroll(mapGlobal(1)(join() >>> reduceSeq(add, 0.0f)))) diff --git a/src/test/scala/rise/core/dependentTypes.scala b/src/test/scala/rise/core/dependentTypes.scala index a927d588a..23813c932 100644 --- a/src/test/scala/rise/core/dependentTypes.scala +++ b/src/test/scala/rise/core/dependentTypes.scala @@ -160,7 +160,7 @@ class dependentTypes extends test_util.Tests { depMapSeq(depFun((_: Nat) => mapSeq(fun(x => x))))(array) )) - val inferred: Expr = infer(e) + val inferred: Expr = DSL.infer(e) logger.debug(inferred) logger.debug(inferred.t) assert(inferred.t =~= @@ -175,7 +175,7 @@ class dependentTypes extends test_util.Tests { depMapSeq(depFun((_: Nat) => reduceSeq(fun(x => fun(y => x + y)))(lf32(0.0f))))(array) )) - val inferred: Expr = infer(e) + val inferred: Expr = DSL.infer(e) logger.debug(inferred) logger.debug(inferred.t) assert(inferred.t =~= @@ -204,7 +204,7 @@ class dependentTypes extends test_util.Tests { } )))) - val inferred: Expr = infer(e) + val inferred: Expr = DSL.infer(e) logger.debug(inferred) logger.debug(inferred.t) function.asStringFromExpr(inferred) diff --git a/src/test/scala/rise/core/showScalaTest.scala b/src/test/scala/rise/core/showScalaTest.scala index 6c73c4798..fa4ed7f9a 100644 --- a/src/test/scala/rise/core/showScalaTest.scala +++ b/src/test/scala/rise/core/showScalaTest.scala @@ -29,26 +29,26 @@ class showScalaTest extends test_util.Tests { }))(lf32(0.0f))(zip(join(elem))(weights)) ) - test("show dotElemWeights as an example") { - import scala.reflect.runtime.universe - import scala.tools.reflect.ToolBox - - val typedDotElemWeights = dotElemWeights.toExpr - - val untypedScala = prefixImports(showScala.expr(dotElemWeights)) - val typedScala = prefixImports(showScala.expr(typedDotElemWeights)) - - logger.debug(untypedScala) - logger.debug(typedScala) - - val toolbox = universe.runtimeMirror(getClass.getClassLoader).mkToolBox() - val expr = toolbox.eval(toolbox.parse(untypedScala)).asInstanceOf[Expr] - val typedExpr = toolbox.eval(toolbox.parse(typedScala)).asInstanceOf[Expr] - - logger.debug(expr) - logger.debug(typedExpr) - - assert(expr =~~= dotElemWeights.toUntypedExpr) - assert(typedExpr =~~= typedDotElemWeights) - } +// test("show dotElemWeights as an example") { +// import scala.reflect.runtime.universe +// import scala.tools.reflect.ToolBox +// +// val typedDotElemWeights = dotElemWeights.toExpr +// +// val untypedScala = prefixImports(showScala.expr(dotElemWeights)) +// val typedScala = prefixImports(showScala.expr(typedDotElemWeights)) +// +// logger.debug(untypedScala) +// logger.debug(typedScala) +// +// val toolbox = universe.runtimeMirror(getClass.getClassLoader).mkToolBox() +// val expr = toolbox.eval(toolbox.parse(untypedScala)).asInstanceOf[Expr] +// val typedExpr = toolbox.eval(toolbox.parse(typedScala)).asInstanceOf[Expr] +// +// logger.debug(expr) +// logger.debug(typedExpr) +// +// assert(expr =~~= dotElemWeights.toUntypedExpr) +// assert(typedExpr =~~= typedDotElemWeights) +// } } diff --git a/src/test/scala/rise/core/traverseTest.scala b/src/test/scala/rise/core/traverseTest.scala index 908842454..09e8d8fc7 100644 --- a/src/test/scala/rise/core/traverseTest.scala +++ b/src/test/scala/rise/core/traverseTest.scala @@ -32,7 +32,7 @@ class traverseTest extends test_util.Tests { test("traversing an expression should traverse identifiers in order") { val equivs = Seq(Seq(0, 3), Seq(1, 2)) - val result = traverse(e, new ExprTraceVisitor()) + val result = traverse.traverse(e, new ExprTraceVisitor()) // the expression should not have changed assert(result._2 == e) @@ -43,7 +43,7 @@ class traverseTest extends test_util.Tests { test("traversing a type should traverse identifiers in order") { val equivs = Seq(Seq(0, 2, 4), Seq(1, 3, 5)) - val result = traverse(e.t, new TypeTraceVisitor()) + val result = traverse.traverse(e.t, new TypeTraceVisitor()) // the type should not have changed assert(result._2 == e.t) // the trace should match expectations @@ -59,7 +59,7 @@ class traverseTest extends test_util.Tests { } } - val result = traverse(e, new Visitor) + val result = traverse.traverse(e, new Visitor) // the expression should have changed assert(result =~~= @@ -86,7 +86,7 @@ class traverseTest extends test_util.Tests { } } - val result = traverse(e, new Visitor) + val result = traverse.traverse(e, new Visitor) // the expression should have changed val expected = depFun((n: Nat) => diff --git a/src/test/scala/rise/elevate/algorithmic.scala b/src/test/scala/rise/elevate/algorithmic.scala index 42e59defc..039746501 100644 --- a/src/test/scala/rise/elevate/algorithmic.scala +++ b/src/test/scala/rise/elevate/algorithmic.scala @@ -31,10 +31,10 @@ class algorithmic extends test_util.Tests { def tileND = rise.elevate.strategies.tiling.tileND(default.RiseTraversable) def tileNDList = rise.elevate.strategies.tiling.tileNDList(default.RiseTraversable) - def DFNF = rise.elevate.strategies.normalForm.DFNF()(default.RiseTraversable) - def RNF = rise.elevate.strategies.normalForm.RNF()(default.RiseTraversable) - def CNF = rise.elevate.strategies.normalForm.CNF()(default.RiseTraversable) - def BENF = rise.elevate.strategies.normalForm.BENF()(default.RiseTraversable) + def DFNF = rise.elevate.strategies.normalForm.DFNF()(using default.RiseTraversable) + def RNF = rise.elevate.strategies.normalForm.RNF()(using default.RiseTraversable) + def CNF = rise.elevate.strategies.normalForm.CNF()(using default.RiseTraversable) + def BENF = rise.elevate.strategies.normalForm.BENF()(using default.RiseTraversable) // Loop Interchange @@ -55,7 +55,7 @@ class algorithmic extends test_util.Tests { )) assert(betaEtaEquals( - body(body(fmap(loopInterchange) `;` DFNF `;` RNF))(input).get, + body(body(rise.elevate.strategies.traversal.fmap(loopInterchange) `;` DFNF `;` RNF))(input).get, gold )) } @@ -286,7 +286,7 @@ class algorithmic extends test_util.Tests { val typed = tile.apply(mm).get // these should be correct, it's just that the mapAcceptorTranslation for split is not defined yet - val lower: Strategy[Rise] = DFNF `;` CNF `;` normalize.apply(lowering.mapSeq <+ lowering.reduceSeq) `;` BENF + val lower: Strategy[Rise] = DFNF `;` CNF `;` normalize(lowering.mapSeq <+ lowering.reduceSeq) `;` BENF logger.debug(gen.c.function.asStringFromExpr(lower(typed).get)) /// TILE + REORDER diff --git a/src/test/scala/rise/elevate/circularBuffering.scala b/src/test/scala/rise/elevate/circularBuffering.scala index 21acf7f80..80f7cad20 100644 --- a/src/test/scala/rise/elevate/circularBuffering.scala +++ b/src/test/scala/rise/elevate/circularBuffering.scala @@ -183,7 +183,7 @@ class circularBuffering extends test_util.Tests { } private val id = fun(x => x) - private val norm = normalize(alternative.RiseTraversable).apply(gentleBetaReduction()) + private val norm = normalize[rise.core.Expr](gentleBetaReduction())(using alternative.RiseTraversable) private def rewriteSteps(a: Rise, steps: scala.collection.Seq[(Strategy[Rise], Rise)]): Unit = { steps.foldLeft[Rise](norm(a).get)({ case (e, (s, expected)) => @@ -215,7 +215,7 @@ class circularBuffering extends test_util.Tests { x |> slide(4)(1) >> map(sum) )) ), - normalize.apply(mapFusion) + normalize(mapFusion) -> ( slide(3)(1) >> map(sum) >> fun(x => makeArray(2)( @@ -231,7 +231,7 @@ class circularBuffering extends test_util.Tests { x |> sum ))) >> transpose ), - (normalize.apply(lowering.reduceSeq) `;` + (normalize(lowering.reduceSeq) `;` topDown(dropBeforeTake) `;` topDown(isApply `;` one(isApply `;` one(isMakeArray)) `;` lowering.mapSeqUnrollWrite) `;` diff --git a/src/test/scala/rise/elevate/fissionFusion.scala b/src/test/scala/rise/elevate/fissionFusion.scala index b3df4c3ac..d07d175e0 100644 --- a/src/test/scala/rise/elevate/fissionFusion.scala +++ b/src/test/scala/rise/elevate/fissionFusion.scala @@ -14,7 +14,7 @@ import rise.elevate.strategies.algorithmic.{mapFirstFission, mapFullFission} class fissionFusion extends test_util.Tests { - val BENF = rise.elevate.strategies.normalForm.BENF()(RiseTraversable) + val BENF = rise.elevate.strategies.normalForm.BENF()(using RiseTraversable) def eq(a: Expr, b: Expr): Unit = { if (! (BENF(a).get =~= BENF(b).get)) { @@ -76,6 +76,6 @@ class fissionFusion extends test_util.Tests { fun(f1 => fun(f2 => fun(f3 => map(f1 >> f2 >> f3)))), position(3)(mapFullFission), fun(f1 => fun(f2 => fun(f3 => map(f1) >> map(f2) >> map(f3)))), - normalize(RiseTraversable)(mapFusion)) + normalize(mapFusion)(using RiseTraversable)) } } diff --git a/src/test/scala/rise/elevate/halide.scala b/src/test/scala/rise/elevate/halide.scala index 09babc0bc..0bb902f3b 100644 --- a/src/test/scala/rise/elevate/halide.scala +++ b/src/test/scala/rise/elevate/halide.scala @@ -11,7 +11,7 @@ import rise.elevate.strategies.halide._ class halide extends test_util.Tests { - private val DFNF = rise.elevate.strategies.normalForm.DFNF()(RiseTraversable) + private val DFNF = rise.elevate.strategies.normalForm.DFNF()(using RiseTraversable) private def LCNFrewrite(a: Rise, s: Strategy[Rise], b: Rise): Unit = { val (closedA, nA) = makeClosed.withCount(a) diff --git a/src/test/scala/rise/elevate/movement.scala b/src/test/scala/rise/elevate/movement.scala index e38274105..32d0cdd7e 100644 --- a/src/test/scala/rise/elevate/movement.scala +++ b/src/test/scala/rise/elevate/movement.scala @@ -16,8 +16,8 @@ class movement extends test_util.Tests { // transpose - val BENF = rise.elevate.strategies.normalForm.BENF()(RiseTraversable) - val DFNF = rise.elevate.strategies.normalForm.DFNF()(RiseTraversable) + val BENF = rise.elevate.strategies.normalForm.BENF()(using RiseTraversable) + val DFNF = rise.elevate.strategies.normalForm.DFNF()(using RiseTraversable) def betaEtaEquals(a: Rise, b: Rise): Boolean = { val na = BENF(a).get diff --git a/src/test/scala/rise/elevate/tiling.scala b/src/test/scala/rise/elevate/tiling.scala index c909851df..adf708d25 100644 --- a/src/test/scala/rise/elevate/tiling.scala +++ b/src/test/scala/rise/elevate/tiling.scala @@ -22,13 +22,13 @@ import scala.language.implicitConversions class tiling extends test_util.Tests { - val BENF = rise.elevate.strategies.normalForm.BENF()(default.RiseTraversable) - val DFNF = rise.elevate.strategies.normalForm.DFNF()(default.RiseTraversable) - val CNF = rise.elevate.strategies.normalForm.CNF()(default.RiseTraversable) - val RNF = rise.elevate.strategies.normalForm.RNF()(default.RiseTraversable) + val BENF = rise.elevate.strategies.normalForm.BENF()(using default.RiseTraversable) + val DFNF = rise.elevate.strategies.normalForm.DFNF()(using default.RiseTraversable) + val CNF = rise.elevate.strategies.normalForm.CNF()(using default.RiseTraversable) + val RNF = rise.elevate.strategies.normalForm.RNF()(using default.RiseTraversable) - def tileND = rise.elevate.strategies.tiling.tileND(default.RiseTraversable) - def tileNDList = rise.elevate.strategies.tiling.tileNDList(default.RiseTraversable) + def tileND = rise.elevate.strategies.tiling.tileND(using default.RiseTraversable) + def tileNDList = rise.elevate.strategies.tiling.tileNDList(using default.RiseTraversable) implicit def rewriteResultToExpr(r: RewriteResult[Expr]): Expr = r.get @@ -79,7 +79,7 @@ class tiling extends test_util.Tests { // inner assert(betaEtaEquals( - body(body(fmap(tileND(1)(tileSize))))(input2D), + body(body(rise.elevate.strategies.traversal.fmap(tileND(1)(tileSize))))(input2D), λ(i => λ(f => *(J o **(f) o S) $ i)) )) } @@ -95,13 +95,13 @@ class tiling extends test_util.Tests { // middle assert(betaEtaEquals( - body(body(fmap(tileND(1)(tileSize))))(input3D), + body(body(rise.elevate.strategies.traversal.fmap(tileND(1)(tileSize))))(input3D), λ(i => λ(f => *(J o ***(f) o S) $ i)) )) // inner assert(betaEtaEquals( - body(body(fmap(fmap(tileND(1)(tileSize)))))(input3D), + body(body(rise.elevate.strategies.traversal.fmap(rise.elevate.strategies.traversal.fmap(tileND(1)(tileSize)))))(input3D), λ(i => λ(f => **(J o **(f) o S) $ i)) )) } @@ -118,19 +118,19 @@ class tiling extends test_util.Tests { // O assert(betaEtaEquals( - body(body(fmap(tileND(1)(tileSize))))(input4D), + body(body(rise.elevate.strategies.traversal.fmap(tileND(1)(tileSize))))(input4D), λ(i => λ(f => *(J o ****(f) o S) $ i)) )) // N assert(betaEtaEquals( - body(body(fmap(fmap(tileND(1)(tileSize)))))(input4D), + body(body(rise.elevate.strategies.traversal.fmap(rise.elevate.strategies.traversal.fmap(tileND(1)(tileSize)))))(input4D), λ(i => λ(f => **(J o ***(f) o S) $ i)) )) // M assert(betaEtaEquals( - body(body(fmap(fmap(fmap(tileND(1)(tileSize))))))(input4D), + body(body(rise.elevate.strategies.traversal.fmap(rise.elevate.strategies.traversal.fmap(rise.elevate.strategies.traversal.fmap(tileND(1)(tileSize))))))(input4D), λ(i => λ(f => ***(J o **(f) o S) $ i)) )) } @@ -155,7 +155,7 @@ class tiling extends test_util.Tests { // inner two assert(betaEtaEquals( - body(body(fmap(tileND(2)(tileSize))))(input3D), + body(body(rise.elevate.strategies.traversal.fmap(tileND(2)(tileSize))))(input3D), DFNF(λ(i => λ(f => *(J o **(J) o *(T) o ****(f) o *(T) o **(S) o S) $ i))) )) } @@ -171,13 +171,13 @@ class tiling extends test_util.Tests { // middle two assert(betaEtaEquals( - body(body(fmap(tileND(2)(tileSize))))(input4D), + body(body(rise.elevate.strategies.traversal.fmap(tileND(2)(tileSize))))(input4D), λ(i => λ(f => *(J o **(J) o *(T) o *****(f) o *(T) o **(S) o S) $ i)) )) // inner two assert(betaEtaEquals( - body(body(fmap(fmap(tileND(2)(tileSize)))))(input4D), + body(body(rise.elevate.strategies.traversal.fmap(rise.elevate.strategies.traversal.fmap(tileND(2)(tileSize)))))(input4D), λ(i => λ(f => **(J o **(J) o *(T) o ****(f) o *(T) o **(S) o S) $ i)) )) } @@ -212,7 +212,7 @@ class tiling extends test_util.Tests { // inner three assert(betaEtaEquals( - body(body(fmap(tileND(3)(tileSize))))(input4D), + body(body(rise.elevate.strategies.traversal.fmap(tileND(3)(tileSize))))(input4D), λ(i => λ(f => *( J o **(J) o ****(J) o ***(T) o *(T) o **(T) o @@ -259,7 +259,7 @@ class tiling extends test_util.Tests { // todo: this should use mapSeqCompute and CNF instead of RNF // ... but mapAcceptorTranslation for split is missing - val lower: Strategy[Rise] = DFNF `;` CNF `;` normalize.apply(lowering.mapSeq) `;` BENF + val lower: Strategy[Rise] = DFNF `;` CNF `;` normalize(lowering.mapSeq) `;` BENF val identity = depFun((t: DataType) => foreignFun("identity", immutable.Seq("y"), "{ return y; }", t ->: t)) val floatId: Expr = identity(f32) @@ -293,7 +293,7 @@ class tiling extends test_util.Tests { //TODO make this work without implicit array assignments ignore("codegen two innermost of three loops") { val highLevel = wrapInLambda(3, i => ***!(floatId) $ i, inputT(3, _)) - val tiled = one(one(one(body(fmap(tileND(2)(tileSize)))))).apply(highLevel).get + val tiled = one(one(one(body(rise.elevate.strategies.traversal.fmap(tileND(2)(tileSize)))))).apply(highLevel).get logger.debug(gen.c.function.asStringFromExpr(lower(highLevel))) diff --git a/src/test/scala/rise/elevate/traversals.scala b/src/test/scala/rise/elevate/traversals.scala index f43b2981c..a8cbb41d0 100644 --- a/src/test/scala/rise/elevate/traversals.scala +++ b/src/test/scala/rise/elevate/traversals.scala @@ -18,9 +18,9 @@ import rise.elevate.rules.traversal.{argument, argumentOf, body, function} class traversals extends test_util.Tests { def tileND = rise.elevate.strategies.tiling.tileND(RiseTraversable) - val DFNF = rise.elevate.strategies.normalForm.DFNF()(RiseTraversable) - val RNF = rise.elevate.strategies.normalForm.RNF()(RiseTraversable) - val FNF = rise.elevate.meta.fission.FNF(rise.elevate.meta.traversal.MetaRiseTraversable(RiseTraversable)) + val DFNF = rise.elevate.strategies.normalForm.DFNF()(using RiseTraversable) + val RNF = rise.elevate.strategies.normalForm.RNF()(using RiseTraversable) + val FNF = rise.elevate.meta.fission.FNF(using rise.elevate.meta.traversal.MetaRiseTraversable(RiseTraversable)) test("rewrite simple elevate strategy") { val expr = fun(f => fun(g => map(f) >> map(g))) diff --git a/src/test/scala/rise/elevate/tvmGemm.scala b/src/test/scala/rise/elevate/tvmGemm.scala index 0cf49dc0f..72c0ac912 100644 --- a/src/test/scala/rise/elevate/tvmGemm.scala +++ b/src/test/scala/rise/elevate/tvmGemm.scala @@ -25,9 +25,9 @@ import _root_.util.gen object tvmGemm { val outermost: (Strategy[Rise]) => (Strategy[Rise]) => Strategy[Rise] = - traversal.outermost(default.RiseTraversable) + traversal.outermost(using default.RiseTraversable) val innermost: (Strategy[Rise]) => (Strategy[Rise]) => Strategy[Rise] = - traversal.innermost(default.RiseTraversable) + traversal.innermost(using default.RiseTraversable) //// MM INPUT EXPRESSION ///////////////////////////////////////////////////// val N = 1024 @@ -48,7 +48,7 @@ object tvmGemm { //// ICFP'20 TVM - STRATEGIES //////////////////////////////////////////////// // -- BASELINE --------------------------------------------------------------- - val baseline: Strategy[Rise] = DFNF()(default.RiseTraversable) `;` + val baseline: Strategy[Rise] = DFNF()(using default.RiseTraversable) `;` fuseReduceMap `@` topDown[Rise] // -- BLOCKING --------------------------------------------------------------- @@ -83,7 +83,7 @@ object tvmGemm { val permuteB: Strategy[Rise] = splitJoin2(32) `;` DFNF() `;` argument(idAfter) `;` topDown(liftId()) `;` topDown(createTransposePair) `;` RNF() `;` - argument(argument(idAfter)) `;` normalize.apply(liftId()) `;` + argument(argument(idAfter)) `;` normalize(liftId()) `;` topDown(idToCopy) val packB: Strategy[Rise] = @@ -171,7 +171,7 @@ class tvmGemm extends test_util.Tests { val versionUC = version.toUpperCase() // reset rewrite step counter - Success.rewriteCount = 0 + elevate.core.SuccessRewriteCounter.rewriteCount = 0 // rewrite the matmul input expresssion val time0 = currentTimeSec @@ -179,7 +179,7 @@ class tvmGemm extends test_util.Tests { val time1 = currentTimeSec logger.debug(s"[$versionUC] rewrite time: ${time1 - time0}s") if (generateFiles) { - val steps = Success.rewriteCount + val steps = elevate.core.SuccessRewriteCounter.rewriteCount logger.debug(s"[$versionUC] required rewrite steps: $steps\n") writeToFile(plotsFolder, version, s"$version,$steps", ".csv") } diff --git a/src/test/scala/rise/elevate/util/package.scala b/src/test/scala/rise/elevate/util/package.scala index d3588d668..f5f428975 100644 --- a/src/test/scala/rise/elevate/util/package.scala +++ b/src/test/scala/rise/elevate/util/package.scala @@ -11,8 +11,8 @@ package object util { // Rise-related utils - def betaEtaEquals(a: Rise, b: Rise)(implicit ev: Traversable[Rise]): Boolean = - BENF()(ev)(makeClosed(a)).get =~= BENF()(ev)(makeClosed(b)).get + def betaEtaEquals(a: Rise, b: Rise)(using ev: Traversable[Rise]): Boolean = + BENF()(using ev)(makeClosed(a)).get =~= BENF()(using ev)(makeClosed(b)).get val tileSize = 4 diff --git a/src/test/scala/rise/eqsat/Basic.scala b/src/test/scala/rise/eqsat/Basic.scala index 4c97ccd6e..d88ebe02f 100644 --- a/src/test/scala/rise/eqsat/Basic.scala +++ b/src/test/scala/rise/eqsat/Basic.scala @@ -2,6 +2,7 @@ package rise.eqsat import rise.{core => rc} import rise.core.{types => rct} +import scala.language.postfixOps class Basic extends test_util.Tests { import Basic.proveEquiv diff --git a/src/test/scala/rise/eqsat/CircularBuffering.scala b/src/test/scala/rise/eqsat/CircularBuffering.scala index 02a54514c..132ae12e8 100644 --- a/src/test/scala/rise/eqsat/CircularBuffering.scala +++ b/src/test/scala/rise/eqsat/CircularBuffering.scala @@ -64,7 +64,7 @@ class CircularBuffering extends test_util.Tests { import rise.elevate.rules.traversal.alternative._ import elevate.core.strategies.basic.normalize - val normGoal = normalize.apply(gentleBetaReduction() <+ etaReduction())(goal).get + val normGoal = normalize(gentleBetaReduction() <+ etaReduction())(goal).get println(s"normalized goal: $normGoal") Basic.proveEquiv(Expr.fromNamed(start), Expr.simplifyNats(Expr.fromNamed(normGoal)), rules) diff --git a/src/test/scala/rise/eqsat/Reorder.scala b/src/test/scala/rise/eqsat/Reorder.scala index 5dc62ddac..5d7e20837 100644 --- a/src/test/scala/rise/eqsat/Reorder.scala +++ b/src/test/scala/rise/eqsat/Reorder.scala @@ -1,24 +1,24 @@ package rise.eqsat -import rise.core.Expr -import rise.core.DSL._ -import rise.core.DSL.Type._ -import rise.core.types._ -import Basic.proveEquiv -import rise.elevate.util._ +import rise.core.DSL.* +import rise.core.DSL.Type.* +import rise.core.Expr as RiseExpr +import rise.core.types.{Nat as RiseNat, DataType as RiseDataType, *} +import rise.elevate.util.* +import rise.eqsat.Basic.proveEquiv class Reorder extends test_util.Tests { test("reorder 2D") { - def wrap(inner: ToBeTyped[Expr] => ToBeTyped[Expr] => ToBeTyped[Expr]): ToBeTyped[Expr] = { - depFun((n: Nat) => depFun((m: Nat) => - depFun((dt1: DataType) => depFun((dt2: DataType) => + def wrap(inner: ToBeTyped[RiseExpr] => ToBeTyped[RiseExpr] => ToBeTyped[RiseExpr]): ToBeTyped[RiseExpr] = { + depFun((n: RiseNat) => depFun((m: RiseNat) => + depFun((dt1: RiseDataType) => depFun((dt2: RiseDataType) => fun(i => fun(f => inner(i :: (n`.`m`.`dt1))(f) :: (n`.`m`.`dt2) )))))) } - val expr: Expr = wrap(i => f => **!(f) $ i) - val gold: Expr = wrap(i => f => (T o **!(f) o T) $ i) + val expr: RiseExpr = wrap(i => f => **!(f) $ i) + val gold: RiseExpr = wrap(i => f => (T o **!(f) o T) $ i) proveEquiv(expr, gold, Seq( rules.eta, rules.beta, rules.betaNat, @@ -29,9 +29,9 @@ class Reorder extends test_util.Tests { // FIXME: difficulties reaching all of the goals ignore("reorder 3D") { - def wrap(inner: ToBeTyped[Expr] => ToBeTyped[Expr] => ToBeTyped[Expr]): Expr = { - depFun((n: Nat) => depFun((m: Nat) => depFun((o: Nat) => - depFun((dt1: DataType) => depFun((dt2: DataType) => + def wrap(inner: ToBeTyped[RiseExpr] => ToBeTyped[RiseExpr] => ToBeTyped[RiseExpr]): RiseExpr = { + depFun((n: RiseNat) => depFun((m: RiseNat) => depFun((o: RiseNat) => + depFun((dt1: RiseDataType) => depFun((dt2: RiseDataType) => fun(i => fun(f => inner(i :: (n`.`m`.`o`.`dt1))(f) :: (n`.`m`.`o`.`dt2) ))))))) @@ -57,19 +57,19 @@ class Reorder extends test_util.Tests { // FIXME: difficulties reaching all of the goals ignore("reorder 4D") { - def wrap(inner: ToBeTyped[Expr] => ToBeTyped[Expr] => ToBeTyped[Expr]): ToBeTyped[Expr] = { - depFun((n: Nat) => depFun((m: Nat) => depFun((o: Nat) => depFun((p: Nat) => - depFun((dt1: DataType) => depFun((dt2: DataType) => + def wrap(inner: ToBeTyped[RiseExpr] => ToBeTyped[RiseExpr] => ToBeTyped[RiseExpr]): ToBeTyped[RiseExpr] = { + depFun((n: RiseNat) => depFun((m: RiseNat) => depFun((o: RiseNat) => depFun((p: RiseNat) => + depFun((dt1: RiseDataType) => depFun((dt2: RiseDataType) => fun(i => fun(f => inner(i :: (n`.`m`.`o`.`p`.`dt1))(f) :: (n`.`m`.`o`.`p`.`dt2) )))))))) } - val expr: Expr = wrap(i => f => ****!(f) $ i) - val gold1243: Expr = wrap(i => f => (**!(T) o ****!(f) o **!(T)) $ i) - val gold1324: Expr = wrap(i => f => (*!(T) o ****!(f) o *!(T)) $ i) - val gold2134: Expr = wrap(i => f => (T o ****!(f) o T) $ i) - val gold4321: Expr = wrap(i => f => (**!(T) o *!(T) o T o **!(T) o *!(T) o **!(T) o ****!(f) o + val expr: RiseExpr = wrap(i => f => ****!(f) $ i) + val gold1243: RiseExpr = wrap(i => f => (**!(T) o ****!(f) o **!(T)) $ i) + val gold1324: RiseExpr = wrap(i => f => (*!(T) o ****!(f) o *!(T)) $ i) + val gold2134: RiseExpr = wrap(i => f => (T o ****!(f) o T) $ i) + val gold4321: RiseExpr = wrap(i => f => (**!(T) o *!(T) o T o **!(T) o *!(T) o **!(T) o ****!(f) o **!(T) o *!(T) o **!(T) o T o *!(T) o **!(T)) $ i) proveEquiv(expr, Seq(gold1243, gold1324, gold2134, gold4321), Seq( diff --git a/src/test/scala/rise/eqsat/TvmGemm.scala b/src/test/scala/rise/eqsat/TvmGemm.scala index f36a6e43a..56c33fca4 100644 --- a/src/test/scala/rise/eqsat/TvmGemm.scala +++ b/src/test/scala/rise/eqsat/TvmGemm.scala @@ -1,12 +1,12 @@ package rise.eqsat -import rise.core.Expr +import rise.core.{Expr => RiseExpr} import rise.elevate.tvmGemm import Basic.proveEquiv class TvmGemm extends test_util.Tests { test("TVM GEMM") { - val mm: Expr = tvmGemm.mm + val mm: RiseExpr = tvmGemm.mm val baseline = tvmGemm.baseline(mm).get val blocking = tvmGemm.blocking(mm).get val vectorization = tvmGemm.vectorization(mm).get diff --git a/src/test/scala/shine/DPIA/Primitives/Pad.scala b/src/test/scala/shine/DPIA/Primitives/Pad.scala index 087aee5c2..f28e99f96 100644 --- a/src/test/scala/shine/DPIA/Primitives/Pad.scala +++ b/src/test/scala/shine/DPIA/Primitives/Pad.scala @@ -1,15 +1,17 @@ package shine.DPIA.Primitives -import rise.core.DSL._ -import rise.core.primitives._ -import Type._ -import rise.core.types._ -import rise.core.types.DataType._ -import HighLevelConstructs.padClamp2D +import rise.core.DSL.* +import rise.core.DSL.HighLevelConstructs.padClamp2D +import rise.core.DSL.Type.* +import rise.core.primitives.* +import rise.core.types.* +import rise.core.types.DataType.* import shine.OpenCL.KernelExecutor.KernelNoSizes.fromKernelModule import util.gen import util.gen.c.function +import scala.reflect.Selectable.reflectiveSelectable + class Pad extends test_util.Tests { private val id = fun(x => x) @@ -38,7 +40,7 @@ class Pad extends test_util.Tests { } test("Simple OpenMP constant pad input and copy") { - import rise.openMP.primitives._ + import rise.openMP.primitives.* val e = depFun((n: Nat) => fun(ArrayType(n, f32))( xs => xs |> padCst(2)(3)(lf32(5.0f)) |> mapPar(fun(x => x)) @@ -48,7 +50,7 @@ class Pad extends test_util.Tests { } test("Simple OpenCL pad input and copy") { - import rise.openCL.DSL._ + import rise.openCL.DSL.* val e = depFun((n: Nat) => fun(ArrayType(n, f32))( xs => xs |> padCst(2)(3)(lf32(5.0f)) |> mapGlobal(fun(x => x)) @@ -58,7 +60,7 @@ class Pad extends test_util.Tests { } test("OpenCL Pad only left") { - import rise.openCL.DSL._ + import rise.openCL.DSL.* val e = depFun((n: Nat) => fun(ArrayType(n, f32))( xs => xs |> padCst(2)(0)(lf32(5.0f)) |> mapGlobal(fun(x => x)) @@ -68,7 +70,7 @@ class Pad extends test_util.Tests { } test("OpenCL Pad only right") { - import rise.openCL.DSL._ + import rise.openCL.DSL.* val e = depFun((n: Nat) => fun(ArrayType(n, f32))( xs => xs |> padCst(0)(3)(lf32(5.0f)) |> mapGlobal(fun(x => x)) @@ -78,7 +80,7 @@ class Pad extends test_util.Tests { } test("OpenCL pad before or after transpose") { - import rise.openCL.DSL._ + import rise.openCL.DSL.* val range = arithexpr.arithmetic.RangeAdd(1, arithexpr.arithmetic.PosInf, 1) val k1 = gen.opencl.kernel.fromExpr(depFun(range, (n: Nat) => @@ -96,7 +98,7 @@ class Pad extends test_util.Tests { val random = new scala.util.Random() val input = Array.fill(4, N)(random.nextInt()) - import shine.OpenCL._ + import shine.OpenCL.* val localSize = LocalSize(1) val globalSize = GlobalSize(1) diff --git a/src/test/scala/shine/DPIA/Primitives/Reduce.scala b/src/test/scala/shine/DPIA/Primitives/Reduce.scala index 1110eef4e..45b347b48 100644 --- a/src/test/scala/shine/DPIA/Primitives/Reduce.scala +++ b/src/test/scala/shine/DPIA/Primitives/Reduce.scala @@ -1,20 +1,20 @@ package shine.DPIA.Primitives import arithexpr.arithmetic.Cst -import rise.core.DSL.Type._ -import rise.core.DSL._ -import rise.core.Expr -import rise.core.primitives -import rise.core.primitives._ -import rise.core.types.DataType._ -import rise.core.types.{AddressSpace, _} +import rise.core.DSL.* +import rise.core.DSL.Type.* +import rise.core.primitives.* +import rise.core.types.DataType.* +import rise.core.types.{AddressSpace, *} +import rise.core.{Expr, primitives} import rise.openCL.primitives.oclReduceSeq +import shine.OpenCL.* import shine.OpenCL.KernelExecutor.KernelNoSizes.fromKernelModule -import shine.OpenCL._ import util.gen import util.gen.c.function import scala.language.postfixOps +import scala.reflect.Selectable.reflectiveSelectable class Reduce extends test_util.TestsWithExecutor { val add = fun(a => fun(b => a + b)) @@ -66,7 +66,7 @@ class Reduce extends test_util.TestsWithExecutor { val e = depFun((m: Nat, n: Nat) => fun(m`.`n`.`f32)(arr => arr - |> oclReduceSeq (AddressSpace.Private) + |> oclReduceSeq(AddressSpace.Private) (fun((in1, in2) => zip (in1) (in2) |> mapSeq (fun(t => t.`1` + t.`2`)))) (initExp (n)) |> mapSeq (fun(x => x)))) diff --git a/src/test/scala/shine/DPIA/Primitives/Scatter.scala b/src/test/scala/shine/DPIA/Primitives/Scatter.scala index 64e7d2013..5c0462675 100644 --- a/src/test/scala/shine/DPIA/Primitives/Scatter.scala +++ b/src/test/scala/shine/DPIA/Primitives/Scatter.scala @@ -1,19 +1,20 @@ package shine.DPIA.Primitives -import rise.core.primitives._ -import rise.core.DSL._ -import rise.core.types._ -import rise.core.types.DataType._ -import rise.core.DSL.Type._ +import rise.core.DSL.* +import rise.core.DSL.Type.* +import rise.core.primitives.* +import rise.core.types.* +import rise.core.types.DataType.* import shine.OpenCL.KernelExecutor.KernelNoSizes.fromKernelModule import util.gen import scala.language.postfixOps +import scala.reflect.Selectable.reflectiveSelectable class Scatter extends test_util.Tests { test("Reversing scatter should generate valid OpenCL") { - import rise.openCL.DSL._ - import shine.OpenCL._ + import rise.openCL.DSL.* + import shine.OpenCL.* val N = 20 val n: Nat = N @@ -36,8 +37,8 @@ class Scatter extends test_util.Tests { } test("Overriding scatter should generate valid OpenCL") { - import rise.openCL.DSL._ - import shine.OpenCL._ + import rise.openCL.DSL.* + import shine.OpenCL.* val N = 20 val n: Nat = N diff --git a/src/test/scala/shine/cuda/MMTest.scala b/src/test/scala/shine/cuda/MMTest.scala index 90dce7939..c1abfcc2e 100644 --- a/src/test/scala/shine/cuda/MMTest.scala +++ b/src/test/scala/shine/cuda/MMTest.scala @@ -1,19 +1,21 @@ package shine.cuda -import shine.DPIA.Phrases._ -import rise.core.types.MatrixLayout._ -import rise.core.types.{Fragment, MatrixLayout, MatrixLayoutIdentifier, NatIdentifier, NatKind, read, write} -import rise.core.types.DataType._ -import shine.DPIA.Types._ -import shine.DPIA._ +import rise.core.types.DataType.* +import rise.core.types.MatrixLayout.* +import rise.core.types.{FunType => _, *} +import shine.DPIA.* +import shine.DPIA.Phrases.* +import shine.DPIA.Types.* import shine.DPIA.primitives.functional.{Fst, Join, Snd, Split, Transpose, Zip} -import shine.OpenCL._ -import shine.cuda.primitives.functional._ +import shine.OpenCL.* import shine.OpenCL.primitives.functional.{ReduceSeq, ToMem} import shine.cuda.KernelExecutor.KernelNoSizes +import shine.cuda.primitives.functional.* import test_util.similar import util.gen +import scala.reflect.Selectable.reflectiveSelectable + class MMTest extends test_util.TestWithCUDA { val n: NatIdentifier = NatIdentifier(freshName("n")) val m: NatIdentifier = NatIdentifier(freshName("m")) diff --git a/src/test/scala/test_util/package.scala b/src/test/scala/test_util/package.scala index 39d84502d..a863a1991 100644 --- a/src/test/scala/test_util/package.scala +++ b/src/test/scala/test_util/package.scala @@ -2,11 +2,11 @@ import opencl.executor.Executor import org.scalatest.BeforeAndAfter import org.scalatest.matchers.should.Matchers import org.scalatest.funsuite.AnyFunSuite -import org.apache.logging.log4j.scala.Logging import util.{AssertSame, Time, TimeSpan} +import wvlet.log.LogSupport package object test_util { - abstract class Tests extends AnyFunSuite with Matchers with Logging { + abstract class Tests extends AnyFunSuite with Matchers with LogSupport { def runsWithSameResult[R, U <: Time.Unit](runs: Seq[(String, (R, TimeSpan[U]))]) (implicit assertSame: AssertSame[R]): Unit = { runs.tail.foreach(r => assertSame(r._2._1, runs.head._2._1, s"${r._1} had a different result"))