Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce observability in watsonx #1142

Merged
merged 1 commit into from
Dec 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,19 @@
import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.TOKEN_COUNT_ESTIMATOR;

import java.util.List;
import java.util.function.Supplier;
import java.util.function.Function;

import jakarta.enterprise.context.ApplicationScoped;

import org.jboss.jandex.AnnotationInstance;
import org.jboss.jandex.ClassType;
import org.jboss.jandex.ParameterizedType;
import org.jboss.jandex.Type;

import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import io.quarkiverse.langchain4j.ModelName;
import io.quarkiverse.langchain4j.deployment.DotNames;
import io.quarkiverse.langchain4j.deployment.items.ChatModelProviderCandidateBuildItem;
import io.quarkiverse.langchain4j.deployment.items.EmbeddingModelProviderCandidateBuildItem;
import io.quarkiverse.langchain4j.deployment.items.ScoringModelProviderCandidateBuildItem;
Expand All @@ -26,6 +30,7 @@
import io.quarkiverse.langchain4j.watsonx.runtime.WatsonxRecorder;
import io.quarkiverse.langchain4j.watsonx.runtime.config.LangChain4jWatsonxConfig;
import io.quarkiverse.langchain4j.watsonx.runtime.config.LangChain4jWatsonxFixedRuntimeConfig;
import io.quarkus.arc.SyntheticCreationalContext;
import io.quarkus.arc.deployment.SyntheticBeanBuildItem;
import io.quarkus.deployment.Capabilities;
import io.quarkus.deployment.Capability;
Expand Down Expand Up @@ -86,8 +91,8 @@ void generateBeans(WatsonxRecorder recorder, LangChain4jWatsonxConfig runtimeCon
? fixedRuntimeConfig.defaultConfig().mode()
: fixedRuntimeConfig.namedConfig().get(configName).mode();

Supplier<ChatLanguageModel> chatLanguageModel;
Supplier<StreamingChatLanguageModel> streamingChatLanguageModel;
Function<SyntheticCreationalContext<ChatLanguageModel>, ChatLanguageModel> chatLanguageModel;
Function<SyntheticCreationalContext<StreamingChatLanguageModel>, StreamingChatLanguageModel> streamingChatLanguageModel;

if (mode.equalsIgnoreCase("chat")) {
chatLanguageModel = recorder.chatModel(runtimeConfig, configName);
Expand All @@ -106,7 +111,9 @@ void generateBeans(WatsonxRecorder recorder, LangChain4jWatsonxConfig runtimeCon
.setRuntimeInit()
.defaultBean()
.scope(ApplicationScoped.class)
.supplier(chatLanguageModel);
.addInjectionPoint(ParameterizedType.create(DotNames.CDI_INSTANCE,
new Type[] { ClassType.create(DotNames.CHAT_MODEL_LISTENER) }, null))
.createWith(chatLanguageModel);

addQualifierIfNecessary(chatBuilder, configName);
beanProducer.produce(chatBuilder.done());
Expand All @@ -116,7 +123,9 @@ void generateBeans(WatsonxRecorder recorder, LangChain4jWatsonxConfig runtimeCon
.setRuntimeInit()
.defaultBean()
.scope(ApplicationScoped.class)
.supplier(chatLanguageModel);
.addInjectionPoint(ParameterizedType.create(DotNames.CDI_INSTANCE,
new Type[] { ClassType.create(DotNames.CHAT_MODEL_LISTENER) }, null))
.createWith(chatLanguageModel);

addQualifierIfNecessary(tokenizerBuilder, configName);
beanProducer.produce(tokenizerBuilder.done());
Expand All @@ -126,7 +135,9 @@ void generateBeans(WatsonxRecorder recorder, LangChain4jWatsonxConfig runtimeCon
.setRuntimeInit()
.defaultBean()
.scope(ApplicationScoped.class)
.supplier(streamingChatLanguageModel);
.addInjectionPoint(ParameterizedType.create(DotNames.CDI_INSTANCE,
new Type[] { ClassType.create(DotNames.CHAT_MODEL_LISTENER) }, null))
.createWith(streamingChatLanguageModel);

addQualifierIfNecessary(streamingBuilder, configName);
beanProducer.produce(streamingBuilder.done());
Expand Down Expand Up @@ -171,9 +182,8 @@ private void addQualifierIfNecessary(SyntheticBeanBuildItem.ExtendedBeanConfigur

/**
* When both {@code rest-client-jackson} and {@code rest-client-jsonb} are present on the classpath we need to make sure
* that Jackson is used.
* This is not a proper solution as it affects all clients, but it's better than the having the reader/writers be selected
* at random.
* that Jackson is used. This is not a proper solution as it affects all clients, but it's better than the having the
* reader/writers be selected at random.
*/
@BuildStep
public void deprioritizeJsonb(Capabilities capabilities,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import static org.awaitility.Awaitility.await;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.fail;

import java.time.Duration;
import java.util.Date;
Expand All @@ -21,10 +22,12 @@
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.chat.TokenCountEstimator;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.output.FinishReason;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.scoring.ScoringModel;
import io.quarkiverse.langchain4j.watsonx.bean.EmbeddingParameters;
Expand Down Expand Up @@ -268,7 +271,25 @@ void check_chat_streaming_model_config() throws Exception {
dev.langchain4j.data.message.UserMessage.from("UserMessage"));

var streamingResponse = new AtomicReference<AiMessage>();
streamingChatModel.generate(messages, WireMockUtil.streamingResponseHandler(streamingResponse));
streamingChatModel.generate(messages, new StreamingResponseHandler<>() {
@Override
public void onNext(String token) {
}

@Override
public void onError(Throwable error) {
fail("Streaming failed: %s".formatted(error.getMessage()), error);
}

@Override
public void onComplete(Response<AiMessage> response) {
assertEquals(FinishReason.LENGTH, response.finishReason());
assertEquals(2, response.tokenUsage().inputTokenCount());
assertEquals(14, response.tokenUsage().outputTokenCount());
assertEquals(16, response.tokenUsage().totalTokenCount());
streamingResponse.set(response.content());
}
});

await().atMost(Duration.ofMinutes(1))
.pollInterval(Duration.ofSeconds(2))
Expand All @@ -277,5 +298,6 @@ void check_chat_streaming_model_config() throws Exception {
assertThat(streamingResponse.get().text())
.isNotNull()
.isEqualTo(". I'm a beginner");

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,31 @@

import java.net.URL;
import java.time.Duration;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;

import org.jboss.logging.Logger;
import org.jboss.resteasy.reactive.client.api.LoggingScope;

import dev.langchain4j.model.chat.listener.ChatModelErrorContext;
import dev.langchain4j.model.chat.listener.ChatModelListener;
import dev.langchain4j.model.chat.listener.ChatModelRequest;
import dev.langchain4j.model.chat.listener.ChatModelRequestContext;
import dev.langchain4j.model.chat.listener.ChatModelResponse;
import dev.langchain4j.model.chat.listener.ChatModelResponseContext;
import io.quarkiverse.langchain4j.watsonx.client.WatsonxRestApi;
import io.quarkiverse.langchain4j.watsonx.client.filter.BearerTokenHeaderFactory;
import io.quarkus.rest.client.reactive.QuarkusRestClientBuilder;

public abstract class Watsonx {

private static final Logger logger = Logger.getLogger(Watsonx.class);

protected final String modelId, projectId, spaceId, version;
protected final WatsonxRestApi client;
protected final List<ChatModelListener> listeners;

public Watsonx(Builder<?> builder) {
QuarkusRestClientBuilder restClientBuilder = QuarkusRestClientBuilder.newBuilder()
Expand All @@ -34,6 +47,38 @@ public Watsonx(Builder<?> builder) {
this.spaceId = builder.spaceId;
this.projectId = builder.projectId;
this.version = builder.version;
this.listeners = builder.listeners;
}

protected void beforeSentRequest(ChatModelRequest request, Map<Object, Object> attributes) {
for (ChatModelListener listener : listeners) {
try {
listener.onRequest(new ChatModelRequestContext(request, attributes));
} catch (Exception e) {
logger.warn("Exception while calling model listener", e);
}
}
}

protected void afterReceivedResponse(ChatModelResponse response, ChatModelRequest request, Map<Object, Object> attributes) {
for (ChatModelListener listener : listeners) {
try {
listener.onResponse(new ChatModelResponseContext(response, request, attributes));
} catch (Exception e) {
logger.warn("Exception while calling model listener", e);
}
}
}

protected void onRequestError(Throwable error, ChatModelRequest request, ChatModelResponse partialResponse,
Map<Object, Object> attributes) {
for (ChatModelListener listener : listeners) {
try {
listener.onError(new ChatModelErrorContext(error, request, partialResponse, attributes));
} catch (Exception e) {
logger.warn("Exception while calling model listener", e);
}
}
}

public WatsonxRestApi getClient() {
Expand Down Expand Up @@ -67,6 +112,7 @@ public static abstract class Builder<T extends Builder<T>> {
protected URL url;
protected boolean logResponses;
protected boolean logRequests;
private List<ChatModelListener> listeners = Collections.emptyList();
protected WatsonxTokenGenerator tokenGenerator;

public T modelId(String modelId) {
Expand Down Expand Up @@ -99,6 +145,11 @@ public T timeout(Duration timeout) {
return (T) this;
}

public T listeners(List<ChatModelListener> listeners) {
this.listeners = listeners;
return (T) this;
}

public T tokenGenerator(WatsonxTokenGenerator tokenGenerator) {
this.tokenGenerator = tokenGenerator;
return (T) this;
Expand Down
Loading
Loading