Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improved RISE syntax #70

Merged
merged 10 commits into from
Oct 7, 2020
12 changes: 6 additions & 6 deletions src/main/scala/apps/cameraPipe.scala
Original file line number Diff line number Diff line change
Expand Up @@ -101,19 +101,19 @@ object cameraPipe {
)
}

def interleaveX: Expr = implDT(dt => implNat(h => implNat(w => fun(
def interleaveX: Expr = impl{ dt: DataType => impl{ h: Nat => impl{ w: Nat => fun(
(h`.`w`.`dt) ->: (h`.`w`.`dt) ->: (h`.`(2*w)`.`dt)
)((a, b) =>
generate(fun(i => select(i =:= lidx(0, 2))(a)(b))) |>
transpose >> map(transpose >> join)
))))
) }}}

def interleaveY: Expr = implDT(dt => implNat(h => implNat(w => fun(
def interleaveY: Expr = impl{ dt: DataType => impl{ h: Nat => impl{ w: Nat => fun(
(h`.`w`.`dt) ->: (h`.`w`.`dt) ->: ((2*h)`.`w`.`dt)
)((a, b) =>
generate(fun(i => select(i =:= lidx(0, 2))(a)(b))) |>
transpose >> join
))))
) }}}

val demosaic: Expr = nFun(h => nFun(w => fun(
(4`.`(h+2)`.`(w+2)`.`i16) ->: (3`.`(2*h)`.`(2*w)`.`i16)
Expand Down Expand Up @@ -425,8 +425,8 @@ object cameraPipe {
// mapSeq(mapSeqUnroll(fun(x => x)))
// )) >> join >> map(transpose) >> transpose >>
// --
map(implNat(w => map(drop(1) >> take(w))) >>
implNat(h => drop(1) >> take(h))) >>
map(impl{ w: Nat => map(drop(1) >> take(w)) } >>
impl{ h: Nat => drop(1) >> take(h) }) >>
fun(x => color_correct(2*h+2)(2*w+2)(hm)(wm)(x)
(matrix_3200)(matrix_7000)(color_temp)) >>
// TODO: reorder and store with elevate
Expand Down
4 changes: 2 additions & 2 deletions src/main/scala/apps/convolution.scala
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ object convolution {
))

def padEmpty(l: Nat, r: Nat): Expr = padClamp(l)(r)
def unpadEmpty(l: Nat, r: Nat): Expr = implNat(n => implDT(t =>
def unpadEmpty(l: Nat, r: Nat): Expr = impl{ n: Nat => impl{ t: DataType =>
drop(l) >> (take(n) :: ((n + r) `.` t) ->: (n `.` t))
))
}}

val blurYTiled2DTiledLoadingTransposed: Expr = nFun(n => fun(
(n `.` n `.` f32) ->: (17 `.` f32) ->: (n `.` n `.` f32)
Expand Down
8 changes: 4 additions & 4 deletions src/main/scala/apps/separableConvolution2D.scala
Original file line number Diff line number Diff line change
Expand Up @@ -130,11 +130,11 @@ object separableConvolution2D {
val shuffle: Expr =
asScalar >> drop(3) >> take(6) >> slideVectors(4)
val scanlinePar: Expr = fun(3`.`f32)(weightsV => fun(3`.`f32)(weightsH =>
map(implNat(w => fun(w`.`f32)(x =>
map(impl{ w: Nat => fun(w`.`f32)(x =>
x |> asVectorAligned(4)
|> padCst(1)(0)(vectorFromScalar(x `@` lidx(0, w)))
|> padCst(0)(1)(vectorFromScalar(x `@` lidx(w - 1, w)))
))) >> padClamp(1)(1) >>
)}) >> padClamp(1)(1) >>
slide(3)(1) >> mapGlobal(
transpose >>
toGlobalFun(mapSeq(weightsSeqVecUnroll(weightsV))) >>
Expand All @@ -156,11 +156,11 @@ object separableConvolution2D {
val Dv = weightsSeqVecUnroll(weightsV)
val Dh = weightsSeqVecUnroll(weightsH)
// map(padClamp(4)(4) >> asVectorAligned(4)) >> padClamp(1)(1) >>
map(implNat(w => fun(w`.`f32)(x =>
map(impl{ w: Nat => fun(w`.`f32)(x =>
x |> asVectorAligned(4)
|> padCst(1)(0)(vectorFromScalar(x `@` lidx(0, w)))
|> padCst(0)(1)(vectorFromScalar(x `@` lidx(w - 1, w)))
))) >> padClamp(1)(1) >>
)}) >> padClamp(1)(1) >>
slide(3)(1) >> mapGlobal(
transpose >>
map(Dv) >>
Expand Down
14 changes: 7 additions & 7 deletions src/main/scala/apps/sgemm.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ import scala.collection.parallel.CollectionConverters._

object sgemm {
// we can use implicit type parameters and type annotations to specify the function type of mult
val mult = implDT(dt => fun(x => x._1 * x._2) :: ((dt x dt) ->: dt))
val mult = impl{ dt: DataType => fun(x => x._1 * x._2) :: ((dt x dt) ->: dt) }
val add = fun(x => fun(y => x + y))
val scal = implNat(n => fun(xs => fun(a => mapSeq(fun(x => a * x), xs))) :: (ArrayType(n, f32) ->: f32 ->: ArrayType(n, f32)))
val scal = impl{ n: Nat => fun(xs => fun(a => mapSeq(fun(x => a * x), xs))) :: (ArrayType(n, f32) ->: f32 ->: ArrayType(n, f32)) }
val dot = fun(x => foreignFun("dot", vec(4, f32) ->: vec(4, f32) ->: f32)(x._1, x._2))
def id: Expr = fun(x => x)

Expand Down Expand Up @@ -57,9 +57,9 @@ object sgemm {
val p3: Nat = 4
val vw: Nat = 4

val write_zeros = implNat(n => implNat(m =>
generate(fun(IndexType(m))(_ => generate(fun(IndexType(n))(_ => l(0.0f)))))
|> mapSeq(mapSeq(id))))
val write_zeros = impl{ n: Nat => impl{ m: Nat =>
generate(fun(IndexType(m))(_ =>
generate(fun(IndexType(n))(_ => l(0.0f))))) |> mapSeq(mapSeq(id)) }}


nFun((n, m, k) =>
Expand Down Expand Up @@ -113,8 +113,8 @@ object sgemm {
generate(fun(IndexType(n2))(_ =>
generate(fun(IndexType(n1))(_ => l(0.0f)))))))))))))

def tile2: Expr = nFun(s1 => nFun(s2 => implNat(n1 => implNat(n2 => fun(ArrayType(n1, ArrayType(n2, f32)))(x =>
transpose (map (transpose) (split (s1) (map (split (s2)) (x)))) )))))
def tile2: Expr = nFun(s1 => nFun(s2 => impl{ n1: Nat => impl{ n2: Nat => fun(ArrayType(n1, ArrayType(n2, f32)))(x =>
transpose (map (transpose) (split (s1) (map (split (s2)) (x)))) ) }}))

def redOp: Expr = fun((8`.`32`.`8`.`4`.`f32) ->: ( (8`.`64`.`f32) x (8`.`128`.`f32) ) ->: (8`.`32`.`8`.`4`.`f32) )((p14, p15) =>
let (p15 |> fun(p29 =>
Expand Down
8 changes: 4 additions & 4 deletions src/main/scala/rise/core/HighLevelConstructs.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@ object HighLevelConstructs {
val reorderWithStride: Expr = {
nFun(s => {
val f =
implNat(n =>
impl{ n: Nat =>
fun(IndexType(n))(i =>
natAsIndex(n)(
(indexAsNat(i) / (n /^ s)) +
((s: Expr) * (indexAsNat(i) % (n /^ s)))
)
)
)
}
reorder(f)(f)
})
}
Expand All @@ -52,8 +52,8 @@ object HighLevelConstructs {
fun(a => fun(b => rec(n, a, b)))
}

def dropLast: Expr = nFun(n =>
implNat(m => implDT(dt => take(m) :: ((m + n) `.` dt) ->: (m `.` dt)))
def dropLast: Expr = nFun(n => impl{ m: Nat => impl{ dt: DataType =>
take(m) :: ((m + n) `.` dt) ->: (m `.` dt) }}
)

// TODO: Investigate. this might be wrong
Expand Down
Loading