diff --git a/macros/src/main/scala/shine/macros/Primitive.scala b/macros/src/main/scala/shine/macros/Primitive.scala index 4c5570f35..92ec8d987 100644 --- a/macros/src/main/scala/shine/macros/Primitive.scala +++ b/macros/src/main/scala/shine/macros/Primitive.scala @@ -23,16 +23,14 @@ object Primitive { class Impl(val c: blackbox.Context) { import c.universe._ - def expPrimitive(annottees : c.Expr[Any]*): c.Expr[Any] = primitive(c => makePrimitiveClass(primitivesFromClassDef(c)))(annottees) - def accPrimitive(annottees : c.Expr[Any]*): c.Expr[Any] = primitive(c => makePrimitiveClass(primitivesFromClassDef(c)))(annottees) - def comPrimitive(annottees : c.Expr[Any]*): c.Expr[Any] = primitive(c => makePrimitiveClass(primitivesFromClassDef(c)))(annottees) + def expPrimitive(annottees : c.Expr[Any]*): c.Expr[Any] = primitive(c => makePrimitiveClass(getClassInfo(c)))(annottees) + def accPrimitive(annottees : c.Expr[Any]*): c.Expr[Any] = primitive(c => makePrimitiveClass(getClassInfo(c)))(annottees) + def comPrimitive(annottees : c.Expr[Any]*): c.Expr[Any] = primitive(c => makePrimitiveClass(getClassInfo(c)))(annottees) def primitive(transform : ClassDef => ClassDef)(annottees: Seq[c.Expr[Any]]): c.Expr[Any] = { annottees.map(_.tree) match { - case (cdef: ClassDef) :: Nil => - c.Expr(transform(cdef)) - case (cdef: ClassDef) :: (md: ModuleDef) :: Nil => - c.Expr(q"{${transform(cdef)}; $md}") + case (cdef: ClassDef) :: Nil => c.Expr(transform(cdef)) + case (cdef: ClassDef) :: (md: ModuleDef) :: Nil => c.Expr(q"{${transform(cdef)}; $md}") case _ => c.abort(c.enclosingPosition, "expected a class definition") } } @@ -40,6 +38,47 @@ object Primitive { def makeLowerCaseName(s: String): String = s"${Character.toLowerCase(s.charAt(0))}${s.substring(1)}" + def makeTraverseCall(v : Tree, name : TermName) : Tree => Option[Tree] = { + case Ident(TypeName("DataType")) | Ident(TypeName("ScalarType")) | + Ident(TypeName("BasicType")) => Some(fq"${name} <- $v.datatype($name)") + case Ident(TypeName("Data")) => Some(fq"${name} <- $v.data($name)") + case Ident(TypeName("Nat")) => Some(fq"${name} <- $v.nat($name)") + case Ident(TypeName("NatIdentifier")) => Some(fq"${name} <- $v.nat($name)") + case Ident(TypeName("NatToNat")) => Some(fq"${name} <- $v.natToNat($name)") + case Ident(TypeName("NatToData")) => Some(fq"${name} <- $v.natToData($name)") + case Ident(TypeName("AccessType")) => Some(fq"${name} <- $v.accessType($name)") + case Ident(TypeName("AddressSpace")) => Some(fq"${name} <- $v.addressSpace($name)") + // Phrase[ExpType] + case AppliedTypeTree((Ident(TypeName("Phrase")), _)) => Some(fq"${name} <- $v.phrase($name)") + // Vector[Phrase[ExpType]] + case AppliedTypeTree((Ident(TypeName("Vector")), + List(AppliedTypeTree((Ident(TypeName("Phrase")), _))))) => Some(fq"${name} <- monad.traverseV($name.map($v.phrase(_)))") + case AppliedTypeTree((Ident(TypeName("Seq")), + List(AppliedTypeTree((Ident(TypeName("Phrase")), _))))) => Some(fq"${name} <- monad.traverse($name.map($v.phrase(_)))") + case _ => None + } + + def makeTraverse(name: TypeName, additionalParams: List[ValDef], params: List[ValDef]): Tree = { + val v = q"v" + val paramNames = params.map { case ValDef(_, name, _, _) => q"$name" } + val additionalParamNames = additionalParams.map { case ValDef(_, name, _, _) => q"$name" } + val forLoopBindings : List[Tree] = params.flatMap { + case ValDef(_, name, tpt, _) => makeTraverseCall(v, name)(tpt) + } + val construct = if (additionalParamNames.isEmpty) q"new $name(..$paramNames)" + else q"new $name(..$additionalParamNames)(..$paramNames)" + val forloop = if (forLoopBindings.isEmpty) q"monad.return_($construct)" + else q"for (..${forLoopBindings}) yield $construct" + + q""" + override def traverse[M[+_]]($v: shine.DPIA.Phrases.traverse.Traversal[M]): M[$name] = { + import util.monad._ + implicit val monad: Monad[M] = implicitly($v.monad) + $forloop + } + """ + } + def makeVisitAndRebuild(name: TypeName, additionalParams: List[ValDef], params: List[ValDef]): Tree = { @@ -151,7 +190,7 @@ object Primitive { body: List[Tree], parents: List[Tree]) - def primitivesFromClassDef: ClassDef => ClassInfo = { + def getClassInfo: ClassDef => ClassInfo = { case q"case class $name(..$params) extends { ..$_ } with ..$parents {..$body} " => ClassInfo( name.asInstanceOf[c.TypeName], @@ -187,12 +226,16 @@ object Primitive { } def makePrimitiveClass : ClassInfo => ClassDef = { case ClassInfo(name, additionalParams, params, body, parents) => + val traverseMissing = + body.collectFirst({ case DefDef(_, TermName("traverse"), _, _, _, _) => ()}).isEmpty val visitAndRebuildMissing = body.collectFirst({ case DefDef(_, TermName("visitAndRebuild"), _, _, _, _) => ()}).isEmpty val xmlPrinterMissing = body.collectFirst({ case DefDef(_, TermName("xmlPrinter"), _, _, _, _) => ()}).isEmpty val generated = q""" + ${if (traverseMissing) makeTraverse(name, additionalParams, params) else q""} + ${if (visitAndRebuildMissing) makeVisitAndRebuild(name, additionalParams, params) else q""} diff --git a/src/main/scala/rise/core/traverse.scala b/src/main/scala/rise/core/traverse.scala index 402527a62..448876cd0 100644 --- a/src/main/scala/rise/core/traverse.scala +++ b/src/main/scala/rise/core/traverse.scala @@ -1,29 +1,23 @@ package rise.core -import scala.language.implicitConversions -import arithexpr.arithmetic.NamedVar +import util.monad import rise.core.semantics._ import rise.core.types._ +import scala.language.implicitConversions object traverse { - trait Monad[M[_]] { - def return_[T] : T => M[T] - def bind[T,S] : M[T] => (T => M[S]) => M[S] - def traverse[A] : Seq[M[A]] => M[Seq[A]] = - _.foldRight(return_(Nil : Seq[A]))({case (mx, mxs) => - bind(mx)(x => bind(mxs)(xs => return_(x +: xs)))}) - } - - implicit def monadicSyntax[M[_], A](m: M[A])(implicit tc: Monad[M]) = new { - def map[B](f: A => B): M[B] = tc.bind(m)(a => tc.return_(f(a)) ) - def flatMap[B](f: A => M[B]): M[B] = tc.bind(m)(f) - } + // Reexport util.monad.* + type Monad[M[+_]] = monad.Monad[M] + type Pure[+T] = monad.Pure[T] + implicit def monadicSyntax[M[+_], A](m: M[A])(implicit tc: monad.Monad[M]) = monad.monadicSyntax(m)(tc) + val PureMonad = monad.PureMonad + val OptionMonad = monad.OptionMonad sealed trait VarType case object Binding extends VarType case object Reference extends VarType - trait Traversal[M[_]] { + trait Traversal[M[+_]] { protected[this] implicit def monad : Monad[M] def return_[T] : T => M[T] = monad.return_ def bind[T,S] : M[T] => (T => M[S]) => M[S] = monad.bind @@ -55,7 +49,7 @@ object traverse { def matrixLayout : MatrixLayout => M[MatrixLayout] = return_ def fragmentKind : FragmentKind => M[FragmentKind] = return_ def datatype : DataType => M[DataType] = { - case i: DataTypeIdentifier => return_(i.asInstanceOf[DataType]) + case i: DataTypeIdentifier => return_(i) case NatType => return_(NatType : DataType) case s : ScalarType => return_(s : DataType) case ArrayType(n, d) => @@ -86,14 +80,14 @@ object traverse { } def natToNat : NatToNat => M[NatToNat] = { - case i : NatToNatIdentifier => return_(i.asInstanceOf[NatToNat]) + case i : NatToNatIdentifier => return_(i) case NatToNatLambda(x, e) => for { x1 <- typeIdentifierDispatch(Binding)(x); e1 <- natDispatch(Reference)(e) } yield NatToNatLambda(x1, e1) } def natToData : NatToData => M[NatToData] = { - case i : NatToDataIdentifier => return_(i.asInstanceOf[NatToData]) + case i : NatToDataIdentifier => return_(i) case NatToDataLambda(x, e) => for { x1 <- typeIdentifierDispatch(Binding)(x); e1 <- datatype(e) } yield NatToDataLambda(x1, e1) @@ -192,27 +186,15 @@ object traverse { } } - trait ExprTraversal[M[_]] extends Traversal[M] { + trait ExprTraversal[M[+_]] extends Traversal[M] { override def `type`[T <: Type] : T => M[T] = return_ } - case class Pure[T](unwrap : T) - implicit object PureMonad extends Monad[Pure] { - override def return_[T] : T => Pure[T] = t => Pure(t) - override def bind[T,S] : Pure[T] => (T => Pure[S]) => Pure[S] = - v => f => v match { case Pure(v) => f(v) } - } - - implicit object OptionMonad extends Monad[Option] { - def return_[T]: T => Option[T] = Some(_) - def bind[T, S]: Option[T] => (T => Option[S]) => Option[S] = v => v.flatMap - } - trait PureTraversal extends Traversal[Pure] { override def monad = PureMonad } trait PureExprTraversal extends PureTraversal with ExprTraversal[Pure] def apply(e : Expr, f : PureTraversal) : Expr = f.expr(e).unwrap - def apply[M[_]](e : Expr, f : Traversal[M]) : M[Expr] = f.expr(e) + def apply[M[+_]](e : Expr, f : Traversal[M]) : M[Expr] = f.expr(e) def apply[T <: Type](t : T, f : PureTraversal) : T = f.`type`(t).unwrap - def apply[T <: Type, M[_]](e : T, f : Traversal[M]) : M[T] = f.`type`(e) + def apply[T <: Type, M[+_]](e : T, f : Traversal[M]) : M[T] = f.`type`(e) } diff --git a/src/main/scala/shine/DPIA/Phrases/Phrase.scala b/src/main/scala/shine/DPIA/Phrases/Phrase.scala index bf0605d35..d8c20e821 100644 --- a/src/main/scala/shine/DPIA/Phrases/Phrase.scala +++ b/src/main/scala/shine/DPIA/Phrases/Phrase.scala @@ -8,6 +8,7 @@ import shine.DPIA.Semantics.OperationalSemantics.{IndexData, NatData} import shine.DPIA.Types._ import shine.DPIA.Types.TypeCheck._ import shine.DPIA._ +import shine.DPIA.Phrases.traverse._ import shine.DPIA.primitives.functional.NatAsIndex sealed trait Phrase[T <: PhraseType] { @@ -46,6 +47,7 @@ final case class DepLambda[K <: Kind, T <: PhraseType](x: K#I, body: Phrase[T]) extends Phrase[K `()->:` T] { override val t: DepFunType[K, T] = DepFunType[K, T](x, body.t) override def toString: String = s"Λ(${x.name} : ${kn.get}). $body" + val kindName : KindName[K] = implicitly(kn) } object DepLambda { @@ -137,7 +139,7 @@ object Phrase { `for`: Phrase[T1], in: Phrase[T2]): Phrase[T2] = { var substCounter = 0 - object Visitor extends VisitAndRebuild.Visitor { + object Visitor extends PureTraversal { def renaming[X <: PhraseType](p: Phrase[X]): Phrase[X] = { case class Renaming(idMap: Map[String, String]) extends VisitAndRebuild.Visitor { override def phrase[T <: PhraseType](p: Phrase[T]): Result[Phrase[T]] = p match { @@ -167,33 +169,33 @@ object Phrase { } VisitAndRebuild(p, Renaming(Map())) } - override def phrase[T <: PhraseType](p: Phrase[T]): Result[Phrase[T]] = { - p match { - case `for` => - val newPh = if (substCounter == 0) ph else renaming(ph) - substCounter += 1 - Stop(newPh.asInstanceOf[Phrase[T]]) - case Natural(n) => - val v = NatIdentifier(`for` match { - case Identifier(name, _) => name - case _ => throw new Exception("This should never happen") - }) - - ph.t match { - case ExpType(NatType, _) => - Stop(Natural(Nat.substitute( - Internal.transientNatFromExpr(ph.asInstanceOf[Phrase[ExpType]]).n, v, n)).asInstanceOf[Phrase[T]]) - case ExpType(IndexType(_), _) => - Stop(Natural(Nat.substitute( - Internal.transientNatFromExpr(ph.asInstanceOf[Phrase[ExpType]]).n, v, n)).asInstanceOf[Phrase[T]]) - case _ => Continue(p, this) - } - case _ => Continue(p, this) - } + + // override def phrase[T <: PhraseType](p: Phrase[T]): Result[Phrase[T]] = { + override def phrase[T <: PhraseType]: Phrase[T] => Pure[Phrase[T]] = p => p match { + case `for` => + val newPh = if (substCounter == 0) ph else renaming(ph) + substCounter += 1 + return_(newPh.asInstanceOf[Phrase[T]]) + case Natural(n) => + val v = NatIdentifier(`for` match { + case Identifier(name, _) => name + case _ => throw new Exception("This should never happen") + }) + + ph.t match { + case ExpType(NatType, _) => + return_(Natural(Nat.substitute( + Internal.transientNatFromExpr(ph.asInstanceOf[Phrase[ExpType]]).n, v, n)).asInstanceOf[Phrase[T]]) + case ExpType(IndexType(_), _) => + return_(Natural(Nat.substitute( + Internal.transientNatFromExpr(ph.asInstanceOf[Phrase[ExpType]]).n, v, n)).asInstanceOf[Phrase[T]]) + case _ => super.phrase(p) + } + case _ => super.phrase(p) } } - VisitAndRebuild(in, Visitor) + Visitor.phrase(in).unwrap } def substitute[T2 <: PhraseType](substitutionMap: Map[Phrase[_], Phrase[_]], @@ -367,6 +369,9 @@ sealed trait Primitive[T <: PhraseType] extends Phrase[T] { def xmlPrinter: xml.Elem = throw new Exception("xmlPrinter should be implemented by a macro") + def traverse[M[+_]](f: Traversal[M]): M[Phrase[T]] = + throw new Exception("traverse should be implemented by a macro") + def visitAndRebuild(f: VisitAndRebuild.Visitor): Phrase[T] = throw new Exception("visitAndRebuild should be implemented by a macro") } diff --git a/src/main/scala/shine/DPIA/Phrases/traverse.scala b/src/main/scala/shine/DPIA/Phrases/traverse.scala new file mode 100644 index 000000000..125a396bf --- /dev/null +++ b/src/main/scala/shine/DPIA/Phrases/traverse.scala @@ -0,0 +1,191 @@ +package shine.DPIA.Phrases + +import scala.language.implicitConversions +import util.monad +import shine.DPIA.Types._ +import shine.DPIA._ +import shine.DPIA.Semantics.OperationalSemantics._ + +object traverse { + // Reexport util.monad.* + type Monad[M[+_]] = monad.Monad[M] + type Pure[+T] = monad.Pure[T] + implicit def monadicSyntax[M[+_], A](m: M[A])(implicit tc: monad.Monad[M]) + = monad.monadicSyntax(m)(tc) + val PureMonad = monad.PureMonad + val OptionMonad = monad.OptionMonad + + trait ExprTraversal[M[+_]] extends Traversal[M] { + override def `type`[T <: PhraseType] : T => M[T] = return_ + } + trait PureTraversal extends Traversal[Pure] { override def monad = PureMonad } + trait PureExprTraversal extends PureTraversal with ExprTraversal[Pure] + + def apply[T <: PhraseType](e : Phrase[T], f : PureTraversal) : Phrase[T] = f.phrase(e).unwrap + def apply[T <: PhraseType, M[+_]](e : Phrase[T], f : Traversal[M]) : M[Phrase[T]] = f.phrase(e) + def apply[T <: PhraseType] (t : T, f : PureTraversal) : T = f.`type`(t).unwrap + def apply[T <: PhraseType, M[+_]](e : T, f : Traversal[M]) : M[T] = f.`type`(e) + + sealed trait VarType + case object Binding extends VarType + case object Reference extends VarType + + trait Traversal[M[+_]] { + implicit def monad: Monad[M] + def return_[T]: T => M[T] = monad.return_ + def bind[T, S]: M[T] => (T => M[S]) => M[S] = monad.bind + + def identifier[T <: PhraseType]: VarType => Identifier[T] => M[Identifier[T]] = _ => i => + for {t1 <- `type`(i.t)} + yield Identifier(i.name, t1) + def typeIdentifier[I <: Kind.Identifier]: VarType => I => M[I] = _ => return_ + def typeIdentifierDispatch[I <: Kind.Identifier]: VarType => I => M[I] = vt => i => (i match { + case n: NatIdentifier => bind(typeIdentifier(vt)(n))(nat) + case dt: DataTypeIdentifier => bind(typeIdentifier(vt)(dt))(datatype) + case a: AddressSpaceIdentifier => bind(typeIdentifier(vt)(a))(addressSpace) + case ac: AccessTypeIdentifier => bind(typeIdentifier(vt)(ac))(accessType) + case n2n: NatToNatIdentifier => bind(typeIdentifier(vt)(n2n))(natToNat) + case n2d: NatToDataIdentifier => bind(typeIdentifier(vt)(n2d))(natToData) + }).asInstanceOf[M[I]] + + def nat[N <: Nat] : N => M[N] = return_ + def addressSpace: AddressSpace => M[AddressSpace] = return_ + def accessType: AccessType => M[AccessType] = return_ + def data: Data => M[Data] = { + case VectorData(vd) => return_(VectorData(vd) : Data) + case NatData(n) => + for { n1 <- nat(n) } + yield NatData(n1) + case IndexData(i, n) => + for { i1 <- nat(i); n1 <- nat(n) } + yield IndexData(i1, n1) + case ArrayData(ad) => + for { ad1 <- monad.traverseV(ad.map(data)) } + yield ArrayData(ad1) + case PairData(l, r) => + for { l1 <- data(l); r1 <- data(r) } + yield PairData(l1, r1) + case d => return_(d) + } + + def datatype[ D <:DataType] : D => M[D] = d => (d match { + case NatType => return_(NatType) + case s : ScalarType => return_(s) + case IndexType(size) => + for {n1 <- nat(size)} + yield IndexType(n1) + case ArrayType(size, dt) => + for {n1 <- nat(size); dt1 <- datatype(dt)} + yield ArrayType(n1, dt1) + case DepArrayType(n, n2d) => + for {n1 <- nat(n); n2d1 <- natToData(n2d)} + yield DepArrayType(n1, n2d1) + case VectorType(size, dt) => + for {n1 <- nat(size); dt1 <- datatype(dt)} + yield VectorType(n1, dt1) + case PairType(l, r) => + for {l1 <- datatype(l); r1 <- datatype(r)} + yield PairType(l1, r1) + case pair@DepPairType(x, e) => + for {x1 <- typeIdentifierDispatch(Binding)(x); e1 <- datatype(e)} + yield DepPairType(x1, e1) + case NatToDataApply(ntdf, n) => + for {ntdf1 <- natToData(ntdf); n1 <- nat(n)} + yield NatToDataApply(ntdf1, n1) + }).asInstanceOf[M[D]] + + def natToNat: NatToNat => M[NatToNat] = { + case i: NatToNatIdentifier => return_(i : NatToNat) + case NatToNatLambda(n, body) => + for {n1 <- typeIdentifierDispatch(Binding)(n); body1 <- nat(body)} + yield NatToNatLambda(n1, body1) + } + + def natToData: NatToData => M[NatToData] = { + case i: NatToDataIdentifier => return_(i) + case NatToDataLambda(n, body) => + for {n1 <- typeIdentifierDispatch(Binding)(n); body1 <- datatype(body)} + yield NatToDataLambda(n1, body1) + } + + def phrase[T <: PhraseType]: Phrase[T] => M[Phrase[T]] = { + case i: Identifier[T] => for {i1 <- identifier(Reference)(i)} yield i1 + case Lambda(x, p) => + for {x1 <- identifier(Binding)(x); p1 <- phrase(p)} + yield Lambda(x1, p1) + case Apply(p, q) => + for {p1 <- phrase(p); q1 <- phrase(q)} + yield Apply(p1, q1) + case dl@DepLambda(i, p) => + for {i1 <- typeIdentifierDispatch(Binding)(i); p1 <- phrase(p)} + yield DepLambda(i1, p1)(dl.kindName) + case da@DepApply(f, x) => x match { + case n: Nat => + for {f1 <- phrase(f); n1 <- nat(n)} + yield DepApply[NatKind, T](f1.asInstanceOf[Phrase[NatKind `()->:` T]], n1) + case dt: DataType => + for {f1 <- phrase(f); dt1 <- datatype(dt)} + yield DepApply[DataKind, T](f1.asInstanceOf[Phrase[DataKind `()->:` T]], dt1) + case a: AddressSpace => + for {f1 <- phrase(f); a1 <- addressSpace(a)} + yield DepApply[AddressSpaceKind, T](f1.asInstanceOf[Phrase[AddressSpaceKind `()->:` T]], a1) + case n2n: NatToNat => + for {f1 <- phrase(f); n2n1 <- natToNat(n2n)} + yield DepApply[NatToNatKind, T](f1.asInstanceOf[Phrase[NatToNatKind `()->:` T]], n2n1) + case n2d: NatToData => + for {f1 <- phrase(f); n2d1 <- natToData(n2d)} + yield DepApply[NatToDataKind, T](f1.asInstanceOf[Phrase[NatToDataKind `()->:` T]], n2d1) + } + case LetNat(binder, defn, body) => + for {defn1 <- phrase(defn); body1 <- phrase(body)} + yield LetNat(binder, defn1, body1) + case PhrasePair(p, q) => + for {p1 <- phrase(p); q1 <- phrase(q)} + yield PhrasePair(p1, q1) + case Proj1(p) => + for {p1 <- phrase(p)} + yield Proj1(p1) + case Proj2(p) => + for {p1 <- phrase(p)} + yield Proj2(p1) + case IfThenElse(cond, thenP, elseP) => + for {cond1 <- phrase(cond); thenP1 <- phrase(thenP); elseP1 <- phrase(elseP)} + yield IfThenElse(cond1, thenP1, elseP1) + case Literal(d) => + for {d1 <- data(d)} + yield Literal(d1) + case Natural(n) => + for {n1 <- nat(n)} + yield Natural(n1) + case UnaryOp(op, x) => + for {x1 <- phrase(x)} + yield UnaryOp(op, x1) + case BinOp(op, lhs, rhs) => + for {lhs1 <- phrase(lhs); rhs1 <- phrase(rhs)} + yield BinOp(op, lhs1, rhs1) + case c: Primitive[T] => c.traverse(this) + } + + def `type`[T <: PhraseType] : T => M[T] = t => (t match { + case CommType() => return_(CommType()) + case ExpType(dt, w) => + for {dt1 <- datatype(dt); w1 <- accessType(w)} + yield ExpType(dt1, w1) + case AccType(dt) => + for {dt1 <- datatype(dt)} + yield AccType(dt1) + case PhrasePairType(l, r) => + for {l1 <- `type`(l); r1 <- `type`(r)} + yield PhrasePairType(l1, r1) + case FunType(inT, outT) => + for {inT1 <- `type`(inT); outT1 <- `type`(outT)} + yield FunType(inT1, outT1) + case PassiveFunType(inT, outT) => + for {inT1 <- `type`(inT); outT1 <- `type`(outT)} + yield PassiveFunType(inT1, outT1) + case df@DepFunType(x, t) => + for {x1 <- typeIdentifierDispatch(Binding)(x); t1 <- `type`(t)} + yield DepFunType(x1, t1)(df.kindName) + }).asInstanceOf[M[T]] + } +} \ No newline at end of file diff --git a/src/main/scala/shine/DPIA/Types/PhraseType.scala b/src/main/scala/shine/DPIA/Types/PhraseType.scala index fc5bd3219..4c57e0d10 100644 --- a/src/main/scala/shine/DPIA/Types/PhraseType.scala +++ b/src/main/scala/shine/DPIA/Types/PhraseType.scala @@ -45,6 +45,7 @@ final case class DepFunType[K <: Kind, +R <: PhraseType](x: K#I, t: R) (implicit val kn: KindName[K]) extends PhraseType { override def toString = s"(${x.name}: ${kn.get}) -> $t" + val kindName = implicitly(kn) } object PhraseType { diff --git a/src/main/scala/util/monad.scala b/src/main/scala/util/monad.scala new file mode 100644 index 000000000..d81005c47 --- /dev/null +++ b/src/main/scala/util/monad.scala @@ -0,0 +1,35 @@ +package util + +import scala.language.implicitConversions + +object monad { + trait Monad[M[+_]] { + def return_[T] : T => M[T] + def bind[T,S] : M[T] => (T => M[S]) => M[S] + def traverse[A] : Seq[M[A]] => M[Seq[A]] = + _.foldRight(return_(Seq() : Seq[A]))({case (mx, mxs) => + bind(mx)(x => bind(mxs)(xs => return_(x +: xs)))}) + // FIXME: We should be able to use S[_] <: Seq[_] for both + def traverseV[A] : Vector[M[A]] => M[Vector[A]] = + _.foldRight(return_(Vector() : Vector[A]))({case (mx, mxs) => + bind(mx)(x => bind(mxs)(xs => return_(x +: xs)))}) + } + + implicit def monadicSyntax[M[+_], A](m: M[A])(implicit tc: Monad[M]) = new { + def map[B](f: A => B): M[B] = tc.bind(m)(a => tc.return_(f(a)) ) + def flatMap[B](f: A => M[B]): M[B] = tc.bind(m)(f) + } + + case class Pure[+T](unwrap : T) + implicit object PureMonad extends Monad[Pure] { + override def return_[T] : T => Pure[T] = t => Pure(t) + override def bind[T,S] : Pure[T] => (T => Pure[S]) => Pure[S] = + v => f => v match { case Pure(v) => f(v) } + } + + implicit object OptionMonad extends Monad[Option] { + def return_[T]: T => Option[T] = Some(_) + def bind[T, S]: Option[T] => (T => Option[S]) => Option[S] = v => v.flatMap + } + +} diff --git a/src/test/scala/rise/core/traverseTest.scala b/src/test/scala/rise/core/traverseTest.scala index 68133bce8..847d33c8a 100644 --- a/src/test/scala/rise/core/traverseTest.scala +++ b/src/test/scala/rise/core/traverseTest.scala @@ -13,7 +13,7 @@ class traverseTest extends test_util.Tests { ) ) - case class Trace[T](unwrap : T) { val trace : Seq[Any] = Seq() } + case class Trace[+T](unwrap : T) { val trace : Seq[Any] = Seq() } implicit object TraceMonad extends Monad[Trace] { def write[T] : T => Trace[T] = t => new Trace(t) { override val trace : Seq[Any] = Seq(t)} override def return_[T] : T => Trace[T] = t => Trace(t)