From 816a91184e7c47a4ffb6eda8428c1318acf16242 Mon Sep 17 00:00:00 2001
From: Sebastian Schmidt
Date: Wed, 29 Jan 2025 16:32:57 -0700
Subject: [PATCH] Add support for DeepSeek

---
 .../android/app/build.gradle.kts              |   2 +-
 .../examples/llminference/ChatMessage.kt      |   5 +-
 .../examples/llminference/ChatScreen.kt       |   6 +-
 .../examples/llminference/ChatUiState.kt      | 135 +++++++++++++----
 .../examples/llminference/ChatViewModel.kt    |  15 +-
 .../examples/llminference/InferenceModel.kt   |  22 ++-
 .../mediapipe/examples/llminference/Model.kt  |   7 +-
 .../app/src/main/res/values/strings.xml       |   1 +
 8 files changed, 132 insertions(+), 61 deletions(-)

diff --git a/examples/llm_inference/android/app/build.gradle.kts b/examples/llm_inference/android/app/build.gradle.kts
index 061e0841..45dd3c8b 100644
--- a/examples/llm_inference/android/app/build.gradle.kts
+++ b/examples/llm_inference/android/app/build.gradle.kts
@@ -64,7 +64,7 @@ dependencies {
     implementation("androidx.compose.ui:ui-tooling-preview")
     implementation("androidx.compose.material3:material3")
 
-    implementation ("com.google.mediapipe:tasks-genai:0.10.16")
+    implementation ("com.google.mediapipe:tasks-genai:0.10.21")
 
     testImplementation("junit:junit:4.13.2")
     androidTestImplementation("androidx.test.ext:junit:1.1.5")
diff --git a/examples/llm_inference/android/app/src/main/java/com/google/mediapipe/examples/llminference/ChatMessage.kt b/examples/llm_inference/android/app/src/main/java/com/google/mediapipe/examples/llminference/ChatMessage.kt
index 3430e3b1..1e13f1fe 100644
--- a/examples/llm_inference/android/app/src/main/java/com/google/mediapipe/examples/llminference/ChatMessage.kt
+++ b/examples/llm_inference/android/app/src/main/java/com/google/mediapipe/examples/llminference/ChatMessage.kt
@@ -9,8 +9,11 @@ data class ChatMessage(
     val id: String = UUID.randomUUID().toString(),
     val rawMessage: String = "",
     val author: String,
-    val isLoading: Boolean = false
+    val isLoading: Boolean = false,
+    val isThinking: Boolean = false,
 ) {
+    val isEmpty: Boolean
+        get() = rawMessage.trim().isEmpty()
     val isFromUser: Boolean
         get() = author == USER_PREFIX
     val message: String
diff --git a/examples/llm_inference/android/app/src/main/java/com/google/mediapipe/examples/llminference/ChatScreen.kt b/examples/llm_inference/android/app/src/main/java/com/google/mediapipe/examples/llminference/ChatScreen.kt
index dc1f1408..606cab4c 100644
--- a/examples/llm_inference/android/app/src/main/java/com/google/mediapipe/examples/llminference/ChatScreen.kt
+++ b/examples/llm_inference/android/app/src/main/java/com/google/mediapipe/examples/llminference/ChatScreen.kt
@@ -134,6 +134,8 @@ fun ChatItem(
 ) {
     val backgroundColor = if (chatMessage.isFromUser) {
         MaterialTheme.colorScheme.tertiaryContainer
+    } else if (chatMessage.isThinking) {
+        MaterialTheme.colorScheme.primaryContainer
     } else {
         MaterialTheme.colorScheme.secondaryContainer
     }
@@ -158,7 +160,9 @@ fun ChatItem(
         ) {
             val author = if (chatMessage.isFromUser) {
                 stringResource(R.string.user_label)
-            } else {
+            } else if (chatMessage.isThinking) {
+                stringResource(R.string.thinking_label)
+            } else {
                 stringResource(R.string.model_label)
             }
             Text(
diff --git a/examples/llm_inference/android/app/src/main/java/com/google/mediapipe/examples/llminference/ChatUiState.kt b/examples/llm_inference/android/app/src/main/java/com/google/mediapipe/examples/llminference/ChatUiState.kt
index 7d6e965f..8a832dc2 100644
--- a/examples/llm_inference/android/app/src/main/java/com/google/mediapipe/examples/llminference/ChatUiState.kt
+++ b/examples/llm_inference/android/app/src/main/java/com/google/mediapipe/examples/llminference/ChatUiState.kt
@@ -7,7 +7,6 @@ const val MODEL_PREFIX = "model"
 
 interface UiState {
     val messages: List<ChatMessage>
-    val fullPrompt: String
 
     /**
      * Creates a new loading message.
@@ -17,15 +16,21 @@ interface UiState {
 
     /**
      * Appends the specified text to the message with the specified ID.
+     * The underlying implementations may split the text, re-use messages or create new ones.
+     * The method always returns the ID of the message used.
      * @param done - indicates whether the model has finished generating the message.
+     * @return the id of the message that was used.
      */
-    fun appendMessage(id: String, text: String, done: Boolean = false)
+    fun appendMessage(id: String, text: String, done: Boolean = false): String
 
     /**
      * Creates a new message with the specified text and author.
      * Return the id of that message.
      */
     fun addMessage(text: String, author: String): String
+
+    /** Formats a message from the user into the prompt format of the model. */
+    fun formatPrompt(text: String): String
 }
 
 /**
@@ -37,26 +42,19 @@ class ChatUiState(
     private val _messages: MutableList<ChatMessage> = messages.toMutableStateList()
     override val messages: List<ChatMessage> = _messages.reversed()
 
-    // Prompt the model with the current chat history
-    override val fullPrompt: String
-        get() = _messages.joinToString(separator = "\n") { it.rawMessage }
-
     override fun createLoadingMessage(): String {
         val chatMessage = ChatMessage(author = MODEL_PREFIX, isLoading = true)
         _messages.add(chatMessage)
         return chatMessage.id
     }
 
-    fun appendFirstMessage(id: String, text: String) {
-        appendMessage(id, text, false)
-    }
-
-    override fun appendMessage(id: String, text: String, done: Boolean) {
+    override fun appendMessage(id: String, text: String, done: Boolean): String {
         val index = _messages.indexOfFirst { it.id == id }
         if (index != -1) {
             val newText = _messages[index].rawMessage + text
             _messages[index] = _messages[index].copy(rawMessage = newText, isLoading = false)
         }
+        return id
     }
 
     override fun addMessage(text: String, author: String): String {
@@ -67,6 +65,10 @@ class ChatUiState(
         _messages.add(chatMessage)
         return chatMessage.id
     }
+
+    override fun formatPrompt(text: String): String {
+        return text
+    }
 }
 
 /**
@@ -77,56 +79,111 @@ class GemmaUiState(
 ) : UiState {
     private val START_TURN = "<start_of_turn>"
     private val END_TURN = "<end_of_turn>"
-    private val lock = Any()
 
     private val _messages: MutableList<ChatMessage> = messages.toMutableStateList()
-    override val messages: List<ChatMessage>
-        get() = synchronized(lock) {
-            _messages.apply{
-                for (i in indices) {
-                    this[i] = this[i].copy(
-                        rawMessage = this[i].rawMessage.replace(START_TURN + this[i].author + "\n", "")
-                            .replace(END_TURN, "")
-                    )
-                }
-            }.asReversed()
+    override val messages: List<ChatMessage> = _messages.asReversed()
+    override fun createLoadingMessage(): String {
+        val chatMessage = ChatMessage(author = MODEL_PREFIX, isLoading = true)
+        _messages.add(chatMessage)
+        return chatMessage.id
+    }
+
+    override fun appendMessage(id: String, text: String, done: Boolean): String {
+        val index = _messages.indexOfFirst { it.id == id }
+        if (index != -1) {
+            val newText = _messages[index].rawMessage + text
+            _messages[index] = _messages[index].copy(rawMessage = newText, isLoading = false)
         }
+        return id
+    }
 
-    // Only using the last 4 messages to keep input + output short
-    override val fullPrompt: String
-        get() = _messages.takeLast(4).joinToString(separator = "\n") { it.rawMessage }
+    override fun addMessage(text: String, author: String): String {
+        val chatMessage = ChatMessage(
+            rawMessage = text,
+            author = author
+        )
+        _messages.add(chatMessage)
+        return chatMessage.id
+    }
+
+    override fun formatPrompt(text: String): String {
+        return "$START_TURN$USER_PREFIX\n$text$END_TURN$START_TURN$MODEL_PREFIX"
+    }
+}
+
+
+/** An implementation of [UiState] to be used with the DeepSeek model. */
+class DeepSeeUiState(
+    messages: List<ChatMessage> = emptyList()
+) : UiState {
+    private var START_TOKEN = "<|begin▁of▁sentence|>"
+    private var PROMPT_PREFIX = "<|User|>"
+    private var PROMPT_SUFFIX = "<|Assistant|>"
+    private var THINKING_MARKER_START = "<think>"
+    private var THINKING_MARKER_END = "</think>"
+
+    private val _messages: MutableList<ChatMessage> = messages.toMutableStateList()
+    override val messages: List<ChatMessage> = _messages.asReversed()
 
     override fun createLoadingMessage(): String {
-        val chatMessage = ChatMessage(author = MODEL_PREFIX, isLoading = true)
+        val chatMessage = ChatMessage(author = MODEL_PREFIX, isLoading = true, isThinking = true)
         _messages.add(chatMessage)
         return chatMessage.id
     }
 
-    fun appendFirstMessage(id: String, text: String) {
-        appendMessage(id, "$START_TURN$MODEL_PREFIX\n$text", false)
-    }
+    override fun appendMessage(id: String, text: String, done: Boolean): String {
+        if (text.contains(THINKING_MARKER_END)) { // The model is done thinking.
+            val markerEnd = text.indexOf(THINKING_MARKER_END) + THINKING_MARKER_END.length
 
-    override fun appendMessage(id: String, text: String, done: Boolean) {
-        val index = _messages.indexOfFirst { it.id == id }
-        if (index != -1) {
-            val newText = if (done) {
-                // Append the Suffix when model is done generating the response
-                _messages[index].rawMessage + text + END_TURN
+            // Add text to current bubble
+            val prefix = text.substring(0, markerEnd)
+            val index = appendToMessage(id, prefix)
+
+            var currentMessage = _messages[index]
+
+            if (currentMessage.isEmpty) {
+                // No thoughts. Turn the bubble into the model response.
+                _messages[index] = _messages[index].copy(
+                    isThinking = false
+                )
+                currentMessage = _messages[index]
             } else {
-                // Append the text
-                _messages[index].rawMessage + text
+                // There are some thoughts. Add a new bubble.
+                currentMessage = ChatMessage(author = MODEL_PREFIX, isLoading = true, isThinking = false)
+                _messages.add(currentMessage)
             }
-            _messages[index] = _messages[index].copy(rawMessage = newText, isLoading = false)
+
+            val suffix = text.substring(markerEnd)
+            appendToMessage(currentMessage.id, suffix)
+
+            return currentMessage.id
+        } else {
+            appendToMessage(id, text)
+            return id
         }
     }
 
+    private fun appendToMessage(id: String, suffix: String): Int {
+        val index = _messages.indexOfFirst { it.id == id }
+        val newText = suffix.replace(THINKING_MARKER_START, "").replace(THINKING_MARKER_END, "")
+        _messages[index] = _messages[index].copy(
+            rawMessage = _messages[index].rawMessage + newText,
+            isLoading = false
+        )
+        return index
+    }
+
     override fun addMessage(text: String, author: String): String {
         val chatMessage = ChatMessage(
-            rawMessage = "$START_TURN$author\n$text$END_TURN",
+            rawMessage = text,
             author = author
         )
         _messages.add(chatMessage)
         return chatMessage.id
     }
+
+    override fun formatPrompt(text: String): String {
+        return "$START_TOKEN$PROMPT_PREFIX$text$PROMPT_SUFFIX"
+    }
 }
diff --git a/examples/llm_inference/android/app/src/main/java/com/google/mediapipe/examples/llminference/ChatViewModel.kt b/examples/llm_inference/android/app/src/main/java/com/google/mediapipe/examples/llminference/ChatViewModel.kt
index cfede4d3..2a9240fd 100644
--- a/examples/llm_inference/android/app/src/main/java/com/google/mediapipe/examples/llminference/ChatViewModel.kt
+++ b/examples/llm_inference/android/app/src/main/java/com/google/mediapipe/examples/llminference/ChatViewModel.kt
@@ -16,9 +16,7 @@ class ChatViewModel(
     private val inferenceModel: InferenceModel
 ) : ViewModel() {
 
-    // `GemmaUiState()` is optimized for the Gemma model.
-    // Replace `GemmaUiState` with `ChatUiState()` if you're using a different model
-    private val _uiState: MutableStateFlow<UiState> = MutableStateFlow(GemmaUiState())
+    private val _uiState: MutableStateFlow<UiState> = MutableStateFlow(inferenceModel.uiState)
 
     val uiState: StateFlow<UiState> = _uiState.asStateFlow()
 
@@ -33,16 +31,11 @@ class ChatViewModel(
             var currentMessageId: String? = _uiState.value.createLoadingMessage()
             setInputEnabled(false)
             try {
-                val fullPrompt = _uiState.value.fullPrompt
-                inferenceModel.generateResponseAsync(fullPrompt)
+                inferenceModel.generateResponseAsync(userMessage)
                 inferenceModel.partialResults
-                    .collectIndexed { index, (partialResult, done) ->
+                    .collectIndexed { _, (partialResult, done) ->
                         currentMessageId?.let {
-                            if (index == 0) {
-                                _uiState.value.appendFirstMessage(it, partialResult)
-                            } else {
-                                _uiState.value.appendMessage(it, partialResult, done)
-                            }
+                            currentMessageId = _uiState.value.appendMessage(it, partialResult, done)
                             if (done) {
                                 currentMessageId = null
                                 // Re-enable text input
diff --git a/examples/llm_inference/android/app/src/main/java/com/google/mediapipe/examples/llminference/InferenceModel.kt b/examples/llm_inference/android/app/src/main/java/com/google/mediapipe/examples/llminference/InferenceModel.kt
index ae5dd71a..503681a2 100644
--- a/examples/llm_inference/android/app/src/main/java/com/google/mediapipe/examples/llminference/InferenceModel.kt
+++ b/examples/llm_inference/android/app/src/main/java/com/google/mediapipe/examples/llminference/InferenceModel.kt
@@ -2,6 +2,8 @@ package com.google.mediapipe.examples.llminference
 
 import android.content.Context
 import com.google.mediapipe.tasks.genai.llminference.LlmInference
+import com.google.mediapipe.tasks.genai.llminference.LlmInferenceSession
+import com.google.mediapipe.tasks.genai.llminference.LlmInferenceSession.LlmInferenceSessionOptions
 import java.io.File
 import kotlinx.coroutines.channels.BufferOverflow
 import kotlinx.coroutines.flow.MutableSharedFlow
@@ -10,6 +12,7 @@ import kotlinx.coroutines.flow.asSharedFlow
 
 class InferenceModel private constructor(context: Context) {
     private var llmInference: LlmInference
+    private var llmInferenceSession: LlmInferenceSession
 
     private val modelExists: Boolean
         get() = File(model.path).exists()
@@ -19,13 +22,14 @@ class InferenceModel private constructor(context: Context) {
         onBufferOverflow = BufferOverflow.DROP_OLDEST
     )
     val partialResults: SharedFlow<Pair<String, Boolean>> = _partialResults.asSharedFlow()
+    val uiState: UiState
 
     init {
         if (!modelExists) {
            throw IllegalArgumentException("Model not found at path: ${model.path}")
        }
 
-        val options = LlmInference.LlmInferenceOptions.builder()
+        val inferenceOptions = LlmInference.LlmInferenceOptions.builder()
            .setModelPath(model.path)
            .setMaxTokens(1024)
            .setResultListener { partialResult, done ->
@@ -33,13 +37,21 @@ class InferenceModel private constructor(context: Context) {
            }
            .build()
 
-        llmInference = LlmInference.createFromOptions(context, options)
+        val sessionOptions = LlmInferenceSessionOptions.builder()
+            .setTemperature(model.temperature)
+            .setTopK(model.topK)
+            .setTopP(model.topP)
+            .build()
+
+        uiState = model.uiState
+        llmInference = LlmInference.createFromOptions(context, inferenceOptions)
+        llmInferenceSession = LlmInferenceSession.createFromOptions(llmInference, sessionOptions)
     }
 
     fun generateResponseAsync(prompt: String) {
-        // Add the gemma prompt prefix to trigger the response.
-        val gemmaPrompt = prompt + "<start_of_turn>model\n"
-        llmInference.generateResponseAsync(gemmaPrompt)
+        val formattedPrompt = model.uiState.formatPrompt(prompt)
+        llmInferenceSession.addQueryChunk(formattedPrompt)
+        llmInferenceSession.generateResponseAsync()
     }
 
     companion object {
diff --git a/examples/llm_inference/android/app/src/main/java/com/google/mediapipe/examples/llminference/Model.kt b/examples/llm_inference/android/app/src/main/java/com/google/mediapipe/examples/llminference/Model.kt
index f4685aeb..15aac91a 100644
--- a/examples/llm_inference/android/app/src/main/java/com/google/mediapipe/examples/llminference/Model.kt
+++ b/examples/llm_inference/android/app/src/main/java/com/google/mediapipe/examples/llminference/Model.kt
@@ -2,7 +2,8 @@ package com.google.mediapipe.examples.llminference
 
 // NB: Make sure the filename is *unique* per model you use!
 // Weight caching is currently based on filename alone.
-enum class Model(val path: String) {
-    GEMMA_CPU("/data/local/tmp/llm/gemma-2b-it-cpu-int4.bin"),
-    GEMMA_GPU("/data/local/tmp/llm/gemma-2b-it-gpu-int4.bin"),
+enum class Model(val path: String, val uiState: UiState, val temperature: Float, val topK: Int, val topP: Float) {
+    GEMMA_CPU("/data/local/tmp/llm/gemma-2b-it-cpu-int4.bin", GemmaUiState(), temperature = 0.8f, topK = 40, topP = 1.0f),
+    GEMMA_GPU("/data/local/tmp/llm/gemma-2b-it-gpu-int4.bin", GemmaUiState(), temperature = 0.8f, topK = 40, topP = 1.0f),
+    DEEPSEEK_CPU("/data/local/tmp/llm/deepseek3k_q8_ekv1280.task", DeepSeeUiState(), temperature = 0.6f, topK = 40, topP = 0.7f),
 }
diff --git a/examples/llm_inference/android/app/src/main/res/values/strings.xml b/examples/llm_inference/android/app/src/main/res/values/strings.xml
index c54676a1..a399b0c6 100644
--- a/examples/llm_inference/android/app/src/main/res/values/strings.xml
+++ b/examples/llm_inference/android/app/src/main/res/values/strings.xml
@@ -5,5 +5,6 @@
     Send
     <string name="user_label">User</string>
     <string name="model_label">Model</string>
+    <string name="thinking_label">Thinking...</string>
     Responses generated by user-provided model
 </resources>
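
Note (illustrative, not part of the patch): the sketch below shows how the DeepSeek UI state is expected to handle a streamed response that contains a closing </think> marker, based on the appendMessage implementation added above. It assumes the classes from this patch (DeepSeeUiState, ChatMessage) and the Compose runtime are on the classpath; the function name and the sample strings are made up.

fun demoDeepSeekStreaming() {
    val uiState = DeepSeeUiState()

    // The loading bubble starts out as a "Thinking..." bubble.
    var messageId = uiState.createLoadingMessage()

    // Chunks arriving before </think> stay in the thinking bubble (the <think> tag is stripped).
    messageId = uiState.appendMessage(messageId, "<think>Considering the question...")

    // The chunk containing </think> closes the thinking bubble and, because that bubble already
    // has content, opens a fresh bubble for the answer; the new bubble's id is returned, which is
    // why ChatViewModel reassigns currentMessageId on every collected partial result.
    messageId = uiState.appendMessage(messageId, "</think>The answer is 42.")

    // messages is newest-first: the answer bubble, then the thinking bubble.
    uiState.messages.forEach { println("${it.author} thinking=${it.isThinking}: ${it.rawMessage}") }
}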
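Note (illustrative, not part of the patch): a minimal sketch of the session-based generation flow that InferenceModel.kt switches to, using only calls that appear in the diff. The helper name and the println listener body are assumptions; the model path and sampling values are copied from the DEEPSEEK_CPU entry in Model.kt, whereas the real app reads them from the Model enum.

import android.content.Context
import com.google.mediapipe.tasks.genai.llminference.LlmInference
import com.google.mediapipe.tasks.genai.llminference.LlmInferenceSession
import com.google.mediapipe.tasks.genai.llminference.LlmInferenceSession.LlmInferenceSessionOptions

fun runDeepSeekPrompt(context: Context, prompt: String) {
    // Engine options: the model file plus a listener that receives streamed partial results.
    val inferenceOptions = LlmInference.LlmInferenceOptions.builder()
        .setModelPath("/data/local/tmp/llm/deepseek3k_q8_ekv1280.task")
        .setMaxTokens(1024)
        .setResultListener { partialResult, done -> println("$partialResult (done=$done)") }
        .build()
    val llmInference = LlmInference.createFromOptions(context, inferenceOptions)

    // Sampling parameters now live on the session rather than on the engine.
    val sessionOptions = LlmInferenceSessionOptions.builder()
        .setTemperature(0.6f)
        .setTopK(40)
        .setTopP(0.7f)
        .build()
    val session = LlmInferenceSession.createFromOptions(llmInference, sessionOptions)

    // The user text is wrapped in the DeepSeek chat format (see formatPrompt in the new UI state).
    session.addQueryChunk("<|begin▁of▁sentence|><|User|>$prompt<|Assistant|>")
    session.generateResponseAsync()
}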