diff --git a/src/main/java/io/aiven/kafka/connect/opensearch/BulkProcessor.java b/src/main/java/io/aiven/kafka/connect/opensearch/BulkProcessor.java index 91f37009..154e491d 100644 --- a/src/main/java/io/aiven/kafka/connect/opensearch/BulkProcessor.java +++ b/src/main/java/io/aiven/kafka/connect/opensearch/BulkProcessor.java @@ -383,6 +383,18 @@ private void sendToErrantRecordReporter(final String errorMessage, final SinkRec } private BulkResponse execute() throws Exception { + class RetriableError extends RuntimeException { + private static final long serialVersionUID = 1L; + + public RetriableError(final String errorMessage) { + super(errorMessage); + } + + public RetriableError(final Throwable cause) { + super(cause); + } + } + return callWithRetry("bulk processing", () -> { try { final var response = @@ -403,7 +415,7 @@ private BulkResponse execute() throws Exception { } else if (responseContainsVersionConflict(itemResponse)) { handleVersionConflict(itemResponse); } else { - throw new RuntimeException( + throw new RetriableError( "One of the item in the bulk response failed. Reason: " + itemResponse.getFailureMessage()); } @@ -418,9 +430,9 @@ private BulkResponse execute() throws Exception { } catch (final IOException e) { LOGGER.error( "Failed to send bulk request from batch {} of {} records", batchId, batch.size(), e); - throw new ConnectException(e); + throw new RetriableError(e); } - }, maxRetries, retryBackoffMs, RuntimeException.class); + }, maxRetries, retryBackoffMs, RetriableError.class); } private void handleVersionConflict(final BulkItemResponse bulkItemResponse) { diff --git a/src/main/java/io/aiven/kafka/connect/opensearch/RetryUtil.java b/src/main/java/io/aiven/kafka/connect/opensearch/RetryUtil.java index 91d970bd..d3e2ea3e 100644 --- a/src/main/java/io/aiven/kafka/connect/opensearch/RetryUtil.java +++ b/src/main/java/io/aiven/kafka/connect/opensearch/RetryUtil.java @@ -118,7 +118,14 @@ public static T callWithRetry( LOGGER.trace("Try {} with attempt {}/{}", callName, attempts, maxAttempts); return callable.call(); } catch (final Exception e) { - if (attempts < maxAttempts && e.getClass().equals(repeatableException)) { + if (!repeatableException.isAssignableFrom(e.getClass())) { + final var msg = String.format( + "Non-repeatable exception trown by %s", + callName + ); + LOGGER.error(msg, e); + throw new ConnectException(msg, e); + } else if (attempts < maxAttempts) { final long sleepTimeMs = computeRandomRetryWaitTimeInMillis(retryAttempts, retryBackoffMs); final var msg = String.format( diff --git a/src/test/java/io/aiven/kafka/connect/opensearch/RetryUtilTest.java b/src/test/java/io/aiven/kafka/connect/opensearch/RetryUtilTest.java index 8db0e62d..758bebd0 100644 --- a/src/test/java/io/aiven/kafka/connect/opensearch/RetryUtilTest.java +++ b/src/test/java/io/aiven/kafka/connect/opensearch/RetryUtilTest.java @@ -17,9 +17,14 @@ package io.aiven.kafka.connect.opensearch; +import java.io.IOException; + +import org.apache.kafka.connect.errors.ConnectException; + import org.junit.jupiter.api.Test; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; public class RetryUtilTest { @@ -49,6 +54,53 @@ public void computeNonRandomRetryTimes() { assertEquals(3200L, RetryUtil.computeRetryWaitTimeInMillis(5, 100L)); } + @Test + public void callWithRetryRetriableError() { + final int[] attempt = new int[1]; + final int maxRetries = 3; + final int res = RetryUtil.callWithRetry("test callWithRetryRetriableError", () -> { + if (attempt[0] < maxRetries) { + ++attempt[0]; + throw new ArithmeticException(); + } + return attempt[0]; + }, maxRetries, 1L, RuntimeException.class); + + assertEquals(maxRetries, res); + } + + @Test + public void callWithRetryMaxRetries() { + final int[] attempt = new int[1]; + final int maxRetries = 3; + assertThrows( + ConnectException.class, + () -> { + RetryUtil.callWithRetry("test callWithRetryMaxRetries", () -> { + ++attempt[0]; + throw new ArithmeticException(); + }, maxRetries, 1L, RuntimeException.class); + }); + + assertEquals(maxRetries + 1, attempt[0]); + } + + @Test + public void callWithRetryNonRetriableError() { + final int[] attempt = new int[1]; + final int maxRetries = 3; + assertThrows( + ConnectException.class, + () -> { + RetryUtil.callWithRetry("test callWithRetryNonRetriableError", () -> { + ++attempt[0]; + throw new ArithmeticException(); + }, maxRetries, 1L, IOException.class); + }); + + assertEquals(1, attempt[0]); + } + protected void assertComputeRetryInRange(final int retryAttempts, final long retryBackoffMs) { for (int i = 0; i != 20; ++i) { for (int retries = 0; retries <= retryAttempts; ++retries) {