Add support for DeepSeek
schmidt-sebastian committed Jan 30, 2025
1 parent 3701863 commit 816a911
Showing 8 changed files with 132 additions and 61 deletions.
2 changes: 1 addition & 1 deletion examples/llm_inference/android/app/build.gradle.kts
@@ -64,7 +64,7 @@ dependencies {
implementation("androidx.compose.ui:ui-tooling-preview")
implementation("androidx.compose.material3:material3")

implementation ("com.google.mediapipe:tasks-genai:0.10.16")
implementation ("com.google.mediapipe:tasks-genai:0.10.21")

testImplementation("junit:junit:4.13.2")
androidTestImplementation("androidx.test.ext:junit:1.1.5")
@@ -9,8 +9,11 @@ data class ChatMessage(
val id: String = UUID.randomUUID().toString(),
val rawMessage: String = "",
val author: String,
val isLoading: Boolean = false
val isLoading: Boolean = false,
val isThinking: Boolean = false,
) {
val isEmpty: Boolean
get() = rawMessage.trim().isEmpty()
val isFromUser: Boolean
get() = author == USER_PREFIX
val message: String
@@ -134,6 +134,8 @@ fun ChatItem(
) {
val backgroundColor = if (chatMessage.isFromUser) {
MaterialTheme.colorScheme.tertiaryContainer
} else if (chatMessage.isThinking) {
MaterialTheme.colorScheme.primaryContainer
} else {
MaterialTheme.colorScheme.secondaryContainer
}
@@ -158,7 +160,9 @@
) {
val author = if (chatMessage.isFromUser) {
stringResource(R.string.user_label)
} else {
} else if (chatMessage.isThinking) {
stringResource(R.string.thinking_label)
} else {
stringResource(R.string.model_label)
}
Text(
@@ -7,7 +7,6 @@ const val MODEL_PREFIX = "model"

interface UiState {
val messages: List<ChatMessage>
val fullPrompt: String

/**
* Creates a new loading message.
@@ -17,15 +16,21 @@ interface UiState {

/**
* Appends the specified text to the message with the specified ID.
* The underlying implementation may re-use the existing message or create a new one. The method
* always returns the ID of the message used.
* @param done - indicates whether the model has finished generating the message.
* @return the id of the message that was used.
*/
fun appendMessage(id: String, text: String, done: Boolean = false)
fun appendMessage(id: String, text: String, done: Boolean = false): String

/**
* Creates a new message with the specified text and author.
* Returns the ID of that message.
*/
fun addMessage(text: String, author: String): String

/** Formats a message from the user into the prompt format of the model. */
fun formatPrompt(text: String): String
}
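For context, here is a minimal, hypothetical sketch (not part of this commit) of how a caller is expected to drive this contract while streaming: create the loading message, append each partial result, and carry forward the ID returned by appendMessage, since an implementation may move the stream into a new message (for example, after a thinking section).

// Hypothetical usage sketch of the UiState contract above; `partials` stands in
// for the streamed (text, done) pairs emitted by the inference engine.
fun streamIntoUi(uiState: UiState, partials: List<Pair<String, Boolean>>) {
    var messageId = uiState.createLoadingMessage()
    for ((text, done) in partials) {
        // appendMessage may return a different ID if the implementation
        // started a new message for this part of the response.
        messageId = uiState.appendMessage(messageId, text, done)
    }
}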

/**
@@ -37,26 +42,19 @@ class ChatUiState(
private val _messages: MutableList<ChatMessage> = messages.toMutableStateList()
override val messages: List<ChatMessage> = _messages.reversed()

// Prompt the model with the current chat history
override val fullPrompt: String
get() = _messages.joinToString(separator = "\n") { it.rawMessage }

override fun createLoadingMessage(): String {
val chatMessage = ChatMessage(author = MODEL_PREFIX, isLoading = true)
_messages.add(chatMessage)
return chatMessage.id
}

fun appendFirstMessage(id: String, text: String) {
appendMessage(id, text, false)
}

override fun appendMessage(id: String, text: String, done: Boolean) {
override fun appendMessage(id: String, text: String, done: Boolean): String {
val index = _messages.indexOfFirst { it.id == id }
if (index != -1) {
val newText = _messages[index].rawMessage + text
_messages[index] = _messages[index].copy(rawMessage = newText, isLoading = false)
}
return id
}

override fun addMessage(text: String, author: String): String {
@@ -67,6 +65,10 @@ class ChatUiState(
_messages.add(chatMessage)
return chatMessage.id
}

override fun formatPrompt(text: String): String {
return text
}
}

/**
@@ -77,56 +79,111 @@ class GemmaUiState(
) : UiState {
private val START_TURN = "<start_of_turn>"
private val END_TURN = "<end_of_turn>"
private val lock = Any()

private val _messages: MutableList<ChatMessage> = messages.toMutableStateList()
override val messages: List<ChatMessage>
get() = synchronized(lock) {
_messages.apply {
for (i in indices) {
this[i] = this[i].copy(
rawMessage = this[i].rawMessage.replace(START_TURN + this[i].author + "\n", "")
.replace(END_TURN, "")
)
}
}.asReversed()
override val messages: List<ChatMessage> = _messages.asReversed()

override fun createLoadingMessage(): String {
val chatMessage = ChatMessage(author = MODEL_PREFIX, isLoading = true)
_messages.add(chatMessage)
return chatMessage.id
}

override fun appendMessage(id: String, text: String, done: Boolean): String {
val index = _messages.indexOfFirst { it.id == id }
if (index != -1) {
val newText = _messages[index].rawMessage + text
_messages[index] = _messages[index].copy(rawMessage = newText, isLoading = false)
}
return id
}

// Only using the last 4 messages to keep input + output short
override val fullPrompt: String
get() = _messages.takeLast(4).joinToString(separator = "\n") { it.rawMessage }
override fun addMessage(text: String, author: String): String {
val chatMessage = ChatMessage(
rawMessage = text,
author = author
)
_messages.add(chatMessage)
return chatMessage.id
}

override fun formatPrompt(text: String): String {
return "$START_TURN$USER_PREFIX\n$text$END_TURN$START_TURN$MODEL_PREFIX"
}
}
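As a concrete illustration of the Gemma template above (assuming USER_PREFIX resolves to "user"), formatPrompt wraps the user text in turn markers and opens the model turn:

val prompt = GemmaUiState().formatPrompt("Tell me a joke")
// prompt == "<start_of_turn>user\nTell me a joke<end_of_turn><start_of_turn>model"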


/** An implementation of [UiState] to be used with the DeepSeek model. */
class DeepSeeUiState(
messages: List<ChatMessage> = emptyList()
) : UiState {
private var START_TOKEN = "<|begin▁of▁sentence|>"
private var PROMPT_PREFIX = "<|User|>"
private var PROMPT_SUFFIX = "<|Assistant|>"
private var THINKING_MARKER_START = "<think>"
private var THINKING_MARKER_END = "</think>"

private val _messages: MutableList<ChatMessage> = messages.toMutableStateList()
override val messages: List<ChatMessage> = _messages.asReversed()

override fun createLoadingMessage(): String {
val chatMessage = ChatMessage(author = MODEL_PREFIX, isLoading = true)
val chatMessage = ChatMessage(author = MODEL_PREFIX, isLoading = true, isThinking = true)
_messages.add(chatMessage)
return chatMessage.id
}

fun appendFirstMessage(id: String, text: String) {
appendMessage(id, "$START_TURN$MODEL_PREFIX\n$text", false)
}
override fun appendMessage(id: String, text: String, done: Boolean): String {
if (text.contains(THINKING_MARKER_END)) { // The model is done thinking.
val markerEnd = text.indexOf(THINKING_MARKER_END) + THINKING_MARKER_END.length

override fun appendMessage(id: String, text: String, done: Boolean) {
val index = _messages.indexOfFirst { it.id == id }
if (index != -1) {
val newText = if (done) {
// Append the Suffix when model is done generating the response
_messages[index].rawMessage + text + END_TURN
// Add text to current bubble
val prefix = text.substring(0, markerEnd)
val index = appendToMessage(id, prefix)

var currentMessage = _messages[index]

if (currentMessage.isEmpty) {
// No thoughts. Turn the bubble into the model response.
_messages[index] = _messages[index].copy(
isThinking = false
)
currentMessage = _messages[index]
} else {
// Append the text
_messages[index].rawMessage + text
// There are some thoughts. Add a new bubble.
currentMessage = ChatMessage(author = MODEL_PREFIX, isLoading = true, isThinking = false)
_messages.add(currentMessage)
}
_messages[index] = _messages[index].copy(rawMessage = newText, isLoading = false)

val suffix = text.substring(markerEnd)
appendToMessage(currentMessage.id, suffix)

return currentMessage.id
} else {
appendToMessage(id, text)
return id
}
}

private fun appendToMessage(id: String, suffix: String): Int {
val index = _messages.indexOfFirst { it.id == id }
val newText = suffix.replace(THINKING_MARKER_START, "").replace(THINKING_MARKER_END, "")
_messages[index] = _messages[index].copy(
rawMessage = _messages[index].rawMessage + newText,
isLoading = false
)
return index
}

override fun addMessage(text: String, author: String): String {
val chatMessage = ChatMessage(
rawMessage = "$START_TURN$author\n$text$END_TURN",
rawMessage = text,
author = author
)
_messages.add(chatMessage)
return chatMessage.id
}

override fun formatPrompt(text: String): String {
return "$START_TOKEN$PROMPT_PREFIX$text$PROMPT_SUFFIX"
}
}
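The heart of the DeepSeek handling above is how a streamed chunk that contains the closing </think> marker is split: everything up to and including the marker stays in the thinking bubble (with the markers stripped), and the remainder starts the model-response bubble. A simplified, self-contained sketch of just that splitting step, illustrative only and independent of the Compose state used above:

// Returns (thinkingPart, responsePart) for one streamed chunk.
fun splitAtThinkingEnd(chunk: String): Pair<String, String> {
    val endMarker = "</think>"
    if (!chunk.contains(endMarker)) return chunk to ""
    val cut = chunk.indexOf(endMarker) + endMarker.length
    val thinking = chunk.substring(0, cut).replace("<think>", "").replace(endMarker, "")
    val response = chunk.substring(cut)
    return thinking to response
}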
@@ -16,9 +16,7 @@ class ChatViewModel(
private val inferenceModel: InferenceModel
) : ViewModel() {

// `GemmaUiState()` is optimized for the Gemma model.
// Replace `GemmaUiState` with `ChatUiState()` if you're using a different model
private val _uiState: MutableStateFlow<GemmaUiState> = MutableStateFlow(GemmaUiState())
private val _uiState: MutableStateFlow<UiState> = MutableStateFlow(inferenceModel.uiState)
val uiState: StateFlow<UiState> =
_uiState.asStateFlow()

@@ -33,16 +31,11 @@
var currentMessageId: String? = _uiState.value.createLoadingMessage()
setInputEnabled(false)
try {
val fullPrompt = _uiState.value.fullPrompt
inferenceModel.generateResponseAsync(fullPrompt)
inferenceModel.generateResponseAsync(userMessage)
inferenceModel.partialResults
.collectIndexed { index, (partialResult, done) ->
.collectIndexed { _, (partialResult, done) ->
currentMessageId?.let {
if (index == 0) {
_uiState.value.appendFirstMessage(it, partialResult)
} else {
_uiState.value.appendMessage(it, partialResult, done)
}
currentMessageId = _uiState.value.appendMessage(it, partialResult, done)
if (done) {
currentMessageId = null
// Re-enable text input
@@ -2,6 +2,8 @@ package com.google.mediapipe.examples.llminference

import android.content.Context
import com.google.mediapipe.tasks.genai.llminference.LlmInference
import com.google.mediapipe.tasks.genai.llminference.LlmInferenceSession
import com.google.mediapipe.tasks.genai.llminference.LlmInferenceSession.LlmInferenceSessionOptions
import java.io.File
import kotlinx.coroutines.channels.BufferOverflow
import kotlinx.coroutines.flow.MutableSharedFlow
@@ -10,6 +12,7 @@ import kotlinx.coroutines.flow.asSharedFlow

class InferenceModel private constructor(context: Context) {
private var llmInference: LlmInference
private var llmInferenceSession: LlmInferenceSession

private val modelExists: Boolean
get() = File(model.path).exists()
@@ -19,27 +22,36 @@ class InferenceModel private constructor(context: Context) {
onBufferOverflow = BufferOverflow.DROP_OLDEST
)
val partialResults: SharedFlow<Pair<String, Boolean>> = _partialResults.asSharedFlow()
val uiState: UiState

init {
if (!modelExists) {
throw IllegalArgumentException("Model not found at path: ${model.path}")
}

val options = LlmInference.LlmInferenceOptions.builder()
val inferenceOptions = LlmInference.LlmInferenceOptions.builder()
.setModelPath(model.path)
.setMaxTokens(1024)
.setResultListener { partialResult, done ->
_partialResults.tryEmit(partialResult to done)
}
.build()

llmInference = LlmInference.createFromOptions(context, options)
val sessionOptions = LlmInferenceSessionOptions.builder()
.setTemperature(model.temperature)
.setTopK(model.topK)
.setTopP(model.topP)
.build()

uiState = model.uiState
llmInference = LlmInference.createFromOptions(context, inferenceOptions)
llmInferenceSession = LlmInferenceSession.createFromOptions(llmInference, sessionOptions)
}

fun generateResponseAsync(prompt: String) {
// Add the gemma prompt prefix to trigger the response.
val gemmaPrompt = prompt + "<start_of_turn>model\n"
llmInference.generateResponseAsync(gemmaPrompt)
val formattedPrompt = model.uiState.formatPrompt(prompt)
llmInferenceSession.addQueryChunk(formattedPrompt)
llmInferenceSession.generateResponseAsync()
}

companion object {
@@ -2,7 +2,8 @@ package com.google.mediapipe.examples.llminference

// NB: Make sure the filename is *unique* per model you use!
// Weight caching is currently based on filename alone.
enum class Model(val path: String) {
GEMMA_CPU("/data/local/tmp/llm/gemma-2b-it-cpu-int4.bin"),
GEMMA_GPU("/data/local/tmp/llm/gemma-2b-it-gpu-int4.bin"),
enum class Model(val path: String, val uiState: UiState, val temperature: Float, val topK: Int, val topP: Float) {
GEMMA_CPU("/data/local/tmp/llm/gemma-2b-it-cpu-int4.bin", GemmaUiState(), temperature = 0.8f, topK = 40, topP = 1.0f),
GEMMA_GPU("/data/local/tmp/llm/gemma-2b-it-gpu-int4.bin", GemmaUiState(), temperature = 0.8f, topK = 40, topP = 1.0f),
DEEPSEEK_CPU("/data/local/tmp/llm/deepseek3k_q8_ekv1280.task", DeepSeeUiState(), temperature = 0.6f, topK = 40, topP = 0.7f),
}
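Each enum entry now bundles everything the inference layer needs: the on-device model path, the model-specific UiState (prompt template and chat-bubble handling), and the sampling parameters that InferenceModel feeds into the session options. A hedged sketch of how those pieces fit together; the helper below is hypothetical and only reuses calls already shown in this commit:

// Hypothetical helper: build a session for a selected model from its sampling parameters.
// Assumes the same imports as InferenceModel above.
fun newSessionFor(engine: LlmInference, model: Model): LlmInferenceSession {
    val sessionOptions = LlmInferenceSessionOptions.builder()
        .setTemperature(model.temperature)
        .setTopK(model.topK)
        .setTopP(model.topP)
        .build()
    return LlmInferenceSession.createFromOptions(engine, sessionOptions)
}

// e.g. Model.DEEPSEEK_CPU.uiState.formatPrompt("Hello") produces
//      "<|begin▁of▁sentence|><|User|>Hello<|Assistant|>"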
@@ -5,5 +5,6 @@
<string name="action_send">Send</string>
<string name="user_label">User</string>
<string name="model_label">Model</string>
<string name="thinking_label">Thinking...</string>
<string name="disclaimer">Responses generated by user-provided model</string>
</resources>
