Skip to content

Commit

Permalink
Merge pull request #1580 from dedis/work-be2-lauener-refactoring-publ…
Browse files Browse the repository at this point in the history
…ishsubscribe

Refactoring Akka graph construction
  • Loading branch information
K1li4nL authored May 22, 2023
2 parents db5d3da + a63c7cf commit 97af26c
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 130 deletions.
42 changes: 30 additions & 12 deletions be2-scala/src/main/scala/ch/epfl/pop/pubsub/PublishSubscribe.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@ import akka.pattern.AskableActorRef
import akka.stream.FlowShape
import akka.stream.scaladsl.{Broadcast, Flow, GraphDSL, Merge, Partition, Sink}
import ch.epfl.pop.decentralized.Monitor
import ch.epfl.pop.model.network.MethodType._
import ch.epfl.pop.model.network.{JsonRpcRequest, JsonRpcResponse}
import ch.epfl.pop.pubsub.graph._
import ch.epfl.pop.pubsub.graph.handlers.{GetMessagesByIdResponseHandler, ParamsWithChannelHandler, ParamsWithMapHandler, ParamsWithMessageHandler}
import ch.epfl.pop.pubsub.graph.handlers.{GetMessagesByIdResponseHandler, ParamsHandler, ParamsWithMapHandler, ParamsWithMessageHandler}

object PublishSubscribe {

Expand Down Expand Up @@ -95,9 +96,12 @@ object PublishSubscribe {
/* partitioner port numbers */
val portPipelineError = 0
val portParamsWithMessage = 1
val portParamsWithChannel = 2
val portParamsWithMap = 3
val totalPorts = 4
val portSubscribe = 2
val portUnsubscribe = 3
val portCatchup = 4
val portHeartbeat = 5
val portGetMessagesById = 6
val totalPorts = 7

/* building blocks */
val input = builder.add(Flow[GraphMessage].collect { case msg: GraphMessage => msg })
Expand All @@ -106,16 +110,27 @@ object PublishSubscribe {
val methodPartitioner = builder.add(Partition[GraphMessage](
totalPorts,
{
case Right(m: JsonRpcRequest) if m.hasParamsMessage => portParamsWithMessage // Publish and Broadcast messages
case Right(m: JsonRpcRequest) if m.hasParamsChannel => portParamsWithChannel
case Right(_: JsonRpcRequest) => portParamsWithMap
case _ => portPipelineError // Pipeline error goes directly in merger
case Right(m: JsonRpcRequest) => m.method match {
case BROADCAST => portParamsWithMessage
case PUBLISH => portParamsWithMessage
case SUBSCRIBE => portSubscribe
case UNSUBSCRIBE => portUnsubscribe
case CATCHUP => portCatchup
case HEARTBEAT => portHeartbeat
case GET_MESSAGES_BY_ID => portGetMessagesById
case _ => portPipelineError
}

case _ => portPipelineError // Pipeline error goes directly in merger
}
))

val hasMessagePartition = builder.add(ParamsWithMessageHandler.graph(messageRegistry))
val hasChannelPartition = builder.add(ParamsWithChannelHandler.graph(clientActorRef))
val hasMapPartition = builder.add(ParamsWithMapHandler.graph(dbActorRef))
val subscribePartition = builder.add(ParamsHandler.subscribeHandler(clientActorRef))
val unsubscribePartition = builder.add(ParamsHandler.unsubscribeHandler(clientActorRef))
val catchupPartition = builder.add(ParamsHandler.catchupHandler(clientActorRef))
val heartbeatPartition = builder.add(ParamsWithMapHandler.heartbeatHandler(dbActorRef))
val getMessagesByIdPartition = builder.add(ParamsWithMapHandler.getMessagesByIdHandler(dbActorRef))

val merger = builder.add(Merge[GraphMessage](totalPorts))

Expand All @@ -124,8 +139,11 @@ object PublishSubscribe {

methodPartitioner.out(portPipelineError) ~> merger
methodPartitioner.out(portParamsWithMessage) ~> hasMessagePartition ~> merger
methodPartitioner.out(portParamsWithChannel) ~> hasChannelPartition ~> merger
methodPartitioner.out(portParamsWithMap) ~> hasMapPartition ~> merger
methodPartitioner.out(portSubscribe) ~> subscribePartition ~> merger
methodPartitioner.out(portUnsubscribe) ~> unsubscribePartition ~> merger
methodPartitioner.out(portCatchup) ~> catchupPartition ~> merger
methodPartitioner.out(portHeartbeat) ~> heartbeatPartition ~> merger
methodPartitioner.out(portGetMessagesById) ~> getMessagesByIdPartition ~> merger

/* close the shape */
FlowShape(input.in, merger.out)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
package ch.epfl.pop.pubsub.graph.handlers

import akka.NotUsed
import akka.actor.ActorRef
import akka.pattern.AskableActorRef
import akka.stream.FlowShape
import akka.stream.scaladsl.{Flow, GraphDSL, Merge, Partition}
import ch.epfl.pop.model.network.method.{Catchup, Subscribe, Unsubscribe}
import akka.stream.scaladsl.Flow
import ch.epfl.pop.model.network.{JsonRpcRequest, JsonRpcResponse}
import ch.epfl.pop.model.objects.Channel
import ch.epfl.pop.pubsub.graph.{ErrorCodes, GraphMessage, PipelineError}
Expand All @@ -14,51 +11,7 @@ import ch.epfl.pop.pubsub.{AskPatternConstants, ClientActor, PubSubMediator}
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.{Await, Future}

object ParamsWithChannelHandler extends AskPatternConstants {

def graph(clientActorRef: ActorRef): Flow[GraphMessage, GraphMessage, NotUsed] = Flow.fromGraph(GraphDSL.create() {
implicit builder: GraphDSL.Builder[NotUsed] =>
{
import GraphDSL.Implicits._

/* partitioner port numbers */
val portPipelineError = 0
val portSubscribe = 1
val portUnsubscribe = 2
val portCatchup = 3
val totalPorts = 4

/* building blocks */
val handlerPartitioner = builder.add(Partition[GraphMessage](
totalPorts,
{
case Right(jsonRpcMessage: JsonRpcRequest) => jsonRpcMessage.getParams match {
case _: Subscribe => portSubscribe
case _: Unsubscribe => portUnsubscribe
case _: Catchup => portCatchup
}
case _ => portPipelineError // Pipeline error goes directly in handlerMerger
}
))

val subscribeHandler = builder.add(ParamsWithChannelHandler.subscribeHandler(clientActorRef))
val unsubscribeHandler = builder.add(ParamsWithChannelHandler.unsubscribeHandler(clientActorRef))
val catchupHandler = builder.add(ParamsWithChannelHandler.catchupHandler(clientActorRef))

val handlerMerger = builder.add(Merge[GraphMessage](totalPorts))

/* glue the components together */
handlerPartitioner.out(portPipelineError) ~> handlerMerger
handlerPartitioner.out(portSubscribe) ~> subscribeHandler ~> handlerMerger
handlerPartitioner.out(portUnsubscribe) ~> unsubscribeHandler ~> handlerMerger
handlerPartitioner.out(portCatchup) ~> catchupHandler ~> handlerMerger

/* close the shape */
FlowShape(handlerPartitioner.in, handlerMerger.out)
}
})

final case class Asking(g: GraphMessage, replyTo: ActorRef)
object ParamsHandler extends AskPatternConstants {

def subscribeHandler(clientActorRef: AskableActorRef): Flow[GraphMessage, GraphMessage, NotUsed] = Flow[GraphMessage].map {
case Right(jsonRpcMessage: JsonRpcRequest) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,64 +2,24 @@ package ch.epfl.pop.pubsub.graph.handlers

import akka.NotUsed
import akka.pattern.AskableActorRef
import akka.stream.scaladsl.{Flow, GraphDSL, Merge, Partition}
import akka.stream.scaladsl.Flow
import ch.epfl.pop.model.network.method.message.Message
import ch.epfl.pop.model.network.method.{GetMessagesById, Heartbeat}
import ch.epfl.pop.model.network.{JsonRpcRequest, JsonRpcResponse}
import ch.epfl.pop.model.network.{JsonRpcRequest, JsonRpcResponse, MethodType, ResultObject}
import ch.epfl.pop.model.objects.{Channel, DbActorNAckException, Hash}
import ch.epfl.pop.pubsub.graph.{ErrorCodes, GraphMessage, PipelineError}
import ch.epfl.pop.pubsub.AskPatternConstants
import ch.epfl.pop.storage.DbActor
import ch.epfl.pop.model.network.MethodType
import ch.epfl.pop.pubsub.graph.validators.RpcValidator
import ch.epfl.pop.model.network.method.message.Message
import scala.collection.mutable
import akka.stream.FlowShape
import ch.epfl.pop.model.network.ResultObject
import ch.epfl.pop.pubsub.graph.{ErrorCodes, GraphMessage, PipelineError}
import ch.epfl.pop.storage.DbActor

import scala.collection.immutable.HashMap
import scala.collection.mutable
import scala.concurrent.Await
import scala.util.{Failure, Success}

object ParamsWithMapHandler extends AskPatternConstants {

def graph(dbActorRef: AskableActorRef): Flow[GraphMessage, GraphMessage, NotUsed] = Flow.fromGraph(GraphDSL.create() {
implicit builder: GraphDSL.Builder[NotUsed] =>
{
import GraphDSL.Implicits._

/* partitioner port numbers */
val portPipelineError = 0
val portHeartBeatHandler = 1
val portGetMessagesByIdHandler = 2
val totalPorts = 3

/* building blocks */
val handlerPartitioner = builder.add(Partition[GraphMessage](
totalPorts,
{
case Right(jsonRpcMessage: JsonRpcRequest) => jsonRpcMessage.getParams match {
case _: Heartbeat => portHeartBeatHandler
case _: GetMessagesById => portGetMessagesByIdHandler
}
case _ => portPipelineError // Pipeline error goes directly in handlerMerger
}
))

val heartbeatHandler = builder.add(ParamsWithMapHandler.heartbeatHandler(dbActorRef))
val getMessagesByIdHandler = builder.add(ParamsWithMapHandler.getMessagesByIdHandler(dbActorRef))
val handlerMerger = builder.add(Merge[GraphMessage](totalPorts))

/* glue the components together */
handlerPartitioner.out(portPipelineError) ~> handlerMerger
handlerPartitioner.out(portHeartBeatHandler) ~> heartbeatHandler ~> handlerMerger
handlerPartitioner.out(portGetMessagesByIdHandler) ~> getMessagesByIdHandler ~> handlerMerger

/* close the shape */
FlowShape(handlerPartitioner.in, handlerMerger.out)
}
})

private def heartbeatHandler(dbActorRef: AskableActorRef): Flow[GraphMessage, GraphMessage, NotUsed] = Flow[GraphMessage].map {
def heartbeatHandler(dbActorRef: AskableActorRef): Flow[GraphMessage, GraphMessage, NotUsed] = Flow[GraphMessage].map {
case Right(jsonRpcMessage: JsonRpcRequest) =>
/** first step is to retrieve the received heartbeat from the jsonRpcRequest */
val receivedHeartBeat: Map[Channel, Set[Hash]] = jsonRpcMessage.getParams.asInstanceOf[Heartbeat].channelsToMessageIds
Expand Down Expand Up @@ -105,7 +65,7 @@ object ParamsWithMapHandler extends AskPatternConstants {
case graphMessage @ _ => graphMessage
}.filter(!isGetMessagesByIdEmpty(_)) // Answer to heartbeats only if some messages are actually missing

private def getMessagesByIdHandler(dbActorRef: AskableActorRef): Flow[GraphMessage, GraphMessage, NotUsed] = Flow[GraphMessage].map {
def getMessagesByIdHandler(dbActorRef: AskableActorRef): Flow[GraphMessage, GraphMessage, NotUsed] = Flow[GraphMessage].map {
case Right(jsonRpcMessage: JsonRpcRequest) =>
val receivedRequest: Map[Channel, Set[Hash]] = jsonRpcMessage.getParams.asInstanceOf[GetMessagesById].channelsToMessageIds
val response: mutable.HashMap[Channel, Set[Message]] = mutable.HashMap()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import util.examples.JsonRpcRequestExample
import scala.concurrent.Await
import scala.util.{Failure, Success}

class ParamsWithChannelHandlerSuite extends FunSuite with Matchers with AskPatternConstants {
class ParamsHandlerSuite extends FunSuite with Matchers with AskPatternConstants {

implicit val system: ActorSystem = ActorSystem()

Expand All @@ -22,7 +22,7 @@ class ParamsWithChannelHandlerSuite extends FunSuite with Matchers with AskPatte
val rpcExample = JsonRpcRequestExample.subscribeRpcRequest
val expectedAsk = ClientActor.SubscribeTo(rpcExample.getParams.channel)
val pipelineOutput = Source.single(Right(rpcExample))
.via(ParamsWithChannelHandler.subscribeHandler(mockClientRef.ref))
.via(ParamsHandler.subscribeHandler(mockClientRef.ref))
.runWith(Sink.head)

val channel = mockClientRef.expectMsg(expectedAsk).channel
Expand All @@ -40,7 +40,7 @@ class ParamsWithChannelHandlerSuite extends FunSuite with Matchers with AskPatte
val rpcExample = JsonRpcRequestExample.subscribeRpcRequest
val expectedAsk = ClientActor.SubscribeTo(rpcExample.getParams.channel)
val pipelineOutput = Source.single(Right(rpcExample))
.via(ParamsWithChannelHandler.subscribeHandler(mockClientRef.ref))
.via(ParamsHandler.subscribeHandler(mockClientRef.ref))
.runWith(Sink.head)

val channel = mockClientRef.expectMsg(expectedAsk).channel
Expand All @@ -59,7 +59,7 @@ class ParamsWithChannelHandlerSuite extends FunSuite with Matchers with AskPatte
val rpcExample = JsonRpcRequestExample.unSubscribeRpcRequest
val expectedAsk = ClientActor.UnsubscribeFrom(rpcExample.getParams.channel)
val pipelineOutput = Source.single(Right(rpcExample))
.via(ParamsWithChannelHandler.unsubscribeHandler(mockClientRef.ref))
.via(ParamsHandler.unsubscribeHandler(mockClientRef.ref))
.runWith(Sink.head)

val channel = mockClientRef.expectMsg(expectedAsk).channel
Expand All @@ -77,7 +77,7 @@ class ParamsWithChannelHandlerSuite extends FunSuite with Matchers with AskPatte
val rpcExample = JsonRpcRequestExample.unSubscribeRpcRequest
val expectedAsk = ClientActor.UnsubscribeFrom(rpcExample.getParams.channel)
val pipelineOutput = Source.single(Right(rpcExample))
.via(ParamsWithChannelHandler.unsubscribeHandler(mockClientRef.ref))
.via(ParamsHandler.unsubscribeHandler(mockClientRef.ref))
.runWith(Sink.head)

val channel = mockClientRef.expectMsg(expectedAsk).channel
Expand Down
Original file line number Diff line number Diff line change
@@ -1,36 +1,32 @@
package ch.epfl.pop.pubsub.graph.handlers

import akka.NotUsed
import akka.actor.{ActorSystem, Props}
import akka.pattern.AskableActorRef
import akka.stream.SinkShape
import akka.stream.scaladsl.{Flow, Sink, Source}
import akka.stream.scaladsl.{Sink, Source}
import akka.testkit.TestKit
import ch.epfl.pop.decentralized.ToyDbActor
import ch.epfl.pop.model.network.{JsonRpcRequest, JsonRpcResponse, MethodType, ResultObject}
import ch.epfl.pop.model.network.method.{GetMessagesById, Heartbeat}
import ch.epfl.pop.model.network.method.message.Message
import ch.epfl.pop.model.objects.{Base64Data, Channel, Hash}
import ch.epfl.pop.model.network.method.GetMessagesById
import ch.epfl.pop.model.network.{JsonRpcRequest, JsonRpcResponse}
import ch.epfl.pop.pubsub.AskPatternConstants
import ch.epfl.pop.pubsub.graph.GraphMessage
import ch.epfl.pop.pubsub.graph.validators.RpcValidator
import org.scalatest.funsuite.{AnyFunSuite, AnyFunSuiteLike}
import org.scalatest.funsuite.AnyFunSuiteLike
import org.scalatest.matchers.should.Matchers.{convertToAnyShouldWrapper, equal}
import util.examples.JsonRpcRequestExample._
import scala.collection.{immutable, mutable}
import scala.concurrent.{Await, Future}
import scala.util.{Failure, Success}

import scala.concurrent.Await
import scala.util.Success
class ParamsWithMapHandlerSuite extends TestKit(ActorSystem("HbActorSuiteActorSystem")) with AnyFunSuiteLike with AskPatternConstants {

final val toyDbActorRef: AskableActorRef = system.actorOf(Props(new ToyDbActor))
final val boxUnderTest: Flow[GraphMessage, GraphMessage, NotUsed] = ParamsWithMapHandler.graph(toyDbActorRef)
final val heartbeatHandler = ParamsWithMapHandler.heartbeatHandler(toyDbActorRef)
final val getMessagesByIdHandler = ParamsWithMapHandler.getMessagesByIdHandler(toyDbActorRef)
final val rpc: String = "rpc"
final val id: Option[Int] = Some(0)

test("sending a heartbeat correctly returns the missing ids") {
val input: List[GraphMessage] = List(Right(VALID_RECEIVED_HEARTBEAT_RPC))
val source = Source(input)
val s = source.via(boxUnderTest).runWith(Sink.seq[GraphMessage])
val s = source.via(heartbeatHandler).runWith(Sink.seq[GraphMessage])
Await.ready(s, duration).value match {
case Some(Success(seq)) => seq.toList.head match {
case Right(jsonRpcReq: JsonRpcRequest) => jsonRpcReq.getParams.asInstanceOf[GetMessagesById].channelsToMessageIds should equal(EXPECTED_MISSING_MESSAGE_IDS)
Expand All @@ -44,7 +40,7 @@ class ParamsWithMapHandlerSuite extends TestKit(ActorSystem("HbActorSuiteActorSy
test("sending a getMessagesById correctly returns the missing messages") {
val input: List[GraphMessage] = List(Right(VALID_RECEIVED_GET_MSG_BY_ID_RPC))
val source = Source(input)
val s = source.via(boxUnderTest).runWith(Sink.seq[GraphMessage])
val s = source.via(getMessagesByIdHandler).runWith(Sink.seq[GraphMessage])
Await.ready(s, duration).value match {
case Some(Success(seq)) => seq.toList.head match {
case Right(jsonRpcResp: JsonRpcResponse) => jsonRpcResp.result.get.resultMap.get should equal(EXPECTED_MISSING_MESSAGES)
Expand All @@ -58,7 +54,7 @@ class ParamsWithMapHandlerSuite extends TestKit(ActorSystem("HbActorSuiteActorSy
test("receiving a heartbeat with unknown channel asks back for this channel") {
val input: List[GraphMessage] = List(Right(VALID_RECEIVED_UNKNOWN_CHANNEL_HEARTBEAT_RPC))
val source = Source(input)
val s = source.via(boxUnderTest).runWith(Sink.seq[GraphMessage])
val s = source.via(heartbeatHandler).runWith(Sink.seq[GraphMessage])
Await.ready(s, duration).value match {
case Some(Success(seq)) => seq.toList.head match {
case Right(jsonRpcReq: JsonRpcRequest) => jsonRpcReq.getParams.asInstanceOf[GetMessagesById].channelsToMessageIds should equal(EXPECTED_UNKNOWN_CHANNEL_MISSING_MESSAGE_IDS)
Expand Down

0 comments on commit 97af26c

Please sign in to comment.