From bd1a667b260906ceddb361e207aaf0314ae2ad15 Mon Sep 17 00:00:00 2001 From: Karel Fajkus Date: Tue, 13 Dec 2022 17:14:19 +0100 Subject: [PATCH] PR fixes --- .../clients/rabbitmq/api/exceptions.scala | 4 +- .../DefaultRabbitMQClientFactory.scala | 77 +++++++++------- .../publisher/BaseRabbitMQProducer.scala | 4 +- .../PublishConfirmsRabbitMQProducer.scala | 14 +-- .../DefaultRabbitMQProducerTest.scala | 3 +- ...ublisherConfirmsRabbitMQProducerTest.scala | 90 +++++++++++++++++++ 6 files changed, 146 insertions(+), 46 deletions(-) create mode 100644 core/src/test/scala/com/avast/clients/rabbitmq/PublisherConfirmsRabbitMQProducerTest.scala diff --git a/api/src/main/scala/com/avast/clients/rabbitmq/api/exceptions.scala b/api/src/main/scala/com/avast/clients/rabbitmq/api/exceptions.scala index 4a58096a..e52a5b76 100644 --- a/api/src/main/scala/com/avast/clients/rabbitmq/api/exceptions.scala +++ b/api/src/main/scala/com/avast/clients/rabbitmq/api/exceptions.scala @@ -7,6 +7,6 @@ case class ChannelNotRecoveredException(desc: String, cause: Throwable = null) e case class TooBigMessage(desc: String, cause: Throwable = null) extends IllegalArgumentException(desc, cause) -case class MaxAttempts(desc: String, cause: Throwable = null) extends RuntimeException(desc, cause) +case class MaxAttemptsReached(desc: String, cause: Throwable = null) extends RuntimeException(desc, cause) -case class NotAcknowledgedPublish(desc: String, cause: Throwable = null) extends RuntimeException(desc, cause) +case class NotAcknowledgedPublish(desc: String, cause: Throwable = null, messageId: Long) extends RuntimeException(desc, cause) diff --git a/core/src/main/scala/com/avast/clients/rabbitmq/DefaultRabbitMQClientFactory.scala b/core/src/main/scala/com/avast/clients/rabbitmq/DefaultRabbitMQClientFactory.scala index e34ae39d..3381e813 100644 --- a/core/src/main/scala/com/avast/clients/rabbitmq/DefaultRabbitMQClientFactory.scala +++ b/core/src/main/scala/com/avast/clients/rabbitmq/DefaultRabbitMQClientFactory.scala @@ -15,6 +15,7 @@ import scala.collection.compat._ import scala.collection.immutable import scala.jdk.CollectionConverters._ import scala.language.implicitConversions +import scala.reflect.ClassTag private[rabbitmq] class DefaultRabbitMQClientFactory[F[_]: ConcurrentEffect: Timer: ContextShift]( connection: RabbitMQConnection[F], @@ -266,7 +267,45 @@ private[rabbitmq] class DefaultRabbitMQClientFactory[F[_]: ConcurrentEffect: Tim private def prepareProducer[A: ProductConverter](producerConfig: ProducerConfig, connection: RabbitMQConnection[F], monitor: Monitor[F]): Resource[F, BaseRabbitMQProducer[F, A]] = { - val logger = ImplicitContextLogger.createLogger[F, BaseRabbitMQProducer[F, A]] + producerConfig.properties.confirms match { + case Some(PublisherConfirmsConfig(true, sendAttempts)) => + prepareProducer(producerConfig, connection) { (defaultProperties, channel, logger) => + Ref.of(Map.empty[Long, Deferred[F, Either[NotAcknowledgedPublish, Unit]]]) + .map { + new PublishConfirmsRabbitMQProducer[F, A]( + producerConfig.name, + producerConfig.exchange, + channel, + defaultProperties, + _, + sendAttempts, + producerConfig.reportUnroutable, + producerConfig.sizeLimitBytes, + blocker, + logger, + monitor) + } + } + case _ => + prepareProducer(producerConfig, connection) { (defaultProperties, channel, logger) => + F.pure { + new DefaultRabbitMQProducer[F, A](producerConfig.name, + producerConfig.exchange, + channel, + defaultProperties, + producerConfig.reportUnroutable, + producerConfig.sizeLimitBytes, + blocker, + logger, + monitor) + } + } + } + } + + private def prepareProducer[T: ClassTag, A: ProductConverter](producerConfig: ProducerConfig, connection: RabbitMQConnection[F])( + createProducer: (MessageProperties, ServerChannel, ImplicitContextLogger[F]) => F[T]) = { + val logger: ImplicitContextLogger[F] = ImplicitContextLogger.createLogger[F, T] connection .newChannel() @@ -274,44 +313,14 @@ private[rabbitmq] class DefaultRabbitMQClientFactory[F[_]: ConcurrentEffect: Tim // auto declare exchange; if configured producerConfig.declare.map { declareExchange(producerConfig.exchange, channel, _)(logger) }.getOrElse(F.unit) } - .evalMap[F, BaseRabbitMQProducer[F, A]] { channel => + .evalMap[F, T] { channel => val defaultProperties = MessageProperties( deliveryMode = DeliveryMode.fromCode(producerConfig.properties.deliveryMode), contentType = producerConfig.properties.contentType, contentEncoding = producerConfig.properties.contentEncoding, priority = producerConfig.properties.priority.map(Integer.valueOf) - ) - - producerConfig.properties.confirms match { - case Some(PublisherConfirmsConfig(true, sendAttempts)) => - Ref.of(Map.empty[Long, Deferred[F, Either[Throwable, Unit]]]) - .map { - new PublishConfirmsRabbitMQProducer[F, A]( - producerConfig.name, - producerConfig.exchange, - channel, - defaultProperties, - _, - sendAttempts, - producerConfig.reportUnroutable, - producerConfig.sizeLimitBytes, - blocker, - logger, - monitor) - } - case _ => - F.pure { - new DefaultRabbitMQProducer[F, A](producerConfig.name, - producerConfig.exchange, - channel, - defaultProperties, - producerConfig.reportUnroutable, - producerConfig.sizeLimitBytes, - blocker, - logger, - monitor) - } - } + ) + createProducer(defaultProperties, channel, logger) } } diff --git a/core/src/main/scala/com/avast/clients/rabbitmq/publisher/BaseRabbitMQProducer.scala b/core/src/main/scala/com/avast/clients/rabbitmq/publisher/BaseRabbitMQProducer.scala index 73d4abb5..ecd68611 100644 --- a/core/src/main/scala/com/avast/clients/rabbitmq/publisher/BaseRabbitMQProducer.scala +++ b/core/src/main/scala/com/avast/clients/rabbitmq/publisher/BaseRabbitMQProducer.scala @@ -57,7 +57,7 @@ abstract class BaseRabbitMQProducer[F[_], A: ProductConverter](name: String, case Right(convertedBody) => for { _ <- checkSize(convertedBody, routingKey) - _ <- logErrors(sendMessage(routingKey, convertedBody, finalProperties), routingKey) + _ <- processErrors(sendMessage(routingKey, convertedBody, finalProperties), routingKey) } yield () case Left(ce) => Sync[F].raiseError(ce) } @@ -76,7 +76,7 @@ abstract class BaseRabbitMQProducer[F[_], A: ProductConverter](name: String, } yield () } - private def logErrors(from: F[Unit], routingKey: String)(implicit correlationId: CorrelationId): F[Unit] = { + private def processErrors(from: F[Unit], routingKey: String)(implicit correlationId: CorrelationId): F[Unit] = { from.recoverWith { case ce: AlreadyClosedException => logger.debug(ce)(s"[$name] Failed to send message with routing key '$routingKey' to exchange '$exchangeName'") >> diff --git a/core/src/main/scala/com/avast/clients/rabbitmq/publisher/PublishConfirmsRabbitMQProducer.scala b/core/src/main/scala/com/avast/clients/rabbitmq/publisher/PublishConfirmsRabbitMQProducer.scala index 287ad3e0..196f118e 100644 --- a/core/src/main/scala/com/avast/clients/rabbitmq/publisher/PublishConfirmsRabbitMQProducer.scala +++ b/core/src/main/scala/com/avast/clients/rabbitmq/publisher/PublishConfirmsRabbitMQProducer.scala @@ -5,7 +5,7 @@ import cats.effect.{Blocker, ConcurrentEffect, ContextShift} import cats.syntax.flatMap._ import cats.syntax.functor._ import com.avast.bytes.Bytes -import com.avast.clients.rabbitmq.api.{MaxAttempts, MessageProperties, NotAcknowledgedPublish} +import com.avast.clients.rabbitmq.api.{MaxAttemptsReached, MessageProperties, NotAcknowledgedPublish} import com.avast.clients.rabbitmq.logging.ImplicitContextLogger import com.avast.clients.rabbitmq.publisher.PublishConfirmsRabbitMQProducer.SentMessages import com.avast.clients.rabbitmq.{CorrelationId, ProductConverter, ServerChannel, startAndForget} @@ -46,11 +46,11 @@ class PublishConfirmsRabbitMQProducer[F[_], A: ProductConverter](name: String, implicit correlationId: CorrelationId): F[Unit] = { if (attemptCount > sendAttempts) { - F.raiseError(MaxAttempts("Exhausted max number of attempts")) + F.raiseError(MaxAttemptsReached("Exhausted max number of attempts")) } else { val messageId = channel.getNextPublishSeqNo for { - defer <- Deferred.apply[F, Either[Throwable, Unit]] + defer <- Deferred.apply[F, Either[NotAcknowledgedPublish, Unit]] _ <- sentMessages.update(_ + (messageId -> defer)) _ <- basicSend(routingKey, body, properties) result <- defer.get @@ -59,7 +59,7 @@ class PublishConfirmsRabbitMQProducer[F[_], A: ProductConverter](name: String, val sendResult = if (sendAttempts > 1) { clearProcessedMessage(messageId) >> sendWithAck(routingKey, body, properties, attemptCount + 1) } else { - F.raiseError(NotAcknowledgedPublish(s"Broker did not acknowledge publish of message $messageId", err)) + F.raiseError(err) } nacked.mark >> sendResult @@ -87,11 +87,11 @@ class PublishConfirmsRabbitMQProducer[F[_], A: ProductConverter](name: String, startAndForget { logger.plainTrace(s"Not acked $deliveryTag") >> completeDefer( deliveryTag, - Left(new Exception(s"Message $deliveryTag not acknowledged by broker"))) + Left(NotAcknowledgedPublish(s"Message $deliveryTag not acknowledged by broker", messageId = deliveryTag))) } } - private def completeDefer(deliveryTag: Long, result: Either[Throwable, Unit]): F[Unit] = { + private def completeDefer(deliveryTag: Long, result: Either[NotAcknowledgedPublish, Unit]): F[Unit] = { sentMessages.get.flatMap(_.get(deliveryTag).traverse_(_.complete(result))) } } @@ -99,5 +99,5 @@ class PublishConfirmsRabbitMQProducer[F[_], A: ProductConverter](name: String, } object PublishConfirmsRabbitMQProducer { - type SentMessages[F[_]] = Ref[F, Map[Long, Deferred[F, Either[Throwable, Unit]]]] + type SentMessages[F[_]] = Ref[F, Map[Long, Deferred[F, Either[NotAcknowledgedPublish, Unit]]]] } diff --git a/core/src/test/scala/com/avast/clients/rabbitmq/DefaultRabbitMQProducerTest.scala b/core/src/test/scala/com/avast/clients/rabbitmq/DefaultRabbitMQProducerTest.scala index 0c3da8a3..94416b92 100644 --- a/core/src/test/scala/com/avast/clients/rabbitmq/DefaultRabbitMQProducerTest.scala +++ b/core/src/test/scala/com/avast/clients/rabbitmq/DefaultRabbitMQProducerTest.scala @@ -1,5 +1,6 @@ package com.avast.clients.rabbitmq +import cats.effect.ConcurrentEffect import com.avast.bytes.Bytes import com.avast.clients.rabbitmq.api._ import com.avast.clients.rabbitmq.logging.ImplicitContextLogger @@ -192,7 +193,7 @@ class DefaultRabbitMQProducerTest extends TestBase { reportUnroutable = false, sizeLimitBytes = Some(limit), blocker = TestBase.testBlocker, - logger = ImplicitContextLogger.createLogger, + logger = ImplicitContextLogger.createLogger ) // don't test anything except it doesn't fail diff --git a/core/src/test/scala/com/avast/clients/rabbitmq/PublisherConfirmsRabbitMQProducerTest.scala b/core/src/test/scala/com/avast/clients/rabbitmq/PublisherConfirmsRabbitMQProducerTest.scala new file mode 100644 index 00000000..c0b5cf68 --- /dev/null +++ b/core/src/test/scala/com/avast/clients/rabbitmq/PublisherConfirmsRabbitMQProducerTest.scala @@ -0,0 +1,90 @@ +package com.avast.clients.rabbitmq + +import cats.effect.concurrent.{Deferred, Ref} +import cats.syntax.parallel._ +import com.avast.bytes.Bytes +import com.avast.clients.rabbitmq.api.{MessageProperties, NotAcknowledgedPublish} +import com.avast.clients.rabbitmq.logging.ImplicitContextLogger +import com.avast.clients.rabbitmq.publisher.PublishConfirmsRabbitMQProducer +import com.avast.clients.rabbitmq.publisher.PublishConfirmsRabbitMQProducer.SentMessages +import com.avast.metrics.scalaeffectapi.Monitor +import com.rabbitmq.client.impl.recovery.AutorecoveringChannel +import monix.eval.Task +import monix.execution.Scheduler.Implicits.global +import org.mockito.Matchers +import org.mockito.Matchers.any +import org.mockito.Mockito.{times, verify, when} + +import scala.util.Random + +class PublisherConfirmsRabbitMQProducerTest extends TestBase { + test("message is acked after one retry") { + val exchangeName = Random.nextString(10) + val routingKey = Random.nextString(10) + val seqNumber = 1L + val seqNumber2 = 2L + + val channel = mock[AutorecoveringChannel] + val ref = Ref.of[Task, Map[Long, Deferred[Task, Either[NotAcknowledgedPublish, Unit]]]](Map.empty).await + val updatedState1 = updateMessageState(ref, seqNumber)(Left(NotAcknowledgedPublish("abcd", messageId = seqNumber))) + val updatedState2 = updateMessageState(ref, seqNumber2)(Right()) + + val producer = new PublishConfirmsRabbitMQProducer[Task, Bytes]( + name = "test", + exchangeName = exchangeName, + channel = channel, + monitor = Monitor.noOp(), + defaultProperties = MessageProperties.empty, + reportUnroutable = false, + sizeLimitBytes = None, + blocker = TestBase.testBlocker, + logger = ImplicitContextLogger.createLogger, + sentMessages = ref, + sendAttempts = 2 + ) + when(channel.getNextPublishSeqNo).thenReturn(seqNumber, seqNumber2) + + producer.send(routingKey, Bytes.copyFrom(Array.fill(499)(32.toByte))).parProduct(updatedState1.parProduct(updatedState2)).await + + verify(channel, times(2)) + .basicPublish(Matchers.eq(exchangeName), Matchers.eq(routingKey), any(), Matchers.eq(Bytes.copyFrom(Array.fill(499)(32.toByte)).toByteArray)) + } + + test("Message not acked returned if number of attempts exhausted") { + val exchangeName = Random.nextString(10) + val routingKey = Random.nextString(10) + val seqNumber = 1L + + val channel = mock[AutorecoveringChannel] + val ref = Ref.of[Task, Map[Long, Deferred[Task, Either[NotAcknowledgedPublish, Unit]]]](Map.empty).await + val updatedState = updateMessageState(ref, seqNumber)(Left(NotAcknowledgedPublish("abcd", messageId = seqNumber))) + + val producer = new PublishConfirmsRabbitMQProducer[Task, Bytes]( + name = "test", + exchangeName = exchangeName, + channel = channel, + monitor = Monitor.noOp(), + defaultProperties = MessageProperties.empty, + reportUnroutable = false, + sizeLimitBytes = None, + blocker = TestBase.testBlocker, + logger = ImplicitContextLogger.createLogger, + sentMessages = ref, + sendAttempts = 1 + ) + when(channel.getNextPublishSeqNo).thenReturn(seqNumber) + + assertThrows[NotAcknowledgedPublish] { + producer.send(routingKey, Bytes.copyFrom(Array.fill(499)(32.toByte))).parProduct(updatedState).await + } + + verify(channel).basicPublish(Matchers.eq(exchangeName), Matchers.eq(routingKey), any(), Matchers.eq(Bytes.copyFrom(Array.fill(499)(32.toByte)).toByteArray)) + } + + private def updateMessageState(ref: SentMessages[Task], messageId: Long)(result: Either[NotAcknowledgedPublish, Unit]): Task[Unit] = { + ref.get.flatMap(map => map.get(messageId) match { + case Some(value) => value.complete(result) + case None => updateMessageState(ref, messageId)(result) + }) + } +}