diff --git a/be2-scala/src/main/scala/ch/epfl/pop/Server.scala b/be2-scala/src/main/scala/ch/epfl/pop/Server.scala index 071e16ed3e..ef8d4b366e 100644 --- a/be2-scala/src/main/scala/ch/epfl/pop/Server.scala +++ b/be2-scala/src/main/scala/ch/epfl/pop/Server.scala @@ -10,6 +10,7 @@ import akka.http.scaladsl.server.{RequestContext, RouteResult} import akka.pattern.AskableActorRef import akka.util.Timeout import ch.epfl.pop.config.{RuntimeEnvironment, ServerConf} +import ch.epfl.pop.decentralized.{ConnectionMediator, HeartbeatGenerator, Monitor} import ch.epfl.pop.pubsub.{MessageRegistry, PubSubMediator, PublishSubscribe} import ch.epfl.pop.storage.DbActor import org.iq80.leveldb.Options @@ -46,12 +47,35 @@ object Server { val pubSubMediatorRef: ActorRef = system.actorOf(PubSubMediator.props, "PubSubMediator") val dbActorRef: AskableActorRef = system.actorOf(Props(DbActor(pubSubMediatorRef, messageRegistry)), "DbActor") + // Create necessary actors for server-server communications + val heartbeatGenRef: ActorRef = system.actorOf(HeartbeatGenerator.props(dbActorRef)) + val monitorRef: ActorRef = system.actorOf(Monitor.props(heartbeatGenRef)) + val connectionMediatorRef: ActorRef = system.actorOf(ConnectionMediator.props(monitorRef, pubSubMediatorRef, dbActorRef, messageRegistry)) + // Setup routes def publishSubscribeRoute: RequestContext => Future[RouteResult] = { path(config.clientPath) { - handleWebSocketMessages(PublishSubscribe.buildGraph(pubSubMediatorRef, dbActorRef, messageRegistry)(system)) + handleWebSocketMessages( + PublishSubscribe.buildGraph( + pubSubMediatorRef, + dbActorRef, + messageRegistry, + monitorRef, + connectionMediatorRef, + isServer = false + )(system) + ) } ~ path(config.serverPath) { - handleWebSocketMessages(PublishSubscribe.buildGraph(pubSubMediatorRef, dbActorRef, messageRegistry)(system)) + handleWebSocketMessages( + PublishSubscribe.buildGraph( + pubSubMediatorRef, + dbActorRef, + messageRegistry, + monitorRef, + connectionMediatorRef, + isServer = true + )(system) + ) } } diff --git a/be2-scala/src/main/scala/ch/epfl/pop/decentralized/ConnectionMediator.scala b/be2-scala/src/main/scala/ch/epfl/pop/decentralized/ConnectionMediator.scala index d0e6b0caca..47a34bb841 100644 --- a/be2-scala/src/main/scala/ch/epfl/pop/decentralized/ConnectionMediator.scala +++ b/be2-scala/src/main/scala/ch/epfl/pop/decentralized/ConnectionMediator.scala @@ -31,7 +31,10 @@ final case class ConnectionMediator( PublishSubscribe.buildGraph( mediatorRef, dbActorRef, - messageRegistry + messageRegistry, + monitorRef, + self, + isServer = true ) ) ) diff --git a/be2-scala/src/main/scala/ch/epfl/pop/pubsub/ClientActor.scala b/be2-scala/src/main/scala/ch/epfl/pop/pubsub/ClientActor.scala index a911b1e208..c27b34917b 100644 --- a/be2-scala/src/main/scala/ch/epfl/pop/pubsub/ClientActor.scala +++ b/be2-scala/src/main/scala/ch/epfl/pop/pubsub/ClientActor.scala @@ -3,6 +3,7 @@ package ch.epfl.pop.pubsub import akka.actor.{Actor, ActorLogging, ActorRef, Props} import akka.event.LoggingReceive import akka.pattern.AskableActorRef +import ch.epfl.pop.decentralized.ConnectionMediator import ch.epfl.pop.model.objects.Channel import ch.epfl.pop.pubsub.ClientActor._ import ch.epfl.pop.pubsub.PubSubMediator._ @@ -13,13 +14,18 @@ import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.{Await, Future} import scala.util.Failure -final case class ClientActor(mediator: ActorRef) extends Actor with ActorLogging with AskPatternConstants { +final case class ClientActor(mediator: ActorRef, connectionMediatorRef: ActorRef, isServer: Boolean) extends Actor with ActorLogging with AskPatternConstants { private var wsHandle: Option[ActorRef] = None private val subscribedChannels: mutable.Set[Channel] = mutable.Set.empty private val mediatorAskable: AskableActorRef = mediator + // Tell connectionMediator we are online + if (isServer) { + connectionMediatorRef ! ConnectionMediator.NewServerConnected(self) + } + private def messageWsHandle(event: ClientActorMessage): Unit = event match { case ClientAnswer(graphMessage) => wsHandle.fold(())(_ ! graphMessage) } @@ -30,7 +36,11 @@ final case class ClientActor(mediator: ActorRef) extends Actor with ActorLogging log.info(s"Connecting wsHandle $wsClient to actor ${this.self}") wsHandle = Some(wsClient) - case DisconnectWsHandle => subscribedChannels.foreach(channel => mediator ! PubSubMediator.UnsubscribeFrom(channel, this.self)) + case DisconnectWsHandle => + subscribedChannels.foreach(channel => mediator ! PubSubMediator.UnsubscribeFrom(channel, this.self)) + if (isServer) { + connectionMediatorRef ! ConnectionMediator.ServerLeft(self) + } case ClientActor.SubscribeTo(channel) => val ask: Future[PubSubMediatorMessage] = (mediatorAskable ? PubSubMediator.SubscribeTo(channel, this.self)).map { @@ -79,7 +89,8 @@ final case class ClientActor(mediator: ActorRef) extends Actor with ActorLogging } object ClientActor { - def props(mediator: ActorRef): Props = Props(new ClientActor(mediator)) + def props(mediator: ActorRef, connectionMediatorRef: ActorRef, isServer: Boolean): Props = + Props(new ClientActor(mediator, connectionMediatorRef, isServer)) sealed trait ClientActorMessage diff --git a/be2-scala/src/main/scala/ch/epfl/pop/pubsub/PublishSubscribe.scala b/be2-scala/src/main/scala/ch/epfl/pop/pubsub/PublishSubscribe.scala index 4a989076c0..74ddad367b 100644 --- a/be2-scala/src/main/scala/ch/epfl/pop/pubsub/PublishSubscribe.scala +++ b/be2-scala/src/main/scala/ch/epfl/pop/pubsub/PublishSubscribe.scala @@ -5,10 +5,11 @@ import akka.actor.{ActorRef, ActorSystem} import akka.http.scaladsl.model.ws.{Message, TextMessage} import akka.pattern.AskableActorRef import akka.stream.FlowShape -import akka.stream.scaladsl.{Flow, GraphDSL, Merge, Partition} -import ch.epfl.pop.model.network.JsonRpcRequest +import akka.stream.scaladsl.{Broadcast, Flow, GraphDSL, Merge, Partition} +import ch.epfl.pop.decentralized.Monitor +import ch.epfl.pop.model.network.{JsonRpcRequest, JsonRpcResponse} import ch.epfl.pop.pubsub.graph._ -import ch.epfl.pop.pubsub.graph.handlers.{ParamsWithChannelHandler, ParamsWithMessageHandler} +import ch.epfl.pop.pubsub.graph.handlers.{GetMessagesByIdResponseHandler, ParamsWithChannelHandler, ParamsWithMapHandler, ParamsWithMessageHandler} object PublishSubscribe { @@ -16,19 +17,30 @@ object PublishSubscribe { def getDbActorRef: AskableActorRef = dbActorRef - def buildGraph(mediatorActorRef: ActorRef, dbActorRefT: AskableActorRef, messageRegistry: MessageRegistry)(implicit system: ActorSystem): Flow[Message, Message, NotUsed] = Flow.fromGraph(GraphDSL.create() { + def buildGraph( + mediatorActorRef: ActorRef, + dbActorRefT: AskableActorRef, + messageRegistry: MessageRegistry, + monitorRef: ActorRef, + connectionMediatorRef: ActorRef, + isServer: Boolean + )(implicit system: ActorSystem): Flow[Message, Message, NotUsed] = Flow.fromGraph(GraphDSL.create() { implicit builder: GraphDSL.Builder[NotUsed] => { import GraphDSL.Implicits._ - val clientActorRef: ActorRef = system.actorOf(ClientActor.props(mediatorActorRef)) + val clientActorRef: ActorRef = system.actorOf(ClientActor.props(mediatorActorRef, connectionMediatorRef, isServer)) dbActorRef = dbActorRefT /* partitioner port numbers */ val portPipelineError = 0 val portParamsWithMessage = 1 val portParamsWithChannel = 2 - val totalPorts = 3 + val portParamsWithMap = 3 + val portResponseHandler = 4 + val totalPorts = 5 + + val totalBroadcastPort = 2 /* building blocks */ // input message from the client @@ -44,15 +56,21 @@ object PublishSubscribe { { case Right(m: JsonRpcRequest) if m.hasParamsMessage => portParamsWithMessage // Publish and Broadcast messages case Right(m: JsonRpcRequest) if m.hasParamsChannel => portParamsWithChannel + case Right(_: JsonRpcRequest) => portParamsWithMap + case Right(_: JsonRpcResponse) => portResponseHandler case _ => portPipelineError // Pipeline error goes directly in merger } )) val hasMessagePartition = builder.add(ParamsWithMessageHandler.graph(messageRegistry)) val hasChannelPartition = builder.add(ParamsWithChannelHandler.graph(clientActorRef)) + val hasMapPartition = builder.add(ParamsWithMapHandler.graph(dbActorRef)) + val responsePartition = builder.add(GetMessagesByIdResponseHandler.graph(dbActorRef.actorRef)) val merger = builder.add(Merge[GraphMessage](totalPorts)) + val broadcast = builder.add(Broadcast[GraphMessage](totalBroadcastPort)) + val monitorSink = builder.add(Monitor.sink(monitorRef)) val jsonRpcAnswerGenerator = builder.add(AnswerGenerator.generator) val jsonRpcAnswerer = builder.add(Answerer.answerer(clientActorRef, mediatorActorRef)) @@ -65,8 +83,12 @@ object PublishSubscribe { methodPartitioner.out(portPipelineError) ~> merger methodPartitioner.out(portParamsWithMessage) ~> hasMessagePartition ~> merger methodPartitioner.out(portParamsWithChannel) ~> hasChannelPartition ~> merger + methodPartitioner.out(portParamsWithMap) ~> hasMapPartition ~> merger + methodPartitioner.out(portResponseHandler) ~> responsePartition ~> merger - merger ~> jsonRpcAnswerGenerator ~> jsonRpcAnswerer ~> output + merger ~> broadcast + broadcast ~> jsonRpcAnswerGenerator ~> jsonRpcAnswerer ~> output + broadcast ~> monitorSink /* close the shape */ FlowShape(input.in, output.out) diff --git a/be2-scala/src/main/scala/ch/epfl/pop/pubsub/graph/AnswerGenerator.scala b/be2-scala/src/main/scala/ch/epfl/pop/pubsub/graph/AnswerGenerator.scala index d1b67cd6e5..b2d4e7b406 100644 --- a/be2-scala/src/main/scala/ch/epfl/pop/pubsub/graph/AnswerGenerator.scala +++ b/be2-scala/src/main/scala/ch/epfl/pop/pubsub/graph/AnswerGenerator.scala @@ -3,8 +3,8 @@ package ch.epfl.pop.pubsub.graph import akka.NotUsed import akka.pattern.AskableActorRef import akka.stream.scaladsl.Flow -import ch.epfl.pop.model.network.method.{Broadcast, Catchup} -import ch.epfl.pop.model.network.{ResultObject, _} +import ch.epfl.pop.model.network.method.{Broadcast, Catchup, GetMessagesById} +import ch.epfl.pop.model.network._ import ch.epfl.pop.model.objects.DbActorNAckException import ch.epfl.pop.pubsub.AskPatternConstants import ch.epfl.pop.pubsub.graph.validators.RpcValidator @@ -48,6 +48,9 @@ class AnswerGenerator(dbActor: => AskableActorRef) extends AskPatternConstants { rpcRequest.id )) + // Let get_messages_by_id request go through + case GetMessagesById(_) => graphMessage + // Standard answer res == 0 case _ => Right(JsonRpcResponse( RpcValidator.JSON_RPC_VERSION, @@ -57,6 +60,9 @@ class AnswerGenerator(dbActor: => AskableActorRef) extends AskPatternConstants { )) } + // Let get_messages_by_id answer go through + case Right(_: JsonRpcResponse) => graphMessage + // Convert PipelineErrors into negative JsonRpcResponses case Left(pipelineError: PipelineError) => Right(JsonRpcResponse( RpcValidator.JSON_RPC_VERSION, diff --git a/be2-scala/src/main/scala/ch/epfl/pop/pubsub/graph/Validator.scala b/be2-scala/src/main/scala/ch/epfl/pop/pubsub/graph/Validator.scala index 281299c177..93cd568951 100644 --- a/be2-scala/src/main/scala/ch/epfl/pop/pubsub/graph/Validator.scala +++ b/be2-scala/src/main/scala/ch/epfl/pop/pubsub/graph/Validator.scala @@ -28,29 +28,29 @@ object Validator { private def validateMethodContent(graphMessage: GraphMessage): GraphMessage = graphMessage match { case Right(jsonRpcRequest: JsonRpcRequest) => jsonRpcRequest.getParams match { - case _: Broadcast => validateBroadcast(jsonRpcRequest) - case _: Catchup => validateCatchup(jsonRpcRequest) - case _: Publish => validatePublish(jsonRpcRequest) - case _: Subscribe => validateSubscribe(jsonRpcRequest) - case _: Unsubscribe => validateUnsubscribe(jsonRpcRequest) - case _ => Left(validationError(jsonRpcRequest.id)) + case _: Broadcast => validateBroadcast(jsonRpcRequest) + case _: Catchup => validateCatchup(jsonRpcRequest) + case _: Publish => validatePublish(jsonRpcRequest) + case _: Subscribe => validateSubscribe(jsonRpcRequest) + case _: Unsubscribe => validateUnsubscribe(jsonRpcRequest) + case _: Heartbeat => graphMessage // No check necessary + case _: GetMessagesById => graphMessage // No check necessary + case _ => Left(validationError(jsonRpcRequest.id)) } - case Right(jsonRpcResponse: JsonRpcResponse) => Left(PipelineError( - ErrorCodes.SERVER_ERROR.id, - "Unsupported action: MethodValidator was given a response message", - jsonRpcResponse.id - )) + case _ => graphMessage } private def validateMessageContent(graphMessage: GraphMessage): GraphMessage = graphMessage match { case Right(jsonRpcRequest: JsonRpcRequest) => jsonRpcRequest.getParams match { - case _: Broadcast => validateMessage(jsonRpcRequest) - case _: Catchup => graphMessage - case _: Publish => validateMessage(jsonRpcRequest) - case _: Subscribe => graphMessage - case _: Unsubscribe => graphMessage - case _ => Left(validationError(jsonRpcRequest.id)) + case _: Broadcast => validateMessage(jsonRpcRequest) + case _: Catchup => graphMessage + case _: Publish => validateMessage(jsonRpcRequest) + case _: Subscribe => graphMessage + case _: Unsubscribe => graphMessage + case _: Heartbeat => graphMessage + case _: GetMessagesById => graphMessage + case _ => Left(validationError(jsonRpcRequest.id)) } case graphMessage @ _ => graphMessage } @@ -59,7 +59,6 @@ object Validator { case Right(_) => validateJsonRpcContent(graphMessage) match { case Right(_) => validateMethodContent(graphMessage) match { case Right(_) => validateMessageContent(graphMessage) match { - case Right(_) => graphMessage case graphMessage @ _ => graphMessage } case graphMessage @ _ => graphMessage