Skip to content
This repository has been archived by the owner on Dec 18, 2020. It is now read-only.

Commit

Permalink
Merge pull request #15 from rise-lang/fixed-width-types
Browse files Browse the repository at this point in the history
Change data types, adding some fixed width
  • Loading branch information
Bastacyclop authored Jan 20, 2020
2 parents 23d215c + fd5ab85 commit c0b768b
Show file tree
Hide file tree
Showing 6 changed files with 25 additions and 34 deletions.
4 changes: 2 additions & 2 deletions src/main/scala/rise/core/semantics/Data.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ final case class BoolData(b: Boolean) extends ScalarData(bool)

final case class IntData(i: Int) extends ScalarData(int)

final case class FloatData(f: Float) extends ScalarData(float)
final case class FloatData(f: Float) extends ScalarData(f32)

final case class DoubleData(d: Double) extends ScalarData(double)
final case class DoubleData(d: Double) extends ScalarData(f64)

final case class VectorData(v: Seq[ScalarData])
extends Data(VectorType(v.length, v.head.dataType))
Expand Down
38 changes: 15 additions & 23 deletions src/main/scala/rise/core/types/Type.scala
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,19 @@ object int extends ScalarType {
override def toString: String = "int"
}

object float extends ScalarType {
override def toString: String = "float"
}
object i8 extends ScalarType { override def toString: String = "i8" }
object i16 extends ScalarType { override def toString: String = "i16" }
object i32 extends ScalarType { override def toString: String = "i32" }
object i64 extends ScalarType { override def toString: String = "i64" }

object u8 extends ScalarType { override def toString: String = "u8" }
object u16 extends ScalarType { override def toString: String = "u16" }
object u32 extends ScalarType { override def toString: String = "u32" }
object u64 extends ScalarType { override def toString: String = "u64" }

object double extends ScalarType { override def toString: String = "double" }
object f16 extends ScalarType { override def toString: String = "f16" }
object f32 extends ScalarType { override def toString: String = "f32" }
object f64 extends ScalarType { override def toString: String = "f64" }

object NatType extends ScalarType { override def toString: String = "nat" }

Expand All @@ -105,25 +113,9 @@ sealed case class VectorType(size: Nat, elemType: DataType) extends BasicType {
override def toString: String = s"<$size>$elemType"
}

object int2 extends VectorType(2, int)

object int3 extends VectorType(3, int)

object int4 extends VectorType(4, int)

object int8 extends VectorType(8, int)

object int16 extends VectorType(16, int)

object float2 extends VectorType(2, float)

object float3 extends VectorType(3, float)

object float4 extends VectorType(4, float)

object float8 extends VectorType(8, float)

object float16 extends VectorType(16, float)
object vec {
def apply(size: Nat, elemType: DataType) = VectorType(size, elemType)
}

final class NatToDataApply(val f: NatToData, val n: Nat) extends DataType {
override def toString: String = s"$f($n)"
Expand Down
2 changes: 1 addition & 1 deletion src/test/scala/rise/core/showRise.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class showRise extends test_util.Tests {

private val blurXTiled2D: Expr = nFun(n =>
fun(
(n `.` n `.` float) ->: (17 `.` float) ->: (n `.` n `.` float)
(n `.` n `.` f32) ->: (17 `.` f32) ->: (n `.` n `.` f32)
)((matrix, weights) =>
unslide2D o mapWorkGroup(1)(
mapWorkGroup(0)(
Expand Down
2 changes: 1 addition & 1 deletion src/test/scala/rise/core/structuralEquality.scala
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class structuralEquality extends test_util.Tests {
)
!=
nFun(m =>
fun(ArrayType(m, float))(b =>
fun(ArrayType(m, f32))(b =>
reduceSeq(fun(y => fun(x => y + x)))(0)(b)
)
)
Expand Down
9 changes: 4 additions & 5 deletions src/test/scala/rise/core/traverse.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@ import scala.collection.mutable
class traverse extends test_util.Tests {
val e = nFun(h =>
nFun(w =>
fun(ArrayType(h, ArrayType(w, float)))(input =>
map(map(fun(x => x)))(input)
fun(ArrayType(h, ArrayType(w, f32)))(input => map(map(fun(x => x)))(input)
)
)
)
Expand Down Expand Up @@ -120,7 +119,7 @@ class traverse extends test_util.Tests {
r ==
nFun(h =>
nFun(w =>
fun(ArrayType(h, ArrayType(w, float)))(input =>
fun(ArrayType(h, ArrayType(w, f32)))(input =>
app(fun(x => x), input)
)
)
Expand All @@ -135,7 +134,7 @@ class traverse extends test_util.Tests {

test("traverse an expression depth-first with global stop") {
val e = nFun(n =>
fun(ArrayType(n, float))(input =>
fun(ArrayType(n, f32))(input =>
input |> map(fun(x => x)) |> map(fun(x => x))
)
)
Expand All @@ -157,7 +156,7 @@ class traverse extends test_util.Tests {
(result: @unchecked) match {
case traversal.Stop(r) =>
val expected = nFun(n =>
fun(ArrayType(n, float))(input => {
fun(ArrayType(n, f32))(input => {
val x = identifier(freshName("x"))
app(lambda(x, x), input |> map(fun(x => x)))
})
Expand Down
4 changes: 2 additions & 2 deletions src/test/scala/rise/core/typedDSL.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ class typedDSL extends test_util.Tests {
val e =
nFun(n =>
fun(
DepArrayType(n, n2dtFun(i => (i + 1) `.` float)) ->: DepArrayType(
DepArrayType(n, n2dtFun(i => (i + 1) `.` f32)) ->: DepArrayType(
n,
n2dtFun(i => (i + 1) `.` float)
n2dtFun(i => (i + 1) `.` f32)
)
)(xs => xs |> depMapSeq(nFun(_ => mapSeq(fun(x => x)))))
)
Expand Down

0 comments on commit c0b768b

Please sign in to comment.