Skip to content

Commit

Permalink
Merge pull request #200 from avast/fix_publish_confirmation
Browse files Browse the repository at this point in the history
fix: Fix bug in the publisher confirmations
  • Loading branch information
jendakol authored Feb 7, 2023
2 parents 0ef2a8a + 0fece01 commit 3ab1cdf
Show file tree
Hide file tree
Showing 5 changed files with 138 additions and 77 deletions.
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package com.avast.clients.rabbitmq

import cats.effect._
import cats.effect.concurrent.{Deferred, Ref}
import cats.implicits.{catsSyntaxFlatMapOps, toFunctorOps, toTraverseOps}
import com.avast.bytes.Bytes
import com.avast.clients.rabbitmq.DefaultRabbitMQClientFactory.startConsumingQueue
Expand Down Expand Up @@ -270,41 +269,42 @@ private[rabbitmq] class DefaultRabbitMQClientFactory[F[_]: ConcurrentEffect: Tim
producerConfig.properties.confirms match {
case Some(PublisherConfirmsConfig(true, sendAttempts)) =>
prepareProducer(producerConfig, connection) { (defaultProperties, channel, logger) =>
Ref.of(Map.empty[Long, Deferred[F, Either[NotAcknowledgedPublish, Unit]]])
.map {
new PublishConfirmsRabbitMQProducer[F, A](
producerConfig.name,
producerConfig.exchange,
channel,
defaultProperties,
_,
sendAttempts,
producerConfig.reportUnroutable,
producerConfig.sizeLimitBytes,
blocker,
logger,
monitor)
}
F.pure {
new PublishConfirmsRabbitMQProducer[F, A](
producerConfig.name,
producerConfig.exchange,
channel,
defaultProperties,
sendAttempts,
producerConfig.reportUnroutable,
producerConfig.sizeLimitBytes,
blocker,
logger,
monitor
)
}
}
case _ =>
prepareProducer(producerConfig, connection) { (defaultProperties, channel, logger) =>
F.pure {
new DefaultRabbitMQProducer[F, A](producerConfig.name,
producerConfig.exchange,
channel,
defaultProperties,
producerConfig.reportUnroutable,
producerConfig.sizeLimitBytes,
blocker,
logger,
monitor)
new DefaultRabbitMQProducer[F, A](
producerConfig.name,
producerConfig.exchange,
channel,
defaultProperties,
producerConfig.reportUnroutable,
producerConfig.sizeLimitBytes,
blocker,
logger,
monitor
)
}
}
}
}

private def prepareProducer[T: ClassTag, A: ProductConverter](producerConfig: ProducerConfig, connection: RabbitMQConnection[F])(
createProducer: (MessageProperties, ServerChannel, ImplicitContextLogger[F]) => F[T]) = {
createProducer: (MessageProperties, ServerChannel, ImplicitContextLogger[F]) => F[T]) = {
val logger: ImplicitContextLogger[F] = ImplicitContextLogger.createLogger[F, T]

connection
Expand All @@ -319,7 +319,7 @@ private[rabbitmq] class DefaultRabbitMQClientFactory[F[_]: ConcurrentEffect: Tim
contentType = producerConfig.properties.contentType,
contentEncoding = producerConfig.properties.contentEncoding,
priority = producerConfig.properties.priority.map(Integer.valueOf)
)
)
createProducer(defaultProperties, channel, logger)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import com.avast.clients.rabbitmq.JavaConverters._
import com.avast.clients.rabbitmq.api.CorrelationIdStrategy.FromPropertiesOrRandomNew
import com.avast.clients.rabbitmq.api._
import com.avast.clients.rabbitmq.logging.ImplicitContextLogger
import com.avast.clients.rabbitmq.{CorrelationId, ProductConverter, ServerChannel, startAndForget}
import com.avast.clients.rabbitmq.{startAndForget, CorrelationId, ProductConverter, ServerChannel}
import com.avast.metrics.scalaeffectapi.Monitor
import com.rabbitmq.client.AMQP.BasicProperties
import com.rabbitmq.client.{AlreadyClosedException, ReturnListener}
Expand Down Expand Up @@ -63,17 +63,21 @@ abstract class BaseRabbitMQProducer[F[_], A: ProductConverter](name: String,
}
}

protected def basicSend(routingKey: String, body: Bytes, properties: MessageProperties)(implicit correlationId: CorrelationId): F[Unit] = {
protected def basicSend(routingKey: String, body: Bytes, properties: MessageProperties, preSendAction: Long => Unit = (_: Long) => ())(
implicit correlationId: CorrelationId): F[Long] = {
for {
_ <- logger.debug(s"Sending message with ${body.size()} B to exchange $exchangeName with routing key '$routingKey' and $properties")
_ <- blocker.delay {
sequenceNumber <- blocker.delay {
sendLock.synchronized {
// see https://www.rabbitmq.com/api-guide.html#channel-threads
val sequenceNumber = channel.getNextPublishSeqNo
preSendAction(sequenceNumber)
channel.basicPublish(exchangeName, routingKey, properties.asAMQP, body.toByteArray)
sequenceNumber
}
}
_ <- sentMeter.mark
} yield ()
} yield sequenceNumber
}

private def processErrors(from: F[Unit], routingKey: String)(implicit correlationId: CorrelationId): F[Unit] = {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.avast.clients.rabbitmq.publisher

import cats.effect.{Blocker, ConcurrentEffect, ContextShift}
import cats.implicits.toFunctorOps
import com.avast.bytes.Bytes
import com.avast.clients.rabbitmq.api.MessageProperties
import com.avast.clients.rabbitmq.logging.ImplicitContextLogger
Expand All @@ -26,5 +27,5 @@ class DefaultRabbitMQProducer[F[_], A: ProductConverter](name: String,
logger,
monitor) {
override def sendMessage(routingKey: String, body: Bytes, properties: MessageProperties)(implicit correlationId: CorrelationId): F[Unit] =
basicSend(routingKey, body, properties)
basicSend(routingKey, body, properties).void
}
Original file line number Diff line number Diff line change
@@ -1,22 +1,22 @@
package com.avast.clients.rabbitmq.publisher

import cats.effect.concurrent.{Deferred, Ref}
import cats.effect.concurrent.Deferred
import cats.effect.{Blocker, ConcurrentEffect, ContextShift}
import cats.syntax.flatMap._
import cats.syntax.functor._
import com.avast.bytes.Bytes
import com.avast.clients.rabbitmq.api.{MaxAttemptsReached, MessageProperties, NotAcknowledgedPublish}
import com.avast.clients.rabbitmq.logging.ImplicitContextLogger
import com.avast.clients.rabbitmq.publisher.PublishConfirmsRabbitMQProducer.SentMessages
import com.avast.clients.rabbitmq.{CorrelationId, ProductConverter, ServerChannel, startAndForget}
import com.avast.clients.rabbitmq.{startAndForget, CorrelationId, ProductConverter, ServerChannel}
import com.avast.metrics.scalaeffectapi.Monitor
import com.rabbitmq.client.ConfirmListener

import java.util.concurrent.ConcurrentHashMap
import scala.collection.JavaConverters._
class PublishConfirmsRabbitMQProducer[F[_], A: ProductConverter](name: String,
exchangeName: String,
channel: ServerChannel,
defaultProperties: MessageProperties,
sentMessages: SentMessages[F],
sendAttempts: Int,
reportUnroutable: Boolean,
sizeLimitBytes: Option[Int],
Expand All @@ -39,6 +39,10 @@ class PublishConfirmsRabbitMQProducer[F[_], A: ProductConverter](name: String,
private val acked = monitor.meter("acked")
private val nacked = monitor.meter("nacked")

private[rabbitmq] val confirmationCallbacks = {
new ConcurrentHashMap[Long, Deferred[F, Either[NotAcknowledgedPublish, Unit]]]().asScala
}

override def sendMessage(routingKey: String, body: Bytes, properties: MessageProperties)(implicit correlationId: CorrelationId): F[Unit] =
sendWithAck(routingKey, body, properties, 1)

Expand All @@ -48,34 +52,29 @@ class PublishConfirmsRabbitMQProducer[F[_], A: ProductConverter](name: String,
if (attemptCount > sendAttempts) {
F.raiseError(MaxAttemptsReached("Exhausted max number of attempts"))
} else {
val messageId = channel.getNextPublishSeqNo
for {
defer <- Deferred.apply[F, Either[NotAcknowledgedPublish, Unit]]
_ <- sentMessages.update(_ + (messageId -> defer))
_ <- basicSend(routingKey, body, properties)
result <- defer.get
confirmationCallback <- Deferred.apply[F, Either[NotAcknowledgedPublish, Unit]]
sequenceNumber <- basicSend(routingKey, body, properties, (sequenceNumber: Long) => {
confirmationCallbacks += sequenceNumber -> confirmationCallback
})
result <- confirmationCallback.get
_ <- F.delay(confirmationCallbacks -= sequenceNumber)
_ <- result match {
case Left(err) =>
val sendResult = if (sendAttempts > 1) {
clearProcessedMessage(messageId) >> sendWithAck(routingKey, body, properties, attemptCount + 1)
sendWithAck(routingKey, body, properties, attemptCount + 1)
} else {
F.raiseError(err)
}

nacked.mark >> sendResult
case Right(_) =>
acked.mark >> clearProcessedMessage(messageId)
acked.mark
}
} yield ()
}
}

private def clearProcessedMessage(messageId: Long): F[Unit] = {
sentMessages.update(_ - messageId)
}

private object DefaultConfirmListener extends ConfirmListener {
import cats.syntax.foldable._

override def handleAck(deliveryTag: Long, multiple: Boolean): Unit = {
startAndForget {
Expand All @@ -92,12 +91,10 @@ class PublishConfirmsRabbitMQProducer[F[_], A: ProductConverter](name: String,
}

private def completeDefer(deliveryTag: Long, result: Either[NotAcknowledgedPublish, Unit]): F[Unit] = {
sentMessages.get.flatMap(_.get(deliveryTag).traverse_(_.complete(result)))
confirmationCallbacks.get(deliveryTag) match {
case Some(callback) => callback.complete(result)
case None => logger.plainWarn("Received confirmation for unknown delivery tag. That is unexpected state.")
}
}
}

}

object PublishConfirmsRabbitMQProducer {
type SentMessages[F[_]] = Ref[F, Map[Long, Deferred[F, Either[NotAcknowledgedPublish, Unit]]]]
}
Original file line number Diff line number Diff line change
@@ -1,33 +1,33 @@
package com.avast.clients.rabbitmq

import cats.effect.concurrent.{Deferred, Ref}
import cats.syntax.parallel._
import cats.implicits.catsSyntaxParallelAp
import com.avast.bytes.Bytes
import com.avast.clients.rabbitmq.api.{MessageProperties, NotAcknowledgedPublish}
import com.avast.clients.rabbitmq.logging.ImplicitContextLogger
import com.avast.clients.rabbitmq.publisher.PublishConfirmsRabbitMQProducer
import com.avast.clients.rabbitmq.publisher.PublishConfirmsRabbitMQProducer.SentMessages
import com.avast.metrics.scalaeffectapi.Monitor
import com.rabbitmq.client.impl.recovery.AutorecoveringChannel
import monix.eval.Task
import monix.execution.Scheduler.Implicits.global
import org.junit.runner.manipulation.InvalidOrderingException
import org.mockito.Matchers
import org.mockito.Matchers.any
import org.mockito.Mockito.{times, verify, when}

import scala.concurrent.Await
import scala.concurrent.duration.DurationInt
import scala.util.Random

class PublisherConfirmsRabbitMQProducerTest extends TestBase {
test("message is acked after one retry") {

test("Message is acked after one retry") {
val exchangeName = Random.nextString(10)
val routingKey = Random.nextString(10)
val seqNumber = 1L
val seqNumber2 = 2L

val channel = mock[AutorecoveringChannel]
val ref = Ref.of[Task, Map[Long, Deferred[Task, Either[NotAcknowledgedPublish, Unit]]]](Map.empty).await
val updatedState1 = updateMessageState(ref, seqNumber)(Left(NotAcknowledgedPublish("abcd", messageId = seqNumber)))
val updatedState2 = updateMessageState(ref, seqNumber2)(Right())
when(channel.getNextPublishSeqNo).thenReturn(seqNumber, seqNumber2)

val producer = new PublishConfirmsRabbitMQProducer[Task, Bytes](
name = "test",
Expand All @@ -39,15 +39,21 @@ class PublisherConfirmsRabbitMQProducerTest extends TestBase {
sizeLimitBytes = None,
blocker = TestBase.testBlocker,
logger = ImplicitContextLogger.createLogger,
sentMessages = ref,
sendAttempts = 2
)
when(channel.getNextPublishSeqNo).thenReturn(seqNumber, seqNumber2)

producer.send(routingKey, Bytes.copyFrom(Array.fill(499)(32.toByte))).parProduct(updatedState1.parProduct(updatedState2)).await
val body = Bytes.copyFrom(Array.fill(499)(32.toByte))

val publishTask = producer.send(routingKey, body).runToFuture

updateMessageState(producer, seqNumber)(Left(NotAcknowledgedPublish("abcd", messageId = seqNumber))).parProduct {
updateMessageState(producer, seqNumber2)(Right())
}.await

Await.result(publishTask, 10.seconds)

verify(channel, times(2))
.basicPublish(Matchers.eq(exchangeName), Matchers.eq(routingKey), any(), Matchers.eq(Bytes.copyFrom(Array.fill(499)(32.toByte)).toByteArray))
.basicPublish(Matchers.eq(exchangeName), Matchers.eq(routingKey), any(), Matchers.eq(body.toByteArray))
}

test("Message not acked returned if number of attempts exhausted") {
Expand All @@ -56,8 +62,7 @@ class PublisherConfirmsRabbitMQProducerTest extends TestBase {
val seqNumber = 1L

val channel = mock[AutorecoveringChannel]
val ref = Ref.of[Task, Map[Long, Deferred[Task, Either[NotAcknowledgedPublish, Unit]]]](Map.empty).await
val updatedState = updateMessageState(ref, seqNumber)(Left(NotAcknowledgedPublish("abcd", messageId = seqNumber)))
when(channel.getNextPublishSeqNo).thenReturn(seqNumber)

val producer = new PublishConfirmsRabbitMQProducer[Task, Bytes](
name = "test",
Expand All @@ -69,22 +74,76 @@ class PublisherConfirmsRabbitMQProducerTest extends TestBase {
sizeLimitBytes = None,
blocker = TestBase.testBlocker,
logger = ImplicitContextLogger.createLogger,
sentMessages = ref,
sendAttempts = 1
)
when(channel.getNextPublishSeqNo).thenReturn(seqNumber)

val body = Bytes.copyFrom(Array.fill(499)(32.toByte))

val publishTask = producer.send(routingKey, body).runToFuture

assertThrows[NotAcknowledgedPublish] {
producer.send(routingKey, Bytes.copyFrom(Array.fill(499)(32.toByte))).parProduct(updatedState).await
updateMessageState(producer, seqNumber)(Left(NotAcknowledgedPublish("abcd", messageId = seqNumber))).await
Await.result(publishTask, 10.seconds)
}

verify(channel).basicPublish(Matchers.eq(exchangeName), Matchers.eq(routingKey), any(), Matchers.eq(Bytes.copyFrom(Array.fill(499)(32.toByte)).toByteArray))
verify(channel).basicPublish(Matchers.eq(exchangeName), Matchers.eq(routingKey), any(), Matchers.eq(body.toByteArray))
}

test("Multiple messages are fully acked") {
val exchangeName = Random.nextString(10)
val routingKey = Random.nextString(10)

val channel = mock[AutorecoveringChannel]

val seqNumbers = 1 to 500
val iterator = seqNumbers.iterator
when(channel.getNextPublishSeqNo).thenAnswer(_ => { iterator.next() })

val producer = new PublishConfirmsRabbitMQProducer[Task, Bytes](
name = "test",
exchangeName = exchangeName,
channel = channel,
monitor = Monitor.noOp(),
defaultProperties = MessageProperties.empty,
reportUnroutable = false,
sizeLimitBytes = None,
blocker = TestBase.testBlocker,
logger = ImplicitContextLogger.createLogger,
sendAttempts = 2
)

val body = Bytes.copyFrom(Array.fill(499)(32.toByte))

val publishTasks = Task.parSequenceUnordered {
seqNumbers.map { _ =>
producer.send(routingKey, body)
}
}.runToFuture

Task
.parSequenceUnordered(seqNumbers.map { seqNumber =>
updateMessageState(producer, seqNumber)(Right())
})
.await(15.seconds)

Await.result(publishTasks, 15.seconds)

verify(channel, times(seqNumbers.length))
.basicPublish(Matchers.eq(exchangeName), Matchers.eq(routingKey), any(), Matchers.eq(body.toByteArray))
}

private def updateMessageState(ref: SentMessages[Task], messageId: Long)(result: Either[NotAcknowledgedPublish, Unit]): Task[Unit] = {
ref.get.flatMap(map => map.get(messageId) match {
case Some(value) => value.complete(result)
case None => updateMessageState(ref, messageId)(result)
})
private def updateMessageState(producer: PublishConfirmsRabbitMQProducer[Task, Bytes], messageId: Long, attempt: Int = 1)(
result: Either[NotAcknowledgedPublish, Unit]): Task[Unit] = {
Task
.delay(producer.confirmationCallbacks.get(messageId))
.flatMap {
case Some(value) => value.complete(result)
case None =>
if (attempt < 90) {
Task.sleep(100.millis) >> updateMessageState(producer, messageId, attempt + 1)(result)
} else {
throw new InvalidOrderingException(s"The message ID $messageId is not present in the list of callbacks")
}
}
}
}

0 comments on commit 3ab1cdf

Please sign in to comment.