diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/RegisterAiService.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/RegisterAiService.java
index 9306bed03..164863c0c 100644
--- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/RegisterAiService.java
+++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/RegisterAiService.java
@@ -23,8 +23,8 @@
import io.quarkiverse.langchain4j.runtime.cache.AiCache;
import io.quarkiverse.langchain4j.runtime.cache.AiCacheProvider;
import io.quarkiverse.langchain4j.runtime.cache.AiCacheStore;
+import io.quarkiverse.langchain4j.runtime.cache.FixedAiCache;
import io.quarkiverse.langchain4j.runtime.cache.InMemoryAiCacheStore;
-import io.quarkiverse.langchain4j.runtime.cache.MessageWindowAiCache;
/**
* Used to create LangChain4j's {@link AiServices} in a declarative manner that the application can then use simply by using the
@@ -101,7 +101,7 @@
* Configures the way to obtain the {@link AiCacheProvider}.
*
* Be default, Quarkus configures a {@link AiCacheProvider} bean that uses a {@link InMemoryAiCacheStore} bean as the
- * backing store. The default type for the actual {@link AiCache} is {@link MessageWindowAiCache} and it is configured with
+ * backing store. The default type for the actual {@link AiCache} is {@link FixedAiCache} and it is configured with
* the value of the {@code quarkus.langchain4j.cache.max-size} configuration property (which default to
* 1) as a way of limiting the number of messages in each cache.
*
diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/AiCacheRecorder.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/AiCacheRecorder.java
index 5e5d4e0e5..5f856a6e1 100644
--- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/AiCacheRecorder.java
+++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/AiCacheRecorder.java
@@ -8,7 +8,7 @@
import io.quarkiverse.langchain4j.runtime.cache.AiCache;
import io.quarkiverse.langchain4j.runtime.cache.AiCacheProvider;
import io.quarkiverse.langchain4j.runtime.cache.AiCacheStore;
-import io.quarkiverse.langchain4j.runtime.cache.MessageWindowAiCache;
+import io.quarkiverse.langchain4j.runtime.cache.FixedAiCache;
import io.quarkiverse.langchain4j.runtime.cache.config.AiCacheConfig;
import io.quarkus.arc.SyntheticCreationalContext;
import io.quarkus.runtime.annotations.Recorder;
@@ -41,7 +41,7 @@ public AiCacheProvider apply(SyntheticCreationalContext context
return new AiCacheProvider() {
@Override
public AiCache get(Object memoryId) {
- return MessageWindowAiCache.Builder
+ return FixedAiCache.Builder
.create(memoryId)
.ttl(ttl)
.maxSize(maxSize)
diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/AiServicesRecorder.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/AiServicesRecorder.java
index 7aeff7bd1..740a97325 100644
--- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/AiServicesRecorder.java
+++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/AiServicesRecorder.java
@@ -252,7 +252,8 @@ public T apply(SyntheticCreationalContext creationalContext) {
}
}
- aiServiceContext.aiCaches = new ConcurrentHashMap<>();
+ if (aiServiceContext.aiCaches == null)
+ aiServiceContext.aiCaches = new ConcurrentHashMap<>();
}
return (T) aiServiceContext;
diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java
index 53217bfde..01b74b5db 100644
--- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java
+++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java
@@ -122,13 +122,9 @@ private static Object doImplement(AiServiceMethodCreateInfo methodCreateInfo, Ob
Object memoryId = memoryId(methodCreateInfo, methodArgs, context.chatMemoryProvider != null);
AiCache cache = null;
- // TODO: REMOVE THIS COMMENT BEFORE MERGING THE PR.
- // - Understand how to implement the concept of cache for the stream responses.
- // - What do we have to do when we have the tools?
-
if (methodCreateInfo.isRequiresCache()) {
- Object cacheId = cacheId(methodCreateInfo, methodArgs);
- cache = context.aiCacheProvider.get(cacheId);
+ Object cacheId = cacheId(methodCreateInfo);
+ cache = context.cache(cacheId);
}
if (context.retrievalAugmentor != null) { // TODO extract method/class
List chatMemory = context.hasChatMemory()
@@ -396,7 +392,7 @@ private static Object memoryId(AiServiceMethodCreateInfo createInfo, Object[] me
return "default";
}
- private static Object cacheId(AiServiceMethodCreateInfo createInfo, Object[] methodArgs) {
+ private static Object cacheId(AiServiceMethodCreateInfo createInfo) {
for (DefaultMemoryIdProvider provider : DEFAULT_MEMORY_ID_PROVIDERS) {
Object memoryId = provider.getMemoryId();
if (memoryId != null) {
@@ -404,9 +400,7 @@ private static Object cacheId(AiServiceMethodCreateInfo createInfo, Object[] met
return memoryId + perServiceSuffix;
}
}
-
- // fallback to the default since there is nothing else we can really use here
- return "default";
+ return "#" + createInfo.getInterfaceName() + "." + createInfo.getMethodName();
}
// TODO: share these methods with LangChain4j
diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/QuarkusAiServiceContext.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/QuarkusAiServiceContext.java
index e77cfb9b6..c608d9bb9 100644
--- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/QuarkusAiServiceContext.java
+++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/QuarkusAiServiceContext.java
@@ -29,8 +29,8 @@ public boolean hasCache() {
return aiCaches != null;
}
- public AiCache cache(Object memoryId) {
- return aiCaches.computeIfAbsent(memoryId, ignored -> aiCacheProvider.get(memoryId));
+ public AiCache cache(Object cacheId) {
+ return aiCaches.computeIfAbsent(cacheId, ignored -> aiCacheProvider.get(cacheId));
}
/**
@@ -58,7 +58,7 @@ private void clearAiCache() {
if (aiCaches != null) {
aiCaches.forEach(new BiConsumer<>() {
@Override
- public void accept(Object memoryId, AiCache aiCache) {
+ public void accept(Object cacheId, AiCache aiCache) {
aiCache.clear();
}
});
diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/cache/MessageWindowAiCache.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/cache/FixedAiCache.java
similarity index 69%
rename from core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/cache/MessageWindowAiCache.java
rename to core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/cache/FixedAiCache.java
index 569bc751b..fc512d854 100644
--- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/cache/MessageWindowAiCache.java
+++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/cache/FixedAiCache.java
@@ -2,11 +2,11 @@
import java.time.Duration;
import java.util.Date;
-import java.util.LinkedList;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.locks.ReentrantLock;
+import java.util.stream.Collectors;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.SystemMessage;
@@ -16,9 +16,9 @@
import io.quarkiverse.langchain4j.runtime.cache.AiCacheStore.CacheRecord;
/**
- * This {@link AiCache} implementation operates as a sliding window of messages.
+ * The default {@link AiCache} implementation: a fixed-size cache that ignores new entries while it is full.
*/
-public class MessageWindowAiCache implements AiCache {
+public class FixedAiCache implements AiCache {
private final Object id;
private final Integer maxMessages;
@@ -30,7 +30,7 @@ public class MessageWindowAiCache implements AiCache {
private final EmbeddingModel embeddingModel;
private final ReentrantLock lock;
- public MessageWindowAiCache(Builder builder) {
+ public FixedAiCache(Builder builder) {
this.id = builder.id;
this.maxMessages = builder.maxSize;
this.store = builder.store;
@@ -64,25 +64,17 @@ public void add(SystemMessage systemMessage, UserMessage userMessage, AiMessage
lock.lock();
- List elements = store.getAll(id);
- if (elements.size() == maxMessages) {
- elements.remove(0);
- }
-
- List items = new LinkedList<>();
- for (int i = 0; i < elements.size(); i++) {
-
- var expiredTime = Date.from(elements.get(i).creation().plus(ttl));
- var currentTime = new Date();
-
- if (currentTime.after(expiredTime))
- continue;
+ List elements = store.getAll(id)
+ .stream()
+ .filter(this::checkTTL)
+ .collect(Collectors.toList());
- items.add(elements.get(i));
+ if (elements.size() == maxMessages) {
+ return;
}
- items.add(CacheRecord.of(embeddingModel.embed(query).content(), aiResponse));
- store.updateCache(id, items);
+ elements.add(CacheRecord.of(embeddingModel.embed(query).content(), aiResponse));
+ store.updateCache(id, elements);
} finally {
lock.unlock();
@@ -101,30 +93,49 @@ public Optional search(SystemMessage systemMessage, UserMessage userM
else
query = "%s%s%s".formatted(queryPrefix, systemMessage.text(), userMessage.text());
- var elements = store.getAll(id);
- double maxScore = 0;
- AiMessage result = null;
+ try {
- for (var cacheRecord : elements) {
+ lock.lock();
- if (ttl != null) {
- var expiredTime = Date.from(cacheRecord.creation().plus(ttl));
- var currentTime = new Date();
+ double maxScore = 0;
+ AiMessage result = null;
+ List records = store.getAll(id)
+ .stream()
+ .filter(this::checkTTL)
+ .collect(Collectors.toList());
- if (currentTime.after(expiredTime))
- continue;
- }
+ for (var record : records) {
- var relevanceScore = CosineSimilarity.between(embeddingModel.embed(query).content(), cacheRecord.embedded());
- var score = (float) CosineSimilarity.fromRelevanceScore(relevanceScore);
+ var relevanceScore = CosineSimilarity.between(embeddingModel.embed(query).content(), record.embedded());
+ var score = (float) CosineSimilarity.fromRelevanceScore(relevanceScore);
- if (score >= threshold.doubleValue() && score >= maxScore) {
- maxScore = score;
- result = cacheRecord.response();
+ if (score >= threshold.doubleValue() && score >= maxScore) {
+ maxScore = score;
+ result = record.response();
+ }
}
+
+ store.updateCache(id, records);
+ return Optional.ofNullable(result);
+
+ } finally {
+ lock.unlock();
+ }
+ }
+
+ private boolean checkTTL(CacheRecord record) {
+
+ if (ttl == null)
+ return true;
+
+ var expiredTime = Date.from(record.creation().plus(ttl));
+ var currentTime = new Date();
+
+ if (currentTime.after(expiredTime)) {
+ return false;
}
- return Optional.ofNullable(result);
+ return true;
}
@Override
@@ -187,7 +198,7 @@ public Builder embeddingModel(EmbeddingModel embeddingModel) {
}
public AiCache build() {
- return new MessageWindowAiCache(this);
+ return new FixedAiCache(this);
}
}
}
diff --git a/model-providers/bam/deployment/src/test/java/io/quarkiverse/langchain4j/bam/deployment/CacheClassTest.java b/model-providers/bam/deployment/src/test/java/io/quarkiverse/langchain4j/bam/deployment/CacheClassTest.java
index 515926f87..1ba6987f6 100644
--- a/model-providers/bam/deployment/src/test/java/io/quarkiverse/langchain4j/bam/deployment/CacheClassTest.java
+++ b/model-providers/bam/deployment/src/test/java/io/quarkiverse/langchain4j/bam/deployment/CacheClassTest.java
@@ -76,16 +76,22 @@ public Response> embedAll(List textSegments) {
@Test
void cache_test() {
- String cacheId = "default";
+ String chatCacheId = "#" + LLMService.class.getName() + ".chat";
+ String chat2CacheId = "#" + LLMService.class.getName() + ".chat2";
+
+ assertEquals(0, aiCacheStore.getAll(chatCacheId).size());
+ assertEquals(0, aiCacheStore.getAll(chat2CacheId).size());
- assertEquals(0, aiCacheStore.getAll(cacheId).size());
service.chat("chat");
- assertEquals(1, aiCacheStore.getAll(cacheId).size());
+ assertEquals(1, aiCacheStore.getAll(chatCacheId).size());
+ assertEquals("result", aiCacheStore.getAll(chatCacheId).get(0).response().text());
+ assertEquals(es, aiCacheStore.getAll(chatCacheId).get(0).embedded().vector());
+ assertEquals(0, aiCacheStore.getAll(chat2CacheId).size());
service.chat2("chat2");
- assertEquals(1, aiCacheStore.getAll(cacheId).size());
- assertEquals("result", aiCacheStore.getAll(cacheId).get(0).response().text());
- assertEquals(es, aiCacheStore.getAll(cacheId).get(0).embedded().vector());
+ assertEquals(1, aiCacheStore.getAll(chat2CacheId).size());
+ assertEquals("result", aiCacheStore.getAll(chat2CacheId).get(0).response().text());
+ assertEquals(es, aiCacheStore.getAll(chat2CacheId).get(0).embedded().vector());
}
@Test
diff --git a/model-providers/bam/deployment/src/test/java/io/quarkiverse/langchain4j/bam/deployment/CacheConfigTest.java b/model-providers/bam/deployment/src/test/java/io/quarkiverse/langchain4j/bam/deployment/CacheConfigTest.java
index 9aa9bc7ad..d8d02f0c4 100644
--- a/model-providers/bam/deployment/src/test/java/io/quarkiverse/langchain4j/bam/deployment/CacheConfigTest.java
+++ b/model-providers/bam/deployment/src/test/java/io/quarkiverse/langchain4j/bam/deployment/CacheConfigTest.java
@@ -26,6 +26,9 @@
import io.quarkiverse.langchain4j.CacheResult;
import io.quarkiverse.langchain4j.RegisterAiService;
import io.quarkiverse.langchain4j.runtime.cache.AiCacheStore;
+import io.quarkus.arc.Arc;
+import io.quarkus.arc.ArcContainer;
+import io.quarkus.arc.ManagedContext;
import io.quarkus.test.QuarkusUnitTest;
public class CacheConfigTest {
@@ -85,7 +88,7 @@ else if (textSegments.get(0).text().equals("TESTFOURTH"))
@Order(1)
void cache_ttl_test() throws InterruptedException {
- String cacheId = "default";
+ String cacheId = "#" + LLMService.class.getName() + ".chat";
aiCacheStore.deleteCache(cacheId);
service.chat("FIRST");
@@ -107,7 +110,7 @@ void cache_ttl_test() throws InterruptedException {
@Order(2)
void cache_max_size_test() {
- String cacheId = "default";
+ String cacheId = "#" + LLMService.class.getName() + ".chat";
aiCacheStore.deleteCache(cacheId);
service.chat("FIRST");
@@ -119,12 +122,18 @@ void cache_max_size_test() {
service.chat("THIRD");
service.chat("FOURTH");
assertEquals(3, aiCacheStore.getAll(cacheId).size());
- assertEquals("cache: TESTSECOND", aiCacheStore.getAll(cacheId).get(0).response().text());
- assertEquals(second, aiCacheStore.getAll(cacheId).get(0).embedded().vector());
- assertEquals("cache: TESTTHIRD", aiCacheStore.getAll(cacheId).get(1).response().text());
- assertEquals(third, aiCacheStore.getAll(cacheId).get(1).embedded().vector());
- assertEquals("cache: TESTFOURTH", aiCacheStore.getAll(cacheId).get(2).response().text());
- assertEquals(fourth, aiCacheStore.getAll(cacheId).get(2).embedded().vector());
+ assertEquals("cache: TESTFIRST", aiCacheStore.getAll(cacheId).get(0).response().text());
+ assertEquals(first, aiCacheStore.getAll(cacheId).get(0).embedded().vector());
+ assertEquals("cache: TESTSECOND", aiCacheStore.getAll(cacheId).get(1).response().text());
+ assertEquals(second, aiCacheStore.getAll(cacheId).get(1).embedded().vector());
+ assertEquals("cache: TESTTHIRD", aiCacheStore.getAll(cacheId).get(2).response().text());
+ assertEquals(third, aiCacheStore.getAll(cacheId).get(2).embedded().vector());
+ }
+
+ private String getContext(String methodName) {
+ ArcContainer container = Arc.container();
+ ManagedContext requestContext = container.requestContext();
+ return requestContext.getState() + "#" + LLMService.class.getName() + "." + methodName;
}
static float[] first = {
diff --git a/model-providers/bam/deployment/src/test/java/io/quarkiverse/langchain4j/bam/deployment/CacheMethodTest.java b/model-providers/bam/deployment/src/test/java/io/quarkiverse/langchain4j/bam/deployment/CacheMethodTest.java
index f701c0cf5..3a8a91514 100644
--- a/model-providers/bam/deployment/src/test/java/io/quarkiverse/langchain4j/bam/deployment/CacheMethodTest.java
+++ b/model-providers/bam/deployment/src/test/java/io/quarkiverse/langchain4j/bam/deployment/CacheMethodTest.java
@@ -79,16 +79,20 @@ public Response> embedAll(List textSegments) {
@Test
void cache_test() {
- String cacheId = "default";
+ String chatCacheId = "#" + LLMService.class.getName() + ".chat";
+ String chatNoCacheCacheId = "#" + LLMService.class.getName() + ".chatNoCache";
- assertEquals(0, aiCacheStore.getAll(cacheId).size());
+ assertEquals(0, aiCacheStore.getAll(chatCacheId).size());
+ assertEquals(0, aiCacheStore.getAll(chatNoCacheCacheId).size());
service.chatNoCache("noCache");
- assertEquals(0, aiCacheStore.getAll(cacheId).size());
+ assertEquals(0, aiCacheStore.getAll(chatCacheId).size());
+ assertEquals(0, aiCacheStore.getAll(chatNoCacheCacheId).size());
service.chat("cache");
- assertEquals(1, aiCacheStore.getAll(cacheId).size());
- assertEquals("result", aiCacheStore.getAll(cacheId).get(0).response().text());
- assertEquals(es, aiCacheStore.getAll(cacheId).get(0).embedded().vector());
+ assertEquals(1, aiCacheStore.getAll(chatCacheId).size());
+ assertEquals("result", aiCacheStore.getAll(chatCacheId).get(0).response().text());
+ assertEquals(es, aiCacheStore.getAll(chatCacheId).get(0).embedded().vector());
+ assertEquals(0, aiCacheStore.getAll(chatNoCacheCacheId).size());
}
@Test
diff --git a/model-providers/bam/deployment/src/test/java/io/quarkiverse/langchain4j/bam/deployment/CachePrefixTest.java b/model-providers/bam/deployment/src/test/java/io/quarkiverse/langchain4j/bam/deployment/CachePrefixTest.java
index 47fd71623..0984e6021 100644
--- a/model-providers/bam/deployment/src/test/java/io/quarkiverse/langchain4j/bam/deployment/CachePrefixTest.java
+++ b/model-providers/bam/deployment/src/test/java/io/quarkiverse/langchain4j/bam/deployment/CachePrefixTest.java
@@ -94,7 +94,7 @@ public Response> embedAll(List textSegments) {
@Order(1)
void cache_prefix_test() throws InterruptedException {
- String cacheId = "default";
+ String cacheId = "#" + LLMService.class.getName() + ".chat";
aiCacheStore.deleteCache(cacheId);
service.chat("firstMessage");