Skip to content

Commit

Permalink
Make @RegisterAiService beans request scoped by default
Browse files Browse the repository at this point in the history
This is done because otherwise the chat memory
does not get cleared properly.

Furthermore, add a way to remove memory entries when the service goes out of scope

Fixes: #95
  • Loading branch information
geoand committed Dec 7, 2023
1 parent aab4e44 commit ae6d811
Show file tree
Hide file tree
Showing 22 changed files with 458 additions and 60 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import java.util.function.Predicate;
import java.util.stream.Collectors;

import jakarta.enterprise.context.ApplicationScoped;
import jakarta.enterprise.inject.Instance;

import org.jboss.jandex.AnnotationInstance;
Expand All @@ -50,6 +49,7 @@
import io.quarkiverse.langchain4j.runtime.aiservice.AiServiceClassCreateInfo;
import io.quarkiverse.langchain4j.runtime.aiservice.AiServiceMethodCreateInfo;
import io.quarkiverse.langchain4j.runtime.aiservice.AiServiceMethodImplementationSupport;
import io.quarkiverse.langchain4j.runtime.aiservice.DeclarativeAiServiceBeanDestroyer;
import io.quarkiverse.langchain4j.runtime.aiservice.DeclarativeAiServiceCreateInfo;
import io.quarkiverse.langchain4j.runtime.aiservice.MetricsWrapper;
import io.quarkiverse.langchain4j.runtime.aiservice.QuarkusAiServiceContext;
Expand All @@ -60,6 +60,8 @@
import io.quarkus.arc.deployment.AdditionalBeanBuildItem;
import io.quarkus.arc.deployment.SyntheticBeanBuildItem;
import io.quarkus.arc.deployment.UnremovableBeanBuildItem;
import io.quarkus.arc.processor.BuiltinScope;
import io.quarkus.arc.processor.ScopeInfo;
import io.quarkus.builder.item.MultiBuildItem;
import io.quarkus.deployment.Capabilities;
import io.quarkus.deployment.Capability;
Expand Down Expand Up @@ -101,6 +103,9 @@ public class AiServicesProcessor {
private static final MethodDescriptor SUPPORT_IMPLEMENT = MethodDescriptor.ofMethod(
AiServiceMethodImplementationSupport.class,
"implement", Object.class, AiServiceMethodImplementationSupport.Input.class);

private static final MethodDescriptor QUARKUS_AI_SERVICES_CONTEXT_CLOSE = MethodDescriptor.ofMethod(
QuarkusAiServiceContext.class, "close", void.class);
public static final DotName CDI_INSTANCE = DotName.createSimple(Instance.class);

@BuildStep
Expand Down Expand Up @@ -211,14 +216,18 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem,
validateSupplierAndRegisterForReflection(auditServiceClassSupplierName, index, reflectiveClassProducer);
}

BuiltinScope declaredScope = BuiltinScope.from(declarativeAiServiceClassInfo);
ScopeInfo cdiScope = declaredScope != null ? declaredScope.getInfo() : BuiltinScope.REQUEST.getInfo();

declarativeAiServiceProducer.produce(
new DeclarativeAiServiceBuildItem(
declarativeAiServiceClassInfo,
chatLanguageModelSupplierClassDotName,
toolDotNames,
chatMemoryProviderSupplierClassDotName,
retrieverSupplierClassDotName,
auditServiceClassSupplierName));
auditServiceClassSupplierName,
cdiScope));
}

if (needChatModelBean) {
Expand Down Expand Up @@ -285,8 +294,9 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
toolClassNames, chatMemoryProviderSupplierClassName,
retrieverSupplierClassName,
auditServiceClassSupplierName)))
.destroyer(DeclarativeAiServiceBeanDestroyer.class)
.setRuntimeInit()
.scope(ApplicationScoped.class);
.scope(bi.getCdiScope());
if ((chatLanguageModelSupplierClassName == null) && selectedChatModelProvider.isPresent()) { // TODO: is second condition needed?
configurator.addInjectionPoint(ClassType.create(Langchain4jDotNames.CHAT_MODEL));
needsChatModelBean = true;
Expand Down Expand Up @@ -403,8 +413,10 @@ public void handleAiServices(AiServicesRecorder recorder,
Set<String> detectedForCreate = new HashSet<>(nameToUsed.keySet());
addCreatedAware(index, detectedForCreate);
addIfacesWithMessageAnns(index, detectedForCreate);
detectedForCreate.addAll(declarativeAiServiceItems.stream().map(bi -> bi.getServiceClassInfo().name().toString())
.collect(Collectors.toList()));
Set<String> registeredAiServiceClassNames = declarativeAiServiceItems.stream()
.map(bi -> bi.getServiceClassInfo().name().toString()).collect(
Collectors.toUnmodifiableSet());
detectedForCreate.addAll(registeredAiServiceClassNames);

Set<ClassInfo> ifacesForCreate = new HashSet<>();
for (String className : detectedForCreate) {
Expand Down Expand Up @@ -453,12 +465,18 @@ public void handleAiServices(AiServicesRecorder recorder,
methodsToImplement.add(method);
}

String implClassName = iface.name().toString() + "$$QuarkusImpl";
try (ClassCreator classCreator = ClassCreator.builder()
String ifaceName = iface.name().toString();
String implClassName = ifaceName + "$$QuarkusImpl";
boolean isRegisteredService = registeredAiServiceClassNames.contains(ifaceName);

ClassCreator.Builder classCreatorBuilder = ClassCreator.builder()
.classOutput(classOutput)
.className(implClassName)
.interfaces(iface.name().toString())
.build()) {
.interfaces(ifaceName);
if (isRegisteredService) {
classCreatorBuilder.interfaces(AutoCloseable.class);
}
try (ClassCreator classCreator = classCreatorBuilder.build()) {

FieldDescriptor contextField = classCreator.getFieldCreator("context", QuarkusAiServiceContext.class)
.setModifiers(Modifier.PRIVATE | Modifier.FINAL)
Expand All @@ -480,7 +498,7 @@ public void handleAiServices(AiServicesRecorder recorder,
MethodCreator mc = classCreator.getMethodCreator(MethodDescriptor.of(methodInfo));
ResultHandle contextHandle = mc.readInstanceField(contextField, mc.getThis());
ResultHandle methodCreateInfoHandle = mc.invokeStaticMethod(RECORDER_METHOD_CREATE_INFO,
mc.load(iface.name().toString()),
mc.load(ifaceName),
mc.load(methodId));
ResultHandle paramsHandle = mc.newArray(Object.class, methodInfo.parametersCount());
for (int i = 0; i < methodInfo.parametersCount(); i++) {
Expand All @@ -498,8 +516,16 @@ public void handleAiServices(AiServicesRecorder recorder,

aiServicesMethodProducer.produce(new AiServicesMethodBuildItem(methodInfo));
}

if (isRegisteredService) {
MethodCreator mc = classCreator.getMethodCreator(
MethodDescriptor.ofMethod(implClassName, "close", void.class));
ResultHandle contextHandle = mc.readInstanceField(contextField, mc.getThis());
mc.invokeVirtualMethod(QUARKUS_AI_SERVICES_CONTEXT_CLOSE, contextHandle);
mc.returnVoid();
}
}
perClassMetadata.put(iface.name().toString(), new AiServiceClassCreateInfo(perMethodMetadata, implClassName));
perClassMetadata.put(ifaceName, new AiServiceClassCreateInfo(perMethodMetadata, implClassName));
// make the constructor accessible reflectively since that is how we create the instance
reflectiveClassProducer.produce(ReflectiveClassBuildItem.builder(implClassName).build());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import org.jboss.jandex.ClassInfo;
import org.jboss.jandex.DotName;

import io.quarkus.arc.processor.ScopeInfo;
import io.quarkus.builder.item.MultiBuildItem;

/**
Expand All @@ -19,18 +20,21 @@ public final class DeclarativeAiServiceBuildItem extends MultiBuildItem {
private final DotName chatMemoryProviderSupplierClassDotName;
private final DotName retrieverSupplierClassDotName;
private final DotName auditServiceClassSupplierDotName;
private final ScopeInfo cdiScope;

public DeclarativeAiServiceBuildItem(ClassInfo serviceClassInfo, DotName languageModelSupplierClassDotName,
List<DotName> toolDotNames,
DotName chatMemoryProviderSupplierClassDotName,
DotName retrieverSupplierClassDotName,
DotName auditServiceClassSupplierDotName) {
DotName auditServiceClassSupplierDotName,
ScopeInfo cdiScope) {
this.serviceClassInfo = serviceClassInfo;
this.languageModelSupplierClassDotName = languageModelSupplierClassDotName;
this.toolDotNames = toolDotNames;
this.chatMemoryProviderSupplierClassDotName = chatMemoryProviderSupplierClassDotName;
this.retrieverSupplierClassDotName = retrieverSupplierClassDotName;
this.auditServiceClassSupplierDotName = auditServiceClassSupplierDotName;
this.cdiScope = cdiScope;
}

public ClassInfo getServiceClassInfo() {
Expand All @@ -56,4 +60,8 @@ public DotName getRetrieverSupplierClassDotName() {
public DotName getAuditServiceClassSupplierDotName() {
return auditServiceClassSupplierDotName;
}

public ScopeInfo getCdiScope() {
return cdiScope;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@
import java.lang.annotation.Target;
import java.util.function.Supplier;

import jakarta.enterprise.context.ApplicationScoped;

import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.memory.chat.ChatMemoryProvider;
Expand All @@ -24,7 +22,9 @@
* while also providing the builder with the proper {@link ChatLanguageModel} bean (mandatory), {@code tools} bean (optional),
* {@link ChatMemoryProvider} and {@link Retriever} beans (which by default are configured if such beans exist).
* <p>
* NOTE: The resulting CDI bean is {@link ApplicationScoped}.
* NOTE: The resulting CDI bean is {@link jakarta.enterprise.context.RequestScoped} be default. If you need to change the scope,
* simply annotate the class with a CDI scope.
* CAUTION: When using anything other than the request scope, you need to be very careful with the chat memory implementation.
* <p>
* NOTE: When the application also contains the {@code quarkus-micrometer} extension, metrics are automatically generated
* for the method invocations.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package io.quarkiverse.langchain4j;

import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.memory.chat.ChatMemoryProvider;

/**
* Extends {@link ChatMemoryProvider} to allow for removing {@link ChatMemory}
* when it is no longer needed.
*/
public interface RemovableChatMemoryProvider extends ChatMemoryProvider {

void remove(Object id);
}
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ private static Object doImplement(AiServiceMethodCreateInfo createInfo, Object[]
}

Object memoryId = memoryId(createInfo, methodArgs).orElse("default");
context.usedMemoryIds.add(memoryId);

if (context.hasChatMemory()) {
ChatMemory chatMemory = context.chatMemory(memoryId);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package io.quarkiverse.langchain4j.runtime.aiservice;

import java.util.Map;

import jakarta.enterprise.context.spi.CreationalContext;

import org.jboss.logging.Logger;

import io.quarkus.arc.BeanDestroyer;

public class DeclarativeAiServiceBeanDestroyer implements BeanDestroyer<AutoCloseable> {

private static final Logger log = Logger.getLogger(DeclarativeAiServiceBeanDestroyer.class);

@Override
public void destroy(AutoCloseable instance, CreationalContext<AutoCloseable> creationalContext,
Map<String, Object> params) {
try {
instance.close();
} catch (Exception e) {
log.error("Unable to close " + instance);
}
}
}
Original file line number Diff line number Diff line change
@@ -1,13 +1,44 @@
package io.quarkiverse.langchain4j.runtime.aiservice;

import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

import dev.langchain4j.service.AiServiceContext;
import io.quarkiverse.langchain4j.RegisterAiService;
import io.quarkiverse.langchain4j.RemovableChatMemoryProvider;
import io.quarkiverse.langchain4j.audit.AuditService;

public class QuarkusAiServiceContext extends AiServiceContext {

public AuditService auditService;

public Set<Object> usedMemoryIds = ConcurrentHashMap.newKeySet();

public QuarkusAiServiceContext(Class<?> aiServiceClass) {
super(aiServiceClass);
}

/**
* This is called by the {@code close} method of AiServices registered with {@link RegisterAiService}
* when the bean's scope is closed
*/
public void close() {
removeChatMemories();
}

private void removeChatMemories() {
if (usedMemoryIds.isEmpty()) {
return;
}
RemovableChatMemoryProvider removableChatMemoryProvider = null;
if (chatMemoryProvider instanceof RemovableChatMemoryProvider) {
removableChatMemoryProvider = (RemovableChatMemoryProvider) chatMemoryProvider;
}
for (Object memoryId : usedMemoryIds) {
if (removableChatMemoryProvider != null) {
removableChatMemoryProvider.remove(memoryId);
}
chatMemories.remove(memoryId);
}
}
}
Original file line number Diff line number Diff line change
@@ -1,17 +1,14 @@
package io.quarkiverse.langchain4j.samples;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

import jakarta.annotation.PreDestroy;
import jakarta.enterprise.context.RequestScoped;

import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.memory.chat.ChatMemoryProvider;
import dev.langchain4j.memory.chat.MessageWindowChatMemory;
import io.quarkiverse.langchain4j.RemovableChatMemoryProvider;
import jakarta.inject.Singleton;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

@RequestScoped
public class ChatMemoryBean implements ChatMemoryProvider {
@Singleton
public class ChatMemoryBean implements RemovableChatMemoryProvider {

private final Map<Object, ChatMemory> memories = new ConcurrentHashMap<>();

Expand All @@ -23,8 +20,8 @@ public ChatMemory get(Object memoryId) {
.build());
}

@PreDestroy
public void close() {
memories.clear();
@Override
public void remove(Object id) {
memories.remove(id);
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package io.quarkiverse.langchain4j.samples;

import io.quarkiverse.langchain4j.RemovableChatMemoryProvider;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Supplier;
Expand All @@ -11,7 +12,7 @@
public class MySmallMemoryProvider implements Supplier<ChatMemoryProvider> {
@Override
public ChatMemoryProvider get() {
return new ChatMemoryProvider() {
return new RemovableChatMemoryProvider() {
private final Map<Object, ChatMemory> memories = new ConcurrentHashMap<>();

@Override
Expand All @@ -21,6 +22,11 @@ public ChatMemory get(Object memoryId) {
.id(memoryId)
.build());
}

@Override
public void remove(Object id) {
memories.remove(id);
}
};
}
}
Loading

0 comments on commit ae6d811

Please sign in to comment.