From 5f81b9dfe0b7f71832beceb09cd177000be49ca0 Mon Sep 17 00:00:00 2001 From: Uma Zalakain Date: Thu, 11 Feb 2021 14:27:25 +0000 Subject: [PATCH 01/11] Move monad trait into utils --- src/main/scala/rise/core/traverse.scala | 34 ++++++------------------- src/main/scala/util/monad.scala | 31 ++++++++++++++++++++++ 2 files changed, 39 insertions(+), 26 deletions(-) create mode 100644 src/main/scala/util/monad.scala diff --git a/src/main/scala/rise/core/traverse.scala b/src/main/scala/rise/core/traverse.scala index b41491c9d..0575b7fa3 100644 --- a/src/main/scala/rise/core/traverse.scala +++ b/src/main/scala/rise/core/traverse.scala @@ -1,23 +1,17 @@ package rise.core -import scala.language.implicitConversions -import arithexpr.arithmetic.NamedVar +import util.monad import rise.core.semantics._ import rise.core.types._ +import scala.language.implicitConversions object traverse { - trait Monad[M[_]] { - def return_[T] : T => M[T] - def bind[T,S] : M[T] => (T => M[S]) => M[S] - def traverse[A] : Seq[M[A]] => M[Seq[A]] = - _.foldRight(return_(Nil : Seq[A]))({case (mx, mxs) => - bind(mx)(x => bind(mxs)(xs => return_(x +: xs)))}) - } - - implicit def monadicSyntax[M[_], A](m: M[A])(implicit tc: Monad[M]) = new { - def map[B](f: A => B): M[B] = tc.bind(m)(a => tc.return_(f(a)) ) - def flatMap[B](f: A => M[B]): M[B] = tc.bind(m)(f) - } + // Reexport util.monad.* + type Monad[M[_]] = monad.Monad[M] + type Pure[T] = monad.Pure[T] + implicit def monadicSyntax[M[_], A](m: M[A])(implicit tc: monad.Monad[M]) = monad.monadicSyntax(m)(tc) + val PureMonad = monad.PureMonad + val OptionMonad = monad.OptionMonad sealed trait VarType case object Binding extends VarType @@ -183,18 +177,6 @@ object traverse { override def `type`[T <: Type] : T => M[T] = return_ } - case class Pure[T](unwrap : T) - implicit object PureMonad extends Monad[Pure] { - override def return_[T] : T => Pure[T] = t => Pure(t) - override def bind[T,S] : Pure[T] => (T => Pure[S]) => Pure[S] = - v => f => v match { case Pure(v) => f(v) } - } - - implicit object OptionMonad extends Monad[Option] { - def return_[T]: T => Option[T] = Some(_) - def bind[T, S]: Option[T] => (T => Option[S]) => Option[S] = v => v.flatMap - } - trait PureTraversal extends Traversal[Pure] { override def monad = PureMonad } trait PureExprTraversal extends PureTraversal with ExprTraversal[Pure] diff --git a/src/main/scala/util/monad.scala b/src/main/scala/util/monad.scala new file mode 100644 index 000000000..2e10fd78c --- /dev/null +++ b/src/main/scala/util/monad.scala @@ -0,0 +1,31 @@ +package util + +import scala.language.implicitConversions + +object monad { + trait Monad[M[_]] { + def return_[T] : T => M[T] + def bind[T,S] : M[T] => (T => M[S]) => M[S] + def traverse[A] : Seq[M[A]] => M[Seq[A]] = + _.foldRight(return_(Nil : Seq[A]))({case (mx, mxs) => + bind(mx)(x => bind(mxs)(xs => return_(x +: xs)))}) + } + + implicit def monadicSyntax[M[_], A](m: M[A])(implicit tc: Monad[M]) = new { + def map[B](f: A => B): M[B] = tc.bind(m)(a => tc.return_(f(a)) ) + def flatMap[B](f: A => M[B]): M[B] = tc.bind(m)(f) + } + + case class Pure[T](unwrap : T) + implicit object PureMonad extends Monad[Pure] { + override def return_[T] : T => Pure[T] = t => Pure(t) + override def bind[T,S] : Pure[T] => (T => Pure[S]) => Pure[S] = + v => f => v match { case Pure(v) => f(v) } + } + + implicit object OptionMonad extends Monad[Option] { + def return_[T]: T => Option[T] = Some(_) + def bind[T, S]: Option[T] => (T => Option[S]) => Option[S] = v => v.flatMap + } + +} From 21bbeaf50ebbf2330f5faa43389744cac95abba1 Mon Sep 17 00:00:00 2001 From: Uma Zalakain Date: Thu, 11 Feb 2021 15:50:08 +0000 Subject: [PATCH 02/11] Refactor VisitAndRebuild into traverse --- .../scala/shine/DPIA/Phrases/Phrase.scala | 5 + .../scala/shine/DPIA/Phrases/traverse.scala | 175 ++++++++++++++++++ .../scala/shine/DPIA/Types/PhraseType.scala | 1 + src/main/scala/util/monad.scala | 6 +- 4 files changed, 186 insertions(+), 1 deletion(-) create mode 100644 src/main/scala/shine/DPIA/Phrases/traverse.scala diff --git a/src/main/scala/shine/DPIA/Phrases/Phrase.scala b/src/main/scala/shine/DPIA/Phrases/Phrase.scala index bf0605d35..d204bebe4 100644 --- a/src/main/scala/shine/DPIA/Phrases/Phrase.scala +++ b/src/main/scala/shine/DPIA/Phrases/Phrase.scala @@ -8,6 +8,7 @@ import shine.DPIA.Semantics.OperationalSemantics.{IndexData, NatData} import shine.DPIA.Types._ import shine.DPIA.Types.TypeCheck._ import shine.DPIA._ +import shine.DPIA.Phrases.traverse._ import shine.DPIA.primitives.functional.NatAsIndex sealed trait Phrase[T <: PhraseType] { @@ -46,6 +47,7 @@ final case class DepLambda[K <: Kind, T <: PhraseType](x: K#I, body: Phrase[T]) extends Phrase[K `()->:` T] { override val t: DepFunType[K, T] = DepFunType[K, T](x, body.t) override def toString: String = s"Λ(${x.name} : ${kn.get}). $body" + val kindName = implicitly(kn) } object DepLambda { @@ -367,6 +369,9 @@ sealed trait Primitive[T <: PhraseType] extends Phrase[T] { def xmlPrinter: xml.Elem = throw new Exception("xmlPrinter should be implemented by a macro") + def traverse[M[_]](f: Traversal[M]): M[Phrase[T]] = + throw new Exception("traverse should be implemented by a macro") + def visitAndRebuild(f: VisitAndRebuild.Visitor): Phrase[T] = throw new Exception("visitAndRebuild should be implemented by a macro") } diff --git a/src/main/scala/shine/DPIA/Phrases/traverse.scala b/src/main/scala/shine/DPIA/Phrases/traverse.scala new file mode 100644 index 000000000..4f661b343 --- /dev/null +++ b/src/main/scala/shine/DPIA/Phrases/traverse.scala @@ -0,0 +1,175 @@ +package shine.DPIA.Phrases + +import util.monad +import shine.DPIA.Types._ +import shine.DPIA._ +import shine.DPIA.Semantics.OperationalSemantics._ + +object traverse { + // Reexport util.monad.* + type Monad[M[_]] = monad.Monad[M] + type Pure[T] = monad.Pure[T] + implicit def monadicSyntax[M[_], A](m: M[A])(implicit tc: monad.Monad[M]) + = monad.monadicSyntax(m)(tc) + val PureMonad = monad.PureMonad + val OptionMonad = monad.OptionMonad + + trait ExprTraversal[M[_]] extends Traversal[M] { + override def `type` : PhraseType => M[PhraseType] = return_ + } + trait PureTraversal extends Traversal[Pure] { override def monad = PureMonad } + trait PureExprTraversal extends PureTraversal with ExprTraversal[Pure] + + def apply[T <: PhraseType](e : Phrase[T], f : PureTraversal) : Phrase[T] = f.phrase(e).unwrap + def apply[T <: PhraseType, M[_]](e : Phrase[T], f : Traversal[M]) : M[Phrase[T]] = f.phrase(e) + def apply(t : PhraseType, f : PureTraversal) : PhraseType = f.`type`(t).unwrap + def apply[M[_]](e : PhraseType, f : Traversal[M]) : M[PhraseType] = f.`type`(e) + + sealed trait VarType + case object Binding extends VarType + case object Reference extends VarType + + trait Traversal[M[_]] { + protected[this] implicit def monad: Monad[M] + def return_[T]: T => M[T] = monad.return_ + def bind[T, S]: M[T] => (T => M[S]) => M[S] = monad.bind + + def identifier[T <: PhraseType]: VarType => Identifier[T] => M[Identifier[T]] = _ => i => + for {t1 <- `type`(i.t)} + yield Identifier(i.name, t1) + def typeIdentifier[I <: Kind.Identifier]: VarType => I => M[I] = _ => return_ + def typeIdentifierDispatch[I <: Kind.Identifier]: VarType => I => M[I] = vt => i => (i match { + case n: NatIdentifier => bind(typeIdentifier(vt)(n))(nat) + case dt: DataTypeIdentifier => bind(typeIdentifier(vt)(dt))(datatype) + case a: AddressSpaceIdentifier => bind(typeIdentifier(vt)(a))(addressSpace) + case ac: AccessTypeIdentifier => bind(typeIdentifier(vt)(ac))(accessType) + case n2n: NatToNatIdentifier => bind(typeIdentifier(vt)(n2n))(natToNat) + case n2d: NatToDataIdentifier => bind(typeIdentifier(vt)(n2d))(natToData) + }).asInstanceOf[M[I]] + + def nat: Nat => M[Nat] = return_ + def addressSpace: AddressSpace => M[AddressSpace] = return_ + def accessType: AccessType => M[AccessType] = return_ + def data: Data => M[Data] = { + case VectorData(vd) => return_(VectorData(vd) : Data) + case NatData(n) => + for { n1 <- nat(n) } + yield NatData(n1) + case IndexData(i, n) => + for { i1 <- nat(i); n1 <- nat(n) } + yield IndexData(i1, n1) + case ArrayData(ad) => + for { ad1 <- monad.traverseV(ad.map(data)) } + yield ArrayData(ad1) + case PairData(l, r) => + for { l1 <- data(l); r1 <- data(r) } + yield PairData(l1, r1) + } + + def datatype[D <: DataType] : D => M[D] = { + case NatType => return_(NatType.asInstanceOf[D]) + case s : ScalarType => return_(s : D) + case IndexType(size) => + for {n1 <- nat(size)} + yield IndexType(n1) + case ArrayType(size, dt) => + for {n1 <- nat(size); dt1 <- datatype(dt)} + yield ArrayType(n1, dt1) + case DepArrayType(n, n2d) => + for {n1 <- nat(n); n2d1 <- natToData(n2d)} + yield DepArrayType(n1, n2d1) + case VectorType(size, dt) => + for {n1 <- nat(size); dt1 <- datatype(dt)} + yield VectorType(n1, dt1) + case PairType(l, r) => + for {l1 <- datatype(l); r1 <- datatype(r)} + yield PairType(l1, l1) + case pair@DepPairType(x, e) => + for {x1 <- typeIdentifierDispatch(Binding)(x); e1 <- datatype(e)} + yield DepPairType(x1, e1) + case NatToDataApply(ntdf, n) => + for {ntdf1 <- natToData(ntdf); n1 <- nat(n)} + yield NatToDataApply(ntdf1, n1) + } + + def natToNat[N <: NatToNat]: N => M[N] = { + case i: NatToNatIdentifier => return_(i: N) + case NatToNatLambda(n, body) => + for {n1 <- typeIdentifierDispatch(Binding)(n); body1 <- nat(body)} + yield NatToNatLambda(n1, body1) + } + + def natToData[N <: NatToData]: N => M[N] = { + case i: NatToDataIdentifier => return_(i: N) + case NatToDataLambda(n, body) => + for {n1 <- typeIdentifierDispatch(Binding)(n); body1 <- datatype(body)} + yield NatToDataLambda(n1, body1) + } + + def phrase[T <: PhraseType]: Phrase[T] => M[Phrase[T]] = { + case i: Identifier[T] => for {i1 <- identifier(Reference)(i)} yield i1 + case Lambda(x, p) => + for {x1 <- identifier(Binding)(x); p1 <- phrase(p)} + yield Lambda(x1, p1) + case Apply(p, q) => + for {p1 <- phrase(p); q1 <- phrase(q)} + yield Apply(p1, q1) + case dl@DepLambda(i, p) => + for {i1 <- typeIdentifierDispatch(Binding)(i); p1 <- phrase(p)} + yield DepLambda(i1, p1)(dl.kindName) + case DepApply(p, i) => + for {p1 <- phrase(p); i1 <- typeIdentifierDispatch(Reference)(i)} + yield DepApply(p1, i1) + case LetNat(i, defn, body) => + for {i1 <- typeIdentifierDispatch(Binding)(i); defn1 <- phrase(defn); body1 <- phrase(body)} + yield LetNat(i1, defn1, body1) + case PhrasePair(p, q) => + for {p1 <- phrase(p); q1 <- phrase(q)} + yield PhrasePair(p1, q1) + case Proj1(p) => + for {p1 <- phrase(p)} + yield Proj1(p1) + case Proj2(p) => + for {p1 <- phrase(p)} + yield Proj2(p1) + case IfThenElse(cond, thenP, elseP) => + for {cond1 <- phrase(cond); thenP1 <- phrase(thenP); elseP1 <- phrase(elseP)} + yield IfThenElse(cond1, thenP1, elseP1) + case Literal(d) => + for {d1 <- data(d)} + yield Literal(d1) + case Natural(n) => + for {n1 <- nat(n)} + yield Natural(n1) + case UnaryOp(op, x) => + for {x1 <- phrase(x)} + yield UnaryOp(op, x1) + case BinOp(op, lhs, rhs) => + for {lhs1 <- phrase(lhs); rhs1 <- phrase(rhs)} + yield BinOp(op, lhs1, rhs1) + case c: Primitive[T] => c.traverse(this) + } + + def `type`: PhraseType => M[PhraseType] = { + case CommType() => return_(CommType(): PhraseType) + case ExpType(dt, w) => + for {dt1 <- datatype(dt); w1 <- accessType(w)} + yield ExpType(dt1, w1) + case AccType(dt) => + for {dt1 <- datatype(dt)} + yield AccType(dt1) + case PhrasePairType(l, r) => + for {l1 <- `type`(l); r1 <- `type`(r)} + yield PhrasePairType(l1, r1) + case FunType(inT, outT) => + for {inT1 <- `type`(inT); outT1 <- `type`(outT)} + yield FunType(inT1, outT1) + case PassiveFunType(inT, outT) => + for {inT1 <- `type`(inT); outT1 <- `type`(outT)} + yield PassiveFunType(inT1, outT1) + case df@DepFunType(x, t) => + for {x1 <- typeIdentifierDispatch(Binding)(x); t1 <- `type`(t)} + yield DepFunType(x1, t1)(df.kindName) + } + } +} \ No newline at end of file diff --git a/src/main/scala/shine/DPIA/Types/PhraseType.scala b/src/main/scala/shine/DPIA/Types/PhraseType.scala index b8e79742a..600db2a9c 100644 --- a/src/main/scala/shine/DPIA/Types/PhraseType.scala +++ b/src/main/scala/shine/DPIA/Types/PhraseType.scala @@ -44,6 +44,7 @@ final case class DepFunType[K <: Kind, +R <: PhraseType](x: K#I, t: R) (implicit val kn: KindName[K]) extends PhraseType { override def toString = s"(${x.name}: ${kn.get}) -> $t" + val kindName = implicitly(kn) } object PhraseType { diff --git a/src/main/scala/util/monad.scala b/src/main/scala/util/monad.scala index 2e10fd78c..d0c3324db 100644 --- a/src/main/scala/util/monad.scala +++ b/src/main/scala/util/monad.scala @@ -7,7 +7,11 @@ object monad { def return_[T] : T => M[T] def bind[T,S] : M[T] => (T => M[S]) => M[S] def traverse[A] : Seq[M[A]] => M[Seq[A]] = - _.foldRight(return_(Nil : Seq[A]))({case (mx, mxs) => + _.foldRight(return_(Seq() : Seq[A]))({case (mx, mxs) => + bind(mx)(x => bind(mxs)(xs => return_(x +: xs)))}) + // FIXME: We should be able to use S[_] <: Seq[_] for both + def traverseV[A] : Vector[M[A]] => M[Vector[A]] = + _.foldRight(return_(Vector() : Vector[A]))({case (mx, mxs) => bind(mx)(x => bind(mxs)(xs => return_(x +: xs)))}) } From cf16ef41aed04d972709489a7d456312cbb81330 Mon Sep 17 00:00:00 2001 From: Uma Zalakain Date: Thu, 11 Feb 2021 17:52:37 +0000 Subject: [PATCH 03/11] Add macro support for Primitive.traverse --- .../main/scala/shine/macros/Primitive.scala | 47 ++++++++++++ .../scala/shine/DPIA/Phrases/Phrase.scala | 1 - .../scala/shine/DPIA/Phrases/traverse.scala | 72 +++++++++++-------- 3 files changed, 91 insertions(+), 29 deletions(-) diff --git a/macros/src/main/scala/shine/macros/Primitive.scala b/macros/src/main/scala/shine/macros/Primitive.scala index 731bf0273..dd10b9bb1 100644 --- a/macros/src/main/scala/shine/macros/Primitive.scala +++ b/macros/src/main/scala/shine/macros/Primitive.scala @@ -40,6 +40,49 @@ object Primitive { def makeLowerCaseName(s: String): String = s"${Character.toLowerCase(s.charAt(0))}${s.substring(1)}" + def makeTraverse(name: TypeName, + additionalParams: List[ValDef], + params: List[ValDef]): Tree = { + + val paramNames = params.map { case ValDef(_, name, _, _) => q"$name" } + val additionalParamNames = additionalParams.map { case ValDef(_, name, _, _) => q"$name" } + + def forLoopBindings(v : Tree) : List[Tree] = params.map { + case ValDef(_, name, tpt, _) => fq"${name} <- ${traverseCall(v, name)(tpt)}" + } + + def traverseCall(v : Tree, name : TermName) : Tree => Tree = { + case Ident(TypeName("DataType")) | Ident(TypeName("ScalarType")) | + Ident(TypeName("BasicType")) => q"$v.datatype($name)" + case Ident(TypeName("Data")) => q"$v.data($name)" + case Ident(TypeName("Nat")) => q"$v.nat($name)" + case Ident(TypeName("NatIdentifier")) => q"$v.nat($name)" + case Ident(TypeName("NatToNat")) => q"$v.natToNat($name)" + case Ident(TypeName("NatToData")) => q"$v.natToData($name)" + case Ident(TypeName("AccessType")) => q"$v.accessType($name)" + case Ident(TypeName("AddressSpace")) => q"$v.addressSpace($name)" + // Phrase[ExpType] + case AppliedTypeTree((Ident(TypeName("Phrase")), _)) => q"$v.phrase($name)" + // Vector[Phrase[ExpType]] + case AppliedTypeTree((Ident(TypeName("Vector")), + List(AppliedTypeTree((Ident(TypeName("Phrase")), _))))) + | AppliedTypeTree((Ident(TypeName("Seq")), + List(AppliedTypeTree((Ident(TypeName("Phrase")), _))))) => q"$name.map($v.phrase(_))" + case _ => + c.error(c.enclosingPosition, s"could not translate `${name.toString}'\n") + q"$name" + } + + val v = q"v" + q""" + override def traverse[M[_]]($v: shine.DPIA.Phrases.traverse.Traversal[M]): M[$name] = { + import shine.DPIA.Phrases.traverse._ + import scala.language.implicitConversions + for (..${forLoopBindings(v)}) yield new $name (..${additionalParamNames}, ..${paramNames}) + } + """ + } + def makeVisitAndRebuild(name: TypeName, additionalParams: List[ValDef], params: List[ValDef]): Tree = { @@ -185,12 +228,16 @@ object Primitive { } def makePrimitiveClass : ClassInfo => ClassDef = { case ClassInfo(name, additionalParams, params, body, parents) => + val traverseMissinng = + body.collectFirst({ case DefDef(_, TermName("traverseMissing"), _, _, _, _) => ()}).isEmpty val visitAndRebuildMissing = body.collectFirst({ case DefDef(_, TermName("visitAndRebuild"), _, _, _, _) => ()}).isEmpty val xmlPrinterMissing = body.collectFirst({ case DefDef(_, TermName("xmlPrinter"), _, _, _, _) => ()}).isEmpty val generated = q""" + ${if (traverseMissinng) makeTraverse(name, additionalParams, params) else q""} + ${if (visitAndRebuildMissing) makeVisitAndRebuild(name, additionalParams, params) else q""} diff --git a/src/main/scala/shine/DPIA/Phrases/Phrase.scala b/src/main/scala/shine/DPIA/Phrases/Phrase.scala index d204bebe4..fa02aba07 100644 --- a/src/main/scala/shine/DPIA/Phrases/Phrase.scala +++ b/src/main/scala/shine/DPIA/Phrases/Phrase.scala @@ -47,7 +47,6 @@ final case class DepLambda[K <: Kind, T <: PhraseType](x: K#I, body: Phrase[T]) extends Phrase[K `()->:` T] { override val t: DepFunType[K, T] = DepFunType[K, T](x, body.t) override def toString: String = s"Λ(${x.name} : ${kn.get}). $body" - val kindName = implicitly(kn) } object DepLambda { diff --git a/src/main/scala/shine/DPIA/Phrases/traverse.scala b/src/main/scala/shine/DPIA/Phrases/traverse.scala index 4f661b343..dccc3ba42 100644 --- a/src/main/scala/shine/DPIA/Phrases/traverse.scala +++ b/src/main/scala/shine/DPIA/Phrases/traverse.scala @@ -1,5 +1,6 @@ package shine.DPIA.Phrases +import scala.language.implicitConversions import util.monad import shine.DPIA.Types._ import shine.DPIA._ @@ -15,15 +16,15 @@ object traverse { val OptionMonad = monad.OptionMonad trait ExprTraversal[M[_]] extends Traversal[M] { - override def `type` : PhraseType => M[PhraseType] = return_ + override def `type`[T <: PhraseType] : T => M[T] = return_ } trait PureTraversal extends Traversal[Pure] { override def monad = PureMonad } trait PureExprTraversal extends PureTraversal with ExprTraversal[Pure] def apply[T <: PhraseType](e : Phrase[T], f : PureTraversal) : Phrase[T] = f.phrase(e).unwrap def apply[T <: PhraseType, M[_]](e : Phrase[T], f : Traversal[M]) : M[Phrase[T]] = f.phrase(e) - def apply(t : PhraseType, f : PureTraversal) : PhraseType = f.`type`(t).unwrap - def apply[M[_]](e : PhraseType, f : Traversal[M]) : M[PhraseType] = f.`type`(e) + def apply[T <: PhraseType] (t : T, f : PureTraversal) : T = f.`type`(t).unwrap + def apply[T <: PhraseType, M[_]](e : T, f : Traversal[M]) : M[T] = f.`type`(e) sealed trait VarType case object Binding extends VarType @@ -53,22 +54,23 @@ object traverse { def data: Data => M[Data] = { case VectorData(vd) => return_(VectorData(vd) : Data) case NatData(n) => - for { n1 <- nat(n) } - yield NatData(n1) + for { n1 <- nat(n) } + yield NatData(n1) case IndexData(i, n) => - for { i1 <- nat(i); n1 <- nat(n) } - yield IndexData(i1, n1) + for { i1 <- nat(i); n1 <- nat(n) } + yield IndexData(i1, n1) case ArrayData(ad) => - for { ad1 <- monad.traverseV(ad.map(data)) } - yield ArrayData(ad1) + for { ad1 <- monad.traverseV(ad.map(data)) } + yield ArrayData(ad1) case PairData(l, r) => - for { l1 <- data(l); r1 <- data(r) } - yield PairData(l1, r1) + for { l1 <- data(l); r1 <- data(r) } + yield PairData(l1, r1) + case d => return_(d) } - def datatype[D <: DataType] : D => M[D] = { - case NatType => return_(NatType.asInstanceOf[D]) - case s : ScalarType => return_(s : D) + def datatype : DataType => M[DataType] = { + case NatType => return_(NatType : DataType) + case s : ScalarType => return_(s : DataType) case IndexType(size) => for {n1 <- nat(size)} yield IndexType(n1) @@ -80,7 +82,7 @@ object traverse { yield DepArrayType(n1, n2d1) case VectorType(size, dt) => for {n1 <- nat(size); dt1 <- datatype(dt)} - yield VectorType(n1, dt1) + yield VectorType(n1, dt1.asInstanceOf[ScalarType]) case PairType(l, r) => for {l1 <- datatype(l); r1 <- datatype(r)} yield PairType(l1, l1) @@ -92,15 +94,15 @@ object traverse { yield NatToDataApply(ntdf1, n1) } - def natToNat[N <: NatToNat]: N => M[N] = { - case i: NatToNatIdentifier => return_(i: N) + def natToNat: NatToNat => M[NatToNat] = { + case i: NatToNatIdentifier => return_(i : NatToNat) case NatToNatLambda(n, body) => for {n1 <- typeIdentifierDispatch(Binding)(n); body1 <- nat(body)} yield NatToNatLambda(n1, body1) } - def natToData[N <: NatToData]: N => M[N] = { - case i: NatToDataIdentifier => return_(i: N) + def natToData: NatToData => M[NatToData] = { + case i: NatToDataIdentifier => return_(i) case NatToDataLambda(n, body) => for {n1 <- typeIdentifierDispatch(Binding)(n); body1 <- datatype(body)} yield NatToDataLambda(n1, body1) @@ -117,12 +119,26 @@ object traverse { case dl@DepLambda(i, p) => for {i1 <- typeIdentifierDispatch(Binding)(i); p1 <- phrase(p)} yield DepLambda(i1, p1)(dl.kindName) - case DepApply(p, i) => - for {p1 <- phrase(p); i1 <- typeIdentifierDispatch(Reference)(i)} - yield DepApply(p1, i1) - case LetNat(i, defn, body) => - for {i1 <- typeIdentifierDispatch(Binding)(i); defn1 <- phrase(defn); body1 <- phrase(body)} - yield LetNat(i1, defn1, body1) + case da@DepApply(f, x) => x match { + case n: Nat => + for {f1 <- phrase(f); n1 <- nat(n)} + yield DepApply[NatKind, T](f1.asInstanceOf[Phrase[NatKind `()->:` T]], n1) + case dt: DataType => + for {f1 <- phrase(f); dt1 <- datatype(dt)} + yield DepApply[DataKind, T](f1.asInstanceOf[Phrase[DataKind `()->:` T]], dt1) + case a: AddressSpace => + for {f1 <- phrase(f); a1 <- addressSpace(a)} + yield DepApply[AddressSpaceKind, T](f1.asInstanceOf[Phrase[AddressSpaceKind `()->:` T]], a1) + case n2n: NatToNat => + for {f1 <- phrase(f); n2n1 <- natToNat(n2n)} + yield DepApply[NatToNatKind, T](f1.asInstanceOf[Phrase[NatToNatKind `()->:` T]], n2n1) + case n2d: NatToData => + for {f1 <- phrase(f); n2d1 <- natToData(n2d)} + yield DepApply[NatToDataKind, T](f1.asInstanceOf[Phrase[NatToDataKind `()->:` T]], n2d1) + } + case LetNat(binder, defn, body) => + for {defn1 <- phrase(defn); body1 <- phrase(body)} + yield LetNat(binder, defn1, body1) case PhrasePair(p, q) => for {p1 <- phrase(p); q1 <- phrase(q)} yield PhrasePair(p1, q1) @@ -150,8 +166,8 @@ object traverse { case c: Primitive[T] => c.traverse(this) } - def `type`: PhraseType => M[PhraseType] = { - case CommType() => return_(CommType(): PhraseType) + def `type`[T <: PhraseType] : T => M[T] = t => (t match { + case CommType() => return_(CommType()) case ExpType(dt, w) => for {dt1 <- datatype(dt); w1 <- accessType(w)} yield ExpType(dt1, w1) @@ -170,6 +186,6 @@ object traverse { case df@DepFunType(x, t) => for {x1 <- typeIdentifierDispatch(Binding)(x); t1 <- `type`(t)} yield DepFunType(x1, t1)(df.kindName) - } + }).asInstanceOf[M[T]] } } \ No newline at end of file From 452fa7cfa1c40d9bcf731c4552058bb3464f9288 Mon Sep 17 00:00:00 2001 From: Uma Zalakain Date: Thu, 11 Feb 2021 18:59:55 +0000 Subject: [PATCH 04/11] Solve most of the issues with the macro --- .../main/scala/shine/macros/Primitive.scala | 65 ++++++++++--------- .../scala/shine/DPIA/Phrases/Phrase.scala | 1 + .../scala/shine/DPIA/Phrases/traverse.scala | 2 +- 3 files changed, 35 insertions(+), 33 deletions(-) diff --git a/macros/src/main/scala/shine/macros/Primitive.scala b/macros/src/main/scala/shine/macros/Primitive.scala index dd10b9bb1..c72d82683 100644 --- a/macros/src/main/scala/shine/macros/Primitive.scala +++ b/macros/src/main/scala/shine/macros/Primitive.scala @@ -40,45 +40,46 @@ object Primitive { def makeLowerCaseName(s: String): String = s"${Character.toLowerCase(s.charAt(0))}${s.substring(1)}" + def makeTraverseCall(v : Tree, name : TermName) : Tree => Option[Tree] = { + case Ident(TypeName("DataType")) | Ident(TypeName("ScalarType")) | + Ident(TypeName("BasicType")) => Some(fq"${name} <- $v.datatype($name)") + case Ident(TypeName("Data")) => Some(fq"${name} <- $v.data($name)") + case Ident(TypeName("Nat")) => Some(fq"${name} <- $v.nat($name)") + case Ident(TypeName("NatIdentifier")) => Some(fq"${name} <- $v.nat($name)") + case Ident(TypeName("NatToNat")) => Some(fq"${name} <- $v.natToNat($name)") + case Ident(TypeName("NatToData")) => Some(fq"${name} <- $v.natToData($name)") + case Ident(TypeName("AccessType")) => Some(fq"${name} <- $v.accessType($name)") + case Ident(TypeName("AddressSpace")) => Some(fq"${name} <- $v.addressSpace($name)") + // Phrase[ExpType] + case AppliedTypeTree((Ident(TypeName("Phrase")), _)) => Some(fq"${name} <- $v.phrase($name)") + // Vector[Phrase[ExpType]] + case AppliedTypeTree((Ident(TypeName("Vector")), + List(AppliedTypeTree((Ident(TypeName("Phrase")), _))))) + | AppliedTypeTree((Ident(TypeName("Seq")), + List(AppliedTypeTree((Ident(TypeName("Phrase")), _))))) => Some(fq"${name} <- $name.map($v.phrase(_))") + case _ => None + } + def makeTraverse(name: TypeName, additionalParams: List[ValDef], params: List[ValDef]): Tree = { + val v = q"v" val paramNames = params.map { case ValDef(_, name, _, _) => q"$name" } val additionalParamNames = additionalParams.map { case ValDef(_, name, _, _) => q"$name" } - - def forLoopBindings(v : Tree) : List[Tree] = params.map { - case ValDef(_, name, tpt, _) => fq"${name} <- ${traverseCall(v, name)(tpt)}" + val forLoopBindings : List[Tree] = params.flatMap { + case ValDef(_, name, tpt, _) => makeTraverseCall(v, name)(tpt) } + val construct = if (additionalParamNames.isEmpty) q"new $name(..$paramNames)" + else q"new $name(..$additionalParamNames)(..$paramNames)" + val forloop = if (forLoopBindings.isEmpty) q"monad.return_($construct)" + else q"for (..${forLoopBindings}) yield $construct" - def traverseCall(v : Tree, name : TermName) : Tree => Tree = { - case Ident(TypeName("DataType")) | Ident(TypeName("ScalarType")) | - Ident(TypeName("BasicType")) => q"$v.datatype($name)" - case Ident(TypeName("Data")) => q"$v.data($name)" - case Ident(TypeName("Nat")) => q"$v.nat($name)" - case Ident(TypeName("NatIdentifier")) => q"$v.nat($name)" - case Ident(TypeName("NatToNat")) => q"$v.natToNat($name)" - case Ident(TypeName("NatToData")) => q"$v.natToData($name)" - case Ident(TypeName("AccessType")) => q"$v.accessType($name)" - case Ident(TypeName("AddressSpace")) => q"$v.addressSpace($name)" - // Phrase[ExpType] - case AppliedTypeTree((Ident(TypeName("Phrase")), _)) => q"$v.phrase($name)" - // Vector[Phrase[ExpType]] - case AppliedTypeTree((Ident(TypeName("Vector")), - List(AppliedTypeTree((Ident(TypeName("Phrase")), _))))) - | AppliedTypeTree((Ident(TypeName("Seq")), - List(AppliedTypeTree((Ident(TypeName("Phrase")), _))))) => q"$name.map($v.phrase(_))" - case _ => - c.error(c.enclosingPosition, s"could not translate `${name.toString}'\n") - q"$name" - } - - val v = q"v" q""" override def traverse[M[_]]($v: shine.DPIA.Phrases.traverse.Traversal[M]): M[$name] = { - import shine.DPIA.Phrases.traverse._ - import scala.language.implicitConversions - for (..${forLoopBindings(v)}) yield new $name (..${additionalParamNames}, ..${paramNames}) + import util.monad._ + implicit val monad: Monad[M] = implicitly($v.monad) + $forloop } """ } @@ -228,15 +229,15 @@ object Primitive { } def makePrimitiveClass : ClassInfo => ClassDef = { case ClassInfo(name, additionalParams, params, body, parents) => - val traverseMissinng = - body.collectFirst({ case DefDef(_, TermName("traverseMissing"), _, _, _, _) => ()}).isEmpty + val traverseMissing = + body.collectFirst({ case DefDef(_, TermName("traverse"), _, _, _, _) => ()}).isEmpty val visitAndRebuildMissing = body.collectFirst({ case DefDef(_, TermName("visitAndRebuild"), _, _, _, _) => ()}).isEmpty val xmlPrinterMissing = body.collectFirst({ case DefDef(_, TermName("xmlPrinter"), _, _, _, _) => ()}).isEmpty val generated = q""" - ${if (traverseMissinng) makeTraverse(name, additionalParams, params) else q""} + ${if (traverseMissing) makeTraverse(name, additionalParams, params) else q""} ${if (visitAndRebuildMissing) makeVisitAndRebuild(name, additionalParams, params) diff --git a/src/main/scala/shine/DPIA/Phrases/Phrase.scala b/src/main/scala/shine/DPIA/Phrases/Phrase.scala index fa02aba07..8e1dc4519 100644 --- a/src/main/scala/shine/DPIA/Phrases/Phrase.scala +++ b/src/main/scala/shine/DPIA/Phrases/Phrase.scala @@ -47,6 +47,7 @@ final case class DepLambda[K <: Kind, T <: PhraseType](x: K#I, body: Phrase[T]) extends Phrase[K `()->:` T] { override val t: DepFunType[K, T] = DepFunType[K, T](x, body.t) override def toString: String = s"Λ(${x.name} : ${kn.get}). $body" + val kindName : KindName[K] = implicitly(kn) } object DepLambda { diff --git a/src/main/scala/shine/DPIA/Phrases/traverse.scala b/src/main/scala/shine/DPIA/Phrases/traverse.scala index dccc3ba42..457e2c13b 100644 --- a/src/main/scala/shine/DPIA/Phrases/traverse.scala +++ b/src/main/scala/shine/DPIA/Phrases/traverse.scala @@ -31,7 +31,7 @@ object traverse { case object Reference extends VarType trait Traversal[M[_]] { - protected[this] implicit def monad: Monad[M] + implicit def monad: Monad[M] def return_[T]: T => M[T] = monad.return_ def bind[T, S]: M[T] => (T => M[S]) => M[S] = monad.bind From 8d4c7eba90acafec0272324fd3bb3604a75b00ca Mon Sep 17 00:00:00 2001 From: Uma Zalakain Date: Thu, 11 Feb 2021 21:32:19 +0000 Subject: [PATCH 05/11] Temporary workaround for monads: make visit and rebuild always pure --- .../main/scala/shine/macros/Primitive.scala | 35 ++++++++----------- .../scala/shine/DPIA/Phrases/Phrase.scala | 2 +- .../scala/shine/DPIA/Phrases/traverse.scala | 20 +++++++---- 3 files changed, 29 insertions(+), 28 deletions(-) diff --git a/macros/src/main/scala/shine/macros/Primitive.scala b/macros/src/main/scala/shine/macros/Primitive.scala index c72d82683..4193f0499 100644 --- a/macros/src/main/scala/shine/macros/Primitive.scala +++ b/macros/src/main/scala/shine/macros/Primitive.scala @@ -23,16 +23,14 @@ object Primitive { class Impl(val c: blackbox.Context) { import c.universe._ - def expPrimitive(annottees : c.Expr[Any]*): c.Expr[Any] = primitive(c => makePrimitiveClass(primitivesFromClassDef(c)))(annottees) - def accPrimitive(annottees : c.Expr[Any]*): c.Expr[Any] = primitive(c => makePrimitiveClass(primitivesFromClassDef(c)))(annottees) - def comPrimitive(annottees : c.Expr[Any]*): c.Expr[Any] = primitive(c => makePrimitiveClass(primitivesFromClassDef(c)))(annottees) + def expPrimitive(annottees : c.Expr[Any]*): c.Expr[Any] = primitive(c => makePrimitiveClass(getClassInfo(c)))(annottees) + def accPrimitive(annottees : c.Expr[Any]*): c.Expr[Any] = primitive(c => makePrimitiveClass(getClassInfo(c)))(annottees) + def comPrimitive(annottees : c.Expr[Any]*): c.Expr[Any] = primitive(c => makePrimitiveClass(getClassInfo(c)))(annottees) def primitive(transform : ClassDef => ClassDef)(annottees: Seq[c.Expr[Any]]): c.Expr[Any] = { annottees.map(_.tree) match { - case (cdef: ClassDef) :: Nil => - c.Expr(transform(cdef)) - case (cdef: ClassDef) :: (md: ModuleDef) :: Nil => - c.Expr(q"{${transform(cdef)}; $md}") + case (cdef: ClassDef) :: Nil => c.Expr(transform(cdef)) + case (cdef: ClassDef) :: (md: ModuleDef) :: Nil => c.Expr(q"{${transform(cdef)}; $md}") case _ => c.abort(c.enclosingPosition, "expected a class definition") } } @@ -45,7 +43,7 @@ object Primitive { Ident(TypeName("BasicType")) => Some(fq"${name} <- $v.datatype($name)") case Ident(TypeName("Data")) => Some(fq"${name} <- $v.data($name)") case Ident(TypeName("Nat")) => Some(fq"${name} <- $v.nat($name)") - case Ident(TypeName("NatIdentifier")) => Some(fq"${name} <- $v.nat($name)") + case Ident(TypeName("NatIdentifier")) => Some(fq"${name} <- $v.typeIdentifierDispatch(shine.DPIA.Phrases.traverse.Reference)($name)") // FIXME: icky case Ident(TypeName("NatToNat")) => Some(fq"${name} <- $v.natToNat($name)") case Ident(TypeName("NatToData")) => Some(fq"${name} <- $v.natToData($name)") case Ident(TypeName("AccessType")) => Some(fq"${name} <- $v.accessType($name)") @@ -54,16 +52,13 @@ object Primitive { case AppliedTypeTree((Ident(TypeName("Phrase")), _)) => Some(fq"${name} <- $v.phrase($name)") // Vector[Phrase[ExpType]] case AppliedTypeTree((Ident(TypeName("Vector")), - List(AppliedTypeTree((Ident(TypeName("Phrase")), _))))) - | AppliedTypeTree((Ident(TypeName("Seq")), - List(AppliedTypeTree((Ident(TypeName("Phrase")), _))))) => Some(fq"${name} <- $name.map($v.phrase(_))") + List(AppliedTypeTree((Ident(TypeName("Phrase")), _))))) => Some(fq"${name} <- monad.traverseV($name.map($v.phrase(_)))") + case AppliedTypeTree((Ident(TypeName("Seq")), + List(AppliedTypeTree((Ident(TypeName("Phrase")), _))))) => Some(fq"${name} <- monad.traverse($name.map($v.phrase(_)))") case _ => None } - def makeTraverse(name: TypeName, - additionalParams: List[ValDef], - params: List[ValDef]): Tree = { - + def makeTraverse(name: TypeName, additionalParams: List[ValDef], params: List[ValDef]): Tree = { val v = q"v" val paramNames = params.map { case ValDef(_, name, _, _) => q"$name" } val additionalParamNames = additionalParams.map { case ValDef(_, name, _, _) => q"$name" } @@ -72,13 +67,13 @@ object Primitive { } val construct = if (additionalParamNames.isEmpty) q"new $name(..$paramNames)" else q"new $name(..$additionalParamNames)(..$paramNames)" - val forloop = if (forLoopBindings.isEmpty) q"monad.return_($construct)" - else q"for (..${forLoopBindings}) yield $construct" + val forloop = if (forLoopBindings.isEmpty) q"$construct" + else q"(for (..${forLoopBindings}) yield $construct).unwrap" q""" - override def traverse[M[_]]($v: shine.DPIA.Phrases.traverse.Traversal[M]): M[$name] = { + override def traverse($v: shine.DPIA.Phrases.traverse.PureTraversal): $name = { import util.monad._ - implicit val monad: Monad[M] = implicitly($v.monad) + implicit val monad: Monad[Pure] = implicitly($v.monad) $forloop } """ @@ -193,7 +188,7 @@ object Primitive { body: List[Tree], parents: List[Tree]) - def primitivesFromClassDef: ClassDef => ClassInfo = { + def getClassInfo: ClassDef => ClassInfo = { case q"case class $name(..$params) extends { ..$_ } with ..$parents {..$body} " => ClassInfo( name.asInstanceOf[c.TypeName], diff --git a/src/main/scala/shine/DPIA/Phrases/Phrase.scala b/src/main/scala/shine/DPIA/Phrases/Phrase.scala index 8e1dc4519..d2421af75 100644 --- a/src/main/scala/shine/DPIA/Phrases/Phrase.scala +++ b/src/main/scala/shine/DPIA/Phrases/Phrase.scala @@ -369,7 +369,7 @@ sealed trait Primitive[T <: PhraseType] extends Phrase[T] { def xmlPrinter: xml.Elem = throw new Exception("xmlPrinter should be implemented by a macro") - def traverse[M[_]](f: Traversal[M]): M[Phrase[T]] = + def traverse(f: PureTraversal): Phrase[T] = throw new Exception("traverse should be implemented by a macro") def visitAndRebuild(f: VisitAndRebuild.Visitor): Phrase[T] = diff --git a/src/main/scala/shine/DPIA/Phrases/traverse.scala b/src/main/scala/shine/DPIA/Phrases/traverse.scala index 457e2c13b..10f10e409 100644 --- a/src/main/scala/shine/DPIA/Phrases/traverse.scala +++ b/src/main/scala/shine/DPIA/Phrases/traverse.scala @@ -18,7 +18,13 @@ object traverse { trait ExprTraversal[M[_]] extends Traversal[M] { override def `type`[T <: PhraseType] : T => M[T] = return_ } - trait PureTraversal extends Traversal[Pure] { override def monad = PureMonad } + trait PureTraversal extends Traversal[Pure] { + override def monad = PureMonad + override def phrase[T <: PhraseType]: Phrase[T] => Pure[Phrase[T]] = { + case c: Primitive[T] => return_(c.traverse(this)) + case p => super.phrase(p) + } + } trait PureExprTraversal extends PureTraversal with ExprTraversal[Pure] def apply[T <: PhraseType](e : Phrase[T], f : PureTraversal) : Phrase[T] = f.phrase(e).unwrap @@ -68,9 +74,9 @@ object traverse { case d => return_(d) } - def datatype : DataType => M[DataType] = { - case NatType => return_(NatType : DataType) - case s : ScalarType => return_(s : DataType) + def datatype[ D <:DataType] : D => M[D] = d => (d match { + case NatType => return_(NatType) + case s : ScalarType => return_(s) case IndexType(size) => for {n1 <- nat(size)} yield IndexType(n1) @@ -85,14 +91,14 @@ object traverse { yield VectorType(n1, dt1.asInstanceOf[ScalarType]) case PairType(l, r) => for {l1 <- datatype(l); r1 <- datatype(r)} - yield PairType(l1, l1) + yield PairType(l1, r1) case pair@DepPairType(x, e) => for {x1 <- typeIdentifierDispatch(Binding)(x); e1 <- datatype(e)} yield DepPairType(x1, e1) case NatToDataApply(ntdf, n) => for {ntdf1 <- natToData(ntdf); n1 <- nat(n)} yield NatToDataApply(ntdf1, n1) - } + }).asInstanceOf[M[D]] def natToNat: NatToNat => M[NatToNat] = { case i: NatToNatIdentifier => return_(i : NatToNat) @@ -163,7 +169,7 @@ object traverse { case BinOp(op, lhs, rhs) => for {lhs1 <- phrase(lhs); rhs1 <- phrase(rhs)} yield BinOp(op, lhs1, rhs1) - case c: Primitive[T] => c.traverse(this) + case c: Primitive[T] => return_(c) } def `type`[T <: PhraseType] : T => M[T] = t => (t match { From 90436c62b66b0d29181d4c189d4d7e3b5fbdedc5 Mon Sep 17 00:00:00 2001 From: Uma Zalakain Date: Thu, 11 Feb 2021 21:41:12 +0000 Subject: [PATCH 06/11] Make monads covariant on their contained type --- src/main/scala/rise/core/traverse.scala | 14 +++++++------- src/main/scala/shine/DPIA/Phrases/traverse.scala | 14 +++++++------- src/main/scala/util/monad.scala | 6 +++--- 3 files changed, 17 insertions(+), 17 deletions(-) diff --git a/src/main/scala/rise/core/traverse.scala b/src/main/scala/rise/core/traverse.scala index 0575b7fa3..7a3274cee 100644 --- a/src/main/scala/rise/core/traverse.scala +++ b/src/main/scala/rise/core/traverse.scala @@ -7,9 +7,9 @@ import scala.language.implicitConversions object traverse { // Reexport util.monad.* - type Monad[M[_]] = monad.Monad[M] - type Pure[T] = monad.Pure[T] - implicit def monadicSyntax[M[_], A](m: M[A])(implicit tc: monad.Monad[M]) = monad.monadicSyntax(m)(tc) + type Monad[M[+_]] = monad.Monad[M] + type Pure[+T] = monad.Pure[T] + implicit def monadicSyntax[M[+_], A](m: M[A])(implicit tc: monad.Monad[M]) = monad.monadicSyntax(m)(tc) val PureMonad = monad.PureMonad val OptionMonad = monad.OptionMonad @@ -17,7 +17,7 @@ object traverse { case object Binding extends VarType case object Reference extends VarType - trait Traversal[M[_]] { + trait Traversal[M[+_]] { protected[this] implicit def monad : Monad[M] def return_[T] : T => M[T] = monad.return_ def bind[T,S] : M[T] => (T => M[S]) => M[S] = monad.bind @@ -173,7 +173,7 @@ object traverse { } } - trait ExprTraversal[M[_]] extends Traversal[M] { + trait ExprTraversal[M[+_]] extends Traversal[M] { override def `type`[T <: Type] : T => M[T] = return_ } @@ -181,7 +181,7 @@ object traverse { trait PureExprTraversal extends PureTraversal with ExprTraversal[Pure] def apply(e : Expr, f : PureTraversal) : Expr = f.expr(e).unwrap - def apply[M[_]](e : Expr, f : Traversal[M]) : M[Expr] = f.expr(e) + def apply[M[+_]](e : Expr, f : Traversal[M]) : M[Expr] = f.expr(e) def apply[T <: Type](t : T, f : PureTraversal) : T = f.`type`(t).unwrap - def apply[T <: Type, M[_]](e : T, f : Traversal[M]) : M[T] = f.`type`(e) + def apply[T <: Type, M[+_]](e : T, f : Traversal[M]) : M[T] = f.`type`(e) } diff --git a/src/main/scala/shine/DPIA/Phrases/traverse.scala b/src/main/scala/shine/DPIA/Phrases/traverse.scala index 10f10e409..48d2aabf6 100644 --- a/src/main/scala/shine/DPIA/Phrases/traverse.scala +++ b/src/main/scala/shine/DPIA/Phrases/traverse.scala @@ -8,14 +8,14 @@ import shine.DPIA.Semantics.OperationalSemantics._ object traverse { // Reexport util.monad.* - type Monad[M[_]] = monad.Monad[M] - type Pure[T] = monad.Pure[T] - implicit def monadicSyntax[M[_], A](m: M[A])(implicit tc: monad.Monad[M]) + type Monad[M[+_]] = monad.Monad[M] + type Pure[+T] = monad.Pure[T] + implicit def monadicSyntax[M[+_], A](m: M[A])(implicit tc: monad.Monad[M]) = monad.monadicSyntax(m)(tc) val PureMonad = monad.PureMonad val OptionMonad = monad.OptionMonad - trait ExprTraversal[M[_]] extends Traversal[M] { + trait ExprTraversal[M[+_]] extends Traversal[M] { override def `type`[T <: PhraseType] : T => M[T] = return_ } trait PureTraversal extends Traversal[Pure] { @@ -28,15 +28,15 @@ object traverse { trait PureExprTraversal extends PureTraversal with ExprTraversal[Pure] def apply[T <: PhraseType](e : Phrase[T], f : PureTraversal) : Phrase[T] = f.phrase(e).unwrap - def apply[T <: PhraseType, M[_]](e : Phrase[T], f : Traversal[M]) : M[Phrase[T]] = f.phrase(e) + def apply[T <: PhraseType, M[+_]](e : Phrase[T], f : Traversal[M]) : M[Phrase[T]] = f.phrase(e) def apply[T <: PhraseType] (t : T, f : PureTraversal) : T = f.`type`(t).unwrap - def apply[T <: PhraseType, M[_]](e : T, f : Traversal[M]) : M[T] = f.`type`(e) + def apply[T <: PhraseType, M[+_]](e : T, f : Traversal[M]) : M[T] = f.`type`(e) sealed trait VarType case object Binding extends VarType case object Reference extends VarType - trait Traversal[M[_]] { + trait Traversal[M[+_]] { implicit def monad: Monad[M] def return_[T]: T => M[T] = monad.return_ def bind[T, S]: M[T] => (T => M[S]) => M[S] = monad.bind diff --git a/src/main/scala/util/monad.scala b/src/main/scala/util/monad.scala index d0c3324db..d81005c47 100644 --- a/src/main/scala/util/monad.scala +++ b/src/main/scala/util/monad.scala @@ -3,7 +3,7 @@ package util import scala.language.implicitConversions object monad { - trait Monad[M[_]] { + trait Monad[M[+_]] { def return_[T] : T => M[T] def bind[T,S] : M[T] => (T => M[S]) => M[S] def traverse[A] : Seq[M[A]] => M[Seq[A]] = @@ -15,12 +15,12 @@ object monad { bind(mx)(x => bind(mxs)(xs => return_(x +: xs)))}) } - implicit def monadicSyntax[M[_], A](m: M[A])(implicit tc: Monad[M]) = new { + implicit def monadicSyntax[M[+_], A](m: M[A])(implicit tc: Monad[M]) = new { def map[B](f: A => B): M[B] = tc.bind(m)(a => tc.return_(f(a)) ) def flatMap[B](f: A => M[B]): M[B] = tc.bind(m)(f) } - case class Pure[T](unwrap : T) + case class Pure[+T](unwrap : T) implicit object PureMonad extends Monad[Pure] { override def return_[T] : T => Pure[T] = t => Pure(t) override def bind[T,S] : Pure[T] => (T => Pure[S]) => Pure[S] = From 44ebce102374b554da9e1bcfeebc740d83fa6af1 Mon Sep 17 00:00:00 2001 From: Uma Zalakain Date: Fri, 12 Feb 2021 10:37:55 +0000 Subject: [PATCH 07/11] Allow traverse on primitives to be monadic --- macros/src/main/scala/shine/macros/Primitive.scala | 8 ++++---- src/main/scala/shine/DPIA/Phrases/Phrase.scala | 2 +- src/main/scala/shine/DPIA/Phrases/traverse.scala | 10 ++-------- 3 files changed, 7 insertions(+), 13 deletions(-) diff --git a/macros/src/main/scala/shine/macros/Primitive.scala b/macros/src/main/scala/shine/macros/Primitive.scala index 4193f0499..ec0c9669c 100644 --- a/macros/src/main/scala/shine/macros/Primitive.scala +++ b/macros/src/main/scala/shine/macros/Primitive.scala @@ -67,13 +67,13 @@ object Primitive { } val construct = if (additionalParamNames.isEmpty) q"new $name(..$paramNames)" else q"new $name(..$additionalParamNames)(..$paramNames)" - val forloop = if (forLoopBindings.isEmpty) q"$construct" - else q"(for (..${forLoopBindings}) yield $construct).unwrap" + val forloop = if (forLoopBindings.isEmpty) q"monad.return_($construct)" + else q"for (..${forLoopBindings}) yield $construct" q""" - override def traverse($v: shine.DPIA.Phrases.traverse.PureTraversal): $name = { + override def traverse[M[+_]]($v: shine.DPIA.Phrases.traverse.Traversal[M]): M[$name] = { import util.monad._ - implicit val monad: Monad[Pure] = implicitly($v.monad) + implicit val monad: Monad[M] = implicitly($v.monad) $forloop } """ diff --git a/src/main/scala/shine/DPIA/Phrases/Phrase.scala b/src/main/scala/shine/DPIA/Phrases/Phrase.scala index d2421af75..1ff1224ce 100644 --- a/src/main/scala/shine/DPIA/Phrases/Phrase.scala +++ b/src/main/scala/shine/DPIA/Phrases/Phrase.scala @@ -369,7 +369,7 @@ sealed trait Primitive[T <: PhraseType] extends Phrase[T] { def xmlPrinter: xml.Elem = throw new Exception("xmlPrinter should be implemented by a macro") - def traverse(f: PureTraversal): Phrase[T] = + def traverse[M[+_]](f: Traversal[M]): M[Phrase[T]] = throw new Exception("traverse should be implemented by a macro") def visitAndRebuild(f: VisitAndRebuild.Visitor): Phrase[T] = diff --git a/src/main/scala/shine/DPIA/Phrases/traverse.scala b/src/main/scala/shine/DPIA/Phrases/traverse.scala index 48d2aabf6..cafa80130 100644 --- a/src/main/scala/shine/DPIA/Phrases/traverse.scala +++ b/src/main/scala/shine/DPIA/Phrases/traverse.scala @@ -18,13 +18,7 @@ object traverse { trait ExprTraversal[M[+_]] extends Traversal[M] { override def `type`[T <: PhraseType] : T => M[T] = return_ } - trait PureTraversal extends Traversal[Pure] { - override def monad = PureMonad - override def phrase[T <: PhraseType]: Phrase[T] => Pure[Phrase[T]] = { - case c: Primitive[T] => return_(c.traverse(this)) - case p => super.phrase(p) - } - } + trait PureTraversal extends Traversal[Pure] { override def monad = PureMonad } trait PureExprTraversal extends PureTraversal with ExprTraversal[Pure] def apply[T <: PhraseType](e : Phrase[T], f : PureTraversal) : Phrase[T] = f.phrase(e).unwrap @@ -169,7 +163,7 @@ object traverse { case BinOp(op, lhs, rhs) => for {lhs1 <- phrase(lhs); rhs1 <- phrase(rhs)} yield BinOp(op, lhs1, rhs1) - case c: Primitive[T] => return_(c) + case c: Primitive[T] => c.traverse(this) } def `type`[T <: PhraseType] : T => M[T] = t => (t match { From 83c57548fed86aa58fecadfd89271033af2a012d Mon Sep 17 00:00:00 2001 From: Uma Zalakain Date: Fri, 12 Feb 2021 14:23:23 +0000 Subject: [PATCH 08/11] Fix covariance in tests --- src/main/scala/shine/DPIA/Phrases/traverse.scala | 2 +- src/test/scala/rise/core/traverseTest.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/main/scala/shine/DPIA/Phrases/traverse.scala b/src/main/scala/shine/DPIA/Phrases/traverse.scala index cafa80130..0efae2d65 100644 --- a/src/main/scala/shine/DPIA/Phrases/traverse.scala +++ b/src/main/scala/shine/DPIA/Phrases/traverse.scala @@ -82,7 +82,7 @@ object traverse { yield DepArrayType(n1, n2d1) case VectorType(size, dt) => for {n1 <- nat(size); dt1 <- datatype(dt)} - yield VectorType(n1, dt1.asInstanceOf[ScalarType]) + yield VectorType(n1, dt1) case PairType(l, r) => for {l1 <- datatype(l); r1 <- datatype(r)} yield PairType(l1, r1) diff --git a/src/test/scala/rise/core/traverseTest.scala b/src/test/scala/rise/core/traverseTest.scala index 84bcee19d..63f49be87 100644 --- a/src/test/scala/rise/core/traverseTest.scala +++ b/src/test/scala/rise/core/traverseTest.scala @@ -13,7 +13,7 @@ class traverseTest extends test_util.Tests { ) ) - case class Trace[T](unwrap : T) { val trace : Seq[Any] = Seq() } + case class Trace[+T](unwrap : T) { val trace : Seq[Any] = Seq() } implicit object TraceMonad extends Monad[Trace] { def write[T] : T => Trace[T] = t => new Trace(t) { override val trace : Seq[Any] = Seq(t)} override def return_[T] : T => Trace[T] = t => Trace(t) From ddecfb9bc4d8cd7c8be7a3f6a4a5db0b83c0475c Mon Sep 17 00:00:00 2001 From: Uma Zalakain Date: Fri, 12 Feb 2021 14:28:36 +0000 Subject: [PATCH 09/11] Get rid of unnecessary casts --- src/main/scala/rise/core/traverse.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/main/scala/rise/core/traverse.scala b/src/main/scala/rise/core/traverse.scala index 7a3274cee..d71da933e 100644 --- a/src/main/scala/rise/core/traverse.scala +++ b/src/main/scala/rise/core/traverse.scala @@ -40,7 +40,7 @@ object traverse { def nat : Nat => M[Nat] = return_ def addressSpace : AddressSpace => M[AddressSpace] = return_ def datatype : DataType => M[DataType] = { - case i: DataTypeIdentifier => return_(i.asInstanceOf[DataType]) + case i: DataTypeIdentifier => return_(i) case NatType => return_(NatType : DataType) case s : ScalarType => return_(s : DataType) case ArrayType(n, d) => @@ -67,14 +67,14 @@ object traverse { } def natToNat : NatToNat => M[NatToNat] = { - case i : NatToNatIdentifier => return_(i.asInstanceOf[NatToNat]) + case i : NatToNatIdentifier => return_(i) case NatToNatLambda(x, e) => for { x1 <- typeIdentifierDispatch(Binding)(x); e1 <- nat(e) } yield NatToNatLambda(x1, e1) } def natToData : NatToData => M[NatToData] = { - case i : NatToDataIdentifier => return_(i.asInstanceOf[NatToData]) + case i : NatToDataIdentifier => return_(i) case NatToDataLambda(x, e) => for { x1 <- typeIdentifierDispatch(Binding)(x); e1 <- datatype(e) } yield NatToDataLambda(x1, e1) From c84a1d10042c276ff2a3e243a3def55f7447c774 Mon Sep 17 00:00:00 2001 From: Uma Zalakain Date: Fri, 12 Feb 2021 14:36:47 +0000 Subject: [PATCH 10/11] Translate substitute --- .../scala/shine/DPIA/Phrases/Phrase.scala | 50 +++++++++---------- 1 file changed, 25 insertions(+), 25 deletions(-) diff --git a/src/main/scala/shine/DPIA/Phrases/Phrase.scala b/src/main/scala/shine/DPIA/Phrases/Phrase.scala index 1ff1224ce..d8c20e821 100644 --- a/src/main/scala/shine/DPIA/Phrases/Phrase.scala +++ b/src/main/scala/shine/DPIA/Phrases/Phrase.scala @@ -139,7 +139,7 @@ object Phrase { `for`: Phrase[T1], in: Phrase[T2]): Phrase[T2] = { var substCounter = 0 - object Visitor extends VisitAndRebuild.Visitor { + object Visitor extends PureTraversal { def renaming[X <: PhraseType](p: Phrase[X]): Phrase[X] = { case class Renaming(idMap: Map[String, String]) extends VisitAndRebuild.Visitor { override def phrase[T <: PhraseType](p: Phrase[T]): Result[Phrase[T]] = p match { @@ -169,33 +169,33 @@ object Phrase { } VisitAndRebuild(p, Renaming(Map())) } - override def phrase[T <: PhraseType](p: Phrase[T]): Result[Phrase[T]] = { - p match { - case `for` => - val newPh = if (substCounter == 0) ph else renaming(ph) - substCounter += 1 - Stop(newPh.asInstanceOf[Phrase[T]]) - case Natural(n) => - val v = NatIdentifier(`for` match { - case Identifier(name, _) => name - case _ => throw new Exception("This should never happen") - }) - - ph.t match { - case ExpType(NatType, _) => - Stop(Natural(Nat.substitute( - Internal.transientNatFromExpr(ph.asInstanceOf[Phrase[ExpType]]).n, v, n)).asInstanceOf[Phrase[T]]) - case ExpType(IndexType(_), _) => - Stop(Natural(Nat.substitute( - Internal.transientNatFromExpr(ph.asInstanceOf[Phrase[ExpType]]).n, v, n)).asInstanceOf[Phrase[T]]) - case _ => Continue(p, this) - } - case _ => Continue(p, this) - } + + // override def phrase[T <: PhraseType](p: Phrase[T]): Result[Phrase[T]] = { + override def phrase[T <: PhraseType]: Phrase[T] => Pure[Phrase[T]] = p => p match { + case `for` => + val newPh = if (substCounter == 0) ph else renaming(ph) + substCounter += 1 + return_(newPh.asInstanceOf[Phrase[T]]) + case Natural(n) => + val v = NatIdentifier(`for` match { + case Identifier(name, _) => name + case _ => throw new Exception("This should never happen") + }) + + ph.t match { + case ExpType(NatType, _) => + return_(Natural(Nat.substitute( + Internal.transientNatFromExpr(ph.asInstanceOf[Phrase[ExpType]]).n, v, n)).asInstanceOf[Phrase[T]]) + case ExpType(IndexType(_), _) => + return_(Natural(Nat.substitute( + Internal.transientNatFromExpr(ph.asInstanceOf[Phrase[ExpType]]).n, v, n)).asInstanceOf[Phrase[T]]) + case _ => super.phrase(p) + } + case _ => super.phrase(p) } } - VisitAndRebuild(in, Visitor) + Visitor.phrase(in).unwrap } def substitute[T2 <: PhraseType](substitutionMap: Map[Phrase[_], Phrase[_]], From 22fb2ccb471d6836a9d30a959e9c7e1a28eb7760 Mon Sep 17 00:00:00 2001 From: Uma Zalakain Date: Fri, 12 Feb 2021 14:43:15 +0000 Subject: [PATCH 11/11] Route NatIdentifiers directly through nat --- macros/src/main/scala/shine/macros/Primitive.scala | 2 +- src/main/scala/shine/DPIA/Phrases/traverse.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/macros/src/main/scala/shine/macros/Primitive.scala b/macros/src/main/scala/shine/macros/Primitive.scala index ec0c9669c..da490e20a 100644 --- a/macros/src/main/scala/shine/macros/Primitive.scala +++ b/macros/src/main/scala/shine/macros/Primitive.scala @@ -43,7 +43,7 @@ object Primitive { Ident(TypeName("BasicType")) => Some(fq"${name} <- $v.datatype($name)") case Ident(TypeName("Data")) => Some(fq"${name} <- $v.data($name)") case Ident(TypeName("Nat")) => Some(fq"${name} <- $v.nat($name)") - case Ident(TypeName("NatIdentifier")) => Some(fq"${name} <- $v.typeIdentifierDispatch(shine.DPIA.Phrases.traverse.Reference)($name)") // FIXME: icky + case Ident(TypeName("NatIdentifier")) => Some(fq"${name} <- $v.nat($name)") case Ident(TypeName("NatToNat")) => Some(fq"${name} <- $v.natToNat($name)") case Ident(TypeName("NatToData")) => Some(fq"${name} <- $v.natToData($name)") case Ident(TypeName("AccessType")) => Some(fq"${name} <- $v.accessType($name)") diff --git a/src/main/scala/shine/DPIA/Phrases/traverse.scala b/src/main/scala/shine/DPIA/Phrases/traverse.scala index 0efae2d65..125a396bf 100644 --- a/src/main/scala/shine/DPIA/Phrases/traverse.scala +++ b/src/main/scala/shine/DPIA/Phrases/traverse.scala @@ -48,7 +48,7 @@ object traverse { case n2d: NatToDataIdentifier => bind(typeIdentifier(vt)(n2d))(natToData) }).asInstanceOf[M[I]] - def nat: Nat => M[Nat] = return_ + def nat[N <: Nat] : N => M[N] = return_ def addressSpace: AddressSpace => M[AddressSpace] = return_ def accessType: AccessType => M[AccessType] = return_ def data: Data => M[Data] = {