Skip to content

Commit

Permalink
PR fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Karel Fajkus committed Dec 13, 2022
1 parent 3b103df commit bd1a667
Show file tree
Hide file tree
Showing 6 changed files with 146 additions and 46 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,6 @@ case class ChannelNotRecoveredException(desc: String, cause: Throwable = null) e

case class TooBigMessage(desc: String, cause: Throwable = null) extends IllegalArgumentException(desc, cause)

case class MaxAttempts(desc: String, cause: Throwable = null) extends RuntimeException(desc, cause)
case class MaxAttemptsReached(desc: String, cause: Throwable = null) extends RuntimeException(desc, cause)

case class NotAcknowledgedPublish(desc: String, cause: Throwable = null) extends RuntimeException(desc, cause)
case class NotAcknowledgedPublish(desc: String, cause: Throwable = null, messageId: Long) extends RuntimeException(desc, cause)
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import scala.collection.compat._
import scala.collection.immutable
import scala.jdk.CollectionConverters._
import scala.language.implicitConversions
import scala.reflect.ClassTag

private[rabbitmq] class DefaultRabbitMQClientFactory[F[_]: ConcurrentEffect: Timer: ContextShift](
connection: RabbitMQConnection[F],
Expand Down Expand Up @@ -266,52 +267,60 @@ private[rabbitmq] class DefaultRabbitMQClientFactory[F[_]: ConcurrentEffect: Tim
private def prepareProducer[A: ProductConverter](producerConfig: ProducerConfig,
connection: RabbitMQConnection[F],
monitor: Monitor[F]): Resource[F, BaseRabbitMQProducer[F, A]] = {
val logger = ImplicitContextLogger.createLogger[F, BaseRabbitMQProducer[F, A]]
producerConfig.properties.confirms match {
case Some(PublisherConfirmsConfig(true, sendAttempts)) =>
prepareProducer(producerConfig, connection) { (defaultProperties, channel, logger) =>
Ref.of(Map.empty[Long, Deferred[F, Either[NotAcknowledgedPublish, Unit]]])
.map {
new PublishConfirmsRabbitMQProducer[F, A](
producerConfig.name,
producerConfig.exchange,
channel,
defaultProperties,
_,
sendAttempts,
producerConfig.reportUnroutable,
producerConfig.sizeLimitBytes,
blocker,
logger,
monitor)
}
}
case _ =>
prepareProducer(producerConfig, connection) { (defaultProperties, channel, logger) =>
F.pure {
new DefaultRabbitMQProducer[F, A](producerConfig.name,
producerConfig.exchange,
channel,
defaultProperties,
producerConfig.reportUnroutable,
producerConfig.sizeLimitBytes,
blocker,
logger,
monitor)
}
}
}
}

private def prepareProducer[T: ClassTag, A: ProductConverter](producerConfig: ProducerConfig, connection: RabbitMQConnection[F])(
createProducer: (MessageProperties, ServerChannel, ImplicitContextLogger[F]) => F[T]) = {
val logger: ImplicitContextLogger[F] = ImplicitContextLogger.createLogger[F, T]

connection
.newChannel()
.evalTap { channel =>
// auto declare exchange; if configured
producerConfig.declare.map { declareExchange(producerConfig.exchange, channel, _)(logger) }.getOrElse(F.unit)
}
.evalMap[F, BaseRabbitMQProducer[F, A]] { channel =>
.evalMap[F, T] { channel =>
val defaultProperties = MessageProperties(
deliveryMode = DeliveryMode.fromCode(producerConfig.properties.deliveryMode),
contentType = producerConfig.properties.contentType,
contentEncoding = producerConfig.properties.contentEncoding,
priority = producerConfig.properties.priority.map(Integer.valueOf)
)

producerConfig.properties.confirms match {
case Some(PublisherConfirmsConfig(true, sendAttempts)) =>
Ref.of(Map.empty[Long, Deferred[F, Either[Throwable, Unit]]])
.map {
new PublishConfirmsRabbitMQProducer[F, A](
producerConfig.name,
producerConfig.exchange,
channel,
defaultProperties,
_,
sendAttempts,
producerConfig.reportUnroutable,
producerConfig.sizeLimitBytes,
blocker,
logger,
monitor)
}
case _ =>
F.pure {
new DefaultRabbitMQProducer[F, A](producerConfig.name,
producerConfig.exchange,
channel,
defaultProperties,
producerConfig.reportUnroutable,
producerConfig.sizeLimitBytes,
blocker,
logger,
monitor)
}
}
)
createProducer(defaultProperties, channel, logger)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ abstract class BaseRabbitMQProducer[F[_], A: ProductConverter](name: String,
case Right(convertedBody) =>
for {
_ <- checkSize(convertedBody, routingKey)
_ <- logErrors(sendMessage(routingKey, convertedBody, finalProperties), routingKey)
_ <- processErrors(sendMessage(routingKey, convertedBody, finalProperties), routingKey)
} yield ()
case Left(ce) => Sync[F].raiseError(ce)
}
Expand All @@ -76,7 +76,7 @@ abstract class BaseRabbitMQProducer[F[_], A: ProductConverter](name: String,
} yield ()
}

private def logErrors(from: F[Unit], routingKey: String)(implicit correlationId: CorrelationId): F[Unit] = {
private def processErrors(from: F[Unit], routingKey: String)(implicit correlationId: CorrelationId): F[Unit] = {
from.recoverWith {
case ce: AlreadyClosedException =>
logger.debug(ce)(s"[$name] Failed to send message with routing key '$routingKey' to exchange '$exchangeName'") >>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import cats.effect.{Blocker, ConcurrentEffect, ContextShift}
import cats.syntax.flatMap._
import cats.syntax.functor._
import com.avast.bytes.Bytes
import com.avast.clients.rabbitmq.api.{MaxAttempts, MessageProperties, NotAcknowledgedPublish}
import com.avast.clients.rabbitmq.api.{MaxAttemptsReached, MessageProperties, NotAcknowledgedPublish}
import com.avast.clients.rabbitmq.logging.ImplicitContextLogger
import com.avast.clients.rabbitmq.publisher.PublishConfirmsRabbitMQProducer.SentMessages
import com.avast.clients.rabbitmq.{CorrelationId, ProductConverter, ServerChannel, startAndForget}
Expand Down Expand Up @@ -46,11 +46,11 @@ class PublishConfirmsRabbitMQProducer[F[_], A: ProductConverter](name: String,
implicit correlationId: CorrelationId): F[Unit] = {

if (attemptCount > sendAttempts) {
F.raiseError(MaxAttempts("Exhausted max number of attempts"))
F.raiseError(MaxAttemptsReached("Exhausted max number of attempts"))
} else {
val messageId = channel.getNextPublishSeqNo
for {
defer <- Deferred.apply[F, Either[Throwable, Unit]]
defer <- Deferred.apply[F, Either[NotAcknowledgedPublish, Unit]]
_ <- sentMessages.update(_ + (messageId -> defer))
_ <- basicSend(routingKey, body, properties)
result <- defer.get
Expand All @@ -59,7 +59,7 @@ class PublishConfirmsRabbitMQProducer[F[_], A: ProductConverter](name: String,
val sendResult = if (sendAttempts > 1) {
clearProcessedMessage(messageId) >> sendWithAck(routingKey, body, properties, attemptCount + 1)
} else {
F.raiseError(NotAcknowledgedPublish(s"Broker did not acknowledge publish of message $messageId", err))
F.raiseError(err)
}

nacked.mark >> sendResult
Expand Down Expand Up @@ -87,17 +87,17 @@ class PublishConfirmsRabbitMQProducer[F[_], A: ProductConverter](name: String,
startAndForget {
logger.plainTrace(s"Not acked $deliveryTag") >> completeDefer(
deliveryTag,
Left(new Exception(s"Message $deliveryTag not acknowledged by broker")))
Left(NotAcknowledgedPublish(s"Message $deliveryTag not acknowledged by broker", messageId = deliveryTag)))
}
}

private def completeDefer(deliveryTag: Long, result: Either[Throwable, Unit]): F[Unit] = {
private def completeDefer(deliveryTag: Long, result: Either[NotAcknowledgedPublish, Unit]): F[Unit] = {
sentMessages.get.flatMap(_.get(deliveryTag).traverse_(_.complete(result)))
}
}

}

object PublishConfirmsRabbitMQProducer {
type SentMessages[F[_]] = Ref[F, Map[Long, Deferred[F, Either[Throwable, Unit]]]]
type SentMessages[F[_]] = Ref[F, Map[Long, Deferred[F, Either[NotAcknowledgedPublish, Unit]]]]
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.avast.clients.rabbitmq

import cats.effect.ConcurrentEffect
import com.avast.bytes.Bytes
import com.avast.clients.rabbitmq.api._
import com.avast.clients.rabbitmq.logging.ImplicitContextLogger
Expand Down Expand Up @@ -192,7 +193,7 @@ class DefaultRabbitMQProducerTest extends TestBase {
reportUnroutable = false,
sizeLimitBytes = Some(limit),
blocker = TestBase.testBlocker,
logger = ImplicitContextLogger.createLogger,
logger = ImplicitContextLogger.createLogger
)

// don't test anything except it doesn't fail
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
package com.avast.clients.rabbitmq

import cats.effect.concurrent.{Deferred, Ref}
import cats.syntax.parallel._
import com.avast.bytes.Bytes
import com.avast.clients.rabbitmq.api.{MessageProperties, NotAcknowledgedPublish}
import com.avast.clients.rabbitmq.logging.ImplicitContextLogger
import com.avast.clients.rabbitmq.publisher.PublishConfirmsRabbitMQProducer
import com.avast.clients.rabbitmq.publisher.PublishConfirmsRabbitMQProducer.SentMessages
import com.avast.metrics.scalaeffectapi.Monitor
import com.rabbitmq.client.impl.recovery.AutorecoveringChannel
import monix.eval.Task
import monix.execution.Scheduler.Implicits.global
import org.mockito.Matchers
import org.mockito.Matchers.any
import org.mockito.Mockito.{times, verify, when}

import scala.util.Random

class PublisherConfirmsRabbitMQProducerTest extends TestBase {
test("message is acked after one retry") {
val exchangeName = Random.nextString(10)
val routingKey = Random.nextString(10)
val seqNumber = 1L
val seqNumber2 = 2L

val channel = mock[AutorecoveringChannel]
val ref = Ref.of[Task, Map[Long, Deferred[Task, Either[NotAcknowledgedPublish, Unit]]]](Map.empty).await
val updatedState1 = updateMessageState(ref, seqNumber)(Left(NotAcknowledgedPublish("abcd", messageId = seqNumber)))
val updatedState2 = updateMessageState(ref, seqNumber2)(Right())

val producer = new PublishConfirmsRabbitMQProducer[Task, Bytes](
name = "test",
exchangeName = exchangeName,
channel = channel,
monitor = Monitor.noOp(),
defaultProperties = MessageProperties.empty,
reportUnroutable = false,
sizeLimitBytes = None,
blocker = TestBase.testBlocker,
logger = ImplicitContextLogger.createLogger,
sentMessages = ref,
sendAttempts = 2
)
when(channel.getNextPublishSeqNo).thenReturn(seqNumber, seqNumber2)

producer.send(routingKey, Bytes.copyFrom(Array.fill(499)(32.toByte))).parProduct(updatedState1.parProduct(updatedState2)).await

verify(channel, times(2))
.basicPublish(Matchers.eq(exchangeName), Matchers.eq(routingKey), any(), Matchers.eq(Bytes.copyFrom(Array.fill(499)(32.toByte)).toByteArray))
}

test("Message not acked returned if number of attempts exhausted") {
val exchangeName = Random.nextString(10)
val routingKey = Random.nextString(10)
val seqNumber = 1L

val channel = mock[AutorecoveringChannel]
val ref = Ref.of[Task, Map[Long, Deferred[Task, Either[NotAcknowledgedPublish, Unit]]]](Map.empty).await
val updatedState = updateMessageState(ref, seqNumber)(Left(NotAcknowledgedPublish("abcd", messageId = seqNumber)))

val producer = new PublishConfirmsRabbitMQProducer[Task, Bytes](
name = "test",
exchangeName = exchangeName,
channel = channel,
monitor = Monitor.noOp(),
defaultProperties = MessageProperties.empty,
reportUnroutable = false,
sizeLimitBytes = None,
blocker = TestBase.testBlocker,
logger = ImplicitContextLogger.createLogger,
sentMessages = ref,
sendAttempts = 1
)
when(channel.getNextPublishSeqNo).thenReturn(seqNumber)

assertThrows[NotAcknowledgedPublish] {
producer.send(routingKey, Bytes.copyFrom(Array.fill(499)(32.toByte))).parProduct(updatedState).await
}

verify(channel).basicPublish(Matchers.eq(exchangeName), Matchers.eq(routingKey), any(), Matchers.eq(Bytes.copyFrom(Array.fill(499)(32.toByte)).toByteArray))
}

private def updateMessageState(ref: SentMessages[Task], messageId: Long)(result: Either[NotAcknowledgedPublish, Unit]): Task[Unit] = {
ref.get.flatMap(map => map.get(messageId) match {
case Some(value) => value.complete(result)
case None => updateMessageState(ref, messageId)(result)
})
}
}

0 comments on commit bd1a667

Please sign in to comment.