From b7d3f47ce5a4810a09156d8140fd85ad98e11c77 Mon Sep 17 00:00:00 2001 From: Jakob Edding <15202881+JakobEdding@users.noreply.github.com> Date: Thu, 18 Apr 2024 15:32:27 +0200 Subject: [PATCH] Accept service URL instead of String and separate protocol --- .../bakdata/kserve/client/KServeClient.java | 20 ++++++++-------- .../kserve/client/KServeClientFactory.java | 16 ++++++------- .../kserve/client/KServeClientFactoryV1.java | 13 +++++----- .../kserve/client/KServeClientFactoryV2.java | 13 +++++----- .../bakdata/kserve/client/KServeClientV1.java | 11 +++++---- .../bakdata/kserve/client/KServeClientV2.java | 14 +++++------ .../bakdata/kserve/predictv2/Parameters.java | 6 ++--- .../kserve/client/KServeClientV1Test.java | 8 +++---- .../kserve/client/KServeClientV2Test.java | 8 +++---- .../java/com/bakdata/kserve/KServeMock.java | 24 +++++++++---------- 10 files changed, 64 insertions(+), 69 deletions(-) diff --git a/src/main/java/com/bakdata/kserve/client/KServeClient.java b/src/main/java/com/bakdata/kserve/client/KServeClient.java index f01f24a..0a0f88b 100644 --- a/src/main/java/com/bakdata/kserve/client/KServeClient.java +++ b/src/main/java/com/bakdata/kserve/client/KServeClient.java @@ -32,6 +32,7 @@ import io.github.resilience4j.retry.RetryConfig; import java.io.IOException; import java.net.HttpURLConnection; +import java.net.URL; import java.time.Duration; import java.util.Optional; import java.util.concurrent.Callable; @@ -82,10 +83,9 @@ public abstract class KServeClient { Optional.ofNullable(System.getenv("KSERVE_RETRY_MAX_INTERVAL")) .map(Integer::parseInt).map(Duration::ofMillis) .orElse(Duration.ofMillis(16000)); - private final String service; + private final URL serviceBaseUrl; private final String modelName; private final OkHttpClient httpClient; - private final boolean httpsEnabled; protected static OkHttpClient getHttpClient(final Duration requestReadTimeout) { return new OkHttpClient.Builder() @@ -124,16 +124,17 @@ private static Request getRequest(final String bodyString, final HttpUrl url) { /** * Make a request to a KServe inference service and return the response. * - * @param inputObject An input object of type {@link I} that contains the data for which a prediction should be made + * @param inputObject An input object of type {@link I} that contains the data for which a prediction should be + * made * @param responseType A class which extends T. The inference service JSON response will be mapped to an object of - * this class + * this class * @param modelNameSuffix A suffix for the model name to use in case a model is deployed multiple times with - * different configurations which can be identified by a suffix to the model name. If not - * needed, it can be set to the empty string "" + * different configurations which can be identified by a suffix to the model name. If not needed, it can be set to + * the empty string "" * @param The base class of the response type. * @return The response of type {@code responseType}. * @throws IOException Thrown if the execution of the request fails or if the body of the response can not be - * decoded to a string + * decoded to a string */ public Optional makeInferenceRequest(final I inputObject, final Class responseType, final String modelNameSuffix) @@ -149,12 +150,11 @@ public Optional makeInferenceRequest(final I inputObject, final Class { * Get a {@link KServeClient} to make requests to an inference service supporting either the v1 or the v2 prediction * protocol. * - * @param service The host name of the service, e.g. "my-classifier.kserve-namespace.svc.cluster.local" + * @param serviceBaseUrl The base URL of the service, e.g. + * "http://my-classifier.kserve-namespace.svc.cluster.local" * @param modelName The model name as specified in model-settings.json or as key metadata.name in the - * InferenceService k8s object configuration file. + * InferenceService k8s object configuration file. * @param requestReadTimeout The read time out as documented for the * OkHttpClient - * which this library uses - * @param httpsEnabled Whether HTTPS should be used (true) or HTTP (false) + * which this library uses * @return An instance of {@link KServeClient} */ KServeClient getKServeClient( - final String service, + final URL serviceBaseUrl, final String modelName, - final Duration requestReadTimeout, - final boolean httpsEnabled); + final Duration requestReadTimeout); } diff --git a/src/main/java/com/bakdata/kserve/client/KServeClientFactoryV1.java b/src/main/java/com/bakdata/kserve/client/KServeClientFactoryV1.java index b7ac5de..04b46fc 100644 --- a/src/main/java/com/bakdata/kserve/client/KServeClientFactoryV1.java +++ b/src/main/java/com/bakdata/kserve/client/KServeClientFactoryV1.java @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2022 bakdata + * Copyright (c) 2024 bakdata * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -24,22 +24,21 @@ package com.bakdata.kserve.client; +import java.net.URL; +import java.time.Duration; import okhttp3.OkHttpClient; import org.json.JSONObject; -import java.time.Duration; - /** * A factory producing a {@link KServeClient} to support the v1 prediction protocol. */ public class KServeClientFactoryV1 implements KServeClientFactory { @Override public KServeClient getKServeClient( - final String service, + final URL serviceBaseUrl, final String modelName, - final Duration requestReadTimeout, - final boolean httpsEnabled) { + final Duration requestReadTimeout) { final OkHttpClient httpClient = KServeClient.getHttpClient(requestReadTimeout); - return new KServeClientV1(service, modelName, httpClient, httpsEnabled); + return new KServeClientV1(serviceBaseUrl, modelName, httpClient); } } diff --git a/src/main/java/com/bakdata/kserve/client/KServeClientFactoryV2.java b/src/main/java/com/bakdata/kserve/client/KServeClientFactoryV2.java index e542309..ac3edd3 100644 --- a/src/main/java/com/bakdata/kserve/client/KServeClientFactoryV2.java +++ b/src/main/java/com/bakdata/kserve/client/KServeClientFactoryV2.java @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2022 bakdata + * Copyright (c) 2024 bakdata * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -25,9 +25,9 @@ package com.bakdata.kserve.client; import com.bakdata.kserve.predictv2.InferenceRequest; -import okhttp3.OkHttpClient; - +import java.net.URL; import java.time.Duration; +import okhttp3.OkHttpClient; /** * A factory producing a {@link KServeClient} to support the @@ -36,11 +36,10 @@ public class KServeClientFactoryV2 implements KServeClientFactory> { @Override public KServeClient> getKServeClient( - final String service, + final URL serviceBaseUrl, final String modelName, - final Duration requestReadTimeout, - final boolean httpsEnabled) { + final Duration requestReadTimeout) { final OkHttpClient httpClient = KServeClient.getHttpClient(requestReadTimeout); - return new KServeClientV2(service, modelName, httpClient, httpsEnabled); + return new KServeClientV2(serviceBaseUrl, modelName, httpClient); } } diff --git a/src/main/java/com/bakdata/kserve/client/KServeClientV1.java b/src/main/java/com/bakdata/kserve/client/KServeClientV1.java index ff4a212..384547b 100644 --- a/src/main/java/com/bakdata/kserve/client/KServeClientV1.java +++ b/src/main/java/com/bakdata/kserve/client/KServeClientV1.java @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2022 bakdata + * Copyright (c) 2024 bakdata * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -24,6 +24,7 @@ package com.bakdata.kserve.client; +import java.net.URL; import lombok.Builder; import lombok.extern.slf4j.Slf4j; import okhttp3.OkHttpClient; @@ -38,8 +39,8 @@ public class KServeClientV1 extends KServeClient { @Builder KServeClientV1( - final String service, final String modelName, final OkHttpClient httpClient, final boolean httpsEnabled) { - super(service, modelName, httpClient, httpsEnabled); + final URL serviceBaseUrl, final String modelName, final OkHttpClient httpClient) { + super(serviceBaseUrl, modelName, httpClient); } @Override @@ -49,8 +50,8 @@ protected String extractErrorMessage(final String stringBody) { } @Override - protected String getUrlString(final String protocol, final String service, final String modelName) { - return String.format("%s://%s/v1/models/%s:predict", protocol, service, modelName); + protected String getUrlString(final URL serviceBaseUrl, final String modelName) { + return String.format("%s/v1/models/%s:predict", serviceBaseUrl, modelName); } @Override diff --git a/src/main/java/com/bakdata/kserve/client/KServeClientV2.java b/src/main/java/com/bakdata/kserve/client/KServeClientV2.java index 55fc7bc..f781a31 100644 --- a/src/main/java/com/bakdata/kserve/client/KServeClientV2.java +++ b/src/main/java/com/bakdata/kserve/client/KServeClientV2.java @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2022 bakdata + * Copyright (c) 2024 bakdata * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -27,12 +27,12 @@ import com.bakdata.kserve.predictv2.InferenceError; import com.bakdata.kserve.predictv2.InferenceRequest; import com.fasterxml.jackson.core.JsonProcessingException; +import java.net.URL; +import java.util.Optional; import lombok.Builder; import lombok.extern.slf4j.Slf4j; import okhttp3.OkHttpClient; -import java.util.Optional; - /** * An implementation of a {@link KServeClient} to support the * v2 prediction protocol. @@ -42,8 +42,8 @@ public class KServeClientV2 extends KServeClient> { @Builder KServeClientV2( - final String service, final String modelName, final OkHttpClient httpClient, final boolean httpsEnabled) { - super(service, modelName, httpClient, httpsEnabled); + final URL serviceBaseUrl, final String modelName, final OkHttpClient httpClient) { + super(serviceBaseUrl, modelName, httpClient); } @Override @@ -60,8 +60,8 @@ protected String extractErrorMessage(final String stringBody) { } @Override - protected String getUrlString(final String protocol, final String service, final String modelName) { - return String.format("%s://%s/v2/models/%s/infer", protocol, service, modelName); + protected String getUrlString(final URL serviceBaseUrl, final String modelName) { + return String.format("%s/v2/models/%s/infer", serviceBaseUrl, modelName); } @Override diff --git a/src/main/java/com/bakdata/kserve/predictv2/Parameters.java b/src/main/java/com/bakdata/kserve/predictv2/Parameters.java index e53d8d4..124bd38 100644 --- a/src/main/java/com/bakdata/kserve/predictv2/Parameters.java +++ b/src/main/java/com/bakdata/kserve/predictv2/Parameters.java @@ -24,10 +24,9 @@ package com.bakdata.kserve.predictv2; -import lombok.AllArgsConstructor; import lombok.Builder; import lombok.Data; -import lombok.NoArgsConstructor; +import lombok.extern.jackson.Jacksonized; /** * A class to represent parameters as @@ -35,8 +34,7 @@ */ @Data @Builder -@NoArgsConstructor -@AllArgsConstructor +@Jacksonized public class Parameters { private String contentType; private Object extra; diff --git a/src/test/java/com/bakdata/kserve/client/KServeClientV1Test.java b/src/test/java/com/bakdata/kserve/client/KServeClientV1Test.java index 78dd980..b4cdcdf 100644 --- a/src/test/java/com/bakdata/kserve/client/KServeClientV1Test.java +++ b/src/test/java/com/bakdata/kserve/client/KServeClientV1Test.java @@ -56,7 +56,7 @@ void makeInferenceRequest() throws IOException { this.mockServer.setModelEndpoint("test-model", "{ \"fake\": \"data\"}"); final KServeClientV1 client = KServeClientV1.builder() - .service(this.mockServer.getWholeServiceEndpoint()) + .serviceBaseUrl(this.mockServer.getServiceBaseUrl()) .modelName("test-model") .httpClient(KServeClientV1.getHttpClient(Duration.ofMillis(10000))) .build(); @@ -72,7 +72,7 @@ void testPredictionNotExistingModel() { this.mockServer.setModelEndpoint("test-model", "{ \"fake\": \"data\"}"); final KServeClientV1 client = KServeClientV1.builder() - .service(this.mockServer.getWholeServiceEndpoint()) + .serviceBaseUrl(this.mockServer.getServiceBaseUrl()) .modelName("fake-model") .httpClient(KServeClientV1.getHttpClient(Duration.ofMillis(10000))) .build(); @@ -102,7 +102,7 @@ public MockResponse dispatch(@NotNull final RecordedRequest recordedRequest) { this.mockServer.getMockWebServer().setDispatcher(dispatcher); final KServeClientV1 client = KServeClientV1.builder() - .service(this.mockServer.getWholeServiceEndpoint()) + .serviceBaseUrl(this.mockServer.getServiceBaseUrl()) .modelName("test-model") .httpClient(KServeClientV1.getHttpClient(Duration.ofMillis(10000))) .build(); @@ -120,7 +120,7 @@ void testRetry() throws IOException { this.mockServer.setUpForRetryTest(); final KServeClientV1 client = KServeClientV1.builder() - .service(this.mockServer.getWholeServiceEndpoint()) + .serviceBaseUrl(this.mockServer.getServiceBaseUrl()) // Important so that request is aborted and retried .httpClient(KServeClientV1.getHttpClient(Duration.ofMillis(1000))) .modelName("test-model") diff --git a/src/test/java/com/bakdata/kserve/client/KServeClientV2Test.java b/src/test/java/com/bakdata/kserve/client/KServeClientV2Test.java index 85fc53b..6209eed 100644 --- a/src/test/java/com/bakdata/kserve/client/KServeClientV2Test.java +++ b/src/test/java/com/bakdata/kserve/client/KServeClientV2Test.java @@ -75,7 +75,7 @@ void makeInferenceRequest() throws IOException { this.mockServer.setModelEndpoint("test-model", "{ \"fake\": \"data\"}"); final KServeClientV2 client = KServeClientV2.builder() - .service(this.mockServer.getWholeServiceEndpoint()) + .serviceBaseUrl(this.mockServer.getServiceBaseUrl()) .modelName("test-model") .httpClient(KServeClientV2.getHttpClient(Duration.ofMillis(10000))) .build(); @@ -91,7 +91,7 @@ void testPredictionNotExistingModel() { this.mockServer.setModelEndpoint("test-model", "{ \"fake\": \"data\"}"); final KServeClientV2 client = KServeClientV2.builder() - .service(this.mockServer.getWholeServiceEndpoint()) + .serviceBaseUrl(this.mockServer.getServiceBaseUrl()) .modelName("fake-model") .httpClient(KServeClientV2.getHttpClient(Duration.ofMillis(10000))) .build(); @@ -116,7 +116,7 @@ public MockResponse dispatch(@NotNull final RecordedRequest recordedRequest) { this.mockServer.getMockWebServer().setDispatcher(dispatcher); final KServeClientV2 client = KServeClientV2.builder() - .service(this.mockServer.getWholeServiceEndpoint()) + .serviceBaseUrl(this.mockServer.getServiceBaseUrl()) .modelName("test-model") .httpClient(KServeClientV2.getHttpClient(Duration.ofMillis(10000))) .build(); @@ -133,7 +133,7 @@ void testRetry() throws IOException { this.mockServer.setUpForRetryTest(); final KServeClientV2 client = KServeClientV2.builder() - .service(this.mockServer.getWholeServiceEndpoint()) + .serviceBaseUrl(this.mockServer.getServiceBaseUrl()) // Important so that request is aborted and retried .httpClient(KServeClientV2.getHttpClient(Duration.ofMillis(1000))) .modelName("test-model") diff --git a/src/testFixtures/java/com/bakdata/kserve/KServeMock.java b/src/testFixtures/java/com/bakdata/kserve/KServeMock.java index dff5295..4223bd2 100644 --- a/src/testFixtures/java/com/bakdata/kserve/KServeMock.java +++ b/src/testFixtures/java/com/bakdata/kserve/KServeMock.java @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2022 bakdata + * Copyright (c) 2024 bakdata * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -24,6 +24,9 @@ package com.bakdata.kserve; +import java.net.MalformedURLException; +import java.net.URL; +import java.util.concurrent.atomic.AtomicInteger; import lombok.Getter; import okhttp3.mockwebserver.Dispatcher; import okhttp3.mockwebserver.MockResponse; @@ -31,24 +34,19 @@ import okhttp3.mockwebserver.RecordedRequest; import org.jetbrains.annotations.NotNull; -import java.util.concurrent.atomic.AtomicInteger; - public abstract class KServeMock { @Getter private final MockWebServer mockWebServer = new MockWebServer(); abstract MockResponse getModelNotFoundResponse(String modelName); - public String getWholeServiceEndpoint() { - return this.getServiceName() + this.getBaseEndpoint(); - } - - public String getBaseEndpoint() { - return ":" + this.mockWebServer.getPort(); - } - - public String getServiceName() { - return this.mockWebServer.getHostName(); + public URL getServiceBaseUrl() { + try { + return new URL( + String.format("http://%s:%s", this.mockWebServer.getHostName(), this.mockWebServer.getPort())); + } catch (final MalformedURLException e) { + throw new RuntimeException(e); + } } public void setModelEndpoint(final String modelName, final String body) {