diff --git a/docs/modules/ROOT/pages/includes/quarkus-langchain4j-openai.adoc b/docs/modules/ROOT/pages/includes/quarkus-langchain4j-openai.adoc index 1b0f76588..20740a991 100644 --- a/docs/modules/ROOT/pages/includes/quarkus-langchain4j-openai.adoc +++ b/docs/modules/ROOT/pages/includes/quarkus-langchain4j-openai.adoc @@ -112,6 +112,23 @@ endif::add-copy-button-to-env-var[] | +a| [[quarkus-langchain4j-openai_quarkus.langchain4j.openai.organization-id]]`link:#quarkus-langchain4j-openai_quarkus.langchain4j.openai.organization-id[quarkus.langchain4j.openai.organization-id]` + + +[.description] +-- +OpenAI Organization ID (https://platform.openai.com/docs/api-reference/organization-optional) + +ifdef::add-copy-button-to-env-var[] +Environment variable: env_var_with_copy_button:+++QUARKUS_LANGCHAIN4J_OPENAI_ORGANIZATION_ID+++[] +endif::add-copy-button-to-env-var[] +ifndef::add-copy-button-to-env-var[] +Environment variable: `+++QUARKUS_LANGCHAIN4J_OPENAI_ORGANIZATION_ID+++` +endif::add-copy-button-to-env-var[] +--|string +| + + a| [[quarkus-langchain4j-openai_quarkus.langchain4j.openai.timeout]]`link:#quarkus-langchain4j-openai_quarkus.langchain4j.openai.timeout[quarkus.langchain4j.openai.timeout]` diff --git a/integration-tests/openai/src/main/java/org/acme/example/openai/QuarkusRestApiResource.java b/integration-tests/openai/src/main/java/org/acme/example/openai/QuarkusRestApiResource.java index c4af53107..eb9aec48a 100644 --- a/integration-tests/openai/src/main/java/org/acme/example/openai/QuarkusRestApiResource.java +++ b/integration-tests/openai/src/main/java/org/acme/example/openai/QuarkusRestApiResource.java @@ -16,9 +16,7 @@ */ package org.acme.example.openai; -import static org.acme.example.openai.MessageUtil.createChatCompletionRequest; -import static org.acme.example.openai.MessageUtil.createCompletionRequest; -import static org.acme.example.openai.MessageUtil.createEmbeddingRequest; +import static org.acme.example.openai.MessageUtil.*; import java.net.URI; import java.net.URISyntaxException; @@ -50,6 +48,7 @@ public class QuarkusRestApiResource { private final OpenAiRestApi restApi; private final String token; + private final String organizationId; public QuarkusRestApiResource(Langchain4jOpenAiConfig runtimeConfig) throws URISyntaxException { @@ -58,7 +57,7 @@ public QuarkusRestApiResource(Langchain4jOpenAiConfig runtimeConfig) .build(OpenAiRestApi.class); this.token = runtimeConfig.apiKey() .orElseThrow(() -> new IllegalArgumentException("quarkus.langchain4j.openai.api-key must be provided")); - + this.organizationId = runtimeConfig.organizationId().orElse(null); } @GET @@ -66,7 +65,10 @@ public QuarkusRestApiResource(Langchain4jOpenAiConfig runtimeConfig) public String chatSync() { return restApi.blockingChatCompletion( createChatCompletionRequest("Write a short 1 paragraph funny poem about segmentation fault"), - OpenAiRestApi.ApiMetadata.of(token, null)) + OpenAiRestApi.ApiMetadata.builder() + .apiKey(token) + .organizationId(organizationId) + .build()) .content(); } @@ -75,7 +77,10 @@ public String chatSync() { public Uni chatAsync() { return restApi .createChatCompletion(createChatCompletionRequest("Write a short 1 paragraph funny poem about Unicode"), - OpenAiRestApi.ApiMetadata.of(token, null)) + OpenAiRestApi.ApiMetadata.builder() + .apiKey(token) + .organizationId(organizationId) + .build()) .map(ChatCompletionResponse::content); } @@ -85,7 +90,10 @@ public Uni chatAsync() { public Multi chatStreaming() { return restApi.streamingChatCompletion( createChatCompletionRequest("Write a short 1 paragraph funny poem about Enterprise Java"), - OpenAiRestApi.ApiMetadata.of(token, null)) + OpenAiRestApi.ApiMetadata.builder() + .apiKey(token) + .organizationId(organizationId) + .build()) .map(r -> { if (r.choices() != null) { if (r.choices().size() == 1) { @@ -115,7 +123,10 @@ public Multi chatStreaming() { public String languageSync() { return restApi.blockingCompletion( createCompletionRequest("Write a short 1 paragraph funny poem about segmentation fault"), - OpenAiRestApi.ApiMetadata.of(token, null)) + OpenAiRestApi.ApiMetadata.builder() + .apiKey(token) + .organizationId(organizationId) + .build()) .text(); } @@ -124,7 +135,10 @@ public String languageSync() { public Uni languageAsync() { return restApi .completion(createCompletionRequest("Write a short 1 paragraph funny poem about Unicode"), - OpenAiRestApi.ApiMetadata.of(token, null)) + OpenAiRestApi.ApiMetadata.builder() + .apiKey(token) + .organizationId(organizationId) + .build()) .map(CompletionResponse::text); } @@ -134,7 +148,10 @@ public Uni languageAsync() { public Multi languageStreaming() { return restApi.streamingCompletion( createCompletionRequest("Write a short 1 paragraph funny poem about Enterprise Java"), - OpenAiRestApi.ApiMetadata.of(token, null)) + OpenAiRestApi.ApiMetadata.builder() + .apiKey(token) + .organizationId(organizationId) + .build()) .map(r -> { if (r.choices() != null) { if (r.choices().size() == 1) { @@ -153,14 +170,22 @@ public Multi languageStreaming() { @Path("embedding/sync") public List embeddingSync() { return restApi.blockingEmbedding(createEmbeddingRequest("Your text string goes here"), - OpenAiRestApi.ApiMetadata.of(token, null)).embedding(); + OpenAiRestApi.ApiMetadata.builder() + .apiKey(token) + .organizationId(organizationId) + .build()) + .embedding(); } @GET @Path("embedding/async") public Uni> embeddingAsync() { return restApi - .embedding(createEmbeddingRequest("Your text string goes here"), OpenAiRestApi.ApiMetadata.of(token, null)) + .embedding(createEmbeddingRequest("Your text string goes here"), + OpenAiRestApi.ApiMetadata.builder() + .apiKey(token) + .organizationId(organizationId) + .build()) .map(EmbeddingResponse::embedding); } diff --git a/openai/openai-common/runtime/src/main/java/io/quarkiverse/langchain4j/openai/OpenAiRestApi.java b/openai/openai-common/runtime/src/main/java/io/quarkiverse/langchain4j/openai/OpenAiRestApi.java index 55c7c6455..1a5260276 100644 --- a/openai/openai-common/runtime/src/main/java/io/quarkiverse/langchain4j/openai/OpenAiRestApi.java +++ b/openai/openai-common/runtime/src/main/java/io/quarkiverse/langchain4j/openai/OpenAiRestApi.java @@ -420,21 +420,48 @@ class ApiMetadata { @QueryParam("api-version") public final String apiVersion; + @HeaderParam("OpenAI-Organization") + public final String organizationId; + private ApiMetadata(String authorization, String apiKey, - String apiVersion) { + String apiVersion, String organizationId) { this.authorization = authorization; this.apiKey = apiKey; this.apiVersion = apiVersion; + this.organizationId = organizationId; + } + + public static ApiMetadata.Builder builder() { + return new Builder(); } - public static ApiMetadata of(String apiKey, String apiVersion) { - if (apiKey == null) { - return new ApiMetadata(null, null, apiVersion); + public static class Builder { + private String apiKey; + private String apiVersion; + private String organizationId; + + public ApiMetadata build() { + return (apiKey == null) ? new ApiMetadata(null, null, apiVersion, organizationId) + : new ApiMetadata( + "Bearer " + apiKey, // typical OpenAI authentication + apiKey, // used by AzureAI + apiVersion, organizationId); + } + + public ApiMetadata.Builder apiKey(String apiKey) { + this.apiKey = apiKey; + return this; + } + + public ApiMetadata.Builder apiVersion(String apiVersion) { + this.apiVersion = apiVersion; + return this; + } + + public ApiMetadata.Builder organizationId(String organizationId) { + this.organizationId = organizationId; + return this; } - return new ApiMetadata( - "Bearer " + apiKey, // typical OpenAI authentication - apiKey, // used by AzureAI - apiVersion); } } } diff --git a/openai/openai-common/runtime/src/main/java/io/quarkiverse/langchain4j/openai/QuarkusOpenAiClient.java b/openai/openai-common/runtime/src/main/java/io/quarkiverse/langchain4j/openai/QuarkusOpenAiClient.java index 21f1cb32a..0a7b0d161 100644 --- a/openai/openai-common/runtime/src/main/java/io/quarkiverse/langchain4j/openai/QuarkusOpenAiClient.java +++ b/openai/openai-common/runtime/src/main/java/io/quarkiverse/langchain4j/openai/QuarkusOpenAiClient.java @@ -50,6 +50,7 @@ public class QuarkusOpenAiClient extends OpenAiClient { private final String apiKey; private final String apiVersion; + private final String organizationId; private final OpenAiRestApi restApi; @@ -70,6 +71,7 @@ public static void clearCache() { private QuarkusOpenAiClient(Builder builder) { this.apiKey = determineApiKey(builder); this.apiVersion = builder.apiVersion; + this.organizationId = builder.organizationId; // cache the client the builder could be called with the same parameters from multiple models this.restApi = cache.compute(builder, new BiFunction() { @Override @@ -94,6 +96,7 @@ public OpenAiRestApi apply(Builder builder, OpenAiRestApi openAiRestApi) { InetSocketAddress socketAddress = (InetSocketAddress) builder.proxy.address(); restApiBuilder.proxyAddress(socketAddress.getHostName(), socketAddress.getPort()); } + return restApiBuilder.build(OpenAiRestApi.class); } catch (URISyntaxException e) { throw new RuntimeException(e); @@ -119,7 +122,11 @@ public SyncOrAsyncOrStreaming completion(CompletionRequest r public CompletionResponse execute() { return restApi.blockingCompletion( CompletionRequest.builder().from(request).stream(null).build(), - OpenAiRestApi.ApiMetadata.of(apiKey, apiVersion)); + OpenAiRestApi.ApiMetadata.builder() + .apiKey(apiKey) + .apiVersion(apiVersion) + .organizationId(organizationId) + .build()); } @Override @@ -128,7 +135,12 @@ public AsyncResponseHandling onResponse(Consumer responseHan new Supplier<>() { @Override public Uni get() { - return restApi.completion(request, OpenAiRestApi.ApiMetadata.of(apiKey, apiVersion)); + return restApi.completion(request, + OpenAiRestApi.ApiMetadata.builder() + .apiKey(apiKey) + .apiVersion(apiVersion) + .organizationId(organizationId) + .build()); } }, responseHandler); @@ -141,7 +153,12 @@ public StreamingResponseHandling onPartialResponse( new Supplier<>() { @Override public Multi get() { - return restApi.streamingCompletion(request, OpenAiRestApi.ApiMetadata.of(apiKey, apiVersion)); + return restApi.streamingCompletion(request, + OpenAiRestApi.ApiMetadata.builder() + .apiKey(apiKey) + .apiVersion(apiVersion) + .organizationId(organizationId) + .build()); } }, partialResponseHandler); } @@ -160,7 +177,11 @@ public SyncOrAsyncOrStreaming chatCompletion(ChatComplet public ChatCompletionResponse execute() { return restApi.blockingChatCompletion( ChatCompletionRequest.builder().from(request).stream(null).build(), - OpenAiRestApi.ApiMetadata.of(apiKey, apiVersion)); + OpenAiRestApi.ApiMetadata.builder() + .apiKey(apiKey) + .apiVersion(apiVersion) + .organizationId(organizationId) + .build()); } @Override @@ -169,7 +190,12 @@ public AsyncResponseHandling onResponse(Consumer respons new Supplier<>() { @Override public Uni get() { - return restApi.createChatCompletion(request, OpenAiRestApi.ApiMetadata.of(apiKey, apiVersion)); + return restApi.createChatCompletion(request, + OpenAiRestApi.ApiMetadata.builder() + .apiKey(apiKey) + .apiVersion(apiVersion) + .organizationId(organizationId) + .build()); } }, responseHandler); @@ -183,7 +209,11 @@ public StreamingResponseHandling onPartialResponse( @Override public Multi get() { return restApi.streamingChatCompletion(request, - OpenAiRestApi.ApiMetadata.of(apiKey, apiVersion)); + OpenAiRestApi.ApiMetadata.builder() + .apiKey(apiKey) + .apiVersion(apiVersion) + .organizationId(organizationId) + .build()); } }, partialResponseHandler); } @@ -200,7 +230,12 @@ public SyncOrAsyncOrStreaming chatCompletion(String userMessage) { @Override public String execute() { return restApi - .blockingChatCompletion(request, OpenAiRestApi.ApiMetadata.of(apiKey, apiVersion)) + .blockingChatCompletion(request, + OpenAiRestApi.ApiMetadata.builder() + .apiKey(apiKey) + .apiVersion(apiVersion) + .organizationId(organizationId) + .build()) .content(); } @@ -213,7 +248,11 @@ public Uni get() { return restApi .createChatCompletion( ChatCompletionRequest.builder().from(request).stream(null).build(), - OpenAiRestApi.ApiMetadata.of(apiKey, apiVersion)) + OpenAiRestApi.ApiMetadata.builder() + .apiKey(apiKey) + .apiVersion(apiVersion) + .organizationId(organizationId) + .build()) .map(ChatCompletionResponse::content); } }, @@ -230,7 +269,11 @@ public Multi get() { return restApi .streamingChatCompletion( ChatCompletionRequest.builder().from(request).stream(true).build(), - OpenAiRestApi.ApiMetadata.of(apiKey, apiVersion)) + OpenAiRestApi.ApiMetadata.builder() + .apiKey(apiKey) + .apiVersion(apiVersion) + .organizationId(organizationId) + .build()) .filter(r -> { if (r.choices() != null) { if (r.choices().size() == 1) { @@ -255,7 +298,12 @@ public SyncOrAsync embedding(EmbeddingRequest request) { return new SyncOrAsync<>() { @Override public EmbeddingResponse execute() { - return restApi.blockingEmbedding(request, OpenAiRestApi.ApiMetadata.of(apiKey, apiVersion)); + return restApi.blockingEmbedding(request, + OpenAiRestApi.ApiMetadata.builder() + .apiKey(apiKey) + .apiVersion(apiVersion) + .organizationId(organizationId) + .build()); } @Override @@ -264,7 +312,12 @@ public AsyncResponseHandling onResponse(Consumer responseHand new Supplier<>() { @Override public Uni get() { - return restApi.embedding(request, OpenAiRestApi.ApiMetadata.of(apiKey, apiVersion)); + return restApi.embedding(request, + OpenAiRestApi.ApiMetadata.builder() + .apiKey(apiKey) + .apiVersion(apiVersion) + .organizationId(organizationId) + .build()); } }, responseHandler); @@ -280,7 +333,13 @@ public SyncOrAsync> embedding(String input) { return new SyncOrAsync<>() { @Override public List execute() { - return restApi.blockingEmbedding(request, OpenAiRestApi.ApiMetadata.of(apiKey, apiVersion)).embedding(); + return restApi.blockingEmbedding(request, + OpenAiRestApi.ApiMetadata.builder() + .apiKey(apiKey) + .apiVersion(apiVersion) + .organizationId(organizationId) + .build()) + .embedding(); } @Override @@ -289,7 +348,12 @@ public AsyncResponseHandling onResponse(Consumer> responseHandler) { new Supplier<>() { @Override public Uni> get() { - return restApi.embedding(request, OpenAiRestApi.ApiMetadata.of(apiKey, apiVersion)) + return restApi.embedding(request, + OpenAiRestApi.ApiMetadata.builder() + .apiKey(apiKey) + .apiVersion(apiVersion) + .organizationId(organizationId) + .build()) .map(EmbeddingResponse::embedding); } }, @@ -303,7 +367,12 @@ public SyncOrAsync moderation(ModerationRequest request) { return new SyncOrAsync<>() { @Override public ModerationResponse execute() { - return restApi.blockingModeration(request, OpenAiRestApi.ApiMetadata.of(apiKey, apiVersion)); + return restApi.blockingModeration(request, + OpenAiRestApi.ApiMetadata.builder() + .apiKey(apiKey) + .apiVersion(apiVersion) + .organizationId(organizationId) + .build()); } @Override @@ -312,7 +381,12 @@ public AsyncResponseHandling onResponse(Consumer responseHan new Supplier<>() { @Override public Uni get() { - return restApi.moderation(request, OpenAiRestApi.ApiMetadata.of(apiKey, apiVersion)); + return restApi.moderation(request, + OpenAiRestApi.ApiMetadata.builder() + .apiKey(apiKey) + .apiVersion(apiVersion) + .organizationId(organizationId) + .build()); } }, responseHandler); @@ -329,7 +403,13 @@ public SyncOrAsync moderation(String input) { return new SyncOrAsync<>() { @Override public ModerationResult execute() { - return restApi.blockingModeration(request, OpenAiRestApi.ApiMetadata.of(apiKey, apiVersion)).results().get(0); + return restApi.blockingModeration(request, + OpenAiRestApi.ApiMetadata.builder() + .apiKey(apiKey) + .apiVersion(apiVersion) + .organizationId(organizationId) + .build()) + .results().get(0); } @Override @@ -338,7 +418,12 @@ public AsyncResponseHandling onResponse(Consumer responseHandl new Supplier<>() { @Override public Uni get() { - return restApi.moderation(request, OpenAiRestApi.ApiMetadata.of(apiKey, apiVersion)) + return restApi.moderation(request, + OpenAiRestApi.ApiMetadata.builder() + .apiKey(apiKey) + .apiVersion(apiVersion) + .organizationId(organizationId) + .build()) .map(r -> r.results().get(0)); } }, @@ -353,7 +438,11 @@ public SyncOrAsync imagesGeneration(GenerateImagesReques @Override public GenerateImagesResponse execute() { return restApi.blockingImagesGenerations(generateImagesRequest, - OpenAiRestApi.ApiMetadata.of(apiKey, apiVersion)); + OpenAiRestApi.ApiMetadata.builder() + .apiKey(apiKey) + .apiVersion(apiVersion) + .organizationId(organizationId) + .build()); } @Override @@ -363,7 +452,11 @@ public AsyncResponseHandling onResponse(Consumer respons @Override public Uni get() { return restApi.imagesGenerations(generateImagesRequest, - OpenAiRestApi.ApiMetadata.of(apiKey, apiVersion)); + OpenAiRestApi.ApiMetadata.builder() + .apiKey(apiKey) + .apiVersion(apiVersion) + .organizationId(organizationId) + .build()); } }, responseHandler); @@ -404,8 +497,9 @@ public boolean equals(Object o) { && logStreamingResponses == builder.logStreamingResponses && Objects.equals(baseUrl, builder.baseUrl) && Objects.equals(apiVersion, builder.apiVersion) && Objects.equals(openAiApiKey, builder.openAiApiKey) - && Objects.equals(azureApiKey, builder.azureApiKey) && Objects.equals( - callTimeout, builder.callTimeout) + && Objects.equals(azureApiKey, builder.azureApiKey) + && Objects.equals(organizationId, builder.organizationId) + && Objects.equals(callTimeout, builder.callTimeout) && Objects.equals(connectTimeout, builder.connectTimeout) && Objects.equals(readTimeout, builder.readTimeout) && Objects.equals(writeTimeout, builder.writeTimeout) @@ -414,7 +508,8 @@ public boolean equals(Object o) { @Override public int hashCode() { - return Objects.hash(baseUrl, apiVersion, openAiApiKey, azureApiKey, callTimeout, connectTimeout, readTimeout, + return Objects.hash(baseUrl, apiVersion, openAiApiKey, azureApiKey, organizationId, callTimeout, connectTimeout, + readTimeout, writeTimeout, proxy, logRequests, logResponses, logStreamingResponses); } } diff --git a/openai/openai-vanilla/deployment/src/test/java/io/quarkiverse/langchain4j/openai/test/OpenAiRestApiSmokeTest.java b/openai/openai-vanilla/deployment/src/test/java/io/quarkiverse/langchain4j/openai/test/OpenAiRestApiSmokeTest.java index 3543b11d2..f8197c4b3 100644 --- a/openai/openai-vanilla/deployment/src/test/java/io/quarkiverse/langchain4j/openai/test/OpenAiRestApiSmokeTest.java +++ b/openai/openai-vanilla/deployment/src/test/java/io/quarkiverse/langchain4j/openai/test/OpenAiRestApiSmokeTest.java @@ -2,8 +2,7 @@ import static com.github.tomakehurst.wiremock.client.WireMock.aResponse; import static com.github.tomakehurst.wiremock.core.WireMockConfiguration.options; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.*; import java.net.URI; import java.net.URISyntaxException; @@ -31,6 +30,7 @@ public class OpenAiRestApiSmokeTest { static final QuarkusUnitTest unitTest = new QuarkusUnitTest() .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class).addClass(WiremockUtils.class)); private static final String TOKEN = "whatever"; + private static final String ORGANIZATION = "org"; static WireMockServer wireMockServer; @@ -52,8 +52,10 @@ void happyPath() throws URISyntaxException { OpenAiRestApi restApi = createClient(); ChatCompletionResponse response = restApi.blockingChatCompletion(ChatCompletionRequest.builder().build(), - OpenAiRestApi.ApiMetadata.of(TOKEN, null)); + OpenAiRestApi.ApiMetadata.builder().apiKey(TOKEN).organizationId(ORGANIZATION).build()); assertThat(response).isNotNull(); + + wireMockServer.verify(WiremockUtils.chatCompletionRequestPattern(TOKEN, ORGANIZATION)); } @Test @@ -67,10 +69,14 @@ void server500() throws URISyntaxException { OpenAiRestApi restApi = createClient(); assertThatThrownBy(() -> restApi.blockingChatCompletion(ChatCompletionRequest.builder().build(), - OpenAiRestApi.ApiMetadata.of(TOKEN, null))) + OpenAiRestApi.ApiMetadata.builder().apiKey(TOKEN).build())) .isInstanceOf( OpenAiHttpException.class) .hasMessage("This is a dummy error message"); + + wireMockServer.verify( + WiremockUtils.chatCompletionRequestPattern(TOKEN) + .withoutHeader("OpenAI-Organization")); } @Test @@ -94,7 +100,7 @@ void server200ButAPIError() throws URISyntaxException { OpenAiRestApi restApi = createClient(); assertThatThrownBy(() -> restApi.blockingChatCompletion(ChatCompletionRequest.builder().build(), - OpenAiRestApi.ApiMetadata.of(TOKEN, null))) + OpenAiRestApi.ApiMetadata.builder().apiKey(TOKEN).build())) .isInstanceOf( OpenAiApiException.class); } diff --git a/openai/openai-vanilla/deployment/src/test/java/io/quarkiverse/langchain4j/openai/test/WiremockUtils.java b/openai/openai-vanilla/deployment/src/test/java/io/quarkiverse/langchain4j/openai/test/WiremockUtils.java index ee0b0eb4c..ec1fb4582 100644 --- a/openai/openai-vanilla/deployment/src/test/java/io/quarkiverse/langchain4j/openai/test/WiremockUtils.java +++ b/openai/openai-vanilla/deployment/src/test/java/io/quarkiverse/langchain4j/openai/test/WiremockUtils.java @@ -1,9 +1,6 @@ package io.quarkiverse.langchain4j.openai.test; -import static com.github.tomakehurst.wiremock.client.WireMock.aResponse; -import static com.github.tomakehurst.wiremock.client.WireMock.equalTo; -import static com.github.tomakehurst.wiremock.client.WireMock.post; -import static com.github.tomakehurst.wiremock.client.WireMock.urlEqualTo; +import static com.github.tomakehurst.wiremock.client.WireMock.*; import java.io.IOException; import java.io.InputStream; @@ -13,6 +10,7 @@ import com.github.tomakehurst.wiremock.client.MappingBuilder; import com.github.tomakehurst.wiremock.client.ResponseDefinitionBuilder; +import com.github.tomakehurst.wiremock.matching.RequestPatternBuilder; import io.quarkus.bootstrap.classloading.QuarkusClassLoader; @@ -54,6 +52,16 @@ public static MappingBuilder chatCompletionMapping(String token) { .withHeader("Authorization", equalTo("Bearer " + token)); } + public static RequestPatternBuilder chatCompletionRequestPattern(String token) { + return postRequestedFor(urlEqualTo("/v1/chat/completions")) + .withHeader("Authorization", equalTo("Bearer " + token)); + } + + public static RequestPatternBuilder chatCompletionRequestPattern(String token, String organization) { + return chatCompletionRequestPattern(token) + .withHeader("OpenAI-Organization", equalTo(organization)); + } + public static MappingBuilder moderationMapping(String token) { return post(urlEqualTo("/v1/moderations")) .withHeader("Authorization", equalTo("Bearer " + token)); diff --git a/openai/openai-vanilla/runtime/src/main/java/io/quarkiverse/langchain4j/openai/runtime/OpenAiRecorder.java b/openai/openai-vanilla/runtime/src/main/java/io/quarkiverse/langchain4j/openai/runtime/OpenAiRecorder.java index 346372cd9..6ea240e64 100644 --- a/openai/openai-vanilla/runtime/src/main/java/io/quarkiverse/langchain4j/openai/runtime/OpenAiRecorder.java +++ b/openai/openai-vanilla/runtime/src/main/java/io/quarkiverse/langchain4j/openai/runtime/OpenAiRecorder.java @@ -38,13 +38,14 @@ public Supplier chatModel(Langchain4jOpenAiConfig runtimeConfig) { .maxRetries(runtimeConfig.maxRetries()) .logRequests(firstOrDefault(false, chatModelConfig.logRequests(), runtimeConfig.logRequests())) .logResponses(firstOrDefault(false, chatModelConfig.logResponses(), runtimeConfig.logResponses())) - .modelName(chatModelConfig.modelName()) .temperature(chatModelConfig.temperature()) .topP(chatModelConfig.topP()) .presencePenalty(chatModelConfig.presencePenalty()) .frequencyPenalty(chatModelConfig.frequencyPenalty()); + runtimeConfig.organizationId().ifPresent(builder::organizationId); + if (chatModelConfig.maxTokens().isPresent()) { builder.maxTokens(chatModelConfig.maxTokens().get()); } @@ -69,13 +70,14 @@ public Supplier streamingChatModel(Langchain4jOpenAiConfig runtimeConfig) { .timeout(runtimeConfig.timeout()) .logRequests(firstOrDefault(false, chatModelConfig.logRequests(), runtimeConfig.logRequests())) .logResponses(firstOrDefault(false, chatModelConfig.logResponses(), runtimeConfig.logResponses())) - .modelName(chatModelConfig.modelName()) .temperature(chatModelConfig.temperature()) .topP(chatModelConfig.topP()) .presencePenalty(chatModelConfig.presencePenalty()) .frequencyPenalty(chatModelConfig.frequencyPenalty()); + runtimeConfig.organizationId().ifPresent(builder::organizationId); + if (chatModelConfig.maxTokens().isPresent()) { builder.maxTokens(chatModelConfig.maxTokens().get()); } @@ -101,9 +103,10 @@ public Supplier embeddingModel(Langchain4jOpenAiConfig runtimeConfig) { .maxRetries(runtimeConfig.maxRetries()) .logRequests(firstOrDefault(false, embeddingModelConfig.logRequests(), runtimeConfig.logRequests())) .logResponses(firstOrDefault(false, embeddingModelConfig.logResponses(), runtimeConfig.logResponses())) - .modelName(embeddingModelConfig.modelName()); + runtimeConfig.organizationId().ifPresent(builder::organizationId); + return new Supplier<>() { @Override public Object get() { @@ -125,9 +128,10 @@ public Supplier moderationModel(Langchain4jOpenAiConfig runtimeConfig) { .maxRetries(runtimeConfig.maxRetries()) .logRequests(firstOrDefault(false, moderationModelConfig.logRequests(), runtimeConfig.logRequests())) .logResponses(firstOrDefault(false, moderationModelConfig.logResponses(), runtimeConfig.logResponses())) - .modelName(moderationModelConfig.modelName()); + runtimeConfig.organizationId().ifPresent(builder::organizationId); + return new Supplier<>() { @Override public Object get() { @@ -149,7 +153,6 @@ public Supplier imageModel(Langchain4jOpenAiConfig runtimeConfig) { .maxRetries(runtimeConfig.maxRetries()) .logRequests(firstOrDefault(false, imageModelConfig.logRequests(), runtimeConfig.logRequests())) .logResponses(firstOrDefault(false, imageModelConfig.logResponses(), runtimeConfig.logResponses())) - .modelName(imageModelConfig.modelName()) .size(imageModelConfig.size()) .quality(imageModelConfig.quality()) @@ -157,6 +160,8 @@ public Supplier imageModel(Langchain4jOpenAiConfig runtimeConfig) { .responseFormat(imageModelConfig.responseFormat()) .user(imageModelConfig.user()); + runtimeConfig.organizationId().ifPresent(builder::organizationId); + // we persist if the directory was set explicitly and the boolean flag was not set to false // or if the boolean flag was set explicitly to true Optional persistDirectory = Optional.empty(); diff --git a/openai/openai-vanilla/runtime/src/main/java/io/quarkiverse/langchain4j/openai/runtime/config/Langchain4jOpenAiConfig.java b/openai/openai-vanilla/runtime/src/main/java/io/quarkiverse/langchain4j/openai/runtime/config/Langchain4jOpenAiConfig.java index dfee8791c..0b3e172d0 100644 --- a/openai/openai-vanilla/runtime/src/main/java/io/quarkiverse/langchain4j/openai/runtime/config/Langchain4jOpenAiConfig.java +++ b/openai/openai-vanilla/runtime/src/main/java/io/quarkiverse/langchain4j/openai/runtime/config/Langchain4jOpenAiConfig.java @@ -25,6 +25,11 @@ public interface Langchain4jOpenAiConfig { */ Optional apiKey(); + /** + * OpenAI Organization ID (https://platform.openai.com/docs/api-reference/organization-optional) + */ + Optional organizationId(); + /** * Timeout for OpenAI calls */