Merge pull request #47 from Taewan-P/feat/ollama-support
Add Ollama Support
Taewan-P authored Oct 6, 2024
2 parents ff53645 + 642b13e commit ad02c0a
Showing 30 changed files with 547 additions and 84 deletions.
1 change: 1 addition & 0 deletions .idea/gradle.xml

6 changes: 3 additions & 3 deletions .idea/kotlinc.xml

1 change: 0 additions & 1 deletion .idea/misc.xml

13 changes: 13 additions & 0 deletions .idea/runConfigurations.xml

7 changes: 4 additions & 3 deletions app/build.gradle.kts
@@ -2,6 +2,7 @@ plugins {
alias(libs.plugins.android.application)
alias(libs.plugins.jetbrains.kotlin.android)
alias(libs.plugins.android.hilt)
+ alias(libs.plugins.compose.compiler)
alias(libs.plugins.kotlin.ksp)
alias(libs.plugins.kotlin.parcelize)
alias(libs.plugins.auto.license)
@@ -51,9 +52,9 @@ android {
buildFeatures {
compose = true
}
- composeOptions {
- kotlinCompilerExtensionVersion = "1.5.13" // Make sure to update this when Kotlin version is updated
- }
+ // composeOptions {
+ // kotlinCompilerExtensionVersion = "1.5.13" // Make sure to update this when Kotlin version is updated
+ // }
packaging {
resources {
excludes += "/META-INF/{AL2.0,LGPL2.1}"
@@ -7,14 +7,17 @@ object ModelConstants {
val openaiModels = linkedSetOf("gpt-4o", "gpt-4-turbo", "gpt-4", "gpt-3.5-turbo")
val anthropicModels = linkedSetOf("claude-3-5-sonnet-20240620", "claude-3-opus-20240229", "claude-3-sonnet-20240229", "claude-3-haiku-20240307")
val googleModels = linkedSetOf("gemini-1.5-pro-latest", "gemini-1.5-flash-latest", "gemini-1.0-pro")
- const val OPENAI_API_URL = "https://api.openai.com"
- const val ANTHROPIC_API_URL = "https://api.anthropic.com"
+ val ollamaModels = linkedSetOf<String>()
+
+ const val OPENAI_API_URL = "https://api.openai.com/v1/"
+ const val ANTHROPIC_API_URL = "https://api.anthropic.com/"
const val GOOGLE_API_URL = "https://generativelanguage.googleapis.com"

fun getDefaultAPIUrl(apiType: ApiType) = when (apiType) {
ApiType.OPENAI -> OPENAI_API_URL
ApiType.ANTHROPIC -> ANTHROPIC_API_URL
ApiType.GOOGLE -> GOOGLE_API_URL
+ ApiType.OLLAMA -> ""
}

const val ANTHROPIC_MAXIMUM_TOKEN = 4096
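
Unlike the hosted platforms, Ollama gets no default URL here: the server address is supplied by the user (via the new SetupAPIUrlScreen added later in this diff) and is later concatenated with "v1/" in ChatRepositoryImpl. A minimal sketch of the kind of normalization a caller might apply to that address — the helper name and validation rules below are illustrative, not part of this PR:

// Illustrative sketch (not in this PR): give the user-entered Ollama address a scheme
// check and a trailing slash so that "${platform.apiUrl}v1/" resolves as intended.
// A typical local server would be "http://localhost:11434/".
fun normalizeOllamaBaseUrl(input: String): String {
    val trimmed = input.trim()
    require(trimmed.startsWith("http://") || trimmed.startsWith("https://")) {
        "Ollama API address must include a scheme, e.g. http://localhost:11434"
    }
    return if (trimmed.endsWith("/")) trimmed else "$trimmed/"
}
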
@@ -20,37 +20,44 @@ class SettingDataSourceImpl @Inject constructor(
private val apiStatusMap = mapOf(
ApiType.OPENAI to booleanPreferencesKey("openai_status"),
ApiType.ANTHROPIC to booleanPreferencesKey("anthropic_status"),
- ApiType.GOOGLE to booleanPreferencesKey("google_status")
+ ApiType.GOOGLE to booleanPreferencesKey("google_status"),
+ ApiType.OLLAMA to booleanPreferencesKey("ollama_status")
)
private val apiUrlMap = mapOf(
ApiType.OPENAI to stringPreferencesKey("openai_url"),
ApiType.ANTHROPIC to stringPreferencesKey("anthropic_url"),
- ApiType.GOOGLE to stringPreferencesKey("google_url")
+ ApiType.GOOGLE to stringPreferencesKey("google_url"),
+ ApiType.OLLAMA to stringPreferencesKey("ollama_url")
)
private val apiTokenMap = mapOf(
ApiType.OPENAI to stringPreferencesKey("openai_token"),
ApiType.ANTHROPIC to stringPreferencesKey("anthropic_token"),
- ApiType.GOOGLE to stringPreferencesKey("google_token")
+ ApiType.GOOGLE to stringPreferencesKey("google_token"),
+ ApiType.OLLAMA to stringPreferencesKey("ollama_token")
)
private val apiModelMap = mapOf(
ApiType.OPENAI to stringPreferencesKey("openai_model"),
ApiType.ANTHROPIC to stringPreferencesKey("anthropic_model"),
- ApiType.GOOGLE to stringPreferencesKey("google_model")
+ ApiType.GOOGLE to stringPreferencesKey("google_model"),
+ ApiType.OLLAMA to stringPreferencesKey("ollama_model")
)
private val apiTemperatureMap = mapOf(
ApiType.OPENAI to floatPreferencesKey("openai_temperature"),
ApiType.ANTHROPIC to floatPreferencesKey("anthropic_temperature"),
- ApiType.GOOGLE to floatPreferencesKey("google_temperature")
+ ApiType.GOOGLE to floatPreferencesKey("google_temperature"),
+ ApiType.OLLAMA to floatPreferencesKey("ollama_temperature")
)
private val apiTopPMap = mapOf(
ApiType.OPENAI to floatPreferencesKey("openai_top_p"),
ApiType.ANTHROPIC to floatPreferencesKey("anthropic_top_p"),
- ApiType.GOOGLE to floatPreferencesKey("google_top_p")
+ ApiType.GOOGLE to floatPreferencesKey("google_top_p"),
+ ApiType.OLLAMA to floatPreferencesKey("ollama_top_p")
)
private val apiSystemPromptMap = mapOf(
ApiType.OPENAI to stringPreferencesKey("openai_system_prompt"),
ApiType.ANTHROPIC to stringPreferencesKey("anthropic_system_prompt"),
- ApiType.GOOGLE to stringPreferencesKey("google_system_prompt")
+ ApiType.GOOGLE to stringPreferencesKey("google_system_prompt"),
+ ApiType.OLLAMA to stringPreferencesKey("ollama_system_prompt")
)
private val dynamicThemeKey = intPreferencesKey("dynamic_mode")
private val themeModeKey = intPreferencesKey("theme_mode")
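
The getters and setters that consume these new Ollama keys are collapsed in this view. As a rough, self-contained sketch of how one of them would be read and written with Jetpack DataStore — the extension-function names are illustrative, not the ones used in SettingDataSourceImpl:

import androidx.datastore.core.DataStore
import androidx.datastore.preferences.core.Preferences
import androidx.datastore.preferences.core.edit
import androidx.datastore.preferences.core.stringPreferencesKey
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.map

private val ollamaUrlKey = stringPreferencesKey("ollama_url")

// Persist the Ollama server address entered during setup.
suspend fun DataStore<Preferences>.setOllamaUrl(url: String) {
    edit { prefs -> prefs[ollamaUrlKey] = url }
}

// Observe the stored address; emits null until Ollama has been configured.
fun DataStore<Preferences>.ollamaUrl(): Flow<String?> =
    data.map { prefs -> prefs[ollamaUrlKey] }
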
@@ -3,5 +3,6 @@ package dev.chungjungsoo.gptmobile.data.model
enum class ApiType {
OPENAI,
ANTHROPIC,
- GOOGLE
+ GOOGLE,
+ OLLAMA
}
@@ -48,7 +48,7 @@ class AnthropicAPIImpl @Inject constructor(

val builder = HttpRequestBuilder().apply {
method = HttpMethod.Post
url("$apiUrl/v1/messages")
if (apiUrl.endsWith("/")) url("${apiUrl}v1/messages") else url("$apiUrl/v1/messages")
contentType(ContentType.Application.Json)
setBody(body)
accept(ContentType.Text.EventStream)
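
The new conditional avoids producing "...com//v1/messages" now that ANTHROPIC_API_URL ships with a trailing slash. A hypothetical helper expressing the same intent (the PR keeps the check inline):

// Sketch only: collapse slashes at the join point instead of branching on endsWith("/").
fun joinUrl(base: String, path: String): String =
    base.trimEnd('/') + "/" + path.trimStart('/')

// joinUrl("https://api.anthropic.com/", "v1/messages") == "https://api.anthropic.com/v1/messages"
// joinUrl("https://api.anthropic.com", "v1/messages")  == "https://api.anthropic.com/v1/messages"
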
@@ -10,6 +10,7 @@ interface ChatRepository {
suspend fun completeOpenAIChat(question: Message, history: List<Message>): Flow<ApiState>
suspend fun completeAnthropicChat(question: Message, history: List<Message>): Flow<ApiState>
suspend fun completeGoogleChat(question: Message, history: List<Message>): Flow<ApiState>
+ suspend fun completeOllamaChat(question: Message, history: List<Message>): Flow<ApiState>
suspend fun fetchChatList(): List<ChatRoom>
suspend fun fetchMessages(chatId: Int): List<Message>
suspend fun updateChatTitle(chatRoom: ChatRoom, title: String)
@@ -46,6 +46,7 @@ class ChatRepositoryImpl @Inject constructor(

private lateinit var openAI: OpenAI
private lateinit var google: GenerativeModel
+ private lateinit var ollama: OpenAI

override suspend fun completeOpenAIChat(question: Message, history: List<Message>): Flow<ApiState> {
val platform = checkNotNull(settingRepository.fetchPlatforms().firstOrNull { it.name == ApiType.OPENAI })
@@ -63,7 +64,7 @@
)

return openAI.chatCompletions(chatCompletionRequest)
.map<ChatCompletionChunk, ApiState> { chunk -> ApiState.Success(chunk.choices[0].delta.content ?: "") }
.map<ChatCompletionChunk, ApiState> { chunk -> ApiState.Success(chunk.choices[0].delta?.content ?: "") }
.catch { throwable -> emit(ApiState.Error(throwable.message ?: "Unknown error")) }
.onStart { emit(ApiState.Loading) }
.onCompletion { emit(ApiState.Done) }
@@ -125,6 +126,28 @@
.onCompletion { emit(ApiState.Done) }
}

+ override suspend fun completeOllamaChat(question: Message, history: List<Message>): Flow<ApiState> {
+ val platform = checkNotNull(settingRepository.fetchPlatforms().firstOrNull { it.name == ApiType.OLLAMA })
+ ollama = OpenAI(platform.token ?: "", host = OpenAIHost(baseUrl = "${platform.apiUrl}v1/"))
+
+ val generatedMessages = messageToOpenAIMessage(history + listOf(question))
+ val generatedMessageWithPrompt = listOf(
+ ChatMessage(role = ChatRole.System, content = platform.systemPrompt ?: ModelConstants.DEFAULT_PROMPT)
+ ) + generatedMessages
+ val chatCompletionRequest = ChatCompletionRequest(
+ model = ModelId(platform.model ?: ""),
+ messages = generatedMessageWithPrompt,
+ temperature = platform.temperature?.toDouble(),
+ topP = platform.topP?.toDouble()
+ )
+
+ return ollama.chatCompletions(chatCompletionRequest)
+ .map<ChatCompletionChunk, ApiState> { chunk -> ApiState.Success(chunk.choices[0].delta?.content ?: "") }
+ .catch { throwable -> emit(ApiState.Error(throwable.message ?: "Unknown error")) }
+ .onStart { emit(ApiState.Loading) }
+ .onCompletion { emit(ApiState.Done) }
+ }
+
override suspend fun fetchChatList(): List<ChatRoom> = chatRoomDao.getChatRooms()

override suspend fun fetchMessages(chatId: Int): List<Message> = messageDao.loadMessages(chatId)
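
completeOllamaChat reuses the openai-kotlin client because Ollama exposes an OpenAI-compatible endpoint under <server>/v1/. A self-contained sketch of that pattern outside the repository class — the server address and model name are placeholders, not values from this PR:

import com.aallam.openai.api.chat.ChatCompletionRequest
import com.aallam.openai.api.chat.ChatMessage
import com.aallam.openai.api.chat.ChatRole
import com.aallam.openai.api.model.ModelId
import com.aallam.openai.client.OpenAI
import com.aallam.openai.client.OpenAIHost
import kotlinx.coroutines.runBlocking

fun main() = runBlocking {
    // Point the OpenAI client at a local Ollama server (default port 11434 assumed).
    val ollama = OpenAI("", host = OpenAIHost(baseUrl = "http://localhost:11434/v1/"))
    val request = ChatCompletionRequest(
        model = ModelId("llama3"), // placeholder model name
        messages = listOf(ChatMessage(role = ChatRole.User, content = "Hello"))
    )
    // Stream the response chunks, mirroring how ChatRepositoryImpl maps them to ApiState.
    ollama.chatCompletions(request).collect { chunk ->
        print(chunk.choices.firstOrNull()?.delta?.content ?: "")
    }
}
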
@@ -19,6 +19,7 @@ class SettingRepositoryImpl @Inject constructor(
ApiType.OPENAI -> settingDataSource.getAPIUrl(apiType) ?: ModelConstants.OPENAI_API_URL
ApiType.ANTHROPIC -> settingDataSource.getAPIUrl(apiType) ?: ModelConstants.ANTHROPIC_API_URL
ApiType.GOOGLE -> settingDataSource.getAPIUrl(apiType) ?: ModelConstants.GOOGLE_API_URL
+ ApiType.OLLAMA -> settingDataSource.getAPIUrl(apiType) ?: ""
}
val token = settingDataSource.getToken(apiType)
val model = settingDataSource.getModel(apiType)
@@ -28,11 +29,12 @@
ApiType.OPENAI -> settingDataSource.getSystemPrompt(ApiType.OPENAI) ?: ModelConstants.OPENAI_PROMPT
ApiType.ANTHROPIC -> settingDataSource.getSystemPrompt(ApiType.ANTHROPIC) ?: ModelConstants.DEFAULT_PROMPT
ApiType.GOOGLE -> settingDataSource.getSystemPrompt(ApiType.GOOGLE) ?: ModelConstants.DEFAULT_PROMPT
+ ApiType.OLLAMA -> settingDataSource.getSystemPrompt(ApiType.OLLAMA) ?: ModelConstants.DEFAULT_PROMPT
}

Platform(
name = apiType,
- enabled = status ?: false,
+ enabled = status == true,
apiUrl = apiUrl,
token = token,
model = model,
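
The switch from "status ?: false" to "status == true" is behavior-preserving for the nullable Boolean flag; a tiny illustrative check (not part of the PR):

fun main() {
    // For a Boolean?, "it == true" and "it ?: false" agree on null, true, and false.
    listOf<Boolean?>(null, true, false).forEach { status ->
        check((status ?: false) == (status == true))
    }
    println("equivalent")
}
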
@@ -24,6 +24,7 @@ import dev.chungjungsoo.gptmobile.presentation.ui.setting.SettingScreen
import dev.chungjungsoo.gptmobile.presentation.ui.setting.SettingViewModel
import dev.chungjungsoo.gptmobile.presentation.ui.setup.SelectModelScreen
import dev.chungjungsoo.gptmobile.presentation.ui.setup.SelectPlatformScreen
+ import dev.chungjungsoo.gptmobile.presentation.ui.setup.SetupAPIUrlScreen
import dev.chungjungsoo.gptmobile.presentation.ui.setup.SetupCompleteScreen
import dev.chungjungsoo.gptmobile.presentation.ui.setup.SetupViewModel
import dev.chungjungsoo.gptmobile.presentation.ui.setup.TokenInputScreen
@@ -117,6 +118,31 @@ fun NavGraphBuilder.setupNavigation(
onBackAction = { navController.navigateUp() }
)
}
+ composable(route = Route.OLLAMA_MODEL_SELECT) {
+ val parentEntry = remember(it) {
+ navController.getBackStackEntry(Route.SETUP_ROUTE)
+ }
+ val setupViewModel: SetupViewModel = hiltViewModel(parentEntry)
+ SelectModelScreen(
+ setupViewModel = setupViewModel,
+ currentRoute = Route.OLLAMA_MODEL_SELECT,
+ platformType = ApiType.OLLAMA,
+ onNavigate = { route -> navController.navigate(route) },
+ onBackAction = { navController.navigateUp() }
+ )
+ }
+ composable(route = Route.OLLAMA_API_ADDRESS) {
+ val parentEntry = remember(it) {
+ navController.getBackStackEntry(Route.SETUP_ROUTE)
+ }
+ val setupViewModel: SetupViewModel = hiltViewModel(parentEntry)
+ SetupAPIUrlScreen(
+ setupViewModel = setupViewModel,
+ currentRoute = Route.OLLAMA_API_ADDRESS,
+ onNavigate = { route -> navController.navigate(route) },
+ onBackAction = { navController.navigateUp() }
+ )
+ }
composable(route = Route.SETUP_COMPLETE) {
val parentEntry = remember(it) {
navController.getBackStackEntry(Route.SETUP_ROUTE)
@@ -188,6 +214,7 @@ fun NavGraphBuilder.settingNavigation(navController: NavHostController) {
ApiType.OPENAI -> navController.navigate(Route.OPENAI_SETTINGS)
ApiType.ANTHROPIC -> navController.navigate(Route.ANTHROPIC_SETTINGS)
ApiType.GOOGLE -> navController.navigate(Route.GOOGLE_SETTINGS)
+ ApiType.OLLAMA -> navController.navigate(Route.OLLAMA_SETTINGS)
}
},
onNavigateToAboutPage = { navController.navigate(Route.ABOUT_PAGE) }
@@ -223,6 +250,16 @@ fun NavGraphBuilder.settingNavigation(navController: NavHostController) {
apiType = ApiType.GOOGLE
) { navController.navigateUp() }
}
+ composable(Route.OLLAMA_SETTINGS) {
+ val parentEntry = remember(it) {
+ navController.getBackStackEntry(Route.SETTING_ROUTE)
+ }
+ val settingViewModel: SettingViewModel = hiltViewModel(parentEntry)
+ PlatformSettingScreen(
+ settingViewModel = settingViewModel,
+ apiType = ApiType.OLLAMA
+ ) { navController.navigateUp() }
+ }
composable(Route.ABOUT_PAGE) {
AboutScreen(
onNavigationClick = { navController.navigateUp() },
@@ -10,6 +10,8 @@ object Route {
const val OPENAI_MODEL_SELECT = "openai_model_select"
const val ANTHROPIC_MODEL_SELECT = "anthropic_model_select"
const val GOOGLE_MODEL_SELECT = "google_model_select"
+ const val OLLAMA_MODEL_SELECT = "ollama_model_select"
+ const val OLLAMA_API_ADDRESS = "ollama_api_address"
const val SETUP_COMPLETE = "setup_complete"

const val CHAT_LIST = "chat_list"
@@ -20,6 +22,7 @@ object Route {
const val OPENAI_SETTINGS = "openai_settings"
const val ANTHROPIC_SETTINGS = "anthropic_settings"
const val GOOGLE_SETTINGS = "google_settings"
+ const val OLLAMA_SETTINGS = "ollama_settings"
const val ABOUT_PAGE = "about"
const val LICENSE = "license"
}
@@ -81,13 +81,18 @@ fun ChatScreen(
val messages by chatViewModel.messages.collectManagedState()
val question by chatViewModel.question.collectManagedState()
val appEnabledPlatforms by chatViewModel.enabledPlatformsInApp.collectManagedState()

val openaiLoadingState by chatViewModel.openaiLoadingState.collectManagedState()
val anthropicLoadingState by chatViewModel.anthropicLoadingState.collectManagedState()
val googleLoadingState by chatViewModel.googleLoadingState.collectManagedState()
+ val ollamaLoadingState by chatViewModel.ollamaLoadingState.collectManagedState()

val userMessage by chatViewModel.userMessage.collectManagedState()

val openAIMessage by chatViewModel.openAIMessage.collectManagedState()
val anthropicMessage by chatViewModel.anthropicMessage.collectManagedState()
val googleMessage by chatViewModel.googleMessage.collectManagedState()
+ val ollamaMessage by chatViewModel.ollamaMessage.collectManagedState()

val canUseChat = (chatViewModel.enabledPlatformsInChat.toSet() - appEnabledPlatforms.toSet()).isEmpty()
val groupedMessages = remember(messages) { groupMessages(messages) }
@@ -205,12 +210,14 @@
ApiType.OPENAI -> openAIMessage
ApiType.ANTHROPIC -> anthropicMessage
ApiType.GOOGLE -> googleMessage
+ ApiType.OLLAMA -> ollamaMessage
}

val loadingState = when (apiType) {
ApiType.OPENAI -> openaiLoadingState
ApiType.ANTHROPIC -> anthropicLoadingState
ApiType.GOOGLE -> googleLoadingState
+ ApiType.OLLAMA -> ollamaLoadingState
}

OpponentChatBubble(