Skip to content

Commit

Permalink
Apply formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
JakobEdding committed Apr 17, 2024
1 parent 4b65145 commit 87515ea
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 74 deletions.
104 changes: 51 additions & 53 deletions src/main/java/com/bakdata/kserve/client/KServeClient.java
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* MIT License
*
* Copyright (c) 2022 bakdata
* Copyright (c) 2024 bakdata GmbH
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -30,6 +30,11 @@
import io.github.resilience4j.core.IntervalFunction;
import io.github.resilience4j.retry.Retry;
import io.github.resilience4j.retry.RetryConfig;
import java.io.IOException;
import java.net.HttpURLConnection;
import java.time.Duration;
import java.util.Optional;
import java.util.concurrent.Callable;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import okhttp3.HttpUrl;
Expand All @@ -41,17 +46,11 @@
import okhttp3.Response;
import org.jetbrains.annotations.NotNull;

import java.io.IOException;
import java.net.HttpURLConnection;
import java.time.Duration;
import java.util.Optional;
import java.util.concurrent.Callable;

/**
* <p>An abstract client base class to make requests to a KServe inference service.</p>
* It builds the request, handles the response and deals with various errors from the inference service. It
* automatically retries requests in case they time out, e.g. due to the inference service being scaled to zero at
* the time of the request. The exponential random back-off retry mechanism can be configured using the following
* automatically retries requests in case they time out, e.g. due to the inference service being scaled to zero at the
* time of the request. The exponential random back-off retry mechanism can be configured using the following
* environment variables:
* <ul>
* <li>KSERVE_RETRY_MAX_ATTEMPTS: The maximum number of retry attempts.</li>
Expand All @@ -68,6 +67,7 @@
@RequiredArgsConstructor
public abstract class KServeClient<I> {

protected static final ObjectMapper OBJECT_MAPPER = createObjectMapper();
private static final int RETRY_MAX_ATTEMPTS = Optional.ofNullable(System.getenv("KSERVE_RETRY_MAX_ATTEMPTS"))
.map(Integer::parseInt)
.orElse(10);
Expand All @@ -82,47 +82,11 @@ public abstract class KServeClient<I> {
Optional.ofNullable(System.getenv("KSERVE_RETRY_MAX_INTERVAL"))
.map(Integer::parseInt).map(Duration::ofMillis)
.orElse(Duration.ofMillis(16000));

protected static final ObjectMapper OBJECT_MAPPER = createObjectMapper();
private final String service;
private final String modelName;
private final OkHttpClient httpClient;
private final boolean httpsEnabled;

/**
 * OkHttp interceptor that re-issues a request when executing it fails with an {@link IOException}.
 *
 * <p>The wait time between attempts follows an exponential random back-off built from
 * {@code RETRY_INITIAL_INTERVAL}, {@code RETRY_MULTIPLIER} and {@code RETRY_MAX_INTERVAL}; the retry's
 * {@code maxAttempts} is set to {@code RETRY_MAX_ATTEMPTS}.</p>
 */
@Slf4j
private static class RetryInterceptor implements Interceptor {
    /**
     * Proceeds with the intercepted request, retrying it when an {@link IOException} is thrown.
     *
     * @param chain interceptor chain that supplies the request and executes it
     * @return the response of the first attempt that completes without an exception
     * @throws IOException the I/O failure raised by the retry-decorated call, propagated with its original
     * type, or a wrapper around any other exception raised while executing the request
     */
    @NotNull
    @Override
    public Response intercept(final Chain chain) throws IOException {
        final Request request = chain.request();

        // wait_interval = min(max_interval, (initial_interval * multiplier^n) +/- (random_interval))
        final IntervalFunction intervalFn = IntervalFunction.ofExponentialRandomBackoff(
                RETRY_INITIAL_INTERVAL, RETRY_MULTIPLIER, IntervalFunction.DEFAULT_RANDOMIZATION_FACTOR,
                RETRY_MAX_INTERVAL);

        final RetryConfig retryConfig = RetryConfig.custom()
                // IOException may be thrown by chain.proceed(request)
                .retryExceptions(IOException.class)
                .maxAttempts(RETRY_MAX_ATTEMPTS)
                .intervalFunction(intervalFn)
                .failAfterMaxAttempts(true)
                .build();
        final Retry retry = Retry.of("kserve-request-retry", retryConfig);

        final Callable<Response> requestCallable = Retry.decorateCallable(retry, () -> {
            log.debug("Making or retrying request {}.", request);
            return chain.proceed(request);
        });

        try {
            return requestCallable.call();
        } catch (final IOException e) {
            // Callable#call declares Exception, so the broad catch below would also catch I/O failures and
            // wrap them in a second, generic IOException. Rethrow them untouched so callers still see the
            // concrete type (for example a timeout subtype) and its message.
            throw e;
        } catch (final Exception e) {
            // Anything that is not an IOException is wrapped to satisfy the interceptor's signature.
            throw new IOException(e);
        }
    }
}

protected static OkHttpClient getHttpClient(final Duration requestReadTimeout) {
return new OkHttpClient.Builder()
.readTimeout(requestReadTimeout)
Expand All @@ -147,6 +111,16 @@ private static <T> Optional<T> processJsonResponse(final String stringBody, fina
}
}

/**
 * Builds the HTTP POST request that carries the given payload to the given URL.
 *
 * @param bodyString payload sent as the request body, declared as UTF-8 encoded JSON
 * @param url target of the request
 * @return the assembled POST request
 */
private static Request getRequest(final String bodyString, final HttpUrl url) {
    final RequestBody jsonBody =
            RequestBody.create(bodyString, MediaType.get("application/json; charset=utf-8"));
    final Request.Builder builder = new Request.Builder();
    return builder.url(url).post(jsonBody).build();
}

/**
* Make a request to a KServe inference service and return the response.
*
Expand Down Expand Up @@ -205,14 +179,38 @@ private <T> Optional<T> processResponse(final Response response, final Class<? e
}
}

private static Request getRequest(final String bodyString, final HttpUrl url) {
final MediaType mediaType = MediaType.get("application/json; charset=utf-8");
final RequestBody requestBody = RequestBody
.create(bodyString, mediaType);
return new Request.Builder()
.url(url)
.post(requestBody)
.build();
/**
 * OkHttp interceptor that re-issues a request when executing it fails with an {@link IOException}.
 *
 * <p>The wait time between attempts follows an exponential random back-off built from
 * {@code RETRY_INITIAL_INTERVAL}, {@code RETRY_MULTIPLIER} and {@code RETRY_MAX_INTERVAL}; the retry's
 * {@code maxAttempts} is set to {@code RETRY_MAX_ATTEMPTS}.</p>
 */
@Slf4j
private static class RetryInterceptor implements Interceptor {
    /**
     * Proceeds with the intercepted request, retrying it when an {@link IOException} is thrown.
     *
     * @param chain interceptor chain that supplies the request and executes it
     * @return the response of the first attempt that completes without an exception
     * @throws IOException the I/O failure raised by the retry-decorated call, propagated with its original
     * type, or a wrapper around any other exception raised while executing the request
     */
    @NotNull
    @Override
    public Response intercept(final Chain chain) throws IOException {
        final Request request = chain.request();

        // wait_interval = min(max_interval, (initial_interval * multiplier^n) +/- (random_interval))
        final IntervalFunction intervalFn = IntervalFunction.ofExponentialRandomBackoff(
                RETRY_INITIAL_INTERVAL, RETRY_MULTIPLIER, IntervalFunction.DEFAULT_RANDOMIZATION_FACTOR,
                RETRY_MAX_INTERVAL);

        final RetryConfig retryConfig = RetryConfig.custom()
                // IOException may be thrown by chain.proceed(request)
                .retryExceptions(IOException.class)
                .maxAttempts(RETRY_MAX_ATTEMPTS)
                .intervalFunction(intervalFn)
                .failAfterMaxAttempts(true)
                .build();
        final Retry retry = Retry.of("kserve-request-retry", retryConfig);

        final Callable<Response> requestCallable = Retry.decorateCallable(retry, () -> {
            log.debug("Making or retrying request {}.", request);
            return chain.proceed(request);
        });

        try {
            return requestCallable.call();
        } catch (final IOException e) {
            // Callable#call declares Exception, so the broad catch below would also catch I/O failures and
            // wrap them in a second, generic IOException. Rethrow them untouched so callers still see the
            // concrete type (for example a timeout subtype) and its message.
            throw e;
        } catch (final Exception e) {
            // Anything that is not an IOException is wrapped to satisfy the interceptor's signature.
            throw new IOException(e);
        }
    }
}

protected static final class InferenceRequestException extends IllegalArgumentException {
Expand Down
24 changes: 13 additions & 11 deletions src/test/java/com/bakdata/kserve/client/KServeClientV1Test.java
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* MIT License
*
* Copyright (c) 2022 bakdata
* Copyright (c) 2024 bakdata GmbH
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand All @@ -25,6 +25,8 @@
package com.bakdata.kserve.client;

import com.bakdata.kserve.client.KServeClient.InferenceRequestException;
import java.io.IOException;
import java.time.Duration;
import lombok.Getter;
import okhttp3.mockwebserver.Dispatcher;
import okhttp3.mockwebserver.MockResponse;
Expand All @@ -38,9 +40,6 @@
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;

import java.io.IOException;
import java.time.Duration;

@ExtendWith(SoftAssertionsExtension.class)
class KServeClientV1Test {
private com.bakdata.kserve.KServeMock mockServer = null;
Expand All @@ -63,8 +62,9 @@ void makeInferenceRequest() throws IOException {
.build();

this.softly.assertThat(client.makeInferenceRequest(new JSONObject("{ \"input\": \"data\" }"),
FakePrediction.class, ""))
.hasValueSatisfying(fakePrediction -> this.softly.assertThat(fakePrediction.getFake()).isEqualTo("data"));
FakePrediction.class, ""))
.hasValueSatisfying(
fakePrediction -> this.softly.assertThat(fakePrediction.getFake()).isEqualTo("data"));
}

@Test
Expand All @@ -78,7 +78,7 @@ void testPredictionNotExistingModel() {
.build();

this.softly.assertThatThrownBy(() -> client.makeInferenceRequest(new JSONObject("{ \"input\": \"data\" }"),
FakePrediction.class, ""))
FakePrediction.class, ""))
.isInstanceOf(InferenceRequestException.class)
.hasMessage("Inference request failed: 404: Model with name model does not exist.");
}
Expand Down Expand Up @@ -108,9 +108,11 @@ public MockResponse dispatch(@NotNull final RecordedRequest recordedRequest) {
.build();

this.softly.assertThatThrownBy(() -> client.makeInferenceRequest(new JSONObject("{ \"input\": \"data\" }"),
FakePrediction.class, ""))
FakePrediction.class, ""))
.isInstanceOf(KServeClientV1.InferenceRequestException.class)
.hasMessage("Inference request failed: 400: Unrecognized request format: Expecting ',' delimiter: line 3 column 1 (char 48)");
.hasMessage(
"Inference request failed: 400: Unrecognized request format: Expecting ',' delimiter: line 3 "
+ "column 1 (char 48)");
}

@Test
Expand All @@ -125,11 +127,11 @@ void testRetry() throws IOException {
.build();

this.softly.assertThat(client.makeInferenceRequest(new JSONObject("{ \"input\": \"data\" }"),
CallCounterFakePrediction.class, ""))
CallCounterFakePrediction.class, ""))
.hasValueSatisfying(fakePrediction -> this.softly.assertThat(fakePrediction.getCounter()).isEqualTo(2));

this.softly.assertThat(client.makeInferenceRequest(new JSONObject("{ \"input\": \"data\" }"),
CallCounterFakePrediction.class, ""))
CallCounterFakePrediction.class, ""))
.hasValueSatisfying(fakePrediction -> this.softly.assertThat(fakePrediction.getCounter()).isEqualTo(3));
}

Expand Down
21 changes: 11 additions & 10 deletions src/test/java/com/bakdata/kserve/client/KServeClientV2Test.java
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* MIT License
*
* Copyright (c) 2022 bakdata
* Copyright (c) 2024 bakdata GmbH
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -30,6 +30,9 @@
import com.bakdata.kserve.predictv2.InferenceRequest;
import com.bakdata.kserve.predictv2.Parameters;
import com.bakdata.kserve.predictv2.RequestInput;
import java.io.IOException;
import java.time.Duration;
import java.util.List;
import lombok.Getter;
import okhttp3.mockwebserver.Dispatcher;
import okhttp3.mockwebserver.MockResponse;
Expand All @@ -42,10 +45,6 @@
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;

import java.io.IOException;
import java.time.Duration;
import java.util.List;

@ExtendWith(SoftAssertionsExtension.class)
class KServeClientV2Test {
private KServeMock mockServer = null;
Expand Down Expand Up @@ -82,7 +81,7 @@ void makeInferenceRequest() throws IOException {
.build();

this.softly.assertThat(client.makeInferenceRequest(getFakeInferenceRequest("data"),
FakePrediction.class, ""))
FakePrediction.class, ""))
.map(FakePrediction::getFake)
.hasValue("data");
}
Expand All @@ -97,7 +96,8 @@ void testPredictionNotExistingModel() {
.httpClient(KServeClientV2.getHttpClient(Duration.ofMillis(10000)))
.build();
final InferenceRequest<String> fakeInferenceRequest = getFakeInferenceRequest("data");
this.softly.assertThatThrownBy(() -> client.makeInferenceRequest(fakeInferenceRequest, FakePrediction.class, "") )
this.softly.assertThatThrownBy(
() -> client.makeInferenceRequest(fakeInferenceRequest, FakePrediction.class, ""))
.isInstanceOf(InferenceRequestException.class)
.hasMessage("Inference request failed: Model test-model not found");
}
Expand All @@ -122,7 +122,8 @@ public MockResponse dispatch(@NotNull final RecordedRequest recordedRequest) {
.build();

final InferenceRequest<String> fakeInferenceRequest = getFakeInferenceRequest("data");
this.softly.assertThatThrownBy(() -> client.makeInferenceRequest(fakeInferenceRequest, FakePrediction.class, ""))
this.softly.assertThatThrownBy(
() -> client.makeInferenceRequest(fakeInferenceRequest, FakePrediction.class, ""))
.isInstanceOf(InferenceRequestException.class)
.hasMessage("Inference request failed: Not Found");
}
Expand All @@ -140,11 +141,11 @@ void testRetry() throws IOException {

final InferenceRequest<String> fakeInferenceRequest = getFakeInferenceRequest("data");
this.softly.assertThat(client.makeInferenceRequest(fakeInferenceRequest,
CallCounterFakePrediction.class, ""))
CallCounterFakePrediction.class, ""))
.hasValueSatisfying(fakePrediction -> this.softly.assertThat(fakePrediction.getCounter()).isEqualTo(2));

this.softly.assertThat(client.makeInferenceRequest(fakeInferenceRequest,
CallCounterFakePrediction.class, ""))
CallCounterFakePrediction.class, ""))
.hasValueSatisfying(fakePrediction -> this.softly.assertThat(fakePrediction.getCounter()).isEqualTo(3));
}

Expand Down

0 comments on commit 87515ea

Please sign in to comment.