From 1ca082d0336828e586c25c57da8d9342fa14fc96 Mon Sep 17 00:00:00 2001 From: Ales Justin Date: Sat, 30 Nov 2024 14:21:53 +0100 Subject: [PATCH] Make response handling more generic. Resulting in generic cost handling. --- .../deployment/ListenersProcessor.java | 20 +++++++ .../config/LangChain4jBuildConfig.java | 6 ++ .../cost/CostEstimatorResponseListener.java | 55 +++++++++++++++++++ .../cost/CostEstimatorService.java | 7 ++- .../langchain4j/cost/CostListener.java | 14 +++++ .../response/ResponseInterceptor.java | 47 ++++++++++++++++ .../response/ResponseInterceptorBase.java | 41 ++++++++++++++ .../response/ResponseInterceptorBinding.java | 14 +++++ .../ResponseInterceptorBindingSource.java | 5 ++ .../response/ResponseListener.java | 12 ++++ .../langchain4j/response/ResponseRecord.java | 18 ++++++ .../includes/quarkus-langchain4j-core.adoc | 17 ++++++ ...-langchain4j-core_quarkus.langchain4j.adoc | 17 ++++++ .../multiple/MultipleChatProvidersTest.java | 4 +- .../multiple/MultipleEmbeddingModelsTest.java | 2 +- .../MultipleModerationProvidersTest.java | 5 +- .../acme/example/multiple/SubclassUtil.java | 24 ++++++++ .../openai/deployment/OpenAiProcessor.java | 13 +++-- .../openai/runtime/OpenAiRecorder.java | 46 +++++++++------- .../DisabledModelsOpenAiRecorderTest.java | 6 +- .../sample/chatbot/MovieMuseCostListener.java | 15 +++++ .../src/main/resources/application.properties | 2 + 22 files changed, 356 insertions(+), 34 deletions(-) create mode 100644 core/runtime/src/main/java/io/quarkiverse/langchain4j/cost/CostEstimatorResponseListener.java create mode 100644 core/runtime/src/main/java/io/quarkiverse/langchain4j/cost/CostListener.java create mode 100644 core/runtime/src/main/java/io/quarkiverse/langchain4j/response/ResponseInterceptor.java create mode 100644 core/runtime/src/main/java/io/quarkiverse/langchain4j/response/ResponseInterceptorBase.java create mode 100644 core/runtime/src/main/java/io/quarkiverse/langchain4j/response/ResponseInterceptorBinding.java 
create mode 100644 core/runtime/src/main/java/io/quarkiverse/langchain4j/response/ResponseInterceptorBindingSource.java create mode 100644 core/runtime/src/main/java/io/quarkiverse/langchain4j/response/ResponseListener.java create mode 100644 core/runtime/src/main/java/io/quarkiverse/langchain4j/response/ResponseRecord.java create mode 100644 integration-tests/multiple-providers/src/test/java/org/acme/example/multiple/SubclassUtil.java create mode 100644 samples/sql-chatbot/src/main/java/io/quarkiverse/langchain4j/sample/chatbot/MovieMuseCostListener.java diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/ListenersProcessor.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/ListenersProcessor.java index 872525073..ae836050b 100644 --- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/ListenersProcessor.java +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/ListenersProcessor.java @@ -2,6 +2,12 @@ import java.util.Optional; +import jakarta.inject.Singleton; + +import org.jboss.jandex.DotName; + +import io.quarkiverse.langchain4j.cost.CostEstimatorResponseListener; +import io.quarkiverse.langchain4j.deployment.config.LangChain4jBuildConfig; import io.quarkiverse.langchain4j.runtime.listeners.MetricsChatModelListener; import io.quarkiverse.langchain4j.runtime.listeners.SpanChatModelListener; import io.quarkus.arc.deployment.AdditionalBeanBuildItem; @@ -14,6 +20,20 @@ public class ListenersProcessor { + @BuildStep + public void costListener( + LangChain4jBuildConfig config, + BuildProducer additionalBeanProducer) { + if (config.costListener()) { + additionalBeanProducer.produce( + AdditionalBeanBuildItem.builder() + .addBeanClass(CostEstimatorResponseListener.class) + .setDefaultScope(DotName.createSimple(Singleton.class)) + .setUnremovable() + .build()); + } + } + @BuildStep public void spanListeners(Capabilities capabilities, Optional metricsCapability, diff --git 
a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/config/LangChain4jBuildConfig.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/config/LangChain4jBuildConfig.java index 1eb4448d1..b0040fec5 100644 --- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/config/LangChain4jBuildConfig.java +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/config/LangChain4jBuildConfig.java @@ -43,6 +43,12 @@ public interface LangChain4jBuildConfig { @WithDefault("true") boolean responseSchema(); + /** + * Configuration property to enable or disable generic cost listener + */ + @WithDefault("false") + boolean costListener(); + interface BaseConfig { /** * Chat model diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/cost/CostEstimatorResponseListener.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/cost/CostEstimatorResponseListener.java new file mode 100644 index 000000000..2fb73fb13 --- /dev/null +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/cost/CostEstimatorResponseListener.java @@ -0,0 +1,55 @@ +package io.quarkiverse.langchain4j.cost; + +import java.util.ArrayList; +import java.util.Comparator; +import java.util.List; + +import jakarta.inject.Inject; + +import dev.langchain4j.model.output.TokenUsage; +import io.quarkiverse.langchain4j.response.ResponseListener; +import io.quarkiverse.langchain4j.response.ResponseRecord; +import io.quarkus.arc.All; +import io.smallrye.common.annotation.Experimental; + +/** + * Estimates the cost of each model response and passes the result to every registered CostListener + */ +@Experimental("This feature is experimental and the API is subject to change") +public class CostEstimatorResponseListener implements ResponseListener { + + private final CostEstimatorService service; + private final List listeners; + + @Inject + public CostEstimatorResponseListener(CostEstimatorService service, @All List listeners) { + this.service = 
service; + this.listeners = new ArrayList<>(listeners); + this.listeners.sort(Comparator.comparingInt(CostListener::order)); + } + + @Override + public void onResponse(ResponseRecord rr) { + String model = rr.model(); + TokenUsage tokenUsage = rr.tokenUsage(); + CostEstimator.CostContext context = new MyCostContext(tokenUsage, model); + Cost cost = service.estimate(context); + if (cost != null) { + for (CostListener cl : listeners) { + cl.handleCost(model, tokenUsage, cost); + } + } + } + + private record MyCostContext(TokenUsage tokenUsage, String model) implements CostEstimator.CostContext { + @Override + public Integer inputTokens() { + return tokenUsage().inputTokenCount(); + } + + @Override + public Integer outputTokens() { + return tokenUsage().outputTokenCount(); + } + } +} diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/cost/CostEstimatorService.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/cost/CostEstimatorService.java index 7c499fdd6..cb36a5ed3 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/cost/CostEstimatorService.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/cost/CostEstimatorService.java @@ -28,10 +28,13 @@ public CostEstimatorService(@All List costEstimators) { public Cost estimate(ChatModelResponseContext response) { TokenUsage tokenUsage = response.response().tokenUsage(); CostEstimator.CostContext costContext = new MyCostContext(tokenUsage, response); + return estimate(costContext); + } + public Cost estimate(CostEstimator.CostContext context) { for (CostEstimator costEstimator : costEstimators) { - if (costEstimator.supports(costContext)) { - CostEstimator.CostResult costResult = costEstimator.estimate(costContext); + if (costEstimator.supports(context)) { + CostEstimator.CostResult costResult = costEstimator.estimate(context); if (costResult != null) { BigDecimal totalCost = costResult.inputTokensCost().add(costResult.outputTokensCost()); return new Cost(totalCost, 
costResult.currency()); diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/cost/CostListener.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/cost/CostListener.java new file mode 100644 index 000000000..bd21c6b90 --- /dev/null +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/cost/CostListener.java @@ -0,0 +1,14 @@ +package io.quarkiverse.langchain4j.cost; + +import dev.langchain4j.model.output.TokenUsage; + +/** + * Allows for user code to handle the estimated cost; e.g. some simple accounting + */ +public interface CostListener { + void handleCost(String model, TokenUsage tokenUsage, Cost cost); + + default int order() { + return 0; + } +} diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/response/ResponseInterceptor.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/response/ResponseInterceptor.java new file mode 100644 index 000000000..d8cfcbda2 --- /dev/null +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/response/ResponseInterceptor.java @@ -0,0 +1,47 @@ +package io.quarkiverse.langchain4j.response; + +import java.util.Map; + +import jakarta.annotation.Priority; +import jakarta.interceptor.AroundInvoke; +import jakarta.interceptor.Interceptor; +import jakarta.interceptor.InvocationContext; + +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.model.chat.listener.ChatModelResponse; +import dev.langchain4j.model.chat.response.ChatResponse; +import dev.langchain4j.model.output.Response; + +/** + * Simple (Chat)Response interceptor, to be applied directly on the model. 
+ */ +@Interceptor +@ResponseInterceptorBinding +@Priority(0) +public class ResponseInterceptor extends ResponseInterceptorBase { + + @AroundInvoke + public Object intercept(InvocationContext context) throws Exception { + Object result = context.proceed(); + ResponseRecord rr = null; + if (result instanceof Response response) { + Object content = response.content(); + if (content instanceof AiMessage am) { + rr = new ResponseRecord(getModel(context.getTarget()), am, response.tokenUsage(), response.finishReason(), + response.metadata()); + } + } else if (result instanceof ChatResponse response) { + rr = new ResponseRecord(getModel(context.getTarget()), response.aiMessage(), response.tokenUsage(), + response.finishReason(), Map.of()); + } else if (result instanceof ChatModelResponse response) { + rr = new ResponseRecord(response.model(), response.aiMessage(), response.tokenUsage(), response.finishReason(), + Map.of("id", response.id())); + } + if (rr != null) { + for (ResponseListener l : getListeners()) { + l.onResponse(rr); + } + } + return result; + } +} diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/response/ResponseInterceptorBase.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/response/ResponseInterceptorBase.java new file mode 100644 index 000000000..b3ace1fb7 --- /dev/null +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/response/ResponseInterceptorBase.java @@ -0,0 +1,41 @@ +package io.quarkiverse.langchain4j.response; + +import java.lang.reflect.Method; +import java.util.Comparator; +import java.util.List; + +import jakarta.enterprise.inject.Any; +import jakarta.enterprise.inject.spi.CDI; + +/** + * Simple (Chat)Response interceptor base, to be applied directly on the model. + */ +public abstract class ResponseInterceptorBase { + + private volatile String model; + private volatile List listeners; + + // TODO -- uh uh ... reflection ... 
puke + protected String getModel(Object target) { + if (model == null) { + try { + Class clazz = target.getClass(); + Method method = clazz.getMethod("modelName"); + model = (String) method.invoke(target); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + return model; + } + + protected List getListeners() { + if (listeners == null) { + listeners = CDI.current().select(ResponseListener.class, Any.Literal.INSTANCE) + .stream() + .sorted(Comparator.comparing(ResponseListener::order)) + .toList(); + } + return listeners; + } +} diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/response/ResponseInterceptorBinding.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/response/ResponseInterceptorBinding.java new file mode 100644 index 000000000..986c1d2db --- /dev/null +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/response/ResponseInterceptorBinding.java @@ -0,0 +1,14 @@ +package io.quarkiverse.langchain4j.response; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +import jakarta.interceptor.InterceptorBinding; + +@InterceptorBinding +@Target({ ElementType.TYPE, ElementType.METHOD }) +@Retention(RetentionPolicy.RUNTIME) +public @interface ResponseInterceptorBinding { +} diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/response/ResponseInterceptorBindingSource.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/response/ResponseInterceptorBindingSource.java new file mode 100644 index 000000000..05fac2f3d --- /dev/null +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/response/ResponseInterceptorBindingSource.java @@ -0,0 +1,5 @@ +package io.quarkiverse.langchain4j.response; + +@ResponseInterceptorBinding +public abstract class ResponseInterceptorBindingSource { +} diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/response/ResponseListener.java 
b/core/runtime/src/main/java/io/quarkiverse/langchain4j/response/ResponseListener.java new file mode 100644 index 000000000..ab7966f68 --- /dev/null +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/response/ResponseListener.java @@ -0,0 +1,12 @@ +package io.quarkiverse.langchain4j.response; + +/** + * Simple ResponseRecord listener, to be implemented by the (advanced) users. + */ +public interface ResponseListener { + void onResponse(ResponseRecord response); + + default int order() { + return 0; + } +} diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/response/ResponseRecord.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/response/ResponseRecord.java new file mode 100644 index 000000000..81c717174 --- /dev/null +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/response/ResponseRecord.java @@ -0,0 +1,18 @@ +package io.quarkiverse.langchain4j.response; + +import java.util.Map; + +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.model.output.FinishReason; +import dev.langchain4j.model.output.TokenUsage; + +/** + * Abstract away Response vs ChatResponse. 
+ */ +public record ResponseRecord( + String model, + AiMessage content, + TokenUsage tokenUsage, + FinishReason finishReason, + Map metadata) { +} diff --git a/docs/modules/ROOT/pages/includes/quarkus-langchain4j-core.adoc b/docs/modules/ROOT/pages/includes/quarkus-langchain4j-core.adoc index cd612e282..332fa2c73 100644 --- a/docs/modules/ROOT/pages/includes/quarkus-langchain4j-core.adoc +++ b/docs/modules/ROOT/pages/includes/quarkus-langchain4j-core.adoc @@ -101,6 +101,23 @@ endif::add-copy-button-to-env-var[] |boolean |`true` +a|icon:lock[title=Fixed at build time] [[quarkus-langchain4j-core_quarkus-langchain4j-cost-listener]] [.property-path]##link:#quarkus-langchain4j-core_quarkus-langchain4j-cost-listener[`quarkus.langchain4j.cost-listener`]## + +[.description] +-- +Configuration property to enable or disable generic cost listener + + +ifdef::add-copy-button-to-env-var[] +Environment variable: env_var_with_copy_button:+++QUARKUS_LANGCHAIN4J_COST_LISTENER+++[] +endif::add-copy-button-to-env-var[] +ifndef::add-copy-button-to-env-var[] +Environment variable: `+++QUARKUS_LANGCHAIN4J_COST_LISTENER+++` +endif::add-copy-button-to-env-var[] +-- +|boolean +|`false` + a| [[quarkus-langchain4j-core_quarkus-langchain4j-chat-memory-memory-window-max-messages]] [.property-path]##link:#quarkus-langchain4j-core_quarkus-langchain4j-chat-memory-memory-window-max-messages[`quarkus.langchain4j.chat-memory.memory-window.max-messages`]## [.description] diff --git a/docs/modules/ROOT/pages/includes/quarkus-langchain4j-core_quarkus.langchain4j.adoc b/docs/modules/ROOT/pages/includes/quarkus-langchain4j-core_quarkus.langchain4j.adoc index cd612e282..332fa2c73 100644 --- a/docs/modules/ROOT/pages/includes/quarkus-langchain4j-core_quarkus.langchain4j.adoc +++ b/docs/modules/ROOT/pages/includes/quarkus-langchain4j-core_quarkus.langchain4j.adoc @@ -101,6 +101,23 @@ endif::add-copy-button-to-env-var[] |boolean |`true` +a|icon:lock[title=Fixed at build time] 
[[quarkus-langchain4j-core_quarkus-langchain4j-cost-listener]] [.property-path]##link:#quarkus-langchain4j-core_quarkus-langchain4j-cost-listener[`quarkus.langchain4j.cost-listener`]## + +[.description] +-- +Configuration property to enable or disable generic cost listener + + +ifdef::add-copy-button-to-env-var[] +Environment variable: env_var_with_copy_button:+++QUARKUS_LANGCHAIN4J_COST_LISTENER+++[] +endif::add-copy-button-to-env-var[] +ifndef::add-copy-button-to-env-var[] +Environment variable: `+++QUARKUS_LANGCHAIN4J_COST_LISTENER+++` +endif::add-copy-button-to-env-var[] +-- +|boolean +|`false` + a| [[quarkus-langchain4j-core_quarkus-langchain4j-chat-memory-memory-window-max-messages]] [.property-path]##link:#quarkus-langchain4j-core_quarkus-langchain4j-chat-memory-memory-window-max-messages[`quarkus.langchain4j.chat-memory.memory-window.max-messages`]## [.description] diff --git a/integration-tests/multiple-providers/src/test/java/org/acme/example/multiple/MultipleChatProvidersTest.java b/integration-tests/multiple-providers/src/test/java/org/acme/example/multiple/MultipleChatProvidersTest.java index 857cd61c9..03e69150e 100644 --- a/integration-tests/multiple-providers/src/test/java/org/acme/example/multiple/MultipleChatProvidersTest.java +++ b/integration-tests/multiple-providers/src/test/java/org/acme/example/multiple/MultipleChatProvidersTest.java @@ -64,12 +64,12 @@ public class MultipleChatProvidersTest { @Test void defaultModel() { - assertThat(ClientProxy.unwrap(defaultModel)).isInstanceOf(OpenAiChatModel.class); + assertThat(SubclassUtil.unwrap(defaultModel)).isInstanceOf(OpenAiChatModel.class); } @Test void firstNamedModel() { - assertThat(ClientProxy.unwrap(firstNamedModel)).isInstanceOf(OpenAiChatModel.class); + assertThat(SubclassUtil.unwrap(firstNamedModel)).isInstanceOf(OpenAiChatModel.class); } @Test diff --git a/integration-tests/multiple-providers/src/test/java/org/acme/example/multiple/MultipleEmbeddingModelsTest.java 
b/integration-tests/multiple-providers/src/test/java/org/acme/example/multiple/MultipleEmbeddingModelsTest.java index 6d4a05fe5..172b5eef2 100644 --- a/integration-tests/multiple-providers/src/test/java/org/acme/example/multiple/MultipleEmbeddingModelsTest.java +++ b/integration-tests/multiple-providers/src/test/java/org/acme/example/multiple/MultipleEmbeddingModelsTest.java @@ -46,7 +46,7 @@ public class MultipleEmbeddingModelsTest { @Test void firstNamedModel() { - assertThat(ClientProxy.unwrap(firstNamedModel)).isInstanceOf(OpenAiEmbeddingModel.class); + assertThat(SubclassUtil.unwrap(firstNamedModel)).isInstanceOf(OpenAiEmbeddingModel.class); } @Test diff --git a/integration-tests/multiple-providers/src/test/java/org/acme/example/multiple/MultipleModerationProvidersTest.java b/integration-tests/multiple-providers/src/test/java/org/acme/example/multiple/MultipleModerationProvidersTest.java index a02bf5032..5ba8cb320 100644 --- a/integration-tests/multiple-providers/src/test/java/org/acme/example/multiple/MultipleModerationProvidersTest.java +++ b/integration-tests/multiple-providers/src/test/java/org/acme/example/multiple/MultipleModerationProvidersTest.java @@ -11,7 +11,6 @@ import dev.langchain4j.model.moderation.ModerationModel; import dev.langchain4j.model.openai.OpenAiModerationModel; import io.quarkiverse.langchain4j.ModelName; -import io.quarkus.arc.ClientProxy; import io.quarkus.test.junit.QuarkusTest; @QuarkusTest @@ -30,12 +29,12 @@ public class MultipleModerationProvidersTest { @Test void defaultModel() { - assertThat(ClientProxy.unwrap(defaultModel)).isInstanceOf(OpenAiModerationModel.class); + assertThat(SubclassUtil.unwrap(defaultModel)).isInstanceOf(OpenAiModerationModel.class); } @Test void firstNamedModel() { - assertThat(ClientProxy.unwrap(firstNamedModel)).isInstanceOf(OpenAiModerationModel.class); + assertThat(SubclassUtil.unwrap(firstNamedModel)).isInstanceOf(OpenAiModerationModel.class); } @Test diff --git 
a/integration-tests/multiple-providers/src/test/java/org/acme/example/multiple/SubclassUtil.java b/integration-tests/multiple-providers/src/test/java/org/acme/example/multiple/SubclassUtil.java new file mode 100644 index 000000000..e3b5e9feb --- /dev/null +++ b/integration-tests/multiple-providers/src/test/java/org/acme/example/multiple/SubclassUtil.java @@ -0,0 +1,24 @@ +package org.acme.example.multiple; + +import java.lang.reflect.Field; + +import io.quarkus.arc.ClientProxy; +import io.quarkus.arc.Subclass; + +public class SubclassUtil { + + public static T unwrap(T target) { + T sub = ClientProxy.unwrap(target); + if (sub instanceof Subclass) { + try { + Field delegate = sub.getClass().getDeclaredField("delegate"); + delegate.setAccessible(true); + sub = (T) delegate.get(sub); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + return sub; + } + +} diff --git a/model-providers/openai/openai-vanilla/deployment/src/main/java/io/quarkiverse/langchain4j/openai/deployment/OpenAiProcessor.java b/model-providers/openai/openai-vanilla/deployment/src/main/java/io/quarkiverse/langchain4j/openai/deployment/OpenAiProcessor.java index 0bdfa48ca..2c0438a3d 100644 --- a/model-providers/openai/openai-vanilla/deployment/src/main/java/io/quarkiverse/langchain4j/openai/deployment/OpenAiProcessor.java +++ b/model-providers/openai/openai-vanilla/deployment/src/main/java/io/quarkiverse/langchain4j/openai/deployment/OpenAiProcessor.java @@ -35,6 +35,7 @@ import io.quarkiverse.langchain4j.openai.QuarkusOpenAiStreamingChatModelBuilderFactory; import io.quarkiverse.langchain4j.openai.runtime.OpenAiRecorder; import io.quarkiverse.langchain4j.openai.runtime.config.LangChain4jOpenAiConfig; +import io.quarkiverse.langchain4j.response.ResponseInterceptorBindingSource; import io.quarkiverse.langchain4j.runtime.NamedConfigUtil; import io.quarkus.arc.deployment.SyntheticBeanBuildItem; import io.quarkus.deployment.annotations.BuildProducer; @@ -96,7 +97,8 @@ void 
generateBeans(OpenAiRecorder recorder, .scope(ApplicationScoped.class) .addInjectionPoint(ParameterizedType.create(DotNames.CDI_INSTANCE, new Type[] { ClassType.create(DotNames.CHAT_MODEL_LISTENER) }, null)) - .createWith(recorder.chatModel(config, configName)); + .createWith(recorder.chatModel(config, configName)) + .injectInterceptionProxy(ResponseInterceptorBindingSource.class); addQualifierIfNecessary(builder, configName); beanProducer.produce(builder.done()); @@ -122,7 +124,8 @@ void generateBeans(OpenAiRecorder recorder, .defaultBean() .unremovable() .scope(ApplicationScoped.class) - .supplier(recorder.embeddingModel(config, configName)); + .createWith(recorder.embeddingModel(config, configName)) + .injectInterceptionProxy(ResponseInterceptorBindingSource.class); addQualifierIfNecessary(builder, configName); beanProducer.produce(builder.done()); } @@ -136,7 +139,8 @@ void generateBeans(OpenAiRecorder recorder, .setRuntimeInit() .defaultBean() .scope(ApplicationScoped.class) - .supplier(recorder.moderationModel(config, configName)); + .createWith(recorder.moderationModel(config, configName)) + .injectInterceptionProxy(ResponseInterceptorBindingSource.class); addQualifierIfNecessary(builder, configName); beanProducer.produce(builder.done()); } @@ -150,7 +154,8 @@ void generateBeans(OpenAiRecorder recorder, .setRuntimeInit() .defaultBean() .scope(ApplicationScoped.class) - .supplier(recorder.imageModel(config, configName)); + .createWith(recorder.imageModel(config, configName)) + .injectInterceptionProxy(ResponseInterceptorBindingSource.class); addQualifierIfNecessary(builder, configName); beanProducer.produce(builder.done()); } diff --git a/model-providers/openai/openai-vanilla/runtime/src/main/java/io/quarkiverse/langchain4j/openai/runtime/OpenAiRecorder.java b/model-providers/openai/openai-vanilla/runtime/src/main/java/io/quarkiverse/langchain4j/openai/runtime/OpenAiRecorder.java index 625315a7f..5033e2112 100644 --- 
a/model-providers/openai/openai-vanilla/runtime/src/main/java/io/quarkiverse/langchain4j/openai/runtime/OpenAiRecorder.java +++ b/model-providers/openai/openai-vanilla/runtime/src/main/java/io/quarkiverse/langchain4j/openai/runtime/OpenAiRecorder.java @@ -44,6 +44,7 @@ import io.quarkiverse.langchain4j.openai.runtime.config.LangChain4jOpenAiConfig; import io.quarkiverse.langchain4j.openai.runtime.config.ModerationModelConfig; import io.quarkiverse.langchain4j.runtime.NamedConfigUtil; +import io.quarkus.arc.InterceptionProxy; import io.quarkus.arc.SyntheticCreationalContext; import io.quarkus.runtime.ShutdownContext; import io.quarkus.runtime.annotations.Recorder; @@ -102,7 +103,8 @@ public Function, ChatLanguageModel public ChatLanguageModel apply(SyntheticCreationalContext context) { builder.listeners(context.getInjectedReference(CHAT_MODEL_LISTENER_TYPE_LITERAL).stream() .collect(Collectors.toList())); - return builder.build(); + InterceptionProxy proxy = context.getInterceptionProxy(); + return proxy.create(builder.build()); } }; } else { @@ -175,7 +177,8 @@ public StreamingChatLanguageModel apply( } } - public Supplier embeddingModel(LangChain4jOpenAiConfig runtimeConfig, String configName) { + public Function, EmbeddingModel> embeddingModel( + LangChain4jOpenAiConfig runtimeConfig, String configName) { LangChain4jOpenAiConfig.OpenAiConfig openAiConfig = correspondingOpenAiConfig(runtimeConfig, configName); if (openAiConfig.enableIntegration()) { @@ -206,17 +209,18 @@ public Supplier embeddingModel(LangChain4jOpenAiConfig runtimeCo new InetSocketAddress(host, openAiConfig.proxyPort()))); }); - return new Supplier<>() { + return new Function<>() { @Override - public EmbeddingModel get() { - return builder.build(); + public EmbeddingModel apply(SyntheticCreationalContext context) { + InterceptionProxy proxy = context.getInterceptionProxy(); + return proxy.create(builder.build()); } }; } else { - return new Supplier<>() { + return new Function<>() { @Override - 
public EmbeddingModel get() { + public EmbeddingModel apply(SyntheticCreationalContext context) { return new DisabledEmbeddingModel(); } @@ -224,7 +228,8 @@ public EmbeddingModel get() { } } - public Supplier moderationModel(LangChain4jOpenAiConfig runtimeConfig, String configName) { + public Function, ModerationModel> moderationModel( + LangChain4jOpenAiConfig runtimeConfig, String configName) { LangChain4jOpenAiConfig.OpenAiConfig openAiConfig = correspondingOpenAiConfig(runtimeConfig, configName); if (openAiConfig.enableIntegration()) { @@ -251,17 +256,18 @@ public Supplier moderationModel(LangChain4jOpenAiConfig runtime new InetSocketAddress(host, openAiConfig.proxyPort()))); }); - return new Supplier<>() { + return new Function<>() { @Override - public ModerationModel get() { - return builder.build(); + public ModerationModel apply(SyntheticCreationalContext context) { + InterceptionProxy proxy = context.getInterceptionProxy(); + return proxy.create(builder.build()); } }; } else { - return new Supplier<>() { + return new Function<>() { @Override - public ModerationModel get() { + public ModerationModel apply(SyntheticCreationalContext context) { return new DisabledModerationModel(); } @@ -269,7 +275,8 @@ public ModerationModel get() { } } - public Supplier imageModel(LangChain4jOpenAiConfig runtimeConfig, String configName) { + public Function, ImageModel> imageModel(LangChain4jOpenAiConfig runtimeConfig, + String configName) { LangChain4jOpenAiConfig.OpenAiConfig openAiConfig = correspondingOpenAiConfig(runtimeConfig, configName); if (openAiConfig.enableIntegration()) { @@ -316,17 +323,18 @@ public Optional get() { builder.persistDirectory(persistDirectory); - return new Supplier<>() { + return new Function<>() { @Override - public ImageModel get() { - return builder.build(); + public ImageModel apply(SyntheticCreationalContext context) { + InterceptionProxy proxy = context.getInterceptionProxy(); + return proxy.create(builder.build()); } }; } else { - return 
new Supplier<>() { + return new Function<>() { @Override - public ImageModel get() { + public ImageModel apply(SyntheticCreationalContext context) { return new DisabledImageModel(); } }; diff --git a/model-providers/openai/openai-vanilla/runtime/src/test/java/io/quarkiverse/langchain4j/openai/runtime/DisabledModelsOpenAiRecorderTest.java b/model-providers/openai/openai-vanilla/runtime/src/test/java/io/quarkiverse/langchain4j/openai/runtime/DisabledModelsOpenAiRecorderTest.java index 6f0dfe3b6..81d97d90e 100644 --- a/model-providers/openai/openai-vanilla/runtime/src/test/java/io/quarkiverse/langchain4j/openai/runtime/DisabledModelsOpenAiRecorderTest.java +++ b/model-providers/openai/openai-vanilla/runtime/src/test/java/io/quarkiverse/langchain4j/openai/runtime/DisabledModelsOpenAiRecorderTest.java @@ -45,21 +45,21 @@ void disabledStreamingChatModel() { @Test void disabledEmbeddingModel() { - assertThat(recorder.embeddingModel(config, NamedConfigUtil.DEFAULT_NAME).get()) + assertThat(recorder.embeddingModel(config, NamedConfigUtil.DEFAULT_NAME).apply(null)) .isNotNull() .isExactlyInstanceOf(DisabledEmbeddingModel.class); } @Test void disabledImageModel() { - assertThat(recorder.imageModel(config, NamedConfigUtil.DEFAULT_NAME).get()) + assertThat(recorder.imageModel(config, NamedConfigUtil.DEFAULT_NAME).apply(null)) .isNotNull() .isExactlyInstanceOf(DisabledImageModel.class); } @Test void disabledModerationModel() { - assertThat(recorder.moderationModel(config, NamedConfigUtil.DEFAULT_NAME).get()) + assertThat(recorder.moderationModel(config, NamedConfigUtil.DEFAULT_NAME).apply(null)) .isNotNull() .isExactlyInstanceOf(DisabledModerationModel.class); } diff --git a/samples/sql-chatbot/src/main/java/io/quarkiverse/langchain4j/sample/chatbot/MovieMuseCostListener.java b/samples/sql-chatbot/src/main/java/io/quarkiverse/langchain4j/sample/chatbot/MovieMuseCostListener.java new file mode 100644 index 000000000..12de240c8 --- /dev/null +++ 
b/samples/sql-chatbot/src/main/java/io/quarkiverse/langchain4j/sample/chatbot/MovieMuseCostListener.java @@ -0,0 +1,15 @@ +package io.quarkiverse.langchain4j.sample.chatbot; + +import dev.langchain4j.model.output.TokenUsage; +import io.quarkiverse.langchain4j.cost.Cost; +import io.quarkiverse.langchain4j.cost.CostListener; +import jakarta.enterprise.context.ApplicationScoped; + +@ApplicationScoped +public class MovieMuseCostListener implements CostListener { + public void handleCost(String model, TokenUsage tokenUsage, Cost cost) { + System.out.println("model = " + model); + System.out.println("tokenUsage = " + tokenUsage); + System.out.println("cost = " + cost); + } +} diff --git a/samples/sql-chatbot/src/main/resources/application.properties b/samples/sql-chatbot/src/main/resources/application.properties index 8fb28bb44..75ba3b789 100644 --- a/samples/sql-chatbot/src/main/resources/application.properties +++ b/samples/sql-chatbot/src/main/resources/application.properties @@ -1,6 +1,8 @@ quarkus.langchain4j.timeout=60s csv.file=src/main/resources/data/movies.csv +quarkus.langchain4j.cost-listener=true + quarkus.hibernate-orm.database.generation=drop-and-create # if you want to log the requests and responses that go to OpenAI: