diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java index 63d60fd01..294c1d212 100644 --- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java @@ -522,6 +522,7 @@ public void handleDeclarativeServices(AiServicesRecorder recorder, boolean needsAuditServiceBean = false; boolean needsModerationModelBean = false; boolean needsImageModelBean = false; + boolean needsToolProviderBean = false; Set allToolNames = new HashSet<>(); Set allToolProviders = new HashSet<>(); @@ -761,7 +762,8 @@ public void handleDeclarativeServices(AiServicesRecorder recorder, .equals(toolProviderSupplierClassName)) { configurator.addInjectionPoint(ParameterizedType.create(DotNames.CDI_INSTANCE, new Type[] { ClassType.create(LangChain4jDotNames.TOOL_PROVIDER) }, null)); - } else if (!RegisterAiService.BeanIfExistsToolProviderSupplier.class.getName() + needsToolProviderBean = true; + } else if (!RegisterAiService.NoToolProviderSupplier.class.getName() .equals(toolProviderSupplierClassName) && toolProviderSupplierClassName != null) { DotName toolProvider = DotName.createSimple(toolProviderSupplierClassName); configurator.addInjectionPoint(ClassType.create(toolProvider)); @@ -800,6 +802,9 @@ public void handleDeclarativeServices(AiServicesRecorder recorder, if (needsImageModelBean) { unremovableProducer.produce(UnremovableBeanBuildItem.beanTypes(LangChain4jDotNames.IMAGE_MODEL)); } + if (needsToolProviderBean) { + unremovableProducer.produce(UnremovableBeanBuildItem.beanTypes(LangChain4jDotNames.TOOL_PROVIDER)); + } if (!allToolProviders.isEmpty()) { unremovableProducer.produce(UnremovableBeanBuildItem.beanTypes(allToolProviders)); } diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/ToolProviderTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/ToolProviderTest.java deleted file mode 100644 index f4b18ff3c..000000000 --- a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/ToolProviderTest.java +++ /dev/null @@ -1,150 +0,0 @@ -package io.quarkiverse.langchain4j.test; - -import static dev.langchain4j.data.message.ChatMessageType.TOOL_EXECUTION_RESULT; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertNotNull; - -import java.util.List; -import java.util.function.Supplier; - -import jakarta.enterprise.context.ApplicationScoped; -import jakarta.enterprise.context.control.ActivateRequestContext; -import jakarta.inject.Inject; - -import org.jboss.shrinkwrap.api.ShrinkWrap; -import org.jboss.shrinkwrap.api.spec.JavaArchive; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.extension.RegisterExtension; - -import dev.langchain4j.agent.tool.ToolExecutionRequest; -import dev.langchain4j.agent.tool.ToolSpecification; -import dev.langchain4j.data.message.AiMessage; -import dev.langchain4j.data.message.ChatMessage; -import dev.langchain4j.data.message.ToolExecutionResultMessage; -import dev.langchain4j.model.chat.ChatLanguageModel; -import dev.langchain4j.model.output.FinishReason; -import dev.langchain4j.model.output.Response; -import dev.langchain4j.model.output.TokenUsage; -import dev.langchain4j.service.MemoryId; -import dev.langchain4j.service.UserMessage; -import dev.langchain4j.service.tool.ToolExecutor; -import dev.langchain4j.service.tool.ToolProvider; -import dev.langchain4j.service.tool.ToolProviderRequest; -import dev.langchain4j.service.tool.ToolProviderResult; -import io.quarkiverse.langchain4j.RegisterAiService; -import io.quarkus.test.QuarkusUnitTest; - -class ToolProviderTest { - @Inject - MyServiceWithCustomToolProvider myServiceWithTools; - - @Inject - MyServiceWithDefaultToolProviderConfig myServiceWithIfExistsTools; - - @Inject - MyServiceWithNoToolProvider myServiceWithNoToolProvider; - - @ApplicationScoped - public static class MyCustomToolProviderSupplier implements Supplier { - @Inject - MyCustomToolProvider myCustomToolProvider; - - @Override - public ToolProvider get() { - return myCustomToolProvider; - } - } - - @ApplicationScoped - public static class MyCustomToolProvider implements ToolProvider { - @Inject - MyServiceWithDefaultToolProviderConfig myServiceWithoutTools; - - @Override - public ToolProviderResult provideTools(ToolProviderRequest request) { - assertNotNull(myServiceWithoutTools); - - ToolSpecification toolSpecification = ToolSpecification.builder() - .name("get_booking_details") - .description("Returns booking details") - .build(); - ToolExecutor toolExecutor = (t, m) -> "0"; - return ToolProviderResult.builder() - .add(toolSpecification, toolExecutor) - .build(); - } - } - - public static class TestAiSupplier implements Supplier { - @Override - public ChatLanguageModel get() { - return new TestAiModel(); - } - } - - public static class TestAiModel implements ChatLanguageModel { - @Override - public Response generate(List messages) { - return new Response<>(new AiMessage("42")); - } - - @Override - public Response generate(List messages, List toolSpecifications) { - ChatMessage lastMsg = messages.get(messages.size() - 1); - boolean isLastMsgToolResponse = lastMsg.type().equals(TOOL_EXECUTION_RESULT); - if (isLastMsgToolResponse) { - ToolExecutionResultMessage msg = (ToolExecutionResultMessage) lastMsg; - return new Response<>(new AiMessage(msg.text())); - } - ToolSpecification toolSpecification = toolSpecifications.get(0); - ToolExecutionRequest toolExecutionRequest = ToolExecutionRequest.builder() - .name(toolSpecification.name()) - .id(toolSpecification.name()) - .build(); - TokenUsage usage = new TokenUsage(42, 42); - return new Response<>(AiMessage.from(toolExecutionRequest), usage, FinishReason.TOOL_EXECUTION); - } - } - - @RegisterAiService(toolProviderSupplier = MyCustomToolProviderSupplier.class, chatLanguageModelSupplier = TestAiSupplier.class) - interface MyServiceWithCustomToolProvider { - String chat(@UserMessage String msg, @MemoryId Object id); - } - - @RegisterAiService(chatLanguageModelSupplier = TestAiSupplier.class) - interface MyServiceWithDefaultToolProviderConfig { - String chat(@UserMessage String msg, @MemoryId Object id); - } - - @RegisterAiService(toolProviderSupplier = RegisterAiService.NoToolProviderSupplier.class, chatLanguageModelSupplier = TestAiSupplier.class, chatMemoryProviderSupplier = RegisterAiService.NoChatMemoryProviderSupplier.class) - interface MyServiceWithNoToolProvider { - String chat(@UserMessage String msg, @MemoryId Object id); - } - - @RegisterExtension - static final QuarkusUnitTest unitTest = new QuarkusUnitTest() - .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class) - .addClasses(MyServiceWithCustomToolProvider.class, MyCustomToolProvider.class, - BlockingChatLanguageModelSupplierTest.MyModelSupplier.class)); - - @Test - @ActivateRequestContext - void testCall() { - String answer = myServiceWithTools.chat("hello", 1); - assertEquals("0", answer); - } - - @Test - @ActivateRequestContext - void testCallDefaultTools() { - String answer = myServiceWithIfExistsTools.chat("hello", 1); - assertEquals("0", answer); - } - - @Test - @ActivateRequestContext - void testCallNoTools() { - String answer = myServiceWithNoToolProvider.chat("hello", 1); - assertEquals("42", answer); - } -} diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/toolresolution/AutomaticToolProviderTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/toolresolution/AutomaticToolProviderTest.java new file mode 100644 index 000000000..ef5126ed2 --- /dev/null +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/toolresolution/AutomaticToolProviderTest.java @@ -0,0 +1,47 @@ +package io.quarkiverse.langchain4j.test.toolresolution; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import jakarta.enterprise.context.control.ActivateRequestContext; +import jakarta.inject.Inject; + +import org.jboss.shrinkwrap.api.ShrinkWrap; +import org.jboss.shrinkwrap.api.spec.JavaArchive; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import dev.langchain4j.service.MemoryId; +import dev.langchain4j.service.UserMessage; +import io.quarkiverse.langchain4j.RegisterAiService; +import io.quarkus.test.QuarkusUnitTest; + +/** + * If the AI service does not explicitly specify tools nor a tool provider + * and there is a bean that implements ToolProvider, that bean should be used. + */ +public class AutomaticToolProviderTest { + + @RegisterExtension + static final QuarkusUnitTest unitTest = new QuarkusUnitTest() + .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class) + .addClasses(TestAiSupplier.class, + TestAiModel.class, + ServiceWithDefaultToolProviderConfig.class, + MyCustomToolProvider.class)); + + @RegisterAiService(chatLanguageModelSupplier = TestAiSupplier.class) + interface ServiceWithDefaultToolProviderConfig { + String chat(@UserMessage String msg, @MemoryId Object id); + } + + @Inject + ServiceWithDefaultToolProviderConfig service; + + @Test + @ActivateRequestContext + void testCall() { + String answer = service.chat("hello", 1); + assertEquals("TOOL1", answer); + } + +} diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/toolresolution/ExplicitToolProviderSupplierTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/toolresolution/ExplicitToolProviderSupplierTest.java new file mode 100644 index 000000000..c3441e3ec --- /dev/null +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/toolresolution/ExplicitToolProviderSupplierTest.java @@ -0,0 +1,50 @@ +package io.quarkiverse.langchain4j.test.toolresolution; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import jakarta.enterprise.context.control.ActivateRequestContext; +import jakarta.inject.Inject; + +import org.jboss.shrinkwrap.api.ShrinkWrap; +import org.jboss.shrinkwrap.api.spec.JavaArchive; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import dev.langchain4j.service.MemoryId; +import dev.langchain4j.service.UserMessage; +import io.quarkiverse.langchain4j.RegisterAiService; +import io.quarkus.test.QuarkusUnitTest; + +/** + * If an AI service specifies an explicit tool provider (and no specific tools), + * that tool provider should be used. + */ +public class ExplicitToolProviderSupplierTest { + + @RegisterExtension + static final QuarkusUnitTest unitTest = new QuarkusUnitTest() + .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class) + .addClasses(TestAiSupplier.class, + TestAiModel.class, + ServiceWithExplicitToolProviderSupplier.class, + MyCustomToolProviderSupplier.class, + MyCustomToolProvider.class)); + + @RegisterAiService(toolProviderSupplier = MyCustomToolProviderSupplier.class, chatLanguageModelSupplier = TestAiSupplier.class) + interface ServiceWithExplicitToolProviderSupplier { + + String chat(@UserMessage String msg, @MemoryId Object id); + + } + + @Inject + ServiceWithExplicitToolProviderSupplier service; + + @Test + @ActivateRequestContext + void testCall() { + String answer = service.chat("hello", 1); + assertEquals("TOOL1", answer); + } + +} diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/toolresolution/ExplicitToolsAndNoBeanToolProviderTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/toolresolution/ExplicitToolsAndNoBeanToolProviderTest.java new file mode 100644 index 000000000..42957299c --- /dev/null +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/toolresolution/ExplicitToolsAndNoBeanToolProviderTest.java @@ -0,0 +1,48 @@ +package io.quarkiverse.langchain4j.test.toolresolution; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import jakarta.enterprise.context.control.ActivateRequestContext; +import jakarta.inject.Inject; + +import org.jboss.shrinkwrap.api.ShrinkWrap; +import org.jboss.shrinkwrap.api.spec.JavaArchive; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import dev.langchain4j.service.MemoryId; +import dev.langchain4j.service.UserMessage; +import io.quarkiverse.langchain4j.RegisterAiService; +import io.quarkus.test.QuarkusUnitTest; + +/** + * If the AI service explicitly specifies tools, and there is a bean that implements ToolProvider, + * but the service also declares a NoToolProviderSupplier, the explicit tools should be used. + */ +public class ExplicitToolsAndNoBeanToolProviderTest { + + @RegisterExtension + static final QuarkusUnitTest unitTest = new QuarkusUnitTest() + .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class) + .addClasses(TestAiSupplier.class, + TestAiModel.class, + ServiceWithExplicitToolsAndNoToolProviderSupplier.class, + MyCustomToolProvider.class, + ToolsClass.class)); + + @RegisterAiService(chatLanguageModelSupplier = TestAiSupplier.class, tools = ToolsClass.class, toolProviderSupplier = RegisterAiService.NoToolProviderSupplier.class) + interface ServiceWithExplicitToolsAndNoToolProviderSupplier { + String chat(@UserMessage String msg, @MemoryId Object id); + } + + @Inject + ServiceWithExplicitToolsAndNoToolProviderSupplier service; + + @Test + @ActivateRequestContext + void testCall() { + String answer = service.chat("hello", 1); + assertEquals("\"EXPLICIT TOOL\"", answer); + } + +} diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/toolresolution/ExplicitToolsAndProviderSupplierTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/toolresolution/ExplicitToolsAndProviderSupplierTest.java new file mode 100644 index 000000000..76ca5ed87 --- /dev/null +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/toolresolution/ExplicitToolsAndProviderSupplierTest.java @@ -0,0 +1,53 @@ +package io.quarkiverse.langchain4j.test.toolresolution; + +import jakarta.enterprise.context.control.ActivateRequestContext; +import jakarta.inject.Inject; + +import org.assertj.core.api.Assertions; +import org.jboss.shrinkwrap.api.ShrinkWrap; +import org.jboss.shrinkwrap.api.spec.JavaArchive; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import dev.langchain4j.service.MemoryId; +import dev.langchain4j.service.UserMessage; +import io.quarkiverse.langchain4j.RegisterAiService; +import io.quarkus.test.QuarkusUnitTest; + +/** + * + */ +public class ExplicitToolsAndProviderSupplierTest { + + @RegisterExtension + static final QuarkusUnitTest unitTest = new QuarkusUnitTest() + .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class) + .addClasses(TestAiSupplier.class, + TestAiModel.class, + ServiceWithToolClash.class, + MyCustomToolProviderSupplier.class, + MyCustomToolProvider.class, + ToolsClass.class)); + + @RegisterAiService(toolProviderSupplier = MyCustomToolProviderSupplier.class, tools = ToolsClass.class, chatLanguageModelSupplier = TestAiSupplier.class) + interface ServiceWithToolClash { + + String chat(@UserMessage String msg, @MemoryId Object id); + + } + + @Inject + ServiceWithToolClash service; + + @Test + @ActivateRequestContext + void testCall() { + try { + String answer = service.chat("hello", 1); + Assertions.fail("Exception expected"); + } catch (Exception e) { + Assertions.assertThat(e.getMessage()).contains(" Cannot use a tool provider when explicit tools are provided"); + } + } + +} diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/toolresolution/ExplicitToolsWhenBeanToolProviderExistsTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/toolresolution/ExplicitToolsWhenBeanToolProviderExistsTest.java new file mode 100644 index 000000000..5401d25f0 --- /dev/null +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/toolresolution/ExplicitToolsWhenBeanToolProviderExistsTest.java @@ -0,0 +1,48 @@ +package io.quarkiverse.langchain4j.test.toolresolution; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import jakarta.enterprise.context.control.ActivateRequestContext; +import jakarta.inject.Inject; + +import org.jboss.shrinkwrap.api.ShrinkWrap; +import org.jboss.shrinkwrap.api.spec.JavaArchive; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import dev.langchain4j.service.MemoryId; +import dev.langchain4j.service.UserMessage; +import io.quarkiverse.langchain4j.RegisterAiService; +import io.quarkus.test.QuarkusUnitTest; + +/** + * If the AI service explicitly specifies tools AND there is a bean that implements ToolProvider, + * the explicit tools should be used, and the ToolProvider bean should be ignored. + */ +public class ExplicitToolsWhenBeanToolProviderExistsTest { + + @RegisterExtension + static final QuarkusUnitTest unitTest = new QuarkusUnitTest() + .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class) + .addClasses(TestAiSupplier.class, + TestAiModel.class, + ServiceWithExplicitTools.class, + MyCustomToolProvider.class, + ToolsClass.class)); + + @RegisterAiService(chatLanguageModelSupplier = TestAiSupplier.class, tools = ToolsClass.class) + interface ServiceWithExplicitTools { + String chat(@UserMessage String msg, @MemoryId Object id); + } + + @Inject + ServiceWithExplicitTools service; + + @Test + @ActivateRequestContext + void testCall() { + String answer = service.chat("hello", 1); + assertEquals("\"EXPLICIT TOOL\"", answer); + } + +} diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/toolresolution/MyCustomToolProvider.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/toolresolution/MyCustomToolProvider.java new file mode 100644 index 000000000..2a9e41cc5 --- /dev/null +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/toolresolution/MyCustomToolProvider.java @@ -0,0 +1,25 @@ +package io.quarkiverse.langchain4j.test.toolresolution; + +import jakarta.enterprise.context.ApplicationScoped; + +import dev.langchain4j.agent.tool.ToolSpecification; +import dev.langchain4j.service.tool.ToolExecutor; +import dev.langchain4j.service.tool.ToolProvider; +import dev.langchain4j.service.tool.ToolProviderRequest; +import dev.langchain4j.service.tool.ToolProviderResult; + +@ApplicationScoped +public class MyCustomToolProvider implements ToolProvider { + + @Override + public ToolProviderResult provideTools(ToolProviderRequest request) { + ToolSpecification toolSpecification = ToolSpecification.builder() + .name("get_booking_details") + .description("Returns booking details") + .build(); + ToolExecutor toolExecutor = (t, m) -> "TOOL1"; + return ToolProviderResult.builder() + .add(toolSpecification, toolExecutor) + .build(); + } +} diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/toolresolution/MyCustomToolProviderSupplier.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/toolresolution/MyCustomToolProviderSupplier.java new file mode 100644 index 000000000..0ace25f60 --- /dev/null +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/toolresolution/MyCustomToolProviderSupplier.java @@ -0,0 +1,19 @@ +package io.quarkiverse.langchain4j.test.toolresolution; + +import java.util.function.Supplier; + +import jakarta.enterprise.context.ApplicationScoped; +import jakarta.inject.Inject; + +import dev.langchain4j.service.tool.ToolProvider; + +@ApplicationScoped +public class MyCustomToolProviderSupplier implements Supplier { + @Inject + MyCustomToolProvider myCustomToolProvider; + + @Override + public ToolProvider get() { + return myCustomToolProvider; + } +} diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/toolresolution/NoExplicitToolsAndNoToolProviderTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/toolresolution/NoExplicitToolsAndNoToolProviderTest.java new file mode 100644 index 000000000..820605773 --- /dev/null +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/toolresolution/NoExplicitToolsAndNoToolProviderTest.java @@ -0,0 +1,47 @@ +package io.quarkiverse.langchain4j.test.toolresolution; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import jakarta.enterprise.context.control.ActivateRequestContext; +import jakarta.inject.Inject; + +import org.jboss.shrinkwrap.api.ShrinkWrap; +import org.jboss.shrinkwrap.api.spec.JavaArchive; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import dev.langchain4j.service.MemoryId; +import dev.langchain4j.service.UserMessage; +import io.quarkiverse.langchain4j.RegisterAiService; +import io.quarkus.test.QuarkusUnitTest; + +/** + * If the AI service explicitly specifies a NoToolProviderSupplier, then even if a ToolProvider + * instance exists as a CDI bean, it should not be used. + */ +public class NoExplicitToolsAndNoToolProviderTest { + + @RegisterExtension + static final QuarkusUnitTest unitTest = new QuarkusUnitTest() + .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class) + .addClasses(TestAiSupplier.class, + TestAiModel.class, + ServiceWithNoToolProvider.class, + MyCustomToolProvider.class)); + + @RegisterAiService(toolProviderSupplier = RegisterAiService.NoToolProviderSupplier.class, chatLanguageModelSupplier = TestAiSupplier.class, chatMemoryProviderSupplier = RegisterAiService.NoChatMemoryProviderSupplier.class) + interface ServiceWithNoToolProvider { + String chat(@UserMessage String msg, @MemoryId Object id); + } + + @Inject + ServiceWithNoToolProvider service; + + @Test + @ActivateRequestContext + void testCall() { + String answer = service.chat("hello", 1); + assertEquals("NO TOOL", answer); + } + +} diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/toolresolution/TestAiModel.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/toolresolution/TestAiModel.java new file mode 100644 index 000000000..90dbca7c3 --- /dev/null +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/toolresolution/TestAiModel.java @@ -0,0 +1,39 @@ +package io.quarkiverse.langchain4j.test.toolresolution; + +import static dev.langchain4j.data.message.ChatMessageType.TOOL_EXECUTION_RESULT; + +import java.util.List; + +import dev.langchain4j.agent.tool.ToolExecutionRequest; +import dev.langchain4j.agent.tool.ToolSpecification; +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.data.message.ToolExecutionResultMessage; +import dev.langchain4j.model.chat.ChatLanguageModel; +import dev.langchain4j.model.output.FinishReason; +import dev.langchain4j.model.output.Response; +import dev.langchain4j.model.output.TokenUsage; + +public class TestAiModel implements ChatLanguageModel { + @Override + public Response generate(List messages) { + return new Response<>(new AiMessage("NO TOOL")); + } + + @Override + public Response generate(List messages, List toolSpecifications) { + ChatMessage lastMsg = messages.get(messages.size() - 1); + boolean isLastMsgToolResponse = lastMsg.type().equals(TOOL_EXECUTION_RESULT); + if (isLastMsgToolResponse) { + ToolExecutionResultMessage msg = (ToolExecutionResultMessage) lastMsg; + return new Response<>(new AiMessage(msg.text())); + } + ToolSpecification toolSpecification = toolSpecifications.get(0); + ToolExecutionRequest toolExecutionRequest = ToolExecutionRequest.builder() + .name(toolSpecification.name()) + .id(toolSpecification.name()) + .build(); + TokenUsage usage = new TokenUsage(42, 42); + return new Response<>(AiMessage.from(toolExecutionRequest), usage, FinishReason.TOOL_EXECUTION); + } +} diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/toolresolution/TestAiSupplier.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/toolresolution/TestAiSupplier.java new file mode 100644 index 000000000..3104c94f2 --- /dev/null +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/toolresolution/TestAiSupplier.java @@ -0,0 +1,12 @@ +package io.quarkiverse.langchain4j.test.toolresolution; + +import java.util.function.Supplier; + +import dev.langchain4j.model.chat.ChatLanguageModel; + +public class TestAiSupplier implements Supplier { + @Override + public ChatLanguageModel get() { + return new TestAiModel(); + } +} diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/toolresolution/ToolsClass.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/toolresolution/ToolsClass.java new file mode 100644 index 000000000..84a5930e7 --- /dev/null +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/toolresolution/ToolsClass.java @@ -0,0 +1,14 @@ +package io.quarkiverse.langchain4j.test.toolresolution; + +import jakarta.enterprise.context.ApplicationScoped; + +import dev.langchain4j.agent.tool.Tool; + +@ApplicationScoped +public class ToolsClass { + + @Tool + public String hello() { + return "EXPLICIT TOOL"; + } +} diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/AiServicesRecorder.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/AiServicesRecorder.java index 834f1b014..4ca13700b 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/AiServicesRecorder.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/AiServicesRecorder.java @@ -153,7 +153,8 @@ public T apply(SyntheticCreationalContext creationalContext) { } Map> toolsClasses = info.toolsClassInfo(); - if ((toolsClasses != null) && !toolsClasses.isEmpty()) { + boolean hasExplicitTools = (toolsClasses != null) && !toolsClasses.isEmpty(); + if (hasExplicitTools) { List tools = new ArrayList<>(toolsClasses.size()); for (var entry : toolsClasses.entrySet()) { AnnotationLiteral qualifier = entry.getValue(); @@ -172,10 +173,16 @@ public T apply(SyntheticCreationalContext creationalContext) { quarkusAiServices.tools(tools); } + // if no explicit tools are provided, check if we should use a tool provider if (info.toolProviderSupplier() != null) { if (!RegisterAiService.BeanIfExistsToolProviderSupplier.class.getName() .equals(info.toolProviderSupplier())) { // specific provider + if (hasExplicitTools) { + // if the service has both explicit tools and a specific tool provider, + // this is an error + throw new IllegalStateException("Cannot use a tool provider when explicit tools are provided"); + } Class toolProviderClass = Thread.currentThread().getContextClassLoader() .loadClass(info.toolProviderSupplier()); Supplier toolProvider = (Supplier) creationalContext @@ -185,7 +192,9 @@ public T apply(SyntheticCreationalContext creationalContext) { // if-exists provider Instance instance = creationalContext .getInjectedReference(TOOL_PROVIDER_TYPE_LITERAL); - if (instance.isResolvable()) { + // if the service has explicit tools and a BeanIfExistsToolProviderSupplier, + // just give priority to the explicit tools, don't throw an error + if (instance.isResolvable() && !hasExplicitTools) { quarkusAiServices.toolProvider(instance.get()); } }