Skip to content

Commit

Permalink
Merge pull request #70 from rise-lang/feature/improve-RISE-syntax
Browse files Browse the repository at this point in the history
  • Loading branch information
Michel Steuwer authored Oct 7, 2020
2 parents cde826c + b9d4885 commit 90586cd
Show file tree
Hide file tree
Showing 20 changed files with 564 additions and 541 deletions.
12 changes: 6 additions & 6 deletions src/main/scala/apps/cameraPipe.scala
Original file line number Diff line number Diff line change
Expand Up @@ -101,19 +101,19 @@ object cameraPipe {
)
}

def interleaveX: Expr = implDT(dt => implNat(h => implNat(w => fun(
def interleaveX: Expr = impl{ dt: DataType => impl{ h: Nat => impl{ w: Nat => fun(
(h`.`w`.`dt) ->: (h`.`w`.`dt) ->: (h`.`(2*w)`.`dt)
)((a, b) =>
generate(fun(i => select(i =:= lidx(0, 2))(a)(b))) |>
transpose >> map(transpose >> join)
))))
) }}}

def interleaveY: Expr = implDT(dt => implNat(h => implNat(w => fun(
def interleaveY: Expr = impl{ dt: DataType => impl{ h: Nat => impl{ w: Nat => fun(
(h`.`w`.`dt) ->: (h`.`w`.`dt) ->: ((2*h)`.`w`.`dt)
)((a, b) =>
generate(fun(i => select(i =:= lidx(0, 2))(a)(b))) |>
transpose >> join
))))
) }}}

val demosaic: Expr = nFun(h => nFun(w => fun(
(4`.`(h+2)`.`(w+2)`.`i16) ->: (3`.`(2*h)`.`(2*w)`.`i16)
Expand Down Expand Up @@ -425,8 +425,8 @@ object cameraPipe {
// mapSeq(mapSeqUnroll(fun(x => x)))
// )) >> join >> map(transpose) >> transpose >>
// --
map(implNat(w => map(drop(1) >> take(w))) >>
implNat(h => drop(1) >> take(h))) >>
map(impl{ w: Nat => map(drop(1) >> take(w)) } >>
impl{ h: Nat => drop(1) >> take(h) }) >>
fun(x => color_correct(2*h+2)(2*w+2)(hm)(wm)(x)
(matrix_3200)(matrix_7000)(color_temp)) >>
// TODO: reorder and store with elevate
Expand Down
4 changes: 2 additions & 2 deletions src/main/scala/apps/convolution.scala
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ object convolution {
))

def padEmpty(l: Nat, r: Nat): Expr = padClamp(l)(r)
def unpadEmpty(l: Nat, r: Nat): Expr = implNat(n => implDT(t =>
def unpadEmpty(l: Nat, r: Nat): Expr = impl{ n: Nat => impl{ t: DataType =>
drop(l) >> (take(n) :: ((n + r) `.` t) ->: (n `.` t))
))
}}

val blurYTiled2DTiledLoadingTransposed: Expr = nFun(n => fun(
(n `.` n `.` f32) ->: (17 `.` f32) ->: (n `.` n `.` f32)
Expand Down
8 changes: 4 additions & 4 deletions src/main/scala/apps/separableConvolution2D.scala
Original file line number Diff line number Diff line change
Expand Up @@ -130,11 +130,11 @@ object separableConvolution2D {
val shuffle: Expr =
asScalar >> drop(3) >> take(6) >> slideVectors(4)
val scanlinePar: Expr = fun(3`.`f32)(weightsV => fun(3`.`f32)(weightsH =>
map(implNat(w => fun(w`.`f32)(x =>
map(impl{ w: Nat => fun(w`.`f32)(x =>
x |> asVectorAligned(4)
|> padCst(1)(0)(vectorFromScalar(x `@` lidx(0, w)))
|> padCst(0)(1)(vectorFromScalar(x `@` lidx(w - 1, w)))
))) >> padClamp(1)(1) >>
)}) >> padClamp(1)(1) >>
slide(3)(1) >> mapGlobal(
transpose >>
toGlobalFun(mapSeq(weightsSeqVecUnroll(weightsV))) >>
Expand All @@ -156,11 +156,11 @@ object separableConvolution2D {
val Dv = weightsSeqVecUnroll(weightsV)
val Dh = weightsSeqVecUnroll(weightsH)
// map(padClamp(4)(4) >> asVectorAligned(4)) >> padClamp(1)(1) >>
map(implNat(w => fun(w`.`f32)(x =>
map(impl{ w: Nat => fun(w`.`f32)(x =>
x |> asVectorAligned(4)
|> padCst(1)(0)(vectorFromScalar(x `@` lidx(0, w)))
|> padCst(0)(1)(vectorFromScalar(x `@` lidx(w - 1, w)))
))) >> padClamp(1)(1) >>
)}) >> padClamp(1)(1) >>
slide(3)(1) >> mapGlobal(
transpose >>
map(Dv) >>
Expand Down
14 changes: 7 additions & 7 deletions src/main/scala/apps/sgemm.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ import scala.collection.parallel.CollectionConverters._

object sgemm {
// we can use implicit type parameters and type annotations to specify the function type of mult
val mult = implDT(dt => fun(x => x._1 * x._2) :: ((dt x dt) ->: dt))
val mult = impl{ dt: DataType => fun(x => x._1 * x._2) :: ((dt x dt) ->: dt) }
val add = fun(x => fun(y => x + y))
val scal = implNat(n => fun(xs => fun(a => mapSeq(fun(x => a * x), xs))) :: (ArrayType(n, f32) ->: f32 ->: ArrayType(n, f32)))
val scal = impl{ n: Nat => fun(xs => fun(a => mapSeq(fun(x => a * x), xs))) :: (ArrayType(n, f32) ->: f32 ->: ArrayType(n, f32)) }
val dot = fun(x => foreignFun("dot", vec(4, f32) ->: vec(4, f32) ->: f32)(x._1, x._2))
def id: Expr = fun(x => x)

Expand Down Expand Up @@ -57,9 +57,9 @@ object sgemm {
val p3: Nat = 4
val vw: Nat = 4

val write_zeros = implNat(n => implNat(m =>
generate(fun(IndexType(m))(_ => generate(fun(IndexType(n))(_ => l(0.0f)))))
|> mapSeq(mapSeq(id))))
val write_zeros = impl{ n: Nat => impl{ m: Nat =>
generate(fun(IndexType(m))(_ =>
generate(fun(IndexType(n))(_ => l(0.0f))))) |> mapSeq(mapSeq(id)) }}


nFun((n, m, k) =>
Expand Down Expand Up @@ -113,8 +113,8 @@ object sgemm {
generate(fun(IndexType(n2))(_ =>
generate(fun(IndexType(n1))(_ => l(0.0f)))))))))))))

def tile2: Expr = nFun(s1 => nFun(s2 => implNat(n1 => implNat(n2 => fun(ArrayType(n1, ArrayType(n2, f32)))(x =>
transpose (map (transpose) (split (s1) (map (split (s2)) (x)))) )))))
def tile2: Expr = nFun(s1 => nFun(s2 => impl{ n1: Nat => impl{ n2: Nat => fun(ArrayType(n1, ArrayType(n2, f32)))(x =>
transpose (map (transpose) (split (s1) (map (split (s2)) (x)))) ) }}))

def redOp: Expr = fun((8`.`32`.`8`.`4`.`f32) ->: ( (8`.`64`.`f32) x (8`.`128`.`f32) ) ->: (8`.`32`.`8`.`4`.`f32) )((p14, p15) =>
let (p15 |> fun(p29 =>
Expand Down
8 changes: 4 additions & 4 deletions src/main/scala/rise/core/HighLevelConstructs.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@ object HighLevelConstructs {
val reorderWithStride: Expr = {
nFun(s => {
val f =
implNat(n =>
impl{ n: Nat =>
fun(IndexType(n))(i =>
natAsIndex(n)(
(indexAsNat(i) / (n /^ s)) +
((s: Expr) * (indexAsNat(i) % (n /^ s)))
)
)
)
}
reorder(f)(f)
})
}
Expand All @@ -52,8 +52,8 @@ object HighLevelConstructs {
fun(a => fun(b => rec(n, a, b)))
}

def dropLast: Expr = nFun(n =>
implNat(m => implDT(dt => take(m) :: ((m + n) `.` dt) ->: (m `.` dt)))
def dropLast: Expr = nFun(n => impl{ m: Nat => impl{ dt: DataType =>
take(m) :: ((m + n) `.` dt) ->: (m `.` dt) }}
)

// TODO: Investigate. this might be wrong
Expand Down
Loading

0 comments on commit 90586cd

Please sign in to comment.