Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make response handling more generic, resulting in generic cost handling. #1138

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,12 @@

import java.util.Optional;

import jakarta.inject.Singleton;

import org.jboss.jandex.DotName;

import io.quarkiverse.langchain4j.cost.CostEstimatorResponseListener;
import io.quarkiverse.langchain4j.deployment.config.LangChain4jBuildConfig;
import io.quarkiverse.langchain4j.runtime.listeners.MetricsChatModelListener;
import io.quarkiverse.langchain4j.runtime.listeners.SpanChatModelListener;
import io.quarkus.arc.deployment.AdditionalBeanBuildItem;
Expand All @@ -14,6 +20,20 @@

public class ListenersProcessor {

/**
 * Registers {@link CostEstimatorResponseListener} as an additional, unremovable bean with
 * {@code @Singleton} as its default scope, but only when the build-time
 * {@code quarkus.langchain4j.cost-listener} property is enabled.
 */
@BuildStep
public void costListener(
        LangChain4jBuildConfig config,
        BuildProducer<AdditionalBeanBuildItem> additionalBeanProducer) {
    // Feature is opt-in: nothing is registered unless the property is switched on.
    if (!config.costListener()) {
        return;
    }
    AdditionalBeanBuildItem costListenerBean = AdditionalBeanBuildItem.builder()
            .addBeanClass(CostEstimatorResponseListener.class)
            .setDefaultScope(DotName.createSimple(Singleton.class))
            .setUnremovable()
            .build();
    additionalBeanProducer.produce(costListenerBean);
}

@BuildStep
public void spanListeners(Capabilities capabilities,
Optional<MetricsCapabilityBuildItem> metricsCapability,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,12 @@ public interface LangChain4jBuildConfig {
@WithDefault("true")
boolean responseSchema();

/**
 * Whether the generic cost listener ({@code CostEstimatorResponseListener}) is registered as a bean.
 * Disabled by default.
 */
@WithDefault("false")
boolean costListener();

interface BaseConfig {
/**
* Chat model
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package io.quarkiverse.langchain4j.cost;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

import jakarta.inject.Inject;

import dev.langchain4j.model.output.TokenUsage;
import io.quarkiverse.langchain4j.response.ResponseListener;
import io.quarkiverse.langchain4j.response.ResponseRecord;
import io.quarkus.arc.All;
import io.smallrye.common.annotation.Experimental;

/**
 * {@link ResponseListener} that turns every recorded model response into a cost estimate and
 * forwards that estimate to the registered {@link CostListener}s.
 * <p>
 * The estimate is obtained from {@link CostEstimatorService}. Listeners are notified in ascending
 * {@link CostListener#order()} and only when the service returns a non-null {@link Cost}.
 * Responses that carry no {@link TokenUsage} are ignored, because there is nothing to price.
 */
@Experimental("This feature is experimental and the API is subject to change")
public class CostEstimatorResponseListener implements ResponseListener {

    private final CostEstimatorService service;
    private final List<CostListener> listeners;

    /**
     * @param service the service used to compute the cost estimate
     * @param listeners all available cost listeners; they are copied and sorted by {@link CostListener#order()}
     */
    @Inject
    public CostEstimatorResponseListener(CostEstimatorService service, @All List<CostListener> listeners) {
        this.service = service;
        // Sort a private copy rather than the injected list.
        this.listeners = new ArrayList<>(listeners);
        this.listeners.sort(Comparator.comparingInt(CostListener::order));
    }

    /**
     * Estimates the cost of the given response and notifies every cost listener.
     *
     * @param rr the recorded response; its model name and token usage are the inputs of the estimate
     */
    @Override
    public void onResponse(ResponseRecord rr) {
        TokenUsage tokenUsage = rr.tokenUsage();
        if (tokenUsage == null) {
            // Without usage data MyCostContext would throw a NullPointerException
            // when an estimator asks for the token counts, so skip the estimate entirely.
            return;
        }
        String model = rr.model();
        Cost cost = service.estimate(new MyCostContext(tokenUsage, model));
        if (cost == null) {
            return;
        }
        for (CostListener listener : listeners) {
            listener.handleCost(model, tokenUsage, cost);
        }
    }

    /**
     * Minimal {@link CostEstimator.CostContext} backed by a non-null token usage and a model name.
     */
    private record MyCostContext(TokenUsage tokenUsage, String model) implements CostEstimator.CostContext {
        @Override
        public Integer inputTokens() {
            return tokenUsage().inputTokenCount();
        }

        @Override
        public Integer outputTokens() {
            return tokenUsage().outputTokenCount();
        }
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,13 @@ public CostEstimatorService(@All List<CostEstimator> costEstimators) {
public Cost estimate(ChatModelResponseContext response) {
TokenUsage tokenUsage = response.response().tokenUsage();
CostEstimator.CostContext costContext = new MyCostContext(tokenUsage, response);
return estimate(costContext);
}

public Cost estimate(CostEstimator.CostContext context) {
for (CostEstimator costEstimator : costEstimators) {
if (costEstimator.supports(costContext)) {
CostEstimator.CostResult costResult = costEstimator.estimate(costContext);
if (costEstimator.supports(context)) {
CostEstimator.CostResult costResult = costEstimator.estimate(context);
if (costResult != null) {
BigDecimal totalCost = costResult.inputTokensCost().add(costResult.outputTokensCost());
return new Cost(totalCost, costResult.currency());
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package io.quarkiverse.langchain4j.cost;

import dev.langchain4j.model.output.TokenUsage;

/**
 * Allows user code to handle an estimated cost, e.g. for some simple accounting.
 */
public interface CostListener {
/**
 * Called with the estimated cost of a single model response.
 *
 * @param model the model name taken from the recorded response
 * @param tokenUsage the token usage the estimate was based on
 * @param cost the estimated cost; {@code CostEstimatorResponseListener} only calls this method with a non-null value
 */
void handleCost(String model, TokenUsage tokenUsage, Cost cost);

/**
 * Position of this listener relative to the others. {@code CostEstimatorResponseListener} sorts
 * listeners by this value in ascending order before notifying them.
 *
 * @return the ordering value, {@code 0} by default
 */
default int order() {
return 0;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package io.quarkiverse.langchain4j.response;

import java.util.Map;

import jakarta.annotation.Priority;
import jakarta.interceptor.AroundInvoke;
import jakarta.interceptor.Interceptor;
import jakarta.interceptor.InvocationContext;

import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.chat.listener.ChatModelResponse;
import dev.langchain4j.model.chat.response.ChatResponse;
import dev.langchain4j.model.output.Response;

/**
 * Simple (Chat)Response interceptor, to be applied directly on the model.
 * <p>
 * After the intercepted method returns, a recognized result ({@link Response} holding an
 * {@link AiMessage}, {@link ChatResponse} or {@link ChatModelResponse}) is converted into a
 * {@link ResponseRecord} and handed to every {@link ResponseListener}. Any other result is
 * passed through without notifying listeners. The intercepted result itself is never altered.
 */
@Interceptor
@ResponseInterceptorBinding
@Priority(0)
public class ResponseInterceptor extends ResponseInterceptorBase {

    /**
     * Proceeds with the intercepted invocation and publishes its result to the listeners.
     *
     * @param context the invocation being intercepted
     * @return the unmodified result of the intercepted invocation
     * @throws Exception whatever the intercepted invocation throws
     */
    @AroundInvoke
    public Object intercept(InvocationContext context) throws Exception {
        Object result = context.proceed();
        ResponseRecord rr = toRecord(result, context.getTarget());
        if (rr != null) {
            for (ResponseListener l : getListeners()) {
                l.onResponse(rr);
            }
        }
        return result;
    }

    /**
     * Converts a supported result type into a {@link ResponseRecord}.
     *
     * @param result the value returned by the intercepted method
     * @param target the intercepted instance, used to look up the model name when the result does not carry it
     * @return the record, or {@code null} when the result is not a supported response type
     */
    private ResponseRecord toRecord(Object result, Object target) {
        if (result instanceof Response<?> legacyResponse) {
            if (legacyResponse.content() instanceof AiMessage am) {
                return new ResponseRecord(getModel(target), am, legacyResponse.tokenUsage(),
                        legacyResponse.finishReason(), legacyResponse.metadata());
            }
            // Response with non-AiMessage content: nothing to report.
            return null;
        }
        if (result instanceof ChatResponse chatResponse) {
            return new ResponseRecord(getModel(target), chatResponse.aiMessage(), chatResponse.tokenUsage(),
                    chatResponse.finishReason(), Map.of());
        }
        if (result instanceof ChatModelResponse modelResponse) {
            // Map.of rejects null values, so a response without an id gets empty metadata
            // instead of failing the whole model call with a NullPointerException.
            Object id = modelResponse.id();
            Map<String, Object> metadata = (id == null) ? Map.of() : Map.of("id", id);
            return new ResponseRecord(modelResponse.model(), modelResponse.aiMessage(), modelResponse.tokenUsage(),
                    modelResponse.finishReason(), metadata);
        }
        return null;
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package io.quarkiverse.langchain4j.response;

import java.lang.reflect.Method;
import java.util.Comparator;
import java.util.List;

import jakarta.enterprise.inject.Any;
import jakarta.enterprise.inject.spi.CDI;

/**
 * Base class for (Chat)Response interceptors that are applied directly on the model.
 * <p>
 * Provides lazily resolved, cached access to the model name of the intercepted target and to the
 * ordered list of {@link ResponseListener}s.
 */
public abstract class ResponseInterceptorBase {

    private volatile String model;
    private volatile List<ResponseListener> listeners;

    /**
     * Returns the model name of the intercepted target. On first use the name is read by
     * reflectively invoking a public no-arg {@code modelName()} method on the target; the value
     * is then cached in this instance.
     *
     * @param target the intercepted model instance
     * @return the model name
     * @throws RuntimeException wrapping any failure of the reflective lookup or invocation
     */
    // TODO: replace the reflective lookup with a typed accessor
    protected String getModel(Object target) {
        String cached = model;
        if (cached != null) {
            return cached;
        }
        try {
            Method accessor = target.getClass().getMethod("modelName");
            cached = (String) accessor.invoke(target);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
        model = cached;
        return cached;
    }

    /**
     * Returns all {@link ResponseListener} beans, sorted by ascending {@link ResponseListener#order()}.
     * The list is looked up through CDI on first use and cached afterwards.
     *
     * @return the ordered listeners
     */
    protected List<ResponseListener> getListeners() {
        List<ResponseListener> resolved = listeners;
        if (resolved != null) {
            return resolved;
        }
        resolved = CDI.current()
                .select(ResponseListener.class, Any.Literal.INSTANCE)
                .stream()
                .sorted(Comparator.comparingInt(ResponseListener::order))
                .toList();
        listeners = resolved;
        return resolved;
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package io.quarkiverse.langchain4j.response;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

import jakarta.interceptor.InterceptorBinding;

/**
 * Interceptor binding that {@code ResponseInterceptor} is bound to.
 * Can be placed on types and methods and is retained at runtime.
 */
@InterceptorBinding
@Target({ ElementType.TYPE, ElementType.METHOD })
@Retention(RetentionPolicy.RUNTIME)
public @interface ResponseInterceptorBinding {
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package io.quarkiverse.langchain4j.response;

/**
 * Empty abstract class whose only purpose is to carry the {@link ResponseInterceptorBinding} annotation.
 * NOTE(review): presumably used as the source of the binding annotation instance when the
 * interceptor is attached to model beans programmatically - confirm against the usage site.
 */
@ResponseInterceptorBinding
public abstract class ResponseInterceptorBindingSource {
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package io.quarkiverse.langchain4j.response;

/**
 * Simple {@link ResponseRecord} listener, to be implemented by (advanced) users.
 */
public interface ResponseListener {
/**
 * Called by {@code ResponseInterceptor} after an intercepted model call has returned a
 * supported response type.
 *
 * @param response the recorded response
 */
void onResponse(ResponseRecord response);

/**
 * Position of this listener relative to the others; the interceptor base sorts listeners by
 * this value in ascending order.
 *
 * @return the ordering value, {@code 0} by default
 */
default int order() {
return 0;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package io.quarkiverse.langchain4j.response;

import java.util.Map;

import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.output.FinishReason;
import dev.langchain4j.model.output.TokenUsage;

/**
 * Abstracts away the difference between Response, ChatResponse and ChatModelResponse by
 * exposing the values the response listeners work with.
 *
 * @param model the model name; read from the intercepted target, or from the response itself
 *        when the response carries it
 * @param content the AI message of the response
 * @param tokenUsage the token usage reported by the response
 * @param finishReason the finish reason reported by the response
 * @param metadata additional response data: the response's own metadata for a Response, an
 *        empty map for a ChatResponse, and an {@code "id"} entry for a ChatModelResponse
 */
public record ResponseRecord(
String model,
AiMessage content,
TokenUsage tokenUsage,
FinishReason finishReason,
Map<String, Object> metadata) {
}
17 changes: 17 additions & 0 deletions docs/modules/ROOT/pages/includes/quarkus-langchain4j-core.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,23 @@ endif::add-copy-button-to-env-var[]
|boolean
|`true`

a|icon:lock[title=Fixed at build time] [[quarkus-langchain4j-core_quarkus-langchain4j-cost-listener]] [.property-path]##link:#quarkus-langchain4j-core_quarkus-langchain4j-cost-listener[`quarkus.langchain4j.cost-listener`]##

[.description]
--
Configuration property to enable or disable generic cost listener


ifdef::add-copy-button-to-env-var[]
Environment variable: env_var_with_copy_button:+++QUARKUS_LANGCHAIN4J_COST_LISTENER+++[]
endif::add-copy-button-to-env-var[]
ifndef::add-copy-button-to-env-var[]
Environment variable: `+++QUARKUS_LANGCHAIN4J_COST_LISTENER+++`
endif::add-copy-button-to-env-var[]
--
|boolean
|`false`

a| [[quarkus-langchain4j-core_quarkus-langchain4j-chat-memory-memory-window-max-messages]] [.property-path]##link:#quarkus-langchain4j-core_quarkus-langchain4j-chat-memory-memory-window-max-messages[`quarkus.langchain4j.chat-memory.memory-window.max-messages`]##

[.description]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,23 @@ endif::add-copy-button-to-env-var[]
|boolean
|`true`

a|icon:lock[title=Fixed at build time] [[quarkus-langchain4j-core_quarkus-langchain4j-cost-listener]] [.property-path]##link:#quarkus-langchain4j-core_quarkus-langchain4j-cost-listener[`quarkus.langchain4j.cost-listener`]##

[.description]
--
Configuration property to enable or disable generic cost listener


ifdef::add-copy-button-to-env-var[]
Environment variable: env_var_with_copy_button:+++QUARKUS_LANGCHAIN4J_COST_LISTENER+++[]
endif::add-copy-button-to-env-var[]
ifndef::add-copy-button-to-env-var[]
Environment variable: `+++QUARKUS_LANGCHAIN4J_COST_LISTENER+++`
endif::add-copy-button-to-env-var[]
--
|boolean
|`false`

a| [[quarkus-langchain4j-core_quarkus-langchain4j-chat-memory-memory-window-max-messages]] [.property-path]##link:#quarkus-langchain4j-core_quarkus-langchain4j-chat-memory-memory-window-max-messages[`quarkus.langchain4j.chat-memory.memory-window.max-messages`]##

[.description]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,12 @@ public class MultipleChatProvidersTest {

@Test
void defaultModel() {
assertThat(ClientProxy.unwrap(defaultModel)).isInstanceOf(OpenAiChatModel.class);
assertThat(SubclassUtil.unwrap(defaultModel)).isInstanceOf(OpenAiChatModel.class);
}

@Test
void firstNamedModel() {
assertThat(ClientProxy.unwrap(firstNamedModel)).isInstanceOf(OpenAiChatModel.class);
assertThat(SubclassUtil.unwrap(firstNamedModel)).isInstanceOf(OpenAiChatModel.class);
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ public class MultipleEmbeddingModelsTest {

@Test
void firstNamedModel() {
assertThat(ClientProxy.unwrap(firstNamedModel)).isInstanceOf(OpenAiEmbeddingModel.class);
assertThat(SubclassUtil.unwrap(firstNamedModel)).isInstanceOf(OpenAiEmbeddingModel.class);
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import dev.langchain4j.model.moderation.ModerationModel;
import dev.langchain4j.model.openai.OpenAiModerationModel;
import io.quarkiverse.langchain4j.ModelName;
import io.quarkus.arc.ClientProxy;
import io.quarkus.test.junit.QuarkusTest;

@QuarkusTest
Expand All @@ -30,12 +29,12 @@ public class MultipleModerationProvidersTest {

@Test
void defaultModel() {
assertThat(ClientProxy.unwrap(defaultModel)).isInstanceOf(OpenAiModerationModel.class);
assertThat(SubclassUtil.unwrap(defaultModel)).isInstanceOf(OpenAiModerationModel.class);
}

@Test
void firstNamedModel() {
assertThat(ClientProxy.unwrap(firstNamedModel)).isInstanceOf(OpenAiModerationModel.class);
assertThat(SubclassUtil.unwrap(firstNamedModel)).isInstanceOf(OpenAiModerationModel.class);
}

@Test
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package org.acme.example.multiple;

import java.lang.reflect.Field;

import io.quarkus.arc.ClientProxy;
import io.quarkus.arc.Subclass;

/**
 * Test helper that strips the ArC wrappers from an injected bean so that assertions can be made
 * against the real implementation type.
 */
public class SubclassUtil {

    /**
     * Unwraps the client proxy of {@code target} and, if the result is an ArC {@link Subclass},
     * additionally returns the value of its declared {@code delegate} field.
     *
     * @param target the injected bean reference
     * @param <T> the declared bean type
     * @return the unwrapped instance
     * @throws RuntimeException wrapping any failure while reading the {@code delegate} field
     */
    public static <T> T unwrap(T target) {
        T unwrapped = ClientProxy.unwrap(target);
        if (!(unwrapped instanceof Subclass)) {
            return unwrapped;
        }
        try {
            Field delegateField = unwrapped.getClass().getDeclaredField("delegate");
            delegateField.setAccessible(true);
            return (T) delegateField.get(unwrapped);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

}
Loading
Loading