Skip to content

Commit

Permalink
Republishing aware of UserID (#27)
Browse files Browse the repository at this point in the history
Republishing aware of UserID.

(cherry picked from commit d75666b)
  • Loading branch information
augi authored and jendakol committed May 29, 2019
1 parent 077ed13 commit 373b463
Show file tree
Hide file tree
Showing 8 changed files with 84 additions and 39 deletions.
24 changes: 17 additions & 7 deletions core/src/main/scala/com/avast/clients/rabbitmq/ConsumerBase.scala
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package com.avast.clients.rabbitmq

import cats.effect.Effect
import com.avast.clients.rabbitmq.DefaultRabbitMQConsumer.RepublishOriginalRoutingKeyHeaderName
import com.avast.clients.rabbitmq.DefaultRabbitMQConsumer._
import com.avast.clients.rabbitmq.api.DeliveryResult
import com.avast.metrics.scalaapi.Monitor
import com.rabbitmq.client.AMQP.BasicProperties
Expand All @@ -18,6 +18,7 @@ private[rabbitmq] trait ConsumerBase[F[_]] extends StrictLogging {
protected def queueName: String
protected def channel: ServerChannel
protected def blockingScheduler: Scheduler
protected def connectionInfo: RabbitMQConnectionInfo
protected def monitor: Monitor

protected val resultsMonitor: Monitor = monitor.named("results")
Expand All @@ -35,7 +36,7 @@ private[rabbitmq] trait ConsumerBase[F[_]] extends StrictLogging {
case Reject => reject(messageId, deliveryTag)
case Retry => retry(messageId, deliveryTag)
case Republish(newHeaders) =>
republish(messageId, deliveryTag, mergeHeadersForRepublish(newHeaders, properties, routingKey), body)
republish(messageId, deliveryTag, createPropertiesForRepublish(newHeaders, properties, routingKey), body)
}

task.to[F]
Expand Down Expand Up @@ -86,12 +87,21 @@ private[rabbitmq] trait ConsumerBase[F[_]] extends StrictLogging {
}
}.executeOn(blockingScheduler).asyncBoundary

protected def mergeHeadersForRepublish(newHeaders: Map[String, AnyRef],
properties: BasicProperties,
routingKey: String): BasicProperties = {
protected def createPropertiesForRepublish(newHeaders: Map[String, AnyRef],
properties: BasicProperties,
routingKey: String): BasicProperties = {
// values in newHeaders will overwrite values in original headers
val h = newHeaders + (RepublishOriginalRoutingKeyHeaderName -> routingKey)
// we must also ensure that UserID will be the same as current username (or nothing): https://www.rabbitmq.com/validated-user-id.html
val originalUserId = Option(properties.getUserId).filter(_.nonEmpty)
val h = originalUserId match {
case Some(uid) => newHeaders + (RepublishOriginalRoutingKeyHeaderName -> routingKey) + (RepublishOriginalUserId -> uid)
case None => newHeaders + (RepublishOriginalRoutingKeyHeaderName -> routingKey)
}
val headers = Option(properties.getHeaders).map(_.asScala ++ h).getOrElse(h)
properties.builder().headers(headers.asJava).build()
val newUserId = originalUserId match {
case Some(_) => connectionInfo.username.orNull
case None => null
}
properties.builder().headers(headers.asJava).userId(newUserId).build()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -306,13 +306,13 @@ private[rabbitmq] object DefaultRabbitMQClientFactory extends LazyLogging {
}

private[rabbitmq] def declareExchange(name: String,
channelFactoryInfo: RabbitMQConnectionInfo,
connectionInfo: RabbitMQConnectionInfo,
channel: ServerChannel,
autoDeclareExchange: AutoDeclareExchange): Unit = {
import autoDeclareExchange._

if (enabled) {
declareExchange(name, `type`, durable, autoDelete, arguments, channel, channelFactoryInfo)
declareExchange(name, `type`, durable, autoDelete, arguments, channel, connectionInfo)
}
()
}
Expand All @@ -323,8 +323,8 @@ private[rabbitmq] object DefaultRabbitMQClientFactory extends LazyLogging {
autoDelete: Boolean,
arguments: DeclareArguments,
channel: ServerChannel,
channelFactoryInfo: RabbitMQConnectionInfo): Unit = {
logger.info(s"Declaring exchange '$name' of type ${`type`} in virtual host '${channelFactoryInfo.virtualHost}'")
connectionInfo: RabbitMQConnectionInfo): Unit = {
logger.info(s"Declaring exchange '$name' of type ${`type`} in virtual host '${connectionInfo.virtualHost}'")
val javaArguments = argsAsJava(arguments.value)
channel.exchangeDeclare(name, `type`, durable, autoDelete, javaArguments)
()
Expand Down Expand Up @@ -358,33 +358,33 @@ private[rabbitmq] object DefaultRabbitMQClientFactory extends LazyLogging {
private def preparePullConsumer[F[_]: Effect, A: DeliveryConverter](
consumerConfig: PullConsumerConfig,
configName: String,
channelFactoryInfo: RabbitMQConnectionInfo,
connectionInfo: RabbitMQConnectionInfo,
channel: ServerChannel,
blockingScheduler: Scheduler,
monitor: Monitor)(implicit scheduler: Scheduler): DefaultRabbitMQPullConsumer[F, A] = {

import consumerConfig._

// auto declare exchanges
declareExchangesFromBindings(configName, channelFactoryInfo, channel, consumerConfig.bindings)
declareExchangesFromBindings(configName, connectionInfo, channel, consumerConfig.bindings)

// auto declare queue
declareQueue(queueName, channelFactoryInfo, channel, declare)
declareQueue(queueName, connectionInfo, channel, declare)

// auto bind
bindQueues(channelFactoryInfo, channel, consumerConfig.queueName, consumerConfig.bindings)
bindQueues(connectionInfo, channel, consumerConfig.queueName, consumerConfig.bindings)

new DefaultRabbitMQPullConsumer[F, A](name, channel, queueName, failureAction, monitor, blockingScheduler)
new DefaultRabbitMQPullConsumer[F, A](name, channel, queueName, connectionInfo, failureAction, monitor, blockingScheduler)
}

private def declareQueue(queueName: String,
channelFactoryInfo: RabbitMQConnectionInfo,
connectionInfo: RabbitMQConnectionInfo,
channel: ServerChannel,
declare: AutoDeclareQueue): Unit = {
import declare._

if (enabled) {
logger.info(s"Declaring queue '$queueName' in virtual host '${channelFactoryInfo.virtualHost}'")
logger.info(s"Declaring queue '$queueName' in virtual host '${connectionInfo.virtualHost}'")
declareQueue(channel, queueName, durable, exclusive, autoDelete, arguments)
}
}
Expand All @@ -398,15 +398,15 @@ private[rabbitmq] object DefaultRabbitMQClientFactory extends LazyLogging {
channel.queueDeclare(queueName, durable, exclusive, autoDelete, arguments.value)
}

private def bindQueues(channelFactoryInfo: RabbitMQConnectionInfo,
private def bindQueues(connectionInfo: RabbitMQConnectionInfo,
channel: ServerChannel,
queueName: String,
bindings: immutable.Seq[AutoBindQueue]): Unit = {
bindings.foreach { bind =>
import bind._
val exchangeName = bind.exchange.name

bindQueues(channel, queueName, exchangeName, routingKeys, bindArguments, channelFactoryInfo)
bindQueues(channel, queueName, exchangeName, routingKeys, bindArguments, connectionInfo)
}
}

Expand All @@ -415,22 +415,22 @@ private[rabbitmq] object DefaultRabbitMQClientFactory extends LazyLogging {
exchangeName: String,
routingKeys: immutable.Seq[String],
bindArguments: BindArguments,
channelFactoryInfo: RabbitMQConnectionInfo): Unit = {
connectionInfo: RabbitMQConnectionInfo): Unit = {
if (routingKeys.nonEmpty) {
routingKeys.foreach { routingKey =>
bindQueue(channelFactoryInfo)(channel, queueName)(exchangeName, routingKey, bindArguments.value)
bindQueue(connectionInfo)(channel, queueName)(exchangeName, routingKey, bindArguments.value)
}
} else {
// binding without routing key, possibly to fanout exchange

bindQueue(channelFactoryInfo)(channel, queueName)(exchangeName, "", bindArguments.value)
bindQueue(connectionInfo)(channel, queueName)(exchangeName, "", bindArguments.value)
}
}

private[rabbitmq] def bindQueue(channelFactoryInfo: RabbitMQConnectionInfo)(
private[rabbitmq] def bindQueue(connectionInfo: RabbitMQConnectionInfo)(
channel: ServerChannel,
queueName: String)(exchangeName: String, routingKey: String, arguments: ArgumentsMap): AMQP.Queue.BindOk = {
logger.info(s"Binding exchange $exchangeName($routingKey) -> queue '$queueName' in virtual host '${channelFactoryInfo.virtualHost}'")
logger.info(s"Binding exchange $exchangeName($routingKey) -> queue '$queueName' in virtual host '${connectionInfo.virtualHost}'")

channel.queueBind(queueName, exchangeName, routingKey, arguments)
}
Expand All @@ -448,7 +448,7 @@ private[rabbitmq] object DefaultRabbitMQClientFactory extends LazyLogging {
}

private def declareExchangesFromBindings(configName: String,
channelFactoryInfo: RabbitMQConnectionInfo,
connectionInfo: RabbitMQConnectionInfo,
channel: ServerChannel,
bindings: Seq[AutoBindQueue]): Unit = {
bindings.zipWithIndex.foreach {
Expand All @@ -460,13 +460,13 @@ private[rabbitmq] object DefaultRabbitMQClientFactory extends LazyLogging {
val path = s"$configName.bindings.$i.exchange.declare"
val d = declare.wrapped(path).as[AutoDeclareExchange](path)

declareExchange(name, channelFactoryInfo, channel, d)
declareExchange(name, connectionInfo, channel, d)
}
}
}

private def prepareConsumer[F[_]: Effect, A: DeliveryConverter](consumerConfig: ConsumerConfig,
channelFactoryInfo: RabbitMQConnectionInfo,
connectionInfo: RabbitMQConnectionInfo,
channel: ServerChannel,
userReadAction: DeliveryReadAction[F, A],
consumerListener: ConsumerListener,
Expand Down Expand Up @@ -495,7 +495,7 @@ private[rabbitmq] object DefaultRabbitMQClientFactory extends LazyLogging {
}

val consumer = {
new DefaultRabbitMQConsumer(name, channel, queueName, monitor, failureAction, consumerListener, blockingScheduler)(readAction)
new DefaultRabbitMQConsumer(name, channel, queueName, connectionInfo, monitor, failureAction, consumerListener, blockingScheduler)(readAction)
}

val finalConsumerTag = if (consumerTag == "Default") "" else consumerTag
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class DefaultRabbitMQConsumer[F[_]: Effect](
override val name: String,
override protected val channel: ServerChannel,
override protected val queueName: String,
override protected val connectionInfo: RabbitMQConnectionInfo,
override protected val monitor: Monitor,
failureAction: DeliveryResult,
consumerListener: ConsumerListener,
Expand Down Expand Up @@ -157,4 +158,5 @@ class DefaultRabbitMQConsumer[F[_]: Effect](

object DefaultRabbitMQConsumer {
final val RepublishOriginalRoutingKeyHeaderName = "X-Original-Routing-Key"
final val RepublishOriginalUserId = "X-Original-User-Id"
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class DefaultRabbitMQPullConsumer[F[_]: Effect, A: DeliveryConverter](
override val name: String,
protected override val channel: ServerChannel,
protected override val queueName: String,
protected override val connectionInfo: RabbitMQConnectionInfo,
failureAction: DeliveryResult,
protected override val monitor: Monitor,
protected override val blockingScheduler: Scheduler)(implicit sch: Scheduler)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,8 @@ object RabbitMQConnection extends StrictLogging {
connection = connection,
info = RabbitMQConnectionInfo(
hosts = connectionConfig.hosts.toVector,
virtualHost = connectionConfig.virtualHost
virtualHost = connectionConfig.virtualHost,
username = if (connectionConfig.credentials.enabled) Option(connectionConfig.credentials.username) else None
),
config = providedConfig,
connectionListener = connectionListener,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ case class Ssl(enabled: Boolean, trustStore: TrustStore)

case class TrustStore(path: Path, password: String)

private[rabbitmq] case class RabbitMQConnectionInfo(hosts: immutable.Seq[String], virtualHost: String)
private[rabbitmq] case class RabbitMQConnectionInfo(hosts: immutable.Seq[String], virtualHost: String, username: Option[String])

case class ConsumerConfig(queueName: String,
processTimeout: Duration,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,21 @@ import com.rabbitmq.client.impl.recovery.AutorecoveringChannel
import monix.eval.Task
import monix.execution.Scheduler
import monix.execution.Scheduler.Implicits.global
import org.mockito.Matchers
import org.mockito.Matchers.any
import org.mockito.{ArgumentCaptor, Matchers}
import org.mockito.Mockito._
import org.scalatest.time.{Seconds, Span}

import scala.collection.mutable
import scala.collection.immutable
import scala.concurrent.{ExecutionContext, Future}
import scala.util.{Failure, Random, Success}

import scala.collection.JavaConverters._

class DefaultRabbitMQConsumerTest extends TestBase {

private val connectionInfo = RabbitMQConnectionInfo(immutable.Seq("localhost"), "/", None)

test("should ACK") {
val messageId = UUID.randomUUID().toString

Expand All @@ -39,6 +44,7 @@ class DefaultRabbitMQConsumerTest extends TestBase {
"test",
channel,
"queueName",
connectionInfo,
Monitor.noOp,
DeliveryResult.Reject,
DefaultListeners.DefaultConsumerListener,
Expand Down Expand Up @@ -77,6 +83,7 @@ class DefaultRabbitMQConsumerTest extends TestBase {
"test",
channel,
"queueName",
connectionInfo,
Monitor.noOp,
DeliveryResult.Reject,
DefaultListeners.DefaultConsumerListener,
Expand Down Expand Up @@ -115,6 +122,7 @@ class DefaultRabbitMQConsumerTest extends TestBase {
"test",
channel,
"queueName",
connectionInfo,
Monitor.noOp,
DeliveryResult.Reject,
DefaultListeners.DefaultConsumerListener,
Expand Down Expand Up @@ -144,14 +152,16 @@ class DefaultRabbitMQConsumerTest extends TestBase {
val envelope = mock[Envelope]
when(envelope.getDeliveryTag).thenReturn(deliveryTag)

val properties = new BasicProperties.Builder().messageId(messageId).build()
val originalUserId = "OriginalUserId"
val properties = new BasicProperties.Builder().messageId(messageId).userId(originalUserId).build()

val channel = mock[AutorecoveringChannel]

val consumer = new DefaultRabbitMQConsumer[Task](
"test",
channel,
"queueName",
connectionInfo,
Monitor.noOp,
DeliveryResult.Reject,
DefaultListeners.DefaultConsumerListener,
Expand All @@ -169,7 +179,9 @@ class DefaultRabbitMQConsumerTest extends TestBase {
verify(channel, times(1)).basicAck(deliveryTag, false)
verify(channel, times(0)).basicReject(deliveryTag, true)
verify(channel, times(0)).basicReject(deliveryTag, false)
verify(channel, times(1)).basicPublish(Matchers.eq(""), Matchers.eq("queueName"), any(), Matchers.eq(body))
val propertiesCaptor = ArgumentCaptor.forClass(classOf[BasicProperties])
verify(channel, times(1)).basicPublish(Matchers.eq(""), Matchers.eq("queueName"), propertiesCaptor.capture(), Matchers.eq(body))
assertResult(Some(originalUserId))(propertiesCaptor.getValue.getHeaders.asScala.get(DefaultRabbitMQConsumer.RepublishOriginalUserId))
}
}

Expand All @@ -190,6 +202,7 @@ class DefaultRabbitMQConsumerTest extends TestBase {
"test",
channel,
"queueName",
connectionInfo,
Monitor.noOp,
DeliveryResult.Retry,
DefaultListeners.DefaultConsumerListener,
Expand Down Expand Up @@ -225,6 +238,7 @@ class DefaultRabbitMQConsumerTest extends TestBase {
"test",
channel,
"queueName",
connectionInfo,
Monitor.noOp,
DeliveryResult.Retry,
DefaultListeners.DefaultConsumerListener,
Expand Down Expand Up @@ -288,6 +302,7 @@ class DefaultRabbitMQConsumerTest extends TestBase {
"test",
channel,
"queueName",
connectionInfo,
monitor,
DeliveryResult.Retry,
DefaultListeners.DefaultConsumerListener,
Expand Down Expand Up @@ -315,6 +330,7 @@ class DefaultRabbitMQConsumerTest extends TestBase {
"test",
channel,
"queueName",
connectionInfo,
monitor,
DeliveryResult.Retry,
DefaultListeners.DefaultConsumerListener,
Expand Down Expand Up @@ -381,6 +397,7 @@ class DefaultRabbitMQConsumerTest extends TestBase {
"test",
channel,
"queueName",
connectionInfo,
monitor,
DeliveryResult.Retry,
DefaultListeners.DefaultConsumerListener,
Expand Down Expand Up @@ -408,6 +425,7 @@ class DefaultRabbitMQConsumerTest extends TestBase {
"test",
channel,
"queueName",
connectionInfo,
monitor,
DeliveryResult.Retry,
DefaultListeners.DefaultConsumerListener,
Expand Down
Loading

0 comments on commit 373b463

Please sign in to comment.