From 55d9825ab5588d003f2d70e17e71cc25e8f08bb6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Felix=20M=C3=BCller?=
Date: Tue, 10 Sep 2024 11:41:23 +0200
Subject: [PATCH] refactor: Use coroutines for LangChain4j interactions

---
 .../fmueller/jarvis/ai/OllamaService.kt | 27 ++++++++++++-------
 1 file changed, 17 insertions(+), 10 deletions(-)

diff --git a/src/main/kotlin/com/github/fmueller/jarvis/ai/OllamaService.kt b/src/main/kotlin/com/github/fmueller/jarvis/ai/OllamaService.kt
index aecc37c..daf9566 100644
--- a/src/main/kotlin/com/github/fmueller/jarvis/ai/OllamaService.kt
+++ b/src/main/kotlin/com/github/fmueller/jarvis/ai/OllamaService.kt
@@ -5,6 +5,8 @@ import dev.langchain4j.model.ollama.OllamaStreamingChatModel
 import dev.langchain4j.service.AiServices
 import dev.langchain4j.service.TokenStream
 import kotlinx.coroutines.Dispatchers
+import kotlinx.coroutines.ExperimentalCoroutinesApi
+import kotlinx.coroutines.suspendCancellableCoroutine
 import kotlinx.coroutines.withContext
 import kotlinx.serialization.Serializable
 import kotlinx.serialization.decodeFromString
@@ -15,7 +17,7 @@ import java.net.http.HttpClient
 import java.net.http.HttpRequest
 import java.net.http.HttpResponse
 import java.time.Duration
-import java.util.concurrent.CompletableFuture
+import kotlin.coroutines.resumeWithException
 
 @Serializable
 private data class ChatMessage(val role: String, val content: String)
@@ -63,22 +65,27 @@ object OllamaService {
         .systemMessageProvider { chatMemoryId -> systemMessage }
         .build()
 
+    @OptIn(ExperimentalCoroutinesApi::class)
     suspend fun chatLangChain4J(conversation: Conversation): String = withContext(Dispatchers.IO) {
         // TODO check if model is available
         // TODO if not, download model
 
-        // TODO migration to LangChain4J: change to Kotlin Coroutines
         // TODO migration to LangChain4J: add timeout handling
         // TODO migration to LangChain4J: research how token limits work and context windows
         // TODO migration to LangChain4J: research how chat memory is configured
-        val future = CompletableFuture<String>()
-        assistant
-            .chat(conversation.getLastUserMessage()?.content ?: "Tell me that there was not message provided.")
-            .onNext { update -> conversation.addToMessageBeingGenerated(update) }
-            .onComplete { response -> future.complete(response.content().text()) }
-            .onError { error -> future.complete("Error: ${error.message}") }
-            .start()
-        future.get()
+
+        suspendCancellableCoroutine { continuation ->
+            assistant
+                .chat(conversation.getLastUserMessage()?.content ?: "Tell me that there was no message provided.")
+                .onNext { update -> conversation.addToMessageBeingGenerated(update) }
+                .onComplete { response -> continuation.resume(response.content().text()) { t -> /* noop */ } }
+                .onError { error -> continuation.resumeWithException(Exception("Error: ${error.message}")) }
+                .start()
+
+            continuation.invokeOnCancellation {
+                // TODO when LangChain4j implemented AbortController, call it here
+            }
+        }
     }
 
     suspend fun chat(conversation: Conversation): String = withContext(Dispatchers.IO) {