From d75666bc62830205557c7e34b9fc104122bd244a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Michal=20August=C3=BDn?= Date: Wed, 29 May 2019 16:44:09 +0200 Subject: [PATCH] Republishing aware of UserID (#27) Republishing aware of UserID. --- .../avast/clients/rabbitmq/ConsumerBase.scala | 24 +++++++--- .../DefaultRabbitMQClientFactory.scala | 44 +++++++++---------- .../rabbitmq/DefaultRabbitMQConsumer.scala | 2 + .../DefaultRabbitMQPullConsumer.scala | 1 + .../clients/rabbitmq/RabbitMQConnection.scala | 3 +- .../clients/rabbitmq/configuration.scala | 2 +- .../DefaultRabbitMQConsumerTest.scala | 26 +++++++++-- .../DefaultRabbitMQPullConsumerTest.scala | 21 +++++++-- 8 files changed, 84 insertions(+), 39 deletions(-) diff --git a/core/src/main/scala/com/avast/clients/rabbitmq/ConsumerBase.scala b/core/src/main/scala/com/avast/clients/rabbitmq/ConsumerBase.scala index ffa1cc60..a8f5e500 100644 --- a/core/src/main/scala/com/avast/clients/rabbitmq/ConsumerBase.scala +++ b/core/src/main/scala/com/avast/clients/rabbitmq/ConsumerBase.scala @@ -1,7 +1,7 @@ package com.avast.clients.rabbitmq import cats.effect.Effect -import com.avast.clients.rabbitmq.DefaultRabbitMQConsumer.RepublishOriginalRoutingKeyHeaderName +import com.avast.clients.rabbitmq.DefaultRabbitMQConsumer._ import com.avast.clients.rabbitmq.api.DeliveryResult import com.avast.metrics.scalaapi.Monitor import com.rabbitmq.client.AMQP.BasicProperties @@ -18,6 +18,7 @@ private[rabbitmq] trait ConsumerBase[F[_]] extends StrictLogging { protected def queueName: String protected def channel: ServerChannel protected def blockingScheduler: Scheduler + protected def connectionInfo: RabbitMQConnectionInfo protected def monitor: Monitor protected val resultsMonitor: Monitor = monitor.named("results") @@ -35,7 +36,7 @@ private[rabbitmq] trait ConsumerBase[F[_]] extends StrictLogging { case Reject => reject(messageId, deliveryTag) case Retry => retry(messageId, deliveryTag) case Republish(newHeaders) => - republish(messageId, deliveryTag, mergeHeadersForRepublish(newHeaders, properties, routingKey), body) + republish(messageId, deliveryTag, createPropertiesForRepublish(newHeaders, properties, routingKey), body) } task.to[F] @@ -86,12 +87,21 @@ private[rabbitmq] trait ConsumerBase[F[_]] extends StrictLogging { } }.executeOn(blockingScheduler).asyncBoundary - protected def mergeHeadersForRepublish(newHeaders: Map[String, AnyRef], - properties: BasicProperties, - routingKey: String): BasicProperties = { + protected def createPropertiesForRepublish(newHeaders: Map[String, AnyRef], + properties: BasicProperties, + routingKey: String): BasicProperties = { // values in newHeaders will overwrite values in original headers - val h = newHeaders + (RepublishOriginalRoutingKeyHeaderName -> routingKey) + // we must also ensure that UserID will be the same as current username (or nothing): https://www.rabbitmq.com/validated-user-id.html + val originalUserId = Option(properties.getUserId).filter(_.nonEmpty) + val h = originalUserId match { + case Some(uid) => newHeaders + (RepublishOriginalRoutingKeyHeaderName -> routingKey) + (RepublishOriginalUserId -> uid) + case None => newHeaders + (RepublishOriginalRoutingKeyHeaderName -> routingKey) + } val headers = Option(properties.getHeaders).map(_.asScala ++ h).getOrElse(h) - properties.builder().headers(headers.asJava).build() + val newUserId = originalUserId match { + case Some(_) => connectionInfo.username.orNull + case None => null + } + properties.builder().headers(headers.asJava).userId(newUserId).build() } } diff --git a/core/src/main/scala/com/avast/clients/rabbitmq/DefaultRabbitMQClientFactory.scala b/core/src/main/scala/com/avast/clients/rabbitmq/DefaultRabbitMQClientFactory.scala index 39696381..ba8b9068 100644 --- a/core/src/main/scala/com/avast/clients/rabbitmq/DefaultRabbitMQClientFactory.scala +++ b/core/src/main/scala/com/avast/clients/rabbitmq/DefaultRabbitMQClientFactory.scala @@ -313,13 +313,13 @@ private[rabbitmq] object DefaultRabbitMQClientFactory extends LazyLogging { } private[rabbitmq] def declareExchange(name: String, - channelFactoryInfo: RabbitMQConnectionInfo, + connectionInfo: RabbitMQConnectionInfo, channel: ServerChannel, autoDeclareExchange: AutoDeclareExchange): Unit = { import autoDeclareExchange._ if (enabled) { - declareExchange(name, `type`, durable, autoDelete, arguments, channel, channelFactoryInfo) + declareExchange(name, `type`, durable, autoDelete, arguments, channel, connectionInfo) } () } @@ -330,8 +330,8 @@ private[rabbitmq] object DefaultRabbitMQClientFactory extends LazyLogging { autoDelete: Boolean, arguments: DeclareArguments, channel: ServerChannel, - channelFactoryInfo: RabbitMQConnectionInfo): Unit = { - logger.info(s"Declaring exchange '$name' of type ${`type`} in virtual host '${channelFactoryInfo.virtualHost}'") + connectionInfo: RabbitMQConnectionInfo): Unit = { + logger.info(s"Declaring exchange '$name' of type ${`type`} in virtual host '${connectionInfo.virtualHost}'") val javaArguments = argsAsJava(arguments.value) channel.exchangeDeclare(name, `type`, durable, autoDelete, javaArguments) () @@ -365,7 +365,7 @@ private[rabbitmq] object DefaultRabbitMQClientFactory extends LazyLogging { private def preparePullConsumer[F[_]: Effect, A: DeliveryConverter]( consumerConfig: PullConsumerConfig, configName: String, - channelFactoryInfo: RabbitMQConnectionInfo, + connectionInfo: RabbitMQConnectionInfo, channel: ServerChannel, blockingScheduler: Scheduler, monitor: Monitor)(implicit scheduler: Scheduler): DefaultRabbitMQPullConsumer[F, A] = { @@ -373,25 +373,25 @@ private[rabbitmq] object DefaultRabbitMQClientFactory extends LazyLogging { import consumerConfig._ // auto declare exchanges - declareExchangesFromBindings(configName, channelFactoryInfo, channel, consumerConfig.bindings) + declareExchangesFromBindings(configName, connectionInfo, channel, consumerConfig.bindings) // auto declare queue - declareQueue(queueName, channelFactoryInfo, channel, declare) + declareQueue(queueName, connectionInfo, channel, declare) // auto bind - bindQueues(channelFactoryInfo, channel, consumerConfig.queueName, consumerConfig.bindings) + bindQueues(connectionInfo, channel, consumerConfig.queueName, consumerConfig.bindings) - new DefaultRabbitMQPullConsumer[F, A](name, channel, queueName, failureAction, monitor, blockingScheduler) + new DefaultRabbitMQPullConsumer[F, A](name, channel, queueName, connectionInfo, failureAction, monitor, blockingScheduler) } private def declareQueue(queueName: String, - channelFactoryInfo: RabbitMQConnectionInfo, + connectionInfo: RabbitMQConnectionInfo, channel: ServerChannel, declare: AutoDeclareQueue): Unit = { import declare._ if (enabled) { - logger.info(s"Declaring queue '$queueName' in virtual host '${channelFactoryInfo.virtualHost}'") + logger.info(s"Declaring queue '$queueName' in virtual host '${connectionInfo.virtualHost}'") declareQueue(channel, queueName, durable, exclusive, autoDelete, arguments) } } @@ -405,7 +405,7 @@ private[rabbitmq] object DefaultRabbitMQClientFactory extends LazyLogging { channel.queueDeclare(queueName, durable, exclusive, autoDelete, arguments.value) } - private def bindQueues(channelFactoryInfo: RabbitMQConnectionInfo, + private def bindQueues(connectionInfo: RabbitMQConnectionInfo, channel: ServerChannel, queueName: String, bindings: immutable.Seq[AutoBindQueue]): Unit = { @@ -413,7 +413,7 @@ private[rabbitmq] object DefaultRabbitMQClientFactory extends LazyLogging { import bind._ val exchangeName = bind.exchange.name - bindQueues(channel, queueName, exchangeName, routingKeys, bindArguments, channelFactoryInfo) + bindQueues(channel, queueName, exchangeName, routingKeys, bindArguments, connectionInfo) } } @@ -422,22 +422,22 @@ private[rabbitmq] object DefaultRabbitMQClientFactory extends LazyLogging { exchangeName: String, routingKeys: immutable.Seq[String], bindArguments: BindArguments, - channelFactoryInfo: RabbitMQConnectionInfo): Unit = { + connectionInfo: RabbitMQConnectionInfo): Unit = { if (routingKeys.nonEmpty) { routingKeys.foreach { routingKey => - bindQueue(channelFactoryInfo)(channel, queueName)(exchangeName, routingKey, bindArguments.value) + bindQueue(connectionInfo)(channel, queueName)(exchangeName, routingKey, bindArguments.value) } } else { // binding without routing key, possibly to fanout exchange - bindQueue(channelFactoryInfo)(channel, queueName)(exchangeName, "", bindArguments.value) + bindQueue(connectionInfo)(channel, queueName)(exchangeName, "", bindArguments.value) } } - private[rabbitmq] def bindQueue(channelFactoryInfo: RabbitMQConnectionInfo)( + private[rabbitmq] def bindQueue(connectionInfo: RabbitMQConnectionInfo)( channel: ServerChannel, queueName: String)(exchangeName: String, routingKey: String, arguments: ArgumentsMap): AMQP.Queue.BindOk = { - logger.info(s"Binding exchange $exchangeName($routingKey) -> queue '$queueName' in virtual host '${channelFactoryInfo.virtualHost}'") + logger.info(s"Binding exchange $exchangeName($routingKey) -> queue '$queueName' in virtual host '${connectionInfo.virtualHost}'") channel.queueBind(queueName, exchangeName, routingKey, arguments) } @@ -455,7 +455,7 @@ private[rabbitmq] object DefaultRabbitMQClientFactory extends LazyLogging { } private def declareExchangesFromBindings(configName: String, - channelFactoryInfo: RabbitMQConnectionInfo, + connectionInfo: RabbitMQConnectionInfo, channel: ServerChannel, bindings: Seq[AutoBindQueue]): Unit = { bindings.zipWithIndex.foreach { @@ -467,13 +467,13 @@ private[rabbitmq] object DefaultRabbitMQClientFactory extends LazyLogging { val path = s"$configName.bindings.$i.exchange.declare" val d = declare.wrapped(path).as[AutoDeclareExchange](path) - declareExchange(name, channelFactoryInfo, channel, d) + declareExchange(name, connectionInfo, channel, d) } } } private def prepareConsumer[F[_]: Effect, A: DeliveryConverter](consumerConfig: ConsumerConfig, - channelFactoryInfo: RabbitMQConnectionInfo, + connectionInfo: RabbitMQConnectionInfo, channel: ServerChannel, userReadAction: DeliveryReadAction[F, A], consumerListener: ConsumerListener, @@ -502,7 +502,7 @@ private[rabbitmq] object DefaultRabbitMQClientFactory extends LazyLogging { } val consumer = { - new DefaultRabbitMQConsumer(name, channel, queueName, monitor, failureAction, consumerListener, blockingScheduler)(readAction) + new DefaultRabbitMQConsumer(name, channel, queueName, connectionInfo, monitor, failureAction, consumerListener, blockingScheduler)(readAction) } val finalConsumerTag = if (consumerTag == "Default") "" else consumerTag diff --git a/core/src/main/scala/com/avast/clients/rabbitmq/DefaultRabbitMQConsumer.scala b/core/src/main/scala/com/avast/clients/rabbitmq/DefaultRabbitMQConsumer.scala index 64226cf0..e1f4604b 100755 --- a/core/src/main/scala/com/avast/clients/rabbitmq/DefaultRabbitMQConsumer.scala +++ b/core/src/main/scala/com/avast/clients/rabbitmq/DefaultRabbitMQConsumer.scala @@ -24,6 +24,7 @@ class DefaultRabbitMQConsumer[F[_]: Effect]( override val name: String, override protected val channel: ServerChannel, override protected val queueName: String, + override protected val connectionInfo: RabbitMQConnectionInfo, override protected val monitor: Monitor, failureAction: DeliveryResult, consumerListener: ConsumerListener, @@ -153,4 +154,5 @@ class DefaultRabbitMQConsumer[F[_]: Effect]( object DefaultRabbitMQConsumer { final val RepublishOriginalRoutingKeyHeaderName = "X-Original-Routing-Key" + final val RepublishOriginalUserId = "X-Original-User-Id" } diff --git a/core/src/main/scala/com/avast/clients/rabbitmq/DefaultRabbitMQPullConsumer.scala b/core/src/main/scala/com/avast/clients/rabbitmq/DefaultRabbitMQPullConsumer.scala index 4d992a33..57a3a8de 100644 --- a/core/src/main/scala/com/avast/clients/rabbitmq/DefaultRabbitMQPullConsumer.scala +++ b/core/src/main/scala/com/avast/clients/rabbitmq/DefaultRabbitMQPullConsumer.scala @@ -20,6 +20,7 @@ class DefaultRabbitMQPullConsumer[F[_]: Effect, A: DeliveryConverter]( override val name: String, protected override val channel: ServerChannel, protected override val queueName: String, + protected override val connectionInfo: RabbitMQConnectionInfo, failureAction: DeliveryResult, protected override val monitor: Monitor, protected override val blockingScheduler: Scheduler)(implicit sch: Scheduler) diff --git a/core/src/main/scala/com/avast/clients/rabbitmq/RabbitMQConnection.scala b/core/src/main/scala/com/avast/clients/rabbitmq/RabbitMQConnection.scala index 0f34a20e..55775240 100644 --- a/core/src/main/scala/com/avast/clients/rabbitmq/RabbitMQConnection.scala +++ b/core/src/main/scala/com/avast/clients/rabbitmq/RabbitMQConnection.scala @@ -186,7 +186,8 @@ object RabbitMQConnection extends StrictLogging { connection = connection, info = RabbitMQConnectionInfo( hosts = connectionConfig.hosts.toVector, - virtualHost = connectionConfig.virtualHost + virtualHost = connectionConfig.virtualHost, + username = if (connectionConfig.credentials.enabled) Option(connectionConfig.credentials.username) else None ), config = providedConfig, connectionListener = connectionListener, diff --git a/core/src/main/scala/com/avast/clients/rabbitmq/configuration.scala b/core/src/main/scala/com/avast/clients/rabbitmq/configuration.scala index 39529f31..375fcfe7 100644 --- a/core/src/main/scala/com/avast/clients/rabbitmq/configuration.scala +++ b/core/src/main/scala/com/avast/clients/rabbitmq/configuration.scala @@ -28,7 +28,7 @@ case class Ssl(enabled: Boolean, trustStore: TrustStore) case class TrustStore(path: Path, password: String) -private[rabbitmq] case class RabbitMQConnectionInfo(hosts: immutable.Seq[String], virtualHost: String) +private[rabbitmq] case class RabbitMQConnectionInfo(hosts: immutable.Seq[String], virtualHost: String, username: Option[String]) case class ConsumerConfig(queueName: String, processTimeout: Duration, diff --git a/core/src/test/scala/com/avast/clients/rabbitmq/DefaultRabbitMQConsumerTest.scala b/core/src/test/scala/com/avast/clients/rabbitmq/DefaultRabbitMQConsumerTest.scala index c84b837d..e2f268ab 100644 --- a/core/src/test/scala/com/avast/clients/rabbitmq/DefaultRabbitMQConsumerTest.scala +++ b/core/src/test/scala/com/avast/clients/rabbitmq/DefaultRabbitMQConsumerTest.scala @@ -12,16 +12,21 @@ import com.rabbitmq.client.impl.recovery.AutorecoveringChannel import monix.eval.Task import monix.execution.Scheduler import monix.execution.Scheduler.Implicits.global -import org.mockito.Matchers -import org.mockito.Matchers.any +import org.mockito.{ArgumentCaptor, Matchers} import org.mockito.Mockito._ import org.scalatest.time.{Seconds, Span} import scala.collection.mutable +import scala.collection.immutable import scala.concurrent.{ExecutionContext, Future} import scala.util.{Failure, Random, Success} +import scala.collection.JavaConverters._ + class DefaultRabbitMQConsumerTest extends TestBase { + + private val connectionInfo = RabbitMQConnectionInfo(immutable.Seq("localhost"), "/", None) + test("should ACK") { val messageId = UUID.randomUUID().toString @@ -39,6 +44,7 @@ class DefaultRabbitMQConsumerTest extends TestBase { "test", channel, "queueName", + connectionInfo, Monitor.noOp, DeliveryResult.Reject, DefaultListeners.DefaultConsumerListener, @@ -77,6 +83,7 @@ class DefaultRabbitMQConsumerTest extends TestBase { "test", channel, "queueName", + connectionInfo, Monitor.noOp, DeliveryResult.Reject, DefaultListeners.DefaultConsumerListener, @@ -115,6 +122,7 @@ class DefaultRabbitMQConsumerTest extends TestBase { "test", channel, "queueName", + connectionInfo, Monitor.noOp, DeliveryResult.Reject, DefaultListeners.DefaultConsumerListener, @@ -144,7 +152,8 @@ class DefaultRabbitMQConsumerTest extends TestBase { val envelope = mock[Envelope] when(envelope.getDeliveryTag).thenReturn(deliveryTag) - val properties = new BasicProperties.Builder().messageId(messageId).build() + val originalUserId = "OriginalUserId" + val properties = new BasicProperties.Builder().messageId(messageId).userId(originalUserId).build() val channel = mock[AutorecoveringChannel] @@ -152,6 +161,7 @@ class DefaultRabbitMQConsumerTest extends TestBase { "test", channel, "queueName", + connectionInfo, Monitor.noOp, DeliveryResult.Reject, DefaultListeners.DefaultConsumerListener, @@ -169,7 +179,9 @@ class DefaultRabbitMQConsumerTest extends TestBase { verify(channel, times(1)).basicAck(deliveryTag, false) verify(channel, times(0)).basicReject(deliveryTag, true) verify(channel, times(0)).basicReject(deliveryTag, false) - verify(channel, times(1)).basicPublish(Matchers.eq(""), Matchers.eq("queueName"), any(), Matchers.eq(body)) + val propertiesCaptor = ArgumentCaptor.forClass(classOf[BasicProperties]) + verify(channel, times(1)).basicPublish(Matchers.eq(""), Matchers.eq("queueName"), propertiesCaptor.capture(), Matchers.eq(body)) + assertResult(Some(originalUserId))(propertiesCaptor.getValue.getHeaders.asScala.get(DefaultRabbitMQConsumer.RepublishOriginalUserId)) } } @@ -190,6 +202,7 @@ class DefaultRabbitMQConsumerTest extends TestBase { "test", channel, "queueName", + connectionInfo, Monitor.noOp, DeliveryResult.Retry, DefaultListeners.DefaultConsumerListener, @@ -225,6 +238,7 @@ class DefaultRabbitMQConsumerTest extends TestBase { "test", channel, "queueName", + connectionInfo, Monitor.noOp, DeliveryResult.Retry, DefaultListeners.DefaultConsumerListener, @@ -288,6 +302,7 @@ class DefaultRabbitMQConsumerTest extends TestBase { "test", channel, "queueName", + connectionInfo, monitor, DeliveryResult.Retry, DefaultListeners.DefaultConsumerListener, @@ -315,6 +330,7 @@ class DefaultRabbitMQConsumerTest extends TestBase { "test", channel, "queueName", + connectionInfo, monitor, DeliveryResult.Retry, DefaultListeners.DefaultConsumerListener, @@ -381,6 +397,7 @@ class DefaultRabbitMQConsumerTest extends TestBase { "test", channel, "queueName", + connectionInfo, monitor, DeliveryResult.Retry, DefaultListeners.DefaultConsumerListener, @@ -408,6 +425,7 @@ class DefaultRabbitMQConsumerTest extends TestBase { "test", channel, "queueName", + connectionInfo, monitor, DeliveryResult.Retry, DefaultListeners.DefaultConsumerListener, diff --git a/core/src/test/scala/com/avast/clients/rabbitmq/DefaultRabbitMQPullConsumerTest.scala b/core/src/test/scala/com/avast/clients/rabbitmq/DefaultRabbitMQPullConsumerTest.scala index c38549f5..b099f01b 100644 --- a/core/src/test/scala/com/avast/clients/rabbitmq/DefaultRabbitMQPullConsumerTest.scala +++ b/core/src/test/scala/com/avast/clients/rabbitmq/DefaultRabbitMQPullConsumerTest.scala @@ -11,15 +11,19 @@ import com.rabbitmq.client.{Envelope, GetResponse} import monix.eval.Task import monix.execution.Scheduler import monix.execution.Scheduler.Implicits.global -import org.mockito.Matchers -import org.mockito.Matchers.any +import org.mockito.{ArgumentCaptor, Matchers} import org.mockito.Mockito._ import org.scalatest.time.{Seconds, Span} +import scala.collection.immutable import scala.util.Random +import scala.collection.JavaConverters._ + class DefaultRabbitMQPullConsumerTest extends TestBase { + private val connectionInfo = RabbitMQConnectionInfo(immutable.Seq("localhost"), "/", None) + test("should ACK") { val messageId = UUID.randomUUID().toString @@ -43,6 +47,7 @@ class DefaultRabbitMQPullConsumerTest extends TestBase { "test", channel, "queueName", + connectionInfo, DeliveryResult.Reject, Monitor.noOp, Scheduler.global @@ -85,6 +90,7 @@ class DefaultRabbitMQPullConsumerTest extends TestBase { "test", channel, "queueName", + connectionInfo, DeliveryResult.Reject, Monitor.noOp, Scheduler.global @@ -127,6 +133,7 @@ class DefaultRabbitMQPullConsumerTest extends TestBase { "test", channel, "queueName", + connectionInfo, DeliveryResult.Ack, Monitor.noOp, Scheduler.global @@ -154,7 +161,8 @@ class DefaultRabbitMQPullConsumerTest extends TestBase { val envelope = mock[Envelope] when(envelope.getDeliveryTag).thenReturn(deliveryTag) - val properties = new BasicProperties.Builder().messageId(messageId).build() + val originalUserId = "OriginalUserId" + val properties = new BasicProperties.Builder().messageId(messageId).userId(originalUserId).build() val channel = mock[AutorecoveringChannel] @@ -168,6 +176,7 @@ class DefaultRabbitMQPullConsumerTest extends TestBase { "test", channel, "queueName", + connectionInfo, DeliveryResult.Reject, Monitor.noOp, Scheduler.global @@ -183,7 +192,9 @@ class DefaultRabbitMQPullConsumerTest extends TestBase { verify(channel, times(1)).basicAck(deliveryTag, false) verify(channel, times(0)).basicReject(deliveryTag, true) verify(channel, times(0)).basicReject(deliveryTag, false) - verify(channel, times(1)).basicPublish(Matchers.eq(""), Matchers.eq("queueName"), any(), Matchers.eq(body)) + val propertiesCaptor = ArgumentCaptor.forClass(classOf[BasicProperties]) + verify(channel, times(1)).basicPublish(Matchers.eq(""), Matchers.eq("queueName"), propertiesCaptor.capture(), Matchers.eq(body)) + assertResult(Some(originalUserId))(propertiesCaptor.getValue.getHeaders.asScala.get(DefaultRabbitMQConsumer.RepublishOriginalUserId)) } } @@ -216,6 +227,7 @@ class DefaultRabbitMQPullConsumerTest extends TestBase { "test", channel, "queueName", + connectionInfo, DeliveryResult.Retry, Monitor.noOp, Scheduler.global @@ -260,6 +272,7 @@ class DefaultRabbitMQPullConsumerTest extends TestBase { "test", channel, "queueName", + connectionInfo, DeliveryResult.Ack, Monitor.noOp, Scheduler.global