diff --git a/model-providers/ollama/deployment/src/test/java/io/quarkiverse/langchain4j/ollama/deployment/OllamaJsonOutputTest.java b/model-providers/ollama/deployment/src/test/java/io/quarkiverse/langchain4j/ollama/deployment/OllamaJsonOutputTest.java
new file mode 100644
index 000000000..4d4ebc6c1
--- /dev/null
+++ b/model-providers/ollama/deployment/src/test/java/io/quarkiverse/langchain4j/ollama/deployment/OllamaJsonOutputTest.java
@@ -0,0 +1,93 @@
+package io.quarkiverse.langchain4j.ollama.deployment;
+
+import static com.github.tomakehurst.wiremock.client.WireMock.aResponse;
+import static com.github.tomakehurst.wiremock.client.WireMock.equalToJson;
+import static com.github.tomakehurst.wiremock.client.WireMock.post;
+import static com.github.tomakehurst.wiremock.client.WireMock.urlEqualTo;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+import jakarta.inject.Inject;
+import jakarta.inject.Singleton;
+
+import org.jboss.shrinkwrap.api.ShrinkWrap;
+import org.jboss.shrinkwrap.api.spec.JavaArchive;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.RegisterExtension;
+
+import dev.langchain4j.model.output.structured.Description;
+import dev.langchain4j.service.UserName;
+import io.quarkiverse.langchain4j.RegisterAiService;
+import io.quarkiverse.langchain4j.testing.internal.WiremockAware;
+import io.quarkus.test.QuarkusUnitTest;
+
+public class OllamaJsonOutputTest extends WiremockAware { // Exercises an AI service that returns a record while chat-model.format is configured as "json".
+
+    @RegisterExtension
+    static final QuarkusUnitTest unitTest = new QuarkusUnitTest()
+            .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class))
+            .overrideConfigKey("quarkus.langchain4j.ollama.base-url", WiremockAware.wiremockUrlForConfig()) // send requests to the WireMock server
+            .overrideConfigKey("quarkus.langchain4j.devservices.enabled", "false")
+            .overrideConfigKey("quarkus.langchain4j.ollama.chat-model.format", "json"); // explicit format taken from configuration
+
+    @Description("A person")
+    public record Person( // target type of the extraction; the descriptions appear in the expected prompt below
+            @Description("The firstname") String firstname,
+            @Description("The lastname") String lastname) {
+    }
+
+    @Singleton
+    @RegisterAiService
+    interface AiService {
+        Person extractPerson(@UserName String text); // NOTE(review): annotated with @UserName rather than @UserMessage - confirm this is intended
+    }
+
+    @Inject
+    AiService aiService;
+
+    @Test
+    void extract() { // Stubs POST /api/chat and checks that the stubbed reply is mapped to a Person.
+        wiremock().register(
+                post(urlEqualTo("/api/chat"))
+                        .withRequestBody(equalToJson( // expected request: the prompt carries the JSON-format instructions and "format" is the string "json"
+                                """
+                                        {
+                                          "model": "llama3.2",
+                                          "messages": [
+                                            {
+                                              "role": "user",
+                                              "content": "Tell me something about Alan Wake\\nYou must answer strictly in the following JSON format: {\\n\\\"firstname\\\": (The firstname; type: string),\\n\\\"lastname\\\": (The lastname; type: string)\\n}"
+                                            }
+                                          ],
+                                          "stream": false,
+                                          "options": {
+                                            "temperature": 0.8,
+                                            "top_k": 40,
+                                            "top_p": 0.9
+                                          },
+                                          "tools": [],
+                                          "format": "json"
+                                        }"""))
+                        .willReturn(aResponse() // canned Ollama reply whose message content is a JSON document
+                                .withHeader("Content-Type", "application/json")
+                                .withBody("""
+                                        {
+                                          "model": "llama3.2",
+                                          "created_at": "2024-12-11T15:21:23.422542932Z",
+                                          "message": {
+                                            "role": "assistant",
+                                            "content": "{\\\"firstname\\\":\\\"Alan\\\",\\\"lastname\\\":\\\"Wake\\\"}"
+                                          },
+                                          "done_reason": "stop",
+                                          "done": true,
+                                          "total_duration": 8125806496,
+                                          "load_duration": 4223887064,
+                                          "prompt_eval_count": 31,
+                                          "prompt_eval_duration": 1331000000,
+                                          "eval_count": 18,
+                                          "eval_duration": 2569000000
+                                        }""")));
+
+        var result = aiService.extractPerson("Tell me something about Alan Wake");
+        assertEquals(new Person("Alan", "Wake"), result); // the JSON in the stubbed "content" field must come back as the record
+    }
+}
diff --git a/model-providers/ollama/deployment/src/test/java/io/quarkiverse/langchain4j/ollama/deployment/OllamaStructuredOutputTest.java b/model-providers/ollama/deployment/src/test/java/io/quarkiverse/langchain4j/ollama/deployment/OllamaStructuredOutputTest.java
new file mode 100644
index 000000000..37fc38271
--- /dev/null
+++ b/model-providers/ollama/deployment/src/test/java/io/quarkiverse/langchain4j/ollama/deployment/OllamaStructuredOutputTest.java
@@ -0,0 +1,103 @@
+package io.quarkiverse.langchain4j.ollama.deployment;
+
+import static com.github.tomakehurst.wiremock.client.WireMock.aResponse;
+import static
com.github.tomakehurst.wiremock.client.WireMock.equalToJson;
+import static com.github.tomakehurst.wiremock.client.WireMock.post;
+import static com.github.tomakehurst.wiremock.client.WireMock.urlEqualTo;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+import jakarta.inject.Inject;
+import jakarta.inject.Singleton;
+
+import org.jboss.shrinkwrap.api.ShrinkWrap;
+import org.jboss.shrinkwrap.api.spec.JavaArchive;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.RegisterExtension;
+
+import dev.langchain4j.model.output.structured.Description;
+import dev.langchain4j.service.UserName;
+import io.quarkiverse.langchain4j.RegisterAiService;
+import io.quarkiverse.langchain4j.testing.internal.WiremockAware;
+import io.quarkus.test.QuarkusUnitTest;
+
+public class OllamaStructuredOutputTest extends WiremockAware { // Exercises an AI service that returns a record when no chat-model.format is configured.
+
+    @RegisterExtension
+    static final QuarkusUnitTest unitTest = new QuarkusUnitTest()
+            .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class))
+            .overrideConfigKey("quarkus.langchain4j.ollama.base-url", WiremockAware.wiremockUrlForConfig()) // send requests to the WireMock server
+            .overrideConfigKey("quarkus.langchain4j.devservices.enabled", "false");
+
+    @Description("A person")
+    public record Person( // target type; its descriptions appear in the expected JSON schema below
+            @Description("The firstname") String firstname,
+            @Description("The lastname") String lastname) {
+    }
+
+    @Singleton
+    @RegisterAiService
+    interface AiService {
+        Person extractPerson(@UserName String text); // NOTE(review): annotated with @UserName rather than @UserMessage - confirm this is intended
+    }
+
+    @Inject
+    AiService aiService;
+
+    @Test
+    void extract() { // Stubs POST /api/chat and checks that the stubbed reply is mapped to a Person.
+        wiremock().register(
+                post(urlEqualTo("/api/chat")) // expected request: plain user text, with "format" holding a JSON schema object for Person
+                        .withRequestBody(equalToJson("""
+                                {
+                                  "model": "llama3.2",
+                                  "messages": [{"role": "user", "content": "Tell me something about Alan Wake"}],
+                                  "stream": false,
+                                  "options" : {
+                                    "temperature" : 0.8,
+                                    "top_k" : 40,
+                                    "top_p" : 0.9
+                                  },
+                                  "format": {
+                                    "type": "object",
+                                    "description": "A person",
+                                    "properties": {
+                                      "firstname": {
+                                        "description": "The firstname",
+                                        "type": "string"
+                                      },
+                                      "lastname": {
+                                        "description": "The lastname",
+                                        "type": "string"
+                                      }
+                                    },
+                                    "required": [
+                                      "firstname",
+                                      "lastname"
+                                    ]
+                                  }
+                                }
+                                """))
+                        .willReturn(aResponse() // canned Ollama reply whose message content is a JSON document
+                                .withHeader("Content-Type", "application/json")
+                                .withBody("""
+                                        {
+                                          "model": "llama3.2",
+                                          "created_at": "2024-12-11T15:21:23.422542932Z",
+                                          "message": {
+                                            "role": "assistant",
+                                            "content": "{\\\"firstname\\\":\\\"Alan\\\",\\\"lastname\\\":\\\"Wake\\\"}"
+                                          },
+                                          "done_reason": "stop",
+                                          "done": true,
+                                          "total_duration": 8125806496,
+                                          "load_duration": 4223887064,
+                                          "prompt_eval_count": 31,
+                                          "prompt_eval_duration": 1331000000,
+                                          "eval_count": 18,
+                                          "eval_duration": 2569000000
+                                        }""")));
+
+        var result = aiService.extractPerson("Tell me something about Alan Wake");
+        assertEquals(new Person("Alan", "Wake"), result); // the JSON in the stubbed "content" field must come back as the record
+    }
+}
diff --git a/model-providers/ollama/deployment/src/test/java/io/quarkiverse/langchain4j/ollama/deployment/OllamaTextOutputTest.java b/model-providers/ollama/deployment/src/test/java/io/quarkiverse/langchain4j/ollama/deployment/OllamaTextOutputTest.java
new file mode 100644
index 000000000..48f08eea0
--- /dev/null
+++ b/model-providers/ollama/deployment/src/test/java/io/quarkiverse/langchain4j/ollama/deployment/OllamaTextOutputTest.java
@@ -0,0 +1,84 @@
+package io.quarkiverse.langchain4j.ollama.deployment;
+
+import static com.github.tomakehurst.wiremock.client.WireMock.aResponse;
+import static com.github.tomakehurst.wiremock.client.WireMock.equalToJson;
+import static com.github.tomakehurst.wiremock.client.WireMock.post;
+import static com.github.tomakehurst.wiremock.client.WireMock.urlEqualTo;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+import jakarta.inject.Inject;
+import jakarta.inject.Singleton;
+
+import org.jboss.shrinkwrap.api.ShrinkWrap;
+import org.jboss.shrinkwrap.api.spec.JavaArchive;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.RegisterExtension;
+
+import dev.langchain4j.service.UserName;
+import
io.quarkiverse.langchain4j.RegisterAiService;
+import io.quarkiverse.langchain4j.testing.internal.WiremockAware;
+import io.quarkus.test.QuarkusUnitTest;
+
+public class OllamaTextOutputTest extends WiremockAware { // Exercises an AI service that returns plain text: the expected request has no "format" field.
+
+    @RegisterExtension
+    static final QuarkusUnitTest unitTest = new QuarkusUnitTest()
+            .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class))
+            .overrideConfigKey("quarkus.langchain4j.ollama.base-url", WiremockAware.wiremockUrlForConfig())
+            .overrideConfigKey("quarkus.langchain4j.devservices.enabled", "false");
+
+    @Singleton
+    @RegisterAiService
+    interface AiService {
+        String question(@UserName String text);
+    }
+
+    @Inject
+    AiService aiService;
+
+    @Test
+    void extract() { // Stubs POST /api/chat and checks that the assistant text is returned unchanged.
+        wiremock().register(
+                post(urlEqualTo("/api/chat"))
+                        .withRequestBody(equalToJson(
+                                """
+                                        {
+                                          "model": "llama3.2",
+                                          "messages": [
+                                            {
+                                              "role": "user",
+                                              "content": "Tell me something about Alan Wake"
+                                            }
+                                          ],
+                                          "stream": false,
+                                          "options": {
+                                            "temperature": 0.8,
+                                            "top_k": 40,
+                                            "top_p": 0.9
+                                          },
+                                          "tools": []
+                                        }"""))
+                        .willReturn(aResponse()
+                                .withHeader("Content-Type", "application/json")
+                                .withBody("""
+                                        {
+                                          "model": "llama3.2",
+                                          "created_at": "2024-12-11T15:21:23.422542932Z",
+                                          "message": {
+                                            "role": "assistant",
+                                            "content": "He is a writer!"
+                                          },
+                                          "done_reason": "stop",
+                                          "done": true,
+                                          "total_duration": 8125806496,
+                                          "load_duration": 4223887064,
+                                          "prompt_eval_count": 31,
+                                          "prompt_eval_duration": 1331000000,
+                                          "eval_count": 18,
+                                          "eval_duration": 2569000000
+                                        }""")));
+
+        var result = aiService.question("Tell me something about Alan Wake");
+        assertEquals("He is a writer!", result);
+    }
+}
diff --git a/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/ChatRequest.java b/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/ChatRequest.java
index d5fadfe30..967cab22c 100644
--- a/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/ChatRequest.java
+++ b/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/ChatRequest.java
@@ -2,7 +2,14 @@
 
 import java.util.List;
 
-public record ChatRequest(String model, List messages, List tools, Options options, String format,
+import com.fasterxml.jackson.databind.annotation.JsonSerialize;
+
+public record ChatRequest(
+        String model,
+        List messages,
+        List tools,
+        Options options,
+        @JsonSerialize(using = FormatJsonSerializer.class) String format,
         Boolean stream) {
 
     public static Builder builder() {
diff --git a/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/FormatJsonSerializer.java b/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/FormatJsonSerializer.java
new file mode 100644
index 000000000..21f1879ca
--- /dev/null
+++ b/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/FormatJsonSerializer.java
@@ -0,0 +1,41 @@
+package io.quarkiverse.langchain4j.ollama;
+
+import java.io.IOException;
+
+import com.fasterxml.jackson.core.JsonGenerator;
+import com.fasterxml.jackson.databind.JsonSerializer;
+import com.fasterxml.jackson.databind.SerializerProvider;
+
+/**
+ * Serializes the {@code format} field of an Ollama chat request.
+ * <p>
+ * A value that looks like a JSON object (a JSON schema) is written as raw JSON so that it is sent as an object;
+ * any other value (for example {@code json}) is written as a plain JSON string.
+ */
+public class FormatJsonSerializer extends JsonSerializer<String> {
+
+    /**
+     * Writes {@code value} either as raw JSON or as a JSON string.
+     *
+     * @param value the configured or computed format; may be {@code null}
+     * @param gen the generator receiving the output
+     * @param serializers not used
+     * @throws IOException if the generator fails to write
+     */
+    @Override
+    public void serialize(String value, JsonGenerator gen, SerializerProvider serializers) throws IOException {
+        if (value == null) {
+            // Defensive: write an explicit null rather than nothing, so the output stays well-formed if this is reached.
+            gen.writeNull();
+            return;
+        }
+
+        // Ignore surrounding whitespace when detecting a schema, e.g. a multi-line configuration value.
+        String candidate = value.strip();
+        if (candidate.startsWith("{") && candidate.endsWith("}")) {
+            gen.writeRawValue(candidate);
+        } else {
+            gen.writeString(value);
+        }
+    }
+}
diff --git a/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/OllamaChatLanguageModel.java b/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/OllamaChatLanguageModel.java
index 471c951a3..9d68410bb 100644
--- a/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/OllamaChatLanguageModel.java
+++ b/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/OllamaChatLanguageModel.java
@@ -2,6 +2,8 @@
 
 import static dev.langchain4j.data.message.AiMessage.aiMessage;
 import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty;
+import static dev.langchain4j.model.chat.Capability.RESPONSE_FORMAT_JSON_SCHEMA;
+import static dev.langchain4j.model.chat.request.ResponseFormatType.JSON;
 import static io.quarkiverse.langchain4j.ollama.MessageMapper.toOllamaMessages;
 import static io.quarkiverse.langchain4j.ollama.MessageMapper.toTools;
 
@@ -9,14 +11,18 @@
 import java.util.Collections;
 import java.util.List;
 import java.util.Map;
+import java.util.Set;
 import java.util.concurrent.ConcurrentHashMap;
 
 import org.jboss.logging.Logger;
 
+import com.fasterxml.jackson.databind.ObjectMapper;
+
 import dev.langchain4j.agent.tool.ToolExecutionRequest;
 import dev.langchain4j.agent.tool.ToolSpecification;
 import dev.langchain4j.data.message.AiMessage;
 import dev.langchain4j.data.message.ChatMessage;
+import dev.langchain4j.model.chat.Capability;
 import dev.langchain4j.model.chat.ChatLanguageModel;
 import dev.langchain4j.model.chat.listener.ChatModelErrorContext;
 import dev.langchain4j.model.chat.listener.ChatModelListener;
@@ -24,12 +30,16 @@
 import dev.langchain4j.model.chat.listener.ChatModelRequestContext;
 import
dev.langchain4j.model.chat.listener.ChatModelResponse;
 import dev.langchain4j.model.chat.listener.ChatModelResponseContext;
+import dev.langchain4j.model.chat.request.ResponseFormat;
+import dev.langchain4j.model.chat.request.json.JsonSchemaElementHelper;
 import dev.langchain4j.model.output.Response;
 import dev.langchain4j.model.output.TokenUsage;
+import io.quarkiverse.langchain4j.QuarkusJsonCodecFactory;
 
 public class OllamaChatLanguageModel implements ChatLanguageModel {
 
     private static final Logger log = Logger.getLogger(OllamaChatLanguageModel.class);
+    private static final ObjectMapper objectMapper = QuarkusJsonCodecFactory.SnakeCaseObjectMapperHolder.MAPPER;
 
     private final OllamaClient client;
     private final String model;
@@ -51,29 +61,42 @@ public static Builder builder() {
     }
 
+    /**
+     * Builds the Ollama request for the given chat request. A configured {@code format} wins over the request's
+     * response format; otherwise a JSON response format is sent as its JSON schema, or as {@code json} without one.
+     */
     @Override
-    public Response generate(List messages) {
-        return generate(messages, Collections.emptyList());
-    }
-
-    @Override
-    public Response generate(List messages, ToolSpecification toolSpecification) {
-        return generate(messages,
-                toolSpecification != null ? Collections.singletonList(toolSpecification) : Collections.emptyList());
-    }
+    public dev.langchain4j.model.chat.response.ChatResponse chat(dev.langchain4j.model.chat.request.ChatRequest chatRequest) {
+        List messages = chatRequest.messages();
+        List toolSpecifications = chatRequest.toolSpecifications();
+        ResponseFormat responseFormat = chatRequest.responseFormat();
 
-    @Override
-    public Response generate(List messages, List toolSpecifications) {
         ensureNotEmpty(messages, "messages");
 
-        ChatRequest request = ChatRequest.builder()
+        ChatRequest.Builder builder = ChatRequest.builder()
                 .model(model)
                 .messages(toOllamaMessages(messages))
-                .tools(toTools(toolSpecifications))
+                .tools(toolSpecifications == null ? null : toTools(toolSpecifications))
                 .options(options)
-                .format(format)
-                .stream(false)
-                .build();
+                .stream(false);
+
+        if (format != null && !format.isBlank()) {
+            // If the developer specifies something in the "format" property, it has high priority.
+            builder.format(format);
+        } else if (responseFormat != null && responseFormat.type() == JSON) {
+            if (responseFormat.jsonSchema() == null) {
+                // JSON was requested without a schema: use the plain JSON mode instead of failing on a null schema.
+                builder.format("json");
+            } else {
+                try {
+                    var jsonSchema = JsonSchemaElementHelper.toMap(responseFormat.jsonSchema().rootElement());
+                    builder.format(objectMapper.writeValueAsString(jsonSchema));
+                } catch (Exception e) {
+                    throw new RuntimeException("Unable to serialize the JSON schema for the Ollama format field", e);
+                }
+            }
+        }
 
+        ChatRequest request = builder.build();
         ChatModelRequest modelListenerRequest = createModelListenerRequest(request, messages, toolSpecifications);
         Map attributes = new ConcurrentHashMap<>();
         ChatModelRequestContext requestContext = new ChatModelRequestContext(modelListenerRequest, attributes);
@@ -105,7 +119,11 @@ public Response generate(List messages, List generate(List messages, List generate(List messages, List toolSpecifications) {
+        var chatResponse = chat(dev.langchain4j.model.chat.request.ChatRequest.builder()
+                .messages(messages)
+                .toolSpecifications(toolSpecifications)
+                .build());
+
+        return Response.from(chatResponse.aiMessage());
+    }
+
+    @Override
+    public Response generate(List messages) {
+        return generate(messages, Collections.emptyList());
+    }
+
+    @Override
+    public Response generate(List messages, ToolSpecification toolSpecification) {
+        return generate(messages,
+                toolSpecification != null ? Collections.singletonList(toolSpecification) : Collections.emptyList());
+    }
+
+    @Override
+    public Set supportedCapabilities() { // JSON-schema responses are advertised unless the configured format is exactly "json" (case-insensitive)
+        if (format == null || !format.equalsIgnoreCase("json"))
+            return Set.of(RESPONSE_FORMAT_JSON_SCHEMA);
+        return Set.of();
+    }
+
     private static Response toResponse(ChatResponse response) {
         Response result;
         List toolCalls = response.message().toolCalls();
diff --git a/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/OllamaStreamingChatLanguageModel.java b/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/OllamaStreamingChatLanguageModel.java
index 9bebbe6e3..e256a462c 100644
--- a/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/OllamaStreamingChatLanguageModel.java
+++ b/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/OllamaStreamingChatLanguageModel.java
@@ -121,11 +121,14 @@ public void accept(ChatResponse response) {
                                     if (response.model() != null) {
                                         context.put(MODEL_ID, response.model());
                                     }
-                                    TokenUsage tokenUsage = new TokenUsage(
-                                            response.evalCount(),
-                                            response.promptEvalCount(),
-                                            response.evalCount() + response.promptEvalCount());
-                                    context.put(TOKEN_USAGE_CONTEXT, tokenUsage);
+
+                                    if (response.evalCount() != null && response.promptEvalCount() != null) { // record usage only when both counters are present
+                                        TokenUsage tokenUsage = new TokenUsage(
+                                                response.evalCount(),
+                                                response.promptEvalCount(),
+                                                response.evalCount() + response.promptEvalCount());
+                                        context.put(TOKEN_USAGE_CONTEXT, tokenUsage);
+                                    }
                                 }
 
                             } catch (Exception e) {
@@ -170,7 +173,8 @@ public void accept(Throwable error) {
 
                             @Override
                             public void run() {
-                                TokenUsage tokenUsage = context.get(TOKEN_USAGE_CONTEXT);
+                                TokenUsage tokenUsage = context.contains(TOKEN_USAGE_CONTEXT) ? context.get(TOKEN_USAGE_CONTEXT)
+                                        : null;
                                 List chatResponses = context.get(RESPONSE_CONTEXT);
                                 List toolExecutionRequests = context.get(TOOLS_CONTEXT);
 
diff --git a/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/runtime/config/ChatModelConfig.java b/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/runtime/config/ChatModelConfig.java
index fd306bf5e..54305483c 100644
--- a/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/runtime/config/ChatModelConfig.java
+++ b/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/runtime/config/ChatModelConfig.java
@@ -56,7 +56,7 @@ public interface ChatModelConfig {
     Optional seed();
 
     /**
-     * the format to return a response in. Currently, the only accepted value is {@code json}
+     * The format to return a response in. Format can be {@code json} or a JSON schema.
      */
     Optional format();