Skip to content
This repository has been archived by the owner on Dec 18, 2020. It is now read-only.

Change data types, adding some fixed width #15

Merged
merged 2 commits into from
Jan 20, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/main/scala/rise/core/semantics/Data.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ final case class BoolData(b: Boolean) extends ScalarData(bool)

final case class IntData(i: Int) extends ScalarData(int)

final case class FloatData(f: Float) extends ScalarData(float)
final case class FloatData(f: Float) extends ScalarData(f32)

final case class DoubleData(d: Double) extends ScalarData(double)
final case class DoubleData(d: Double) extends ScalarData(f64)

final case class VectorData(v: Seq[ScalarData])
extends Data(VectorType(v.length, v.head.dataType))
Expand Down
38 changes: 15 additions & 23 deletions src/main/scala/rise/core/types/Type.scala
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,19 @@ object int extends ScalarType {
override def toString: String = "int"
}

object float extends ScalarType {
override def toString: String = "float"
}
object i8 extends ScalarType { override def toString: String = "i8" }
object i16 extends ScalarType { override def toString: String = "i16" }
object i32 extends ScalarType { override def toString: String = "i32" }
object i64 extends ScalarType { override def toString: String = "i64" }

object u8 extends ScalarType { override def toString: String = "u8" }
object u16 extends ScalarType { override def toString: String = "u16" }
object u32 extends ScalarType { override def toString: String = "u32" }
object u64 extends ScalarType { override def toString: String = "u64" }

object double extends ScalarType { override def toString: String = "double" }
object f16 extends ScalarType { override def toString: String = "f16" }
object f32 extends ScalarType { override def toString: String = "f32" }
object f64 extends ScalarType { override def toString: String = "f64" }

object NatType extends ScalarType { override def toString: String = "nat" }

Expand All @@ -105,25 +113,9 @@ sealed case class VectorType(size: Nat, elemType: DataType) extends BasicType {
override def toString: String = s"<$size>$elemType"
}

object int2 extends VectorType(2, int)

object int3 extends VectorType(3, int)

object int4 extends VectorType(4, int)

object int8 extends VectorType(8, int)

object int16 extends VectorType(16, int)

object float2 extends VectorType(2, float)

object float3 extends VectorType(3, float)

object float4 extends VectorType(4, float)

object float8 extends VectorType(8, float)

object float16 extends VectorType(16, float)
object vec {
def apply(size: Nat, elemType: DataType) = VectorType(size, elemType)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want to enforce here that size is a Cst and that it is one of: 2, 3, 4, 8, 16?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we want to write some vector-width generic code with types such as nFunT(w => f32 ->: vec(w, f32)? Although you could still use VectorType(w, f32) if you wanted.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel like such checks would be better suited when you try to generate code with a specific backend/target?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This form of polymorphism is dangerous anyway because the type suggests that we can build a VectorType(w, f32) for all w: Nat. But that isn't true.

We will have to be looking a bit more carefully how to define the VectorType on the formal side, I guess.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we can open an issue about that? There is also the issue that elemType should be a ScalarType ideally but we have no way of creating ScalarTypeIdentifiers for example.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we want subtyping constraints (#11) but also nat constraints?

}

final class NatToDataApply(val f: NatToData, val n: Nat) extends DataType {
override def toString: String = s"$f($n)"
Expand Down
2 changes: 1 addition & 1 deletion src/test/scala/rise/core/showRise.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class showRise extends test_util.Tests {

private val blurXTiled2D: Expr = nFun(n =>
fun(
(n `.` n `.` float) ->: (17 `.` float) ->: (n `.` n `.` float)
(n `.` n `.` f32) ->: (17 `.` f32) ->: (n `.` n `.` f32)
)((matrix, weights) =>
unslide2D o mapWorkGroup(1)(
mapWorkGroup(0)(
Expand Down
2 changes: 1 addition & 1 deletion src/test/scala/rise/core/structuralEquality.scala
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class structuralEquality extends test_util.Tests {
)
!=
nFun(m =>
fun(ArrayType(m, float))(b =>
fun(ArrayType(m, f32))(b =>
reduceSeq(fun(y => fun(x => y + x)))(0)(b)
)
)
Expand Down
9 changes: 4 additions & 5 deletions src/test/scala/rise/core/traverse.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@ import scala.collection.mutable
class traverse extends test_util.Tests {
val e = nFun(h =>
nFun(w =>
fun(ArrayType(h, ArrayType(w, float)))(input =>
map(map(fun(x => x)))(input)
fun(ArrayType(h, ArrayType(w, f32)))(input => map(map(fun(x => x)))(input)
)
)
)
Expand Down Expand Up @@ -120,7 +119,7 @@ class traverse extends test_util.Tests {
r ==
nFun(h =>
nFun(w =>
fun(ArrayType(h, ArrayType(w, float)))(input =>
fun(ArrayType(h, ArrayType(w, f32)))(input =>
app(fun(x => x), input)
)
)
Expand All @@ -135,7 +134,7 @@ class traverse extends test_util.Tests {

test("traverse an expression depth-first with global stop") {
val e = nFun(n =>
fun(ArrayType(n, float))(input =>
fun(ArrayType(n, f32))(input =>
input |> map(fun(x => x)) |> map(fun(x => x))
)
)
Expand All @@ -157,7 +156,7 @@ class traverse extends test_util.Tests {
(result: @unchecked) match {
case traversal.Stop(r) =>
val expected = nFun(n =>
fun(ArrayType(n, float))(input => {
fun(ArrayType(n, f32))(input => {
val x = identifier(freshName("x"))
app(lambda(x, x), input |> map(fun(x => x)))
})
Expand Down
4 changes: 2 additions & 2 deletions src/test/scala/rise/core/typedDSL.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ class typedDSL extends test_util.Tests {
val e =
nFun(n =>
fun(
DepArrayType(n, n2dtFun(i => (i + 1) `.` float)) ->: DepArrayType(
DepArrayType(n, n2dtFun(i => (i + 1) `.` f32)) ->: DepArrayType(
n,
n2dtFun(i => (i + 1) `.` float)
n2dtFun(i => (i + 1) `.` f32)
)
)(xs => xs |> depMapSeq(nFun(_ => mapSeq(fun(x => x)))))
)
Expand Down