diff --git a/src/main/java/com/bakdata/kserve/client/KServeClient.java b/src/main/java/com/bakdata/kserve/client/KServeClient.java
index 1b44925..b893f38 100644
--- a/src/main/java/com/bakdata/kserve/client/KServeClient.java
+++ b/src/main/java/com/bakdata/kserve/client/KServeClient.java
@@ -1,7 +1,7 @@
/*
* MIT License
*
- * Copyright (c) 2022 bakdata
+ * Copyright (c) 2024 bakdata GmbH
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
@@ -30,6 +30,11 @@
import io.github.resilience4j.core.IntervalFunction;
import io.github.resilience4j.retry.Retry;
import io.github.resilience4j.retry.RetryConfig;
+import java.io.IOException;
+import java.net.HttpURLConnection;
+import java.time.Duration;
+import java.util.Optional;
+import java.util.concurrent.Callable;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import okhttp3.HttpUrl;
@@ -41,17 +46,11 @@
import okhttp3.Response;
import org.jetbrains.annotations.NotNull;
-import java.io.IOException;
-import java.net.HttpURLConnection;
-import java.time.Duration;
-import java.util.Optional;
-import java.util.concurrent.Callable;
-
/**
*
An abstract client base class to make requests to a KServe inference service.
* It builds the request, handles the response and deals with various errors from the inference service. It
- * automatically retries requests in case they time out, e.g. due to the inference service being scaled to zero at
- * the time of the request. The exponential random back-off retry mechanism can be configured using the following
+ * automatically retries requests in case they time out, e.g. due to the inference service being scaled to zero at the
+ * time of the request. The exponential random back-off retry mechanism can be configured using the following
* environment variables:
*
* - KSERVE_RETRY_MAX_ATTEMPTS: The maximum number of retry attempts.
@@ -68,6 +67,7 @@
@RequiredArgsConstructor
public abstract class KServeClient {
+ protected static final ObjectMapper OBJECT_MAPPER = createObjectMapper();
private static final int RETRY_MAX_ATTEMPTS = Optional.ofNullable(System.getenv("KSERVE_RETRY_MAX_ATTEMPTS"))
.map(Integer::parseInt)
.orElse(10);
@@ -82,47 +82,11 @@ public abstract class KServeClient {
Optional.ofNullable(System.getenv("KSERVE_RETRY_MAX_INTERVAL"))
.map(Integer::parseInt).map(Duration::ofMillis)
.orElse(Duration.ofMillis(16000));
-
- protected static final ObjectMapper OBJECT_MAPPER = createObjectMapper();
private final String service;
private final String modelName;
private final OkHttpClient httpClient;
private final boolean httpsEnabled;
- @Slf4j
- private static class RetryInterceptor implements Interceptor {
- @NotNull
- @Override
- public Response intercept(final Chain chain) throws IOException {
- final Request request = chain.request();
-
- // wait_interval = min(max_interval, (initial_interval * multiplier^n) +/- (random_interval))
- final IntervalFunction intervalFn = IntervalFunction.ofExponentialRandomBackoff(
- RETRY_INITIAL_INTERVAL, RETRY_MULTIPLIER, IntervalFunction.DEFAULT_RANDOMIZATION_FACTOR,
- RETRY_MAX_INTERVAL);
-
- final RetryConfig retryConfig = RetryConfig.custom()
- // IOException may be thrown by chain.proceed(request)
- .retryExceptions(IOException.class)
- .maxAttempts(RETRY_MAX_ATTEMPTS)
- .intervalFunction(intervalFn)
- .failAfterMaxAttempts(true)
- .build();
- final Retry retry = Retry.of("kserve-request-retry", retryConfig);
-
- final Callable requestCallable = Retry.decorateCallable(retry, () -> {
- log.debug("Making or retrying request {}.", request);
- return chain.proceed(request);
- });
-
- try {
- return requestCallable.call();
- } catch (final Exception e) {
- throw new IOException(e);
- }
- }
- }
-
protected static OkHttpClient getHttpClient(final Duration requestReadTimeout) {
return new OkHttpClient.Builder()
.readTimeout(requestReadTimeout)
@@ -147,6 +111,16 @@ private static Optional processJsonResponse(final String stringBody, fina
}
}
+ private static Request getRequest(final String bodyString, final HttpUrl url) { // Builds the HTTP POST request carrying bodyString to url.
+ final MediaType mediaType = MediaType.get("application/json; charset=utf-8"); // body is sent as UTF-8 encoded JSON
+ final RequestBody requestBody = RequestBody
+ .create(bodyString, mediaType);
+ return new Request.Builder()
+ .url(url)
+ .post(requestBody) // POST with the JSON body created above
+ .build();
+ }
+
/**
* Make a request to a KServe inference service and return the response.
*
@@ -205,14 +179,38 @@ private Optional processResponse(final Response response, final Class e
}
}
- private static Request getRequest(final String bodyString, final HttpUrl url) {
- final MediaType mediaType = MediaType.get("application/json; charset=utf-8");
- final RequestBody requestBody = RequestBody
- .create(bodyString, mediaType);
- return new Request.Builder()
- .url(url)
- .post(requestBody)
- .build();
+ @Slf4j // provides the log field used inside intercept()
+ private static class RetryInterceptor implements Interceptor { // OkHttp interceptor that runs each request through a resilience4j Retry
+ @NotNull
+ @Override
+ public Response intercept(final Chain chain) throws IOException { // returns the result of chain.proceed(request); any failure surfaces as IOException
+ final Request request = chain.request();
+
+ // wait_interval = min(max_interval, (initial_interval * multiplier^n) +/- (random_interval))
+ final IntervalFunction intervalFn = IntervalFunction.ofExponentialRandomBackoff(
+ RETRY_INITIAL_INTERVAL, RETRY_MULTIPLIER, IntervalFunction.DEFAULT_RANDOMIZATION_FACTOR,
+ RETRY_MAX_INTERVAL);
+
+ final RetryConfig retryConfig = RetryConfig.custom()
+ // IOException may be thrown by chain.proceed(request)
+ .retryExceptions(IOException.class)
+ .maxAttempts(RETRY_MAX_ATTEMPTS)
+ .intervalFunction(intervalFn)
+ .failAfterMaxAttempts(true)
+ .build();
+ final Retry retry = Retry.of("kserve-request-retry", retryConfig); // a fresh Retry instance is created on every intercept() call
+
+ final Callable requestCallable = Retry.decorateCallable(retry, () -> {
+ log.debug("Making or retrying request {}.", request); // logged each time the decorated callable body runs
+ return chain.proceed(request);
+ });
+
+ try {
+ return requestCallable.call();
+ } catch (final Exception e) {
+ throw new IOException(e); // NOTE(review): also wraps an IOException from chain.proceed in a second IOException - confirm callers only need the cause
+ }
+ }
}
protected static final class InferenceRequestException extends IllegalArgumentException {
diff --git a/src/test/java/com/bakdata/kserve/client/KServeClientV1Test.java b/src/test/java/com/bakdata/kserve/client/KServeClientV1Test.java
index 8af6d1a..325551b 100644
--- a/src/test/java/com/bakdata/kserve/client/KServeClientV1Test.java
+++ b/src/test/java/com/bakdata/kserve/client/KServeClientV1Test.java
@@ -1,7 +1,7 @@
/*
* MIT License
*
- * Copyright (c) 2022 bakdata
+ * Copyright (c) 2024 bakdata GmbH
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
@@ -25,6 +25,8 @@
package com.bakdata.kserve.client;
import com.bakdata.kserve.client.KServeClient.InferenceRequestException;
+import java.io.IOException;
+import java.time.Duration;
import lombok.Getter;
import okhttp3.mockwebserver.Dispatcher;
import okhttp3.mockwebserver.MockResponse;
@@ -38,9 +40,6 @@
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
-import java.io.IOException;
-import java.time.Duration;
-
@ExtendWith(SoftAssertionsExtension.class)
class KServeClientV1Test {
private com.bakdata.kserve.KServeMock mockServer = null;
@@ -63,8 +62,9 @@ void makeInferenceRequest() throws IOException {
.build();
this.softly.assertThat(client.makeInferenceRequest(new JSONObject("{ \"input\": \"data\" }"),
- FakePrediction.class, ""))
- .hasValueSatisfying(fakePrediction -> this.softly.assertThat(fakePrediction.getFake()).isEqualTo("data"));
+ FakePrediction.class, ""))
+ .hasValueSatisfying(
+ fakePrediction -> this.softly.assertThat(fakePrediction.getFake()).isEqualTo("data"));
}
@Test
@@ -78,7 +78,7 @@ void testPredictionNotExistingModel() {
.build();
this.softly.assertThatThrownBy(() -> client.makeInferenceRequest(new JSONObject("{ \"input\": \"data\" }"),
- FakePrediction.class, ""))
+ FakePrediction.class, ""))
.isInstanceOf(InferenceRequestException.class)
.hasMessage("Inference request failed: 404: Model with name model does not exist.");
}
@@ -108,9 +108,11 @@ public MockResponse dispatch(@NotNull final RecordedRequest recordedRequest) {
.build();
this.softly.assertThatThrownBy(() -> client.makeInferenceRequest(new JSONObject("{ \"input\": \"data\" }"),
- FakePrediction.class, ""))
+ FakePrediction.class, ""))
.isInstanceOf(KServeClientV1.InferenceRequestException.class)
- .hasMessage("Inference request failed: 400: Unrecognized request format: Expecting ',' delimiter: line 3 column 1 (char 48)");
+ .hasMessage(
+ "Inference request failed: 400: Unrecognized request format: Expecting ',' delimiter: line 3 "
+ + "column 1 (char 48)");
}
@Test
@@ -125,11 +127,11 @@ void testRetry() throws IOException {
.build();
this.softly.assertThat(client.makeInferenceRequest(new JSONObject("{ \"input\": \"data\" }"),
- CallCounterFakePrediction.class, ""))
+ CallCounterFakePrediction.class, ""))
.hasValueSatisfying(fakePrediction -> this.softly.assertThat(fakePrediction.getCounter()).isEqualTo(2));
this.softly.assertThat(client.makeInferenceRequest(new JSONObject("{ \"input\": \"data\" }"),
- CallCounterFakePrediction.class, ""))
+ CallCounterFakePrediction.class, ""))
.hasValueSatisfying(fakePrediction -> this.softly.assertThat(fakePrediction.getCounter()).isEqualTo(3));
}
diff --git a/src/test/java/com/bakdata/kserve/client/KServeClientV2Test.java b/src/test/java/com/bakdata/kserve/client/KServeClientV2Test.java
index 8df7f60..f5ac16f 100644
--- a/src/test/java/com/bakdata/kserve/client/KServeClientV2Test.java
+++ b/src/test/java/com/bakdata/kserve/client/KServeClientV2Test.java
@@ -1,7 +1,7 @@
/*
* MIT License
*
- * Copyright (c) 2022 bakdata
+ * Copyright (c) 2024 bakdata GmbH
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
@@ -30,6 +30,9 @@
import com.bakdata.kserve.predictv2.InferenceRequest;
import com.bakdata.kserve.predictv2.Parameters;
import com.bakdata.kserve.predictv2.RequestInput;
+import java.io.IOException;
+import java.time.Duration;
+import java.util.List;
import lombok.Getter;
import okhttp3.mockwebserver.Dispatcher;
import okhttp3.mockwebserver.MockResponse;
@@ -42,10 +45,6 @@
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
-import java.io.IOException;
-import java.time.Duration;
-import java.util.List;
-
@ExtendWith(SoftAssertionsExtension.class)
class KServeClientV2Test {
private KServeMock mockServer = null;
@@ -82,7 +81,7 @@ void makeInferenceRequest() throws IOException {
.build();
this.softly.assertThat(client.makeInferenceRequest(getFakeInferenceRequest("data"),
- FakePrediction.class, ""))
+ FakePrediction.class, ""))
.map(FakePrediction::getFake)
.hasValue("data");
}
@@ -97,7 +96,8 @@ void testPredictionNotExistingModel() {
.httpClient(KServeClientV2.getHttpClient(Duration.ofMillis(10000)))
.build();
final InferenceRequest fakeInferenceRequest = getFakeInferenceRequest("data");
- this.softly.assertThatThrownBy(() -> client.makeInferenceRequest(fakeInferenceRequest, FakePrediction.class, "") )
+ this.softly.assertThatThrownBy(
+ () -> client.makeInferenceRequest(fakeInferenceRequest, FakePrediction.class, ""))
.isInstanceOf(InferenceRequestException.class)
.hasMessage("Inference request failed: Model test-model not found");
}
@@ -122,7 +122,8 @@ public MockResponse dispatch(@NotNull final RecordedRequest recordedRequest) {
.build();
final InferenceRequest fakeInferenceRequest = getFakeInferenceRequest("data");
- this.softly.assertThatThrownBy(() -> client.makeInferenceRequest(fakeInferenceRequest, FakePrediction.class, ""))
+ this.softly.assertThatThrownBy(
+ () -> client.makeInferenceRequest(fakeInferenceRequest, FakePrediction.class, ""))
.isInstanceOf(InferenceRequestException.class)
.hasMessage("Inference request failed: Not Found");
}
@@ -140,11 +141,11 @@ void testRetry() throws IOException {
final InferenceRequest fakeInferenceRequest = getFakeInferenceRequest("data");
this.softly.assertThat(client.makeInferenceRequest(fakeInferenceRequest,
- CallCounterFakePrediction.class, ""))
+ CallCounterFakePrediction.class, ""))
.hasValueSatisfying(fakePrediction -> this.softly.assertThat(fakePrediction.getCounter()).isEqualTo(2));
this.softly.assertThat(client.makeInferenceRequest(fakeInferenceRequest,
- CallCounterFakePrediction.class, ""))
+ CallCounterFakePrediction.class, ""))
.hasValueSatisfying(fakePrediction -> this.softly.assertThat(fakePrediction.getCounter()).isEqualTo(3));
}