Skip to content

Commit

Permalink
Accept service URL instead of String and separate protocol
Browse files Browse the repository at this point in the history
  • Loading branch information
JakobEdding committed Apr 18, 2024
1 parent 003a595 commit 02728d5
Show file tree
Hide file tree
Showing 10 changed files with 64 additions and 69 deletions.
20 changes: 10 additions & 10 deletions src/main/java/com/bakdata/kserve/client/KServeClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import io.github.resilience4j.retry.RetryConfig;
import java.io.IOException;
import java.net.HttpURLConnection;
import java.net.URL;
import java.time.Duration;
import java.util.Optional;
import java.util.concurrent.Callable;
Expand Down Expand Up @@ -82,10 +83,9 @@ public abstract class KServeClient<I> {
Optional.ofNullable(System.getenv("KSERVE_RETRY_MAX_INTERVAL"))
.map(Integer::parseInt).map(Duration::ofMillis)
.orElse(Duration.ofMillis(16000));
private final String service;
private final URL serviceBaseUrl;
private final String modelName;
private final OkHttpClient httpClient;
private final boolean httpsEnabled;

protected static OkHttpClient getHttpClient(final Duration requestReadTimeout) {
return new OkHttpClient.Builder()
Expand Down Expand Up @@ -124,16 +124,17 @@ private static Request getRequest(final String bodyString, final HttpUrl url) {
/**
* Make a request to a KServe inference service and return the response.
*
* @param inputObject An input object of type {@link I} that contains the data for which a prediction should be made
* @param inputObject An input object of type {@link I} that contains the data for which a prediction should be
* made
* @param responseType A class which extends T. The inference service JSON response will be mapped to an object of
* this class
* this class
* @param modelNameSuffix A suffix for the model name to use in case a model is deployed multiple times with
* different configurations which can be identified by a suffix to the model name. If not
* needed, it can be set to the empty string ""
* different configurations which can be identified by a suffix to the model name. If not needed, it can be set to
* the empty string ""
* @param <T> The base class of the response type.
* @return The response of type {@code responseType}.
* @throws IOException Thrown if the execution of the request fails or if the body of the response cannot be
* decoded to a string
* decoded to a string
*/
public <T> Optional<T> makeInferenceRequest(final I inputObject, final Class<? extends T> responseType,
final String modelNameSuffix)
Expand All @@ -149,12 +150,11 @@ public <T> Optional<T> makeInferenceRequest(final I inputObject, final Class<? e

protected final HttpUrl getModelURI(final String modelNameSuffix) {
return HttpUrl.get(this.getUrlString(
this.httpsEnabled ? "https" : "http",
this.service,
this.serviceBaseUrl,
String.format("%s%s", this.modelName, modelNameSuffix)));
}

protected abstract String getUrlString(String protocol, String service, String modelName);
protected abstract String getUrlString(URL serviceBaseUrl, String modelName);

abstract String getBodyString(final I inputObject);

Expand Down
16 changes: 8 additions & 8 deletions src/main/java/com/bakdata/kserve/client/KServeClientFactory.java
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* MIT License
*
* Copyright (c) 2022 bakdata
* Copyright (c) 2024 bakdata GmbH
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand All @@ -24,6 +24,7 @@

package com.bakdata.kserve.client;

import java.net.URL;
import java.time.Duration;

/**
Expand All @@ -37,18 +38,17 @@ public interface KServeClientFactory<T> {
* Get a {@link KServeClient} to make requests to an inference service supporting either the v1 or the v2 prediction
* protocol.
*
* @param service The host name of the service, e.g. "my-classifier.kserve-namespace.svc.cluster.local"
* @param serviceBaseUrl The base URL of the service, e.g.
* "http://my-classifier.kserve-namespace.svc.cluster.local"
* @param modelName The model name as specified in model-settings.json or as key metadata.name in the
* InferenceService k8s object configuration file.
* InferenceService k8s object configuration file.
* @param requestReadTimeout The read timeout as documented for the
* <a href="https://square.github.io/okhttp/4.x/okhttp/okhttp3/-ok-http-client/-builder/read-timeout/">OkHttpClient
* </a> which this library uses
* @param httpsEnabled Whether HTTPS should be used (true) or HTTP (false)
* </a> which this library uses
* @return An instance of {@link KServeClient}
*/
KServeClient<T> getKServeClient(
final String service,
final URL serviceBaseUrl,
final String modelName,
final Duration requestReadTimeout,
final boolean httpsEnabled);
final Duration requestReadTimeout);
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* MIT License
*
* Copyright (c) 2022 bakdata
* Copyright (c) 2024 bakdata GmbH
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand All @@ -24,22 +24,21 @@

package com.bakdata.kserve.client;

import java.net.URL;
import java.time.Duration;
import okhttp3.OkHttpClient;
import org.json.JSONObject;

import java.time.Duration;

/**
* A factory producing a {@link KServeClient} to support the v1 prediction protocol.
*/
public class KServeClientFactoryV1 implements KServeClientFactory<JSONObject> {
@Override
public KServeClient<JSONObject> getKServeClient(
final String service,
final URL serviceBaseUrl,
final String modelName,
final Duration requestReadTimeout,
final boolean httpsEnabled) {
final Duration requestReadTimeout) {
final OkHttpClient httpClient = KServeClient.getHttpClient(requestReadTimeout);
return new KServeClientV1(service, modelName, httpClient, httpsEnabled);
return new KServeClientV1(serviceBaseUrl, modelName, httpClient);
}
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* MIT License
*
* Copyright (c) 2022 bakdata
* Copyright (c) 2024 bakdata GmbH
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand All @@ -25,9 +25,9 @@
package com.bakdata.kserve.client;

import com.bakdata.kserve.predictv2.InferenceRequest;
import okhttp3.OkHttpClient;

import java.net.URL;
import java.time.Duration;
import okhttp3.OkHttpClient;

/**
* A factory producing a {@link KServeClient} to support the
Expand All @@ -36,11 +36,10 @@
public class KServeClientFactoryV2 implements KServeClientFactory<InferenceRequest<?>> {
@Override
public KServeClient<InferenceRequest<?>> getKServeClient(
final String service,
final URL serviceBaseUrl,
final String modelName,
final Duration requestReadTimeout,
final boolean httpsEnabled) {
final Duration requestReadTimeout) {
final OkHttpClient httpClient = KServeClient.getHttpClient(requestReadTimeout);
return new KServeClientV2(service, modelName, httpClient, httpsEnabled);
return new KServeClientV2(serviceBaseUrl, modelName, httpClient);
}
}
11 changes: 6 additions & 5 deletions src/main/java/com/bakdata/kserve/client/KServeClientV1.java
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* MIT License
*
* Copyright (c) 2022 bakdata
* Copyright (c) 2024 bakdata GmbH
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand All @@ -24,6 +24,7 @@

package com.bakdata.kserve.client;

import java.net.URL;
import lombok.Builder;
import lombok.extern.slf4j.Slf4j;
import okhttp3.OkHttpClient;
Expand All @@ -38,8 +39,8 @@
public class KServeClientV1 extends KServeClient<JSONObject> {
@Builder
KServeClientV1(
final String service, final String modelName, final OkHttpClient httpClient, final boolean httpsEnabled) {
super(service, modelName, httpClient, httpsEnabled);
final URL serviceBaseUrl, final String modelName, final OkHttpClient httpClient) {
super(serviceBaseUrl, modelName, httpClient);
}

@Override
Expand All @@ -49,8 +50,8 @@ protected String extractErrorMessage(final String stringBody) {
}

@Override
protected String getUrlString(final String protocol, final String service, final String modelName) {
return String.format("%s://%s/v1/models/%s:predict", protocol, service, modelName);
protected String getUrlString(final URL serviceBaseUrl, final String modelName) {
return String.format("%s/v1/models/%s:predict", serviceBaseUrl, modelName);
}

@Override
Expand Down
14 changes: 7 additions & 7 deletions src/main/java/com/bakdata/kserve/client/KServeClientV2.java
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* MIT License
*
* Copyright (c) 2022 bakdata
* Copyright (c) 2024 bakdata GmbH
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand All @@ -27,12 +27,12 @@
import com.bakdata.kserve.predictv2.InferenceError;
import com.bakdata.kserve.predictv2.InferenceRequest;
import com.fasterxml.jackson.core.JsonProcessingException;
import java.net.URL;
import java.util.Optional;
import lombok.Builder;
import lombok.extern.slf4j.Slf4j;
import okhttp3.OkHttpClient;

import java.util.Optional;

/**
* An implementation of a {@link KServeClient} to support the
* <a href="https://kserve.github.io/website/modelserving/inference_api/">v2 prediction protocol</a>.
Expand All @@ -42,8 +42,8 @@ public class KServeClientV2 extends KServeClient<InferenceRequest<?>> {

@Builder
KServeClientV2(
final String service, final String modelName, final OkHttpClient httpClient, final boolean httpsEnabled) {
super(service, modelName, httpClient, httpsEnabled);
final URL serviceBaseUrl, final String modelName, final OkHttpClient httpClient) {
super(serviceBaseUrl, modelName, httpClient);
}

@Override
Expand All @@ -60,8 +60,8 @@ protected String extractErrorMessage(final String stringBody) {
}

@Override
protected String getUrlString(final String protocol, final String service, final String modelName) {
return String.format("%s://%s/v2/models/%s/infer", protocol, service, modelName);
protected String getUrlString(final URL serviceBaseUrl, final String modelName) {
return String.format("%s/v2/models/%s/infer", serviceBaseUrl, modelName);
}

@Override
Expand Down
6 changes: 2 additions & 4 deletions src/main/java/com/bakdata/kserve/predictv2/Parameters.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,17 @@

package com.bakdata.kserve.predictv2;

import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.extern.jackson.Jacksonized;

/**
* A class to represent <a href="https://kserve.github.io/website/modelserving/inference_api/#parameters"> parameters as
* defined in the v2 prediction protocol</a>.
*/
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
@Jacksonized
public class Parameters {
private String contentType;
private Object extra;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ void makeInferenceRequest() throws IOException {
this.mockServer.setModelEndpoint("test-model", "{ \"fake\": \"data\"}");

final KServeClientV1 client = KServeClientV1.builder()
.service(this.mockServer.getWholeServiceEndpoint())
.serviceBaseUrl(this.mockServer.getServiceBaseUrl())
.modelName("test-model")
.httpClient(KServeClientV1.getHttpClient(Duration.ofMillis(10000)))
.build();
Expand All @@ -72,7 +72,7 @@ void testPredictionNotExistingModel() {
this.mockServer.setModelEndpoint("test-model", "{ \"fake\": \"data\"}");

final KServeClientV1 client = KServeClientV1.builder()
.service(this.mockServer.getWholeServiceEndpoint())
.serviceBaseUrl(this.mockServer.getServiceBaseUrl())
.modelName("fake-model")
.httpClient(KServeClientV1.getHttpClient(Duration.ofMillis(10000)))
.build();
Expand Down Expand Up @@ -102,7 +102,7 @@ public MockResponse dispatch(@NotNull final RecordedRequest recordedRequest) {
this.mockServer.getMockWebServer().setDispatcher(dispatcher);

final KServeClientV1 client = KServeClientV1.builder()
.service(this.mockServer.getWholeServiceEndpoint())
.serviceBaseUrl(this.mockServer.getServiceBaseUrl())
.modelName("test-model")
.httpClient(KServeClientV1.getHttpClient(Duration.ofMillis(10000)))
.build();
Expand All @@ -120,7 +120,7 @@ void testRetry() throws IOException {
this.mockServer.setUpForRetryTest();

final KServeClientV1 client = KServeClientV1.builder()
.service(this.mockServer.getWholeServiceEndpoint())
.serviceBaseUrl(this.mockServer.getServiceBaseUrl())
// Important so that request is aborted and retried
.httpClient(KServeClientV1.getHttpClient(Duration.ofMillis(1000)))
.modelName("test-model")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ void makeInferenceRequest() throws IOException {
this.mockServer.setModelEndpoint("test-model", "{ \"fake\": \"data\"}");

final KServeClientV2 client = KServeClientV2.builder()
.service(this.mockServer.getWholeServiceEndpoint())
.serviceBaseUrl(this.mockServer.getServiceBaseUrl())
.modelName("test-model")
.httpClient(KServeClientV2.getHttpClient(Duration.ofMillis(10000)))
.build();
Expand All @@ -91,7 +91,7 @@ void testPredictionNotExistingModel() {
this.mockServer.setModelEndpoint("test-model", "{ \"fake\": \"data\"}");

final KServeClientV2 client = KServeClientV2.builder()
.service(this.mockServer.getWholeServiceEndpoint())
.serviceBaseUrl(this.mockServer.getServiceBaseUrl())
.modelName("fake-model")
.httpClient(KServeClientV2.getHttpClient(Duration.ofMillis(10000)))
.build();
Expand All @@ -116,7 +116,7 @@ public MockResponse dispatch(@NotNull final RecordedRequest recordedRequest) {
this.mockServer.getMockWebServer().setDispatcher(dispatcher);

final KServeClientV2 client = KServeClientV2.builder()
.service(this.mockServer.getWholeServiceEndpoint())
.serviceBaseUrl(this.mockServer.getServiceBaseUrl())
.modelName("test-model")
.httpClient(KServeClientV2.getHttpClient(Duration.ofMillis(10000)))
.build();
Expand All @@ -133,7 +133,7 @@ void testRetry() throws IOException {
this.mockServer.setUpForRetryTest();

final KServeClientV2 client = KServeClientV2.builder()
.service(this.mockServer.getWholeServiceEndpoint())
.serviceBaseUrl(this.mockServer.getServiceBaseUrl())
// Important so that request is aborted and retried
.httpClient(KServeClientV2.getHttpClient(Duration.ofMillis(1000)))
.modelName("test-model")
Expand Down
Loading

0 comments on commit 02728d5

Please sign in to comment.