diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/RegisterAiService.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/RegisterAiService.java index 9306bed03..164863c0c 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/RegisterAiService.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/RegisterAiService.java @@ -23,8 +23,8 @@ import io.quarkiverse.langchain4j.runtime.cache.AiCache; import io.quarkiverse.langchain4j.runtime.cache.AiCacheProvider; import io.quarkiverse.langchain4j.runtime.cache.AiCacheStore; +import io.quarkiverse.langchain4j.runtime.cache.FixedAiCache; import io.quarkiverse.langchain4j.runtime.cache.InMemoryAiCacheStore; -import io.quarkiverse.langchain4j.runtime.cache.MessageWindowAiCache; /** * Used to create LangChain4j's {@link AiServices} in a declarative manner that the application can then use simply by using the @@ -101,7 +101,7 @@ * Configures the way to obtain the {@link AiCacheProvider}. *

* Be default, Quarkus configures a {@link AiCacheProvider} bean that uses a {@link InMemoryAiCacheStore} bean as the - * backing store. The default type for the actual {@link AiCache} is {@link MessageWindowAiCache} and it is configured with + * backing store. The default type for the actual {@link AiCache} is {@link FixedAiCache} and it is configured with * the value of the {@code quarkus.langchain4j.cache.max-size} configuration property (which default to * 1) as a way of limiting the number of messages in each cache. *

diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/AiCacheRecorder.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/AiCacheRecorder.java index 5e5d4e0e5..5f856a6e1 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/AiCacheRecorder.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/AiCacheRecorder.java @@ -8,7 +8,7 @@ import io.quarkiverse.langchain4j.runtime.cache.AiCache; import io.quarkiverse.langchain4j.runtime.cache.AiCacheProvider; import io.quarkiverse.langchain4j.runtime.cache.AiCacheStore; -import io.quarkiverse.langchain4j.runtime.cache.MessageWindowAiCache; +import io.quarkiverse.langchain4j.runtime.cache.FixedAiCache; import io.quarkiverse.langchain4j.runtime.cache.config.AiCacheConfig; import io.quarkus.arc.SyntheticCreationalContext; import io.quarkus.runtime.annotations.Recorder; @@ -41,7 +41,7 @@ public AiCacheProvider apply(SyntheticCreationalContext context return new AiCacheProvider() { @Override public AiCache get(Object memoryId) { - return MessageWindowAiCache.Builder + return FixedAiCache.Builder .create(memoryId) .ttl(ttl) .maxSize(maxSize) diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/AiServicesRecorder.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/AiServicesRecorder.java index 7aeff7bd1..740a97325 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/AiServicesRecorder.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/AiServicesRecorder.java @@ -252,7 +252,8 @@ public T apply(SyntheticCreationalContext creationalContext) { } } - aiServiceContext.aiCaches = new ConcurrentHashMap<>(); + if (aiServiceContext.aiCaches == null) + aiServiceContext.aiCaches = new ConcurrentHashMap<>(); } return (T) aiServiceContext; diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java 
b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java index 53217bfde..01b74b5db 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java @@ -122,13 +122,9 @@ private static Object doImplement(AiServiceMethodCreateInfo methodCreateInfo, Ob Object memoryId = memoryId(methodCreateInfo, methodArgs, context.chatMemoryProvider != null); AiCache cache = null; - // TODO: REMOVE THIS COMMENT BEFORE MERGING THE PR. - // - Understand how to implement the concept of cache for the stream responses. - // - What do we have to do when we have the tools? - if (methodCreateInfo.isRequiresCache()) { - Object cacheId = cacheId(methodCreateInfo, methodArgs); - cache = context.aiCacheProvider.get(cacheId); + Object cacheId = cacheId(methodCreateInfo); + cache = context.cache(cacheId); } if (context.retrievalAugmentor != null) { // TODO extract method/class List chatMemory = context.hasChatMemory() @@ -396,7 +392,7 @@ private static Object memoryId(AiServiceMethodCreateInfo createInfo, Object[] me return "default"; } - private static Object cacheId(AiServiceMethodCreateInfo createInfo, Object[] methodArgs) { + private static Object cacheId(AiServiceMethodCreateInfo createInfo) { for (DefaultMemoryIdProvider provider : DEFAULT_MEMORY_ID_PROVIDERS) { Object memoryId = provider.getMemoryId(); if (memoryId != null) { @@ -404,9 +400,7 @@ private static Object cacheId(AiServiceMethodCreateInfo createInfo, Object[] met return memoryId + perServiceSuffix; } } - - // fallback to the default since there is nothing else we can really use here - return "default"; + return "#" + createInfo.getInterfaceName() + "." 
+ createInfo.getMethodName(); } // TODO: share these methods with LangChain4j diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/QuarkusAiServiceContext.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/QuarkusAiServiceContext.java index e77cfb9b6..c608d9bb9 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/QuarkusAiServiceContext.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/QuarkusAiServiceContext.java @@ -29,8 +29,8 @@ public boolean hasCache() { return aiCaches != null; } - public AiCache cache(Object memoryId) { - return aiCaches.computeIfAbsent(memoryId, ignored -> aiCacheProvider.get(memoryId)); + public AiCache cache(Object cacheId) { + return aiCaches.computeIfAbsent(cacheId, ignored -> aiCacheProvider.get(cacheId)); } /** @@ -58,7 +58,7 @@ private void clearAiCache() { if (aiCaches != null) { aiCaches.forEach(new BiConsumer<>() { @Override - public void accept(Object memoryId, AiCache aiCache) { + public void accept(Object cacheId, AiCache aiCache) { aiCache.clear(); } }); diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/cache/MessageWindowAiCache.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/cache/FixedAiCache.java similarity index 69% rename from core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/cache/MessageWindowAiCache.java rename to core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/cache/FixedAiCache.java index 569bc751b..fc512d854 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/cache/MessageWindowAiCache.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/cache/FixedAiCache.java @@ -2,11 +2,11 @@ import java.time.Duration; import java.util.Date; -import java.util.LinkedList; import java.util.List; import java.util.Objects; import java.util.Optional; import java.util.concurrent.locks.ReentrantLock; 
+import java.util.stream.Collectors; import dev.langchain4j.data.message.AiMessage; import dev.langchain4j.data.message.SystemMessage; @@ -16,9 +16,9 @@ import io.quarkiverse.langchain4j.runtime.cache.AiCacheStore.CacheRecord; /** - * This {@link AiCache} implementation operates as a sliding window of messages. + * The default {@link AiCache} implementation: a fixed-size cache that does not store new entries once it is full. */ -public class MessageWindowAiCache implements AiCache { +public class FixedAiCache implements AiCache { private final Object id; private final Integer maxMessages; @@ -30,7 +30,7 @@ public class MessageWindowAiCache implements AiCache { private final EmbeddingModel embeddingModel; private final ReentrantLock lock; - public MessageWindowAiCache(Builder builder) { + public FixedAiCache(Builder builder) { this.id = builder.id; this.maxMessages = builder.maxSize; this.store = builder.store; @@ -64,25 +64,17 @@ public void add(SystemMessage systemMessage, UserMessage userMessage, AiMessage lock.lock(); - List elements = store.getAll(id); - if (elements.size() == maxMessages) { - elements.remove(0); - } - - List items = new LinkedList<>(); - for (int i = 0; i < elements.size(); i++) { - - var expiredTime = Date.from(elements.get(i).creation().plus(ttl)); - var currentTime = new Date(); - - if (currentTime.after(expiredTime)) - continue; + List elements = store.getAll(id) + .stream() + .filter(this::checkTTL) + .collect(Collectors.toList()); - items.add(elements.get(i)); + if (elements.size() == maxMessages) { + return; } - items.add(CacheRecord.of(embeddingModel.embed(query).content(), aiResponse)); - store.updateCache(id, items); + elements.add(CacheRecord.of(embeddingModel.embed(query).content(), aiResponse)); + store.updateCache(id, elements); } finally { lock.unlock(); @@ -101,30 +93,49 @@ public Optional search(SystemMessage systemMessage, UserMessage userM else query = "%s%s%s".formatted(queryPrefix, systemMessage.text(), userMessage.text()); - var elements = store.getAll(id); - double maxScore = 0; - 
AiMessage result = null; + try { - for (var cacheRecord : elements) { + lock.lock(); - if (ttl != null) { - var expiredTime = Date.from(cacheRecord.creation().plus(ttl)); - var currentTime = new Date(); + double maxScore = 0; + AiMessage result = null; + List records = store.getAll(id) + .stream() + .filter(this::checkTTL) + .collect(Collectors.toList()); - if (currentTime.after(expiredTime)) - continue; - } + for (var record : records) { - var relevanceScore = CosineSimilarity.between(embeddingModel.embed(query).content(), cacheRecord.embedded()); - var score = (float) CosineSimilarity.fromRelevanceScore(relevanceScore); + var relevanceScore = CosineSimilarity.between(embeddingModel.embed(query).content(), record.embedded()); + var score = (float) CosineSimilarity.fromRelevanceScore(relevanceScore); - if (score >= threshold.doubleValue() && score >= maxScore) { - maxScore = score; - result = cacheRecord.response(); + if (score >= threshold.doubleValue() && score >= maxScore) { + maxScore = score; + result = record.response(); + } } + + store.updateCache(id, records); + return Optional.ofNullable(result); + + } finally { + lock.unlock(); + } + } + + private boolean checkTTL(CacheRecord record) { + + if (ttl == null) + return true; + + var expiredTime = Date.from(record.creation().plus(ttl)); + var currentTime = new Date(); + + if (currentTime.after(expiredTime)) { + return false; } - return Optional.ofNullable(result); + return true; } @Override @@ -187,7 +198,7 @@ public Builder embeddingModel(EmbeddingModel embeddingModel) { } public AiCache build() { - return new MessageWindowAiCache(this); + return new FixedAiCache(this); } } } diff --git a/model-providers/bam/deployment/src/test/java/io/quarkiverse/langchain4j/bam/deployment/CacheClassTest.java b/model-providers/bam/deployment/src/test/java/io/quarkiverse/langchain4j/bam/deployment/CacheClassTest.java index 515926f87..1ba6987f6 100644 --- 
a/model-providers/bam/deployment/src/test/java/io/quarkiverse/langchain4j/bam/deployment/CacheClassTest.java +++ b/model-providers/bam/deployment/src/test/java/io/quarkiverse/langchain4j/bam/deployment/CacheClassTest.java @@ -76,16 +76,22 @@ public Response> embedAll(List textSegments) { @Test void cache_test() { - String cacheId = "default"; + String chatCacheId = "#" + LLMService.class.getName() + ".chat"; + String chat2CacheId = "#" + LLMService.class.getName() + ".chat2"; + + assertEquals(0, aiCacheStore.getAll(chatCacheId).size()); + assertEquals(0, aiCacheStore.getAll(chat2CacheId).size()); - assertEquals(0, aiCacheStore.getAll(cacheId).size()); service.chat("chat"); - assertEquals(1, aiCacheStore.getAll(cacheId).size()); + assertEquals(1, aiCacheStore.getAll(chatCacheId).size()); + assertEquals("result", aiCacheStore.getAll(chatCacheId).get(0).response().text()); + assertEquals(es, aiCacheStore.getAll(chatCacheId).get(0).embedded().vector()); + assertEquals(0, aiCacheStore.getAll(chat2CacheId).size()); service.chat2("chat2"); - assertEquals(1, aiCacheStore.getAll(cacheId).size()); - assertEquals("result", aiCacheStore.getAll(cacheId).get(0).response().text()); - assertEquals(es, aiCacheStore.getAll(cacheId).get(0).embedded().vector()); + assertEquals(1, aiCacheStore.getAll(chat2CacheId).size()); + assertEquals("result", aiCacheStore.getAll(chat2CacheId).get(0).response().text()); + assertEquals(es, aiCacheStore.getAll(chat2CacheId).get(0).embedded().vector()); } @Test diff --git a/model-providers/bam/deployment/src/test/java/io/quarkiverse/langchain4j/bam/deployment/CacheConfigTest.java b/model-providers/bam/deployment/src/test/java/io/quarkiverse/langchain4j/bam/deployment/CacheConfigTest.java index 9aa9bc7ad..d8d02f0c4 100644 --- a/model-providers/bam/deployment/src/test/java/io/quarkiverse/langchain4j/bam/deployment/CacheConfigTest.java +++ b/model-providers/bam/deployment/src/test/java/io/quarkiverse/langchain4j/bam/deployment/CacheConfigTest.java @@ 
-26,6 +26,9 @@ import io.quarkiverse.langchain4j.CacheResult; import io.quarkiverse.langchain4j.RegisterAiService; import io.quarkiverse.langchain4j.runtime.cache.AiCacheStore; +import io.quarkus.arc.Arc; +import io.quarkus.arc.ArcContainer; +import io.quarkus.arc.ManagedContext; import io.quarkus.test.QuarkusUnitTest; public class CacheConfigTest { @@ -85,7 +88,7 @@ else if (textSegments.get(0).text().equals("TESTFOURTH")) @Order(1) void cache_ttl_test() throws InterruptedException { - String cacheId = "default"; + String cacheId = "#" + LLMService.class.getName() + ".chat"; aiCacheStore.deleteCache(cacheId); service.chat("FIRST"); @@ -107,7 +110,7 @@ void cache_ttl_test() throws InterruptedException { @Order(2) void cache_max_size_test() { - String cacheId = "default"; + String cacheId = "#" + LLMService.class.getName() + ".chat"; aiCacheStore.deleteCache(cacheId); service.chat("FIRST"); @@ -119,12 +122,18 @@ void cache_max_size_test() { service.chat("THIRD"); service.chat("FOURTH"); assertEquals(3, aiCacheStore.getAll(cacheId).size()); - assertEquals("cache: TESTSECOND", aiCacheStore.getAll(cacheId).get(0).response().text()); - assertEquals(second, aiCacheStore.getAll(cacheId).get(0).embedded().vector()); - assertEquals("cache: TESTTHIRD", aiCacheStore.getAll(cacheId).get(1).response().text()); - assertEquals(third, aiCacheStore.getAll(cacheId).get(1).embedded().vector()); - assertEquals("cache: TESTFOURTH", aiCacheStore.getAll(cacheId).get(2).response().text()); - assertEquals(fourth, aiCacheStore.getAll(cacheId).get(2).embedded().vector()); + assertEquals("cache: TESTFIRST", aiCacheStore.getAll(cacheId).get(0).response().text()); + assertEquals(first, aiCacheStore.getAll(cacheId).get(0).embedded().vector()); + assertEquals("cache: TESTSECOND", aiCacheStore.getAll(cacheId).get(1).response().text()); + assertEquals(second, aiCacheStore.getAll(cacheId).get(1).embedded().vector()); + assertEquals("cache: TESTTHIRD", 
aiCacheStore.getAll(cacheId).get(2).response().text()); + assertEquals(third, aiCacheStore.getAll(cacheId).get(2).embedded().vector()); + } + + private String getContext(String methodName) { + ArcContainer container = Arc.container(); + ManagedContext requestContext = container.requestContext(); + return requestContext.getState() + "#" + LLMService.class.getName() + "." + methodName; } static float[] first = { diff --git a/model-providers/bam/deployment/src/test/java/io/quarkiverse/langchain4j/bam/deployment/CacheMethodTest.java b/model-providers/bam/deployment/src/test/java/io/quarkiverse/langchain4j/bam/deployment/CacheMethodTest.java index f701c0cf5..3a8a91514 100644 --- a/model-providers/bam/deployment/src/test/java/io/quarkiverse/langchain4j/bam/deployment/CacheMethodTest.java +++ b/model-providers/bam/deployment/src/test/java/io/quarkiverse/langchain4j/bam/deployment/CacheMethodTest.java @@ -79,16 +79,20 @@ public Response> embedAll(List textSegments) { @Test void cache_test() { - String cacheId = "default"; + String chatCacheId = "#" + LLMService.class.getName() + ".chat"; + String chatNoCacheCacheId = "#" + LLMService.class.getName() + ".chatNoCache"; - assertEquals(0, aiCacheStore.getAll(cacheId).size()); + assertEquals(0, aiCacheStore.getAll(chatCacheId).size()); + assertEquals(0, aiCacheStore.getAll(chatNoCacheCacheId).size()); service.chatNoCache("noCache"); - assertEquals(0, aiCacheStore.getAll(cacheId).size()); + assertEquals(0, aiCacheStore.getAll(chatCacheId).size()); + assertEquals(0, aiCacheStore.getAll(chatNoCacheCacheId).size()); service.chat("cache"); - assertEquals(1, aiCacheStore.getAll(cacheId).size()); - assertEquals("result", aiCacheStore.getAll(cacheId).get(0).response().text()); - assertEquals(es, aiCacheStore.getAll(cacheId).get(0).embedded().vector()); + assertEquals(1, aiCacheStore.getAll(chatCacheId).size()); + assertEquals("result", aiCacheStore.getAll(chatCacheId).get(0).response().text()); + assertEquals(es, 
aiCacheStore.getAll(chatCacheId).get(0).embedded().vector()); + assertEquals(0, aiCacheStore.getAll(chatNoCacheCacheId).size()); } @Test diff --git a/model-providers/bam/deployment/src/test/java/io/quarkiverse/langchain4j/bam/deployment/CachePrefixTest.java b/model-providers/bam/deployment/src/test/java/io/quarkiverse/langchain4j/bam/deployment/CachePrefixTest.java index 47fd71623..0984e6021 100644 --- a/model-providers/bam/deployment/src/test/java/io/quarkiverse/langchain4j/bam/deployment/CachePrefixTest.java +++ b/model-providers/bam/deployment/src/test/java/io/quarkiverse/langchain4j/bam/deployment/CachePrefixTest.java @@ -94,7 +94,7 @@ public Response> embedAll(List textSegments) { @Order(1) void cache_prefix_test() throws InterruptedException { - String cacheId = "default"; + String cacheId = "#" + LLMService.class.getName() + ".chat"; aiCacheStore.deleteCache(cacheId); service.chat("firstMessage");