From 32c1c98914b3fd25e12d72e4cbc3802988f4f1e7 Mon Sep 17 00:00:00 2001
From: damonsk <709078+damonsk@users.noreply.github.com>
Date: Sun, 27 Oct 2024 01:08:34 +0000
Subject: [PATCH] Add ability to use LLMs and embedding models from Ollama running locally. (#45)

* add new ollama provider, ability to use local ollama llama3.1 llm
* revert header override for local ollama.
* added in ollama nomic-embed-text embeddings
* remove toggle to use ollama and fall back to using llm names.
* add in ollama embeddings - mxbai-embed-large
* configure ollama url via settings, and throw exception if not set.
* adjust dimensions for mxbai_embed_large
* retain bearer header in case needed for hosted ollama
* updated with linting changes.
* remove incept5/llama3.1-claude and be explicit on llama3.1:8b
* removed usage of openAIKey for ollama
* update filenames to be more meaningful
---
 ...02_create_vector_data_nomic_embed_text.sql |  10 +
 ...3_create_vector_data_mxbai_embed_large.sql |  10 +
 drizzle/meta/0002_snapshot.json               | 192 +++++++++++++
 drizzle/meta/0003_snapshot.json               | 256 ++++++++++++++++++
 drizzle/meta/_journal.json                    |  14 +
 src/components/chat-view/Chat.tsx             |   4 +-
 src/constants.ts                              |  18 ++
 src/contexts/llm-context.tsx                  |  20 +-
 src/db/migrations.json                        |  18 ++
 src/db/schema.ts                              |   2 +
 src/settings/SettingTab.tsx                   |  16 ++
 src/types/embedding.ts                        |   2 +
 src/types/settings.ts                         |   1 +
 src/utils/embedding.ts                        |  39 +++
 src/utils/llm/exception.ts                    |   7 +
 src/utils/llm/manager.ts                      |  20 +-
 src/utils/llm/ollama.ts                       |  82 ++++++
 src/utils/llm/openai.ts                       | 108 +-------
 src/utils/llm/openaiCompatibleProvider.ts     | 137 ++++++++++
 src/utils/ragEngine.ts                        |  20 +-
 20 files changed, 857 insertions(+), 119 deletions(-)
 create mode 100644 drizzle/0002_create_vector_data_nomic_embed_text.sql
 create mode 100644 drizzle/0003_create_vector_data_mxbai_embed_large.sql
 create mode 100644 drizzle/meta/0002_snapshot.json
 create mode 100644 drizzle/meta/0003_snapshot.json
 create mode 100644 src/utils/llm/ollama.ts
 create mode 100644 src/utils/llm/openaiCompatibleProvider.ts

diff --git a/drizzle/0002_create_vector_data_nomic_embed_text.sql b/drizzle/0002_create_vector_data_nomic_embed_text.sql
new file mode 100644
index 0000000..2e8fcb1
--- /dev/null
+++ b/drizzle/0002_create_vector_data_nomic_embed_text.sql
@@ -0,0 +1,10 @@
+CREATE TABLE IF NOT EXISTS "vector_data_nomic_embed_text" (
+	"id" serial PRIMARY KEY NOT NULL,
+	"path" text NOT NULL,
+	"mtime" bigint NOT NULL,
+	"content" text NOT NULL,
+	"embedding" vector(768),
+	"metadata" jsonb NOT NULL
+);
+--> statement-breakpoint
+CREATE INDEX IF NOT EXISTS "embeddingIndex_nomic_embed_text" ON "vector_data_nomic_embed_text" USING hnsw ("embedding" vector_cosine_ops);
\ No newline at end of file
diff --git a/drizzle/0003_create_vector_data_mxbai_embed_large.sql b/drizzle/0003_create_vector_data_mxbai_embed_large.sql
new file mode 100644
index 0000000..e03bd96
--- /dev/null
+++ b/drizzle/0003_create_vector_data_mxbai_embed_large.sql
@@ -0,0 +1,10 @@
+CREATE TABLE IF NOT EXISTS "vector_data_mxbai_embed_large" (
+	"id" serial PRIMARY KEY NOT NULL,
+	"path" text NOT NULL,
+	"mtime" bigint NOT NULL,
+	"content" text NOT NULL,
+	"embedding" vector(1024),
+	"metadata" jsonb NOT NULL
+);
+--> statement-breakpoint
+CREATE INDEX IF NOT EXISTS "embeddingIndex_mxbai_embed_large" ON "vector_data_mxbai_embed_large" USING hnsw ("embedding" vector_cosine_ops);
\ No newline at end of file
diff --git a/drizzle/meta/0002_snapshot.json b/drizzle/meta/0002_snapshot.json
new file mode 100644
index 0000000..4456757
--- /dev/null
+++ b/drizzle/meta/0002_snapshot.json @@ -0,0 +1,192 @@ +{ + "id": "66d66503-35b5-4c19-8433-529426a63956", + "prevId": "1084161a-8de5-4452-ba8b-6ffc3549d411", + "version": "7", + "dialect": "postgresql", + "tables": { + "public.vector_data_text_embedding_3_small": { + "name": "vector_data_text_embedding_3_small", + "schema": "", + "columns": { + "id": { + "name": "id", + "type": "serial", + "primaryKey": true, + "notNull": true + }, + "path": { + "name": "path", + "type": "text", + "primaryKey": false, + "notNull": true + }, + "mtime": { + "name": "mtime", + "type": "bigint", + "primaryKey": false, + "notNull": true + }, + "content": { + "name": "content", + "type": "text", + "primaryKey": false, + "notNull": true + }, + "embedding": { + "name": "embedding", + "type": "vector(1536)", + "primaryKey": false, + "notNull": false + }, + "metadata": { + "name": "metadata", + "type": "jsonb", + "primaryKey": false, + "notNull": true + } + }, + "indexes": { + "embeddingIndex_text_embedding_3_small": { + "name": "embeddingIndex_text_embedding_3_small", + "columns": [ + { + "expression": "embedding", + "isExpression": false, + "asc": true, + "nulls": "last", + "opclass": "vector_cosine_ops" + } + ], + "isUnique": false, + "concurrently": false, + "method": "hnsw", + "with": {} + } + }, + "foreignKeys": {}, + "compositePrimaryKeys": {}, + "uniqueConstraints": {}, + "checkConstraints": {} + }, + "public.vector_data_text_embedding_3_large": { + "name": "vector_data_text_embedding_3_large", + "schema": "", + "columns": { + "id": { + "name": "id", + "type": "serial", + "primaryKey": true, + "notNull": true + }, + "path": { + "name": "path", + "type": "text", + "primaryKey": false, + "notNull": true + }, + "mtime": { + "name": "mtime", + "type": "bigint", + "primaryKey": false, + "notNull": true + }, + "content": { + "name": "content", + "type": "text", + "primaryKey": false, + "notNull": true + }, + "embedding": { + "name": "embedding", + "type": "vector(3072)", + "primaryKey": false, + "notNull": false + }, + "metadata": { + "name": "metadata", + "type": "jsonb", + "primaryKey": false, + "notNull": true + } + }, + "indexes": {}, + "foreignKeys": {}, + "compositePrimaryKeys": {}, + "uniqueConstraints": {}, + "checkConstraints": {} + }, + "public.vector_data_nomic_embed_text": { + "name": "vector_data_nomic_embed_text", + "schema": "", + "columns": { + "id": { + "name": "id", + "type": "serial", + "primaryKey": true, + "notNull": true + }, + "path": { + "name": "path", + "type": "text", + "primaryKey": false, + "notNull": true + }, + "mtime": { + "name": "mtime", + "type": "bigint", + "primaryKey": false, + "notNull": true + }, + "content": { + "name": "content", + "type": "text", + "primaryKey": false, + "notNull": true + }, + "embedding": { + "name": "embedding", + "type": "vector(768)", + "primaryKey": false, + "notNull": false + }, + "metadata": { + "name": "metadata", + "type": "jsonb", + "primaryKey": false, + "notNull": true + } + }, + "indexes": { + "embeddingIndex_nomic_embed_text": { + "name": "embeddingIndex_nomic_embed_text", + "columns": [ + { + "expression": "embedding", + "isExpression": false, + "asc": true, + "nulls": "last", + "opclass": "vector_cosine_ops" + } + ], + "isUnique": false, + "concurrently": false, + "method": "hnsw", + "with": {} + } + }, + "foreignKeys": {}, + "compositePrimaryKeys": {}, + "uniqueConstraints": {}, + "checkConstraints": {} + } + }, + "enums": {}, + "schemas": {}, + "sequences": {}, + "views": {}, + "_meta": { + "columns": {}, + "schemas": {}, + "tables": {} 
+ } +} diff --git a/drizzle/meta/0003_snapshot.json b/drizzle/meta/0003_snapshot.json new file mode 100644 index 0000000..4d7126c --- /dev/null +++ b/drizzle/meta/0003_snapshot.json @@ -0,0 +1,256 @@ +{ + "id": "cd7adecf-9de7-40f7-bc6a-bcf78b36053a", + "prevId": "66d66503-35b5-4c19-8433-529426a63956", + "version": "7", + "dialect": "postgresql", + "tables": { + "public.vector_data_text_embedding_3_small": { + "name": "vector_data_text_embedding_3_small", + "schema": "", + "columns": { + "id": { + "name": "id", + "type": "serial", + "primaryKey": true, + "notNull": true + }, + "path": { + "name": "path", + "type": "text", + "primaryKey": false, + "notNull": true + }, + "mtime": { + "name": "mtime", + "type": "bigint", + "primaryKey": false, + "notNull": true + }, + "content": { + "name": "content", + "type": "text", + "primaryKey": false, + "notNull": true + }, + "embedding": { + "name": "embedding", + "type": "vector(1536)", + "primaryKey": false, + "notNull": false + }, + "metadata": { + "name": "metadata", + "type": "jsonb", + "primaryKey": false, + "notNull": true + } + }, + "indexes": { + "embeddingIndex_text_embedding_3_small": { + "name": "embeddingIndex_text_embedding_3_small", + "columns": [ + { + "expression": "embedding", + "isExpression": false, + "asc": true, + "nulls": "last", + "opclass": "vector_cosine_ops" + } + ], + "isUnique": false, + "concurrently": false, + "method": "hnsw", + "with": {} + } + }, + "foreignKeys": {}, + "compositePrimaryKeys": {}, + "uniqueConstraints": {}, + "checkConstraints": {} + }, + "public.vector_data_text_embedding_3_large": { + "name": "vector_data_text_embedding_3_large", + "schema": "", + "columns": { + "id": { + "name": "id", + "type": "serial", + "primaryKey": true, + "notNull": true + }, + "path": { + "name": "path", + "type": "text", + "primaryKey": false, + "notNull": true + }, + "mtime": { + "name": "mtime", + "type": "bigint", + "primaryKey": false, + "notNull": true + }, + "content": { + "name": "content", + "type": "text", + "primaryKey": false, + "notNull": true + }, + "embedding": { + "name": "embedding", + "type": "vector(3072)", + "primaryKey": false, + "notNull": false + }, + "metadata": { + "name": "metadata", + "type": "jsonb", + "primaryKey": false, + "notNull": true + } + }, + "indexes": {}, + "foreignKeys": {}, + "compositePrimaryKeys": {}, + "uniqueConstraints": {}, + "checkConstraints": {} + }, + "public.vector_data_nomic_embed_text": { + "name": "vector_data_nomic_embed_text", + "schema": "", + "columns": { + "id": { + "name": "id", + "type": "serial", + "primaryKey": true, + "notNull": true + }, + "path": { + "name": "path", + "type": "text", + "primaryKey": false, + "notNull": true + }, + "mtime": { + "name": "mtime", + "type": "bigint", + "primaryKey": false, + "notNull": true + }, + "content": { + "name": "content", + "type": "text", + "primaryKey": false, + "notNull": true + }, + "embedding": { + "name": "embedding", + "type": "vector(768)", + "primaryKey": false, + "notNull": false + }, + "metadata": { + "name": "metadata", + "type": "jsonb", + "primaryKey": false, + "notNull": true + } + }, + "indexes": { + "embeddingIndex_nomic_embed_text": { + "name": "embeddingIndex_nomic_embed_text", + "columns": [ + { + "expression": "embedding", + "isExpression": false, + "asc": true, + "nulls": "last", + "opclass": "vector_cosine_ops" + } + ], + "isUnique": false, + "concurrently": false, + "method": "hnsw", + "with": {} + } + }, + "foreignKeys": {}, + "compositePrimaryKeys": {}, + "uniqueConstraints": {}, + 
"checkConstraints": {} + }, + "public.vector_data_mxbai_embed_large": { + "name": "vector_data_mxbai_embed_large", + "schema": "", + "columns": { + "id": { + "name": "id", + "type": "serial", + "primaryKey": true, + "notNull": true + }, + "path": { + "name": "path", + "type": "text", + "primaryKey": false, + "notNull": true + }, + "mtime": { + "name": "mtime", + "type": "bigint", + "primaryKey": false, + "notNull": true + }, + "content": { + "name": "content", + "type": "text", + "primaryKey": false, + "notNull": true + }, + "embedding": { + "name": "embedding", + "type": "vector(1024)", + "primaryKey": false, + "notNull": false + }, + "metadata": { + "name": "metadata", + "type": "jsonb", + "primaryKey": false, + "notNull": true + } + }, + "indexes": { + "embeddingIndex_mxbai_embed_large": { + "name": "embeddingIndex_mxbai_embed_large", + "columns": [ + { + "expression": "embedding", + "isExpression": false, + "asc": true, + "nulls": "last", + "opclass": "vector_cosine_ops" + } + ], + "isUnique": false, + "concurrently": false, + "method": "hnsw", + "with": {} + } + }, + "foreignKeys": {}, + "compositePrimaryKeys": {}, + "uniqueConstraints": {}, + "checkConstraints": {} + } + }, + "enums": {}, + "schemas": {}, + "sequences": {}, + "views": {}, + "_meta": { + "columns": {}, + "schemas": {}, + "tables": {} + } +} diff --git a/drizzle/meta/_journal.json b/drizzle/meta/_journal.json index e3bd512..49172af 100644 --- a/drizzle/meta/_journal.json +++ b/drizzle/meta/_journal.json @@ -15,6 +15,20 @@ "when": 1729509994653, "tag": "0001_create_vector_data_tables", "breakpoints": true + }, + { + "idx": 2, + "version": "7", + "when": 1729890971064, + "tag": "0002_create_vector_data_nomic_embed_text", + "breakpoints": true + }, + { + "idx": 3, + "version": "7", + "when": 1729928816942, + "tag": "0003_create_vector_data_mxbai_embed_large", + "breakpoints": true } ] } diff --git a/src/components/chat-view/Chat.tsx b/src/components/chat-view/Chat.tsx index c6cee8b..6c6a7a0 100644 --- a/src/components/chat-view/Chat.tsx +++ b/src/components/chat-view/Chat.tsx @@ -28,6 +28,7 @@ import { } from '../../types/mentionable' import { applyChangesToFile } from '../../utils/apply' import { + LLMABaseUrlNotSetException, LLMAPIKeyInvalidException, LLMAPIKeyNotSetException, } from '../../utils/llm/exception' @@ -259,7 +260,8 @@ const Chat = forwardRef((props, ref) => { }) if ( error instanceof LLMAPIKeyNotSetException || - error instanceof LLMAPIKeyInvalidException + error instanceof LLMAPIKeyInvalidException || + error instanceof LLMABaseUrlNotSetException ) { new OpenSettingsModal(app, error.message, () => { // eslint-disable-next-line @typescript-eslint/no-explicit-any diff --git a/src/constants.ts b/src/constants.ts index 031a2c7..88223a6 100644 --- a/src/constants.ts +++ b/src/constants.ts @@ -18,6 +18,10 @@ export const CHAT_MODEL_OPTIONS = [ name: 'llama-3.1-70b (Groq)', value: 'llama-3.1-70b-versatile', }, + { + name: 'llama3.1:8b (Ollama)', + value: 'llama3.1:8b', + }, ] export const APPLY_MODEL_OPTIONS = [ @@ -37,6 +41,10 @@ export const APPLY_MODEL_OPTIONS = [ name: 'llama-3.1-70b (Groq)', value: 'llama-3.1-70b-versatile', }, + { + name: 'llama3.1:8b (Ollama)', + value: 'llama3.1:8b', + }, ] // Update table exports in src/db/schema.ts when updating this @@ -51,6 +59,16 @@ export const EMBEDDING_MODEL_OPTIONS = [ value: 'text-embedding-3-large', dimension: 3072, }, + { + name: 'nomic-embed-text (Ollama)', + value: 'nomic-embed-text', + dimension: 768, + }, + { + name: 'mxbai-embed-large (Ollama)', + value: 
'mxbai-embed-large', + dimension: 1024, + }, ] export const PGLITE_DB_PATH = '.smtcmp_vector_db.tar.gz' diff --git a/src/contexts/llm-context.tsx b/src/contexts/llm-context.tsx index 9863a51..08ba16b 100644 --- a/src/contexts/llm-context.tsx +++ b/src/contexts/llm-context.tsx @@ -38,13 +38,21 @@ export function LLMProvider({ children }: PropsWithChildren) { const { settings } = useSettings() useEffect(() => { - const manager = new LLMManager({ - openai: settings.openAIApiKey, - groq: settings.groqApiKey, - anthropic: settings.anthropicApiKey, - }) + const manager = new LLMManager( + { + openai: settings.openAIApiKey, + groq: settings.groqApiKey, + anthropic: settings.anthropicApiKey, + }, + settings.ollamaBaseUrl, + ) setLLMManager(manager) - }, [settings.openAIApiKey, settings.groqApiKey, settings.anthropicApiKey]) + }, [ + settings.openAIApiKey, + settings.groqApiKey, + settings.anthropicApiKey, + settings.ollamaBaseUrl, + ]) const generateResponse = useCallback( async (request: LLMRequestNonStreaming, options?: LLMOptions) => { diff --git a/src/db/migrations.json b/src/db/migrations.json index 1aded28..c38d5a3 100644 --- a/src/db/migrations.json +++ b/src/db/migrations.json @@ -16,5 +16,23 @@ "bps": true, "folderMillis": 1729509994653, "hash": "30520313039892c9c07b13185b6e4aa0b0f9a09b851db96e0f6e400303560aec" + }, + { + "sql": [ + "CREATE TABLE IF NOT EXISTS \"vector_data_nomic_embed_text\" (\n\t\"id\" serial PRIMARY KEY NOT NULL,\n\t\"path\" text NOT NULL,\n\t\"mtime\" bigint NOT NULL,\n\t\"content\" text NOT NULL,\n\t\"embedding\" vector(768),\n\t\"metadata\" jsonb NOT NULL\n);\n", + "\nCREATE INDEX IF NOT EXISTS \"embeddingIndex_nomic_embed_text\" ON \"vector_data_nomic_embed_text\" USING hnsw (\"embedding\" vector_cosine_ops);" + ], + "bps": true, + "folderMillis": 1729890971064, + "hash": "2041cd1e2808ad7ceea75ab34088b1dae1d563286b35d36586129ab6467fb627" + }, + { + "sql": [ + "CREATE TABLE IF NOT EXISTS \"vector_data_mxbai_embed_large\" (\n\t\"id\" serial PRIMARY KEY NOT NULL,\n\t\"path\" text NOT NULL,\n\t\"mtime\" bigint NOT NULL,\n\t\"content\" text NOT NULL,\n\t\"embedding\" vector(1024),\n\t\"metadata\" jsonb NOT NULL\n);\n", + "\nCREATE INDEX IF NOT EXISTS \"embeddingIndex_mxbai_embed_large\" ON \"vector_data_mxbai_embed_large\" USING hnsw (\"embedding\" vector_cosine_ops);" + ], + "bps": true, + "folderMillis": 1729928816942, + "hash": "4dd533fe5e978bc0ad79708c073fc3ce973702b7c6eed71f0e4818b1063212fb" } ] diff --git a/src/db/schema.ts b/src/db/schema.ts index 88a4b92..85ba0fa 100644 --- a/src/db/schema.ts +++ b/src/db/schema.ts @@ -55,3 +55,5 @@ export type VectorMetaData = { // 'npx drizzle-kit generate' requires individual table exports to generate correct migration files export const vectorTable0 = vectorTables[EMBEDDING_MODEL_OPTIONS[0].value] export const vectorTable1 = vectorTables[EMBEDDING_MODEL_OPTIONS[1].value] +export const vectorTable2 = vectorTables[EMBEDDING_MODEL_OPTIONS[2].value] +export const vectorTable3 = vectorTables[EMBEDDING_MODEL_OPTIONS[3].value] diff --git a/src/settings/SettingTab.tsx b/src/settings/SettingTab.tsx index ff9cfc2..a401156 100644 --- a/src/settings/SettingTab.tsx +++ b/src/settings/SettingTab.tsx @@ -71,6 +71,22 @@ export class SmartCopilotSettingTab extends PluginSettingTab { }), ) + new Setting(containerEl) + .setName('Ollama Address') + .setDesc( + 'Set the Ollama URL and port address - normally http://127.0.0.1:11434', + ) + .addText((text) => + text + .setValue(String(this.plugin.settings.ollamaBaseUrl)) + .onChange(async (value) 
=> { + await this.plugin.setSettings({ + ...this.plugin.settings, + ollamaBaseUrl: value, + }) + }), + ) + new Setting(containerEl).setHeading().setName('Model Settings') new Setting(containerEl) diff --git a/src/types/embedding.ts b/src/types/embedding.ts index 07a6a05..15c4a09 100644 --- a/src/types/embedding.ts +++ b/src/types/embedding.ts @@ -1,6 +1,8 @@ export type EmbeddingModelName = | 'text-embedding-3-small' | 'text-embedding-3-large' + | 'nomic-embed-text' + | 'mxbai-embed-large' export type EmbeddingModel = { name: EmbeddingModelName diff --git a/src/types/settings.ts b/src/types/settings.ts index 35ea5ed..16f37ee 100644 --- a/src/types/settings.ts +++ b/src/types/settings.ts @@ -4,6 +4,7 @@ const smartCopilotSettingsSchema = z.object({ openAIApiKey: z.string().default(''), groqApiKey: z.string().default(''), anthropicApiKey: z.string().default(''), + ollamaBaseUrl: z.string().default(''), chatModel: z.string().default('claude-3-5-sonnet-20240620'), applyModel: z.string().default('gpt-4o-mini'), embeddingModel: z.string().default('text-embedding-3-small'), diff --git a/src/utils/embedding.ts b/src/utils/embedding.ts index 31b8318..a34a031 100644 --- a/src/utils/embedding.ts +++ b/src/utils/embedding.ts @@ -2,11 +2,14 @@ import { OpenAI } from 'openai' import { EmbeddingModel } from '../types/embedding' +import { NoStainlessOpenAI } from './llm/ollama' + export const getEmbeddingModel = ( name: string, apiKeys: { openAIApiKey: string }, + ollamaBaseUrl: string, ): EmbeddingModel => { switch (name) { case 'text-embedding-3-small': { @@ -43,6 +46,42 @@ export const getEmbeddingModel = ( }, } } + case 'nomic-embed-text': { + const openai = new NoStainlessOpenAI({ + apiKey: '', + dangerouslyAllowBrowser: true, + baseURL: `${ollamaBaseUrl}/v1`, + }) + return { + name: 'nomic-embed-text', + dimension: 768, + getEmbedding: async (text: string) => { + const embedding = await openai.embeddings.create({ + model: 'nomic-embed-text', + input: text, + }) + return embedding.data[0].embedding + }, + } + } + case 'mxbai-embed-large': { + const openai = new NoStainlessOpenAI({ + apiKey: '', + dangerouslyAllowBrowser: true, + baseURL: `${ollamaBaseUrl}/v1`, + }) + return { + name: 'mxbai-embed-large', + dimension: 1024, + getEmbedding: async (text: string) => { + const embedding = await openai.embeddings.create({ + model: 'mxbai-embed-large', + input: text, + }) + return embedding.data[0].embedding + }, + } + } default: throw new Error('Invalid embedding model') } diff --git a/src/utils/llm/exception.ts b/src/utils/llm/exception.ts index c32944d..9dfdc69 100644 --- a/src/utils/llm/exception.ts +++ b/src/utils/llm/exception.ts @@ -11,3 +11,10 @@ export class LLMAPIKeyInvalidException extends Error { this.name = 'LLMAPIKeyInvalidException' } } + +export class LLMABaseUrlNotSetException extends Error { + constructor(message: string) { + super(message) + this.name = 'LLMABaseUrlNotSetException' + } +} diff --git a/src/utils/llm/manager.ts b/src/utils/llm/manager.ts index d71569e..a5b3b57 100644 --- a/src/utils/llm/manager.ts +++ b/src/utils/llm/manager.ts @@ -10,7 +10,8 @@ import { import { AnthropicProvider } from './anthropic' import { GroqProvider } from './groq' -import { OpenAIProvider } from './openai' +import { OllamaOpenAIProvider } from './ollama' +import { OpenAIAuthenticatedProvider } from './openai' export type LLMManagerInterface = { generateResponse( @@ -24,20 +25,28 @@ export type LLMManagerInterface = { } class LLMManager implements LLMManagerInterface { - private openaiProvider: 
OpenAIProvider + private openaiProvider: OpenAIAuthenticatedProvider private groqProvider: GroqProvider private anthropicProvider: AnthropicProvider + private ollamaProvider: OllamaOpenAIProvider - constructor(apiKeys: { openai?: string; groq?: string; anthropic?: string }) { - this.openaiProvider = new OpenAIProvider(apiKeys.openai ?? '') + constructor( + apiKeys: { openai?: string; groq?: string; anthropic?: string }, + ollamaBaseUrl?: string, + ) { + this.openaiProvider = new OpenAIAuthenticatedProvider(apiKeys.openai ?? '') this.groqProvider = new GroqProvider(apiKeys.groq ?? '') this.anthropicProvider = new AnthropicProvider(apiKeys.anthropic ?? '') + this.ollamaProvider = new OllamaOpenAIProvider(ollamaBaseUrl ?? '') } async generateResponse( request: LLMRequestNonStreaming, options?: LLMOptions, ): Promise { + if (this.ollamaProvider.getSupportedModels().includes(request.model)) { + return await this.ollamaProvider.generateResponse(request, options) + } if (this.openaiProvider.getSupportedModels().includes(request.model)) { return await this.openaiProvider.generateResponse(request, options) } @@ -54,6 +63,9 @@ class LLMManager implements LLMManagerInterface { request: LLMRequestStreaming, options?: LLMOptions, ): Promise> { + if (this.ollamaProvider.getSupportedModels().includes(request.model)) { + return await this.ollamaProvider.streamResponse(request, options) + } if (this.openaiProvider.getSupportedModels().includes(request.model)) { return await this.openaiProvider.streamResponse(request, options) } diff --git a/src/utils/llm/ollama.ts b/src/utils/llm/ollama.ts new file mode 100644 index 0000000..498212f --- /dev/null +++ b/src/utils/llm/ollama.ts @@ -0,0 +1,82 @@ +import OpenAI from 'openai' +import { FinalRequestOptions } from 'openai/core' +import { + LLMOptions, + LLMRequestNonStreaming, + LLMRequestStreaming, +} from 'src/types/llm/request' +import { + LLMResponseNonStreaming, + LLMResponseStreaming, +} from 'src/types/llm/response' + +import { BaseLLMProvider } from './base' +import { LLMABaseUrlNotSetException } from './exception' +import { OpenAICompatibleProvider } from './openaiCompatibleProvider' + +export class NoStainlessOpenAI extends OpenAI { + defaultHeaders() { + return { + Accept: 'application/json', + 'Content-Type': 'application/json', + } + } + + buildRequest( + options: FinalRequestOptions, + { retryCount = 0 }: { retryCount?: number } = {}, + ): { req: RequestInit; url: string; timeout: number } { + const req = super.buildRequest(options, { retryCount }) + const headers = req.req.headers as Record + Object.keys(headers).forEach((k) => { + if (k.startsWith('x-stainless')) { + delete headers[k] + } + }) + return req + } +} + +export type OllamaModel = 'llama3.1:8b' +export const OLLAMA_MODELS: OllamaModel[] = ['llama3.1:8b'] + +export class OllamaOpenAIProvider implements BaseLLMProvider { + private provider: OpenAICompatibleProvider + private ollamaBaseUrl: string + + constructor(baseUrl: string) { + this.ollamaBaseUrl = baseUrl + this.provider = new OpenAICompatibleProvider( + new NoStainlessOpenAI({ + apiKey: '', + dangerouslyAllowBrowser: true, + baseURL: `${baseUrl}/v1`, + }), + ) + } + generateResponse( + request: LLMRequestNonStreaming, + options?: LLMOptions, + ): Promise { + if (!this.ollamaBaseUrl) { + throw new LLMABaseUrlNotSetException( + 'Ollama Address is missing. 
Please set it in settings menu.', + ) + } + return this.provider.generateResponse(request, options) + } + streamResponse( + request: LLMRequestStreaming, + options?: LLMOptions, + ): Promise> { + if (!this.ollamaBaseUrl) { + throw new LLMABaseUrlNotSetException( + 'Ollama Address is missing. Please set it in settings menu.', + ) + } + return this.provider.streamResponse(request, options) + } + getSupportedModels(): string[] { + return OLLAMA_MODELS + } +} diff --git a/src/utils/llm/openai.ts b/src/utils/llm/openai.ts index c9d1541..010977a 100644 --- a/src/utils/llm/openai.ts +++ b/src/utils/llm/openai.ts @@ -1,15 +1,9 @@ import OpenAI from 'openai' -import { - ChatCompletion, - ChatCompletionChunk, - ChatCompletionMessageParam, -} from 'openai/resources/chat/completions' import { LLMOptions, LLMRequestNonStreaming, LLMRequestStreaming, - RequestMessage, } from '../../types/llm/request' import { LLMResponseNonStreaming, @@ -21,15 +15,18 @@ import { LLMAPIKeyInvalidException, LLMAPIKeyNotSetException, } from './exception' +import { OpenAICompatibleProvider } from './openaiCompatibleProvider' export type OpenAIModel = 'gpt-4o' | 'gpt-4o-mini' export const OPENAI_MODELS: OpenAIModel[] = ['gpt-4o', 'gpt-4o-mini'] -export class OpenAIProvider implements BaseLLMProvider { +export class OpenAIAuthenticatedProvider implements BaseLLMProvider { + private provider: OpenAICompatibleProvider private client: OpenAI constructor(apiKey: string) { this.client = new OpenAI({ apiKey, dangerouslyAllowBrowser: true }) + this.provider = new OpenAICompatibleProvider(this.client) } async generateResponse( @@ -41,26 +38,8 @@ export class OpenAIProvider implements BaseLLMProvider { 'OpenAI API key is missing. Please set it in settings menu.', ) } - try { - const response = await this.client.chat.completions.create( - { - model: request.model, - messages: request.messages.map((m) => - OpenAIProvider.parseRequestMessage(m), - ), - max_tokens: request.max_tokens, - temperature: request.temperature, - top_p: request.top_p, - frequency_penalty: request.frequency_penalty, - presence_penalty: request.presence_penalty, - logit_bias: request.logit_bias, - }, - { - signal: options?.signal, - }, - ) - return OpenAIProvider.parseNonStreamingResponse(response) + return this.provider.generateResponse(request, options) } catch (error) { if (error instanceof OpenAI.AuthenticationError) { throw new LLMAPIKeyInvalidException( @@ -82,33 +61,7 @@ export class OpenAIProvider implements BaseLLMProvider { } try { - const stream = await this.client.chat.completions.create( - { - model: request.model, - messages: request.messages.map((m) => - OpenAIProvider.parseRequestMessage(m), - ), - max_completion_tokens: request.max_tokens, - temperature: request.temperature, - top_p: request.top_p, - frequency_penalty: request.frequency_penalty, - presence_penalty: request.presence_penalty, - logit_bias: request.logit_bias, - stream: true, - }, - { - signal: options?.signal, - }, - ) - - // eslint-disable-next-line no-inner-declarations - async function* streamResponse(): AsyncIterable { - for await (const chunk of stream) { - yield OpenAIProvider.parseStreamingResponseChunk(chunk) - } - } - - return streamResponse() + return this.provider.streamResponse(request, options) } catch (error) { if (error instanceof OpenAI.AuthenticationError) { throw new LLMAPIKeyInvalidException( @@ -119,55 +72,6 @@ export class OpenAIProvider implements BaseLLMProvider { } } - static parseRequestMessage( - message: RequestMessage, - ): ChatCompletionMessageParam { - 
return { - role: message.role, - content: message.content, - } - } - - static parseNonStreamingResponse( - response: ChatCompletion, - ): LLMResponseNonStreaming { - return { - id: response.id, - choices: response.choices.map((choice) => ({ - finish_reason: choice.finish_reason, - message: { - content: choice.message.content, - role: choice.message.role, - }, - })), - created: response.created, - model: response.model, - object: 'chat.completion', - system_fingerprint: response.system_fingerprint, - usage: response.usage, - } - } - - static parseStreamingResponseChunk( - chunk: ChatCompletionChunk, - ): LLMResponseStreaming { - return { - id: chunk.id, - choices: chunk.choices.map((choice) => ({ - finish_reason: choice.finish_reason ?? null, - delta: { - content: choice.delta.content ?? null, - role: choice.delta.role, - }, - })), - created: chunk.created, - model: chunk.model, - object: 'chat.completion.chunk', - system_fingerprint: chunk.system_fingerprint, - usage: chunk.usage, - } - } - getSupportedModels(): string[] { return OPENAI_MODELS } diff --git a/src/utils/llm/openaiCompatibleProvider.ts b/src/utils/llm/openaiCompatibleProvider.ts new file mode 100644 index 0000000..18afafa --- /dev/null +++ b/src/utils/llm/openaiCompatibleProvider.ts @@ -0,0 +1,137 @@ +import OpenAI from 'openai' +import { + ChatCompletion, + ChatCompletionChunk, + ChatCompletionMessageParam, +} from 'openai/resources/chat/completions' + +import { + LLMOptions, + LLMRequestNonStreaming, + LLMRequestStreaming, + RequestMessage, +} from '../../types/llm/request' +import { + LLMResponseNonStreaming, + LLMResponseStreaming, +} from '../../types/llm/response' + +import { BaseLLMProvider } from './base' + +export class OpenAICompatibleProvider implements BaseLLMProvider { + private client: OpenAI + + constructor(client: OpenAI) { + this.client = client + } + + async generateResponse( + request: LLMRequestNonStreaming, + options?: LLMOptions, + ): Promise { + const response = await this.client.chat.completions.create( + { + model: request.model, + messages: request.messages.map((m) => + OpenAICompatibleProvider.parseRequestMessage(m), + ), + max_tokens: request.max_tokens, + temperature: request.temperature, + top_p: request.top_p, + frequency_penalty: request.frequency_penalty, + presence_penalty: request.presence_penalty, + logit_bias: request.logit_bias, + }, + { + signal: options?.signal, + }, + ) + return OpenAICompatibleProvider.parseNonStreamingResponse(response) + } + + async streamResponse( + request: LLMRequestStreaming, + options?: LLMOptions, + ): Promise> { + const stream = await this.client.chat.completions.create( + { + model: request.model, + messages: request.messages.map((m) => + OpenAICompatibleProvider.parseRequestMessage(m), + ), + max_completion_tokens: request.max_tokens, + temperature: request.temperature, + top_p: request.top_p, + frequency_penalty: request.frequency_penalty, + presence_penalty: request.presence_penalty, + logit_bias: request.logit_bias, + stream: true, + }, + { + signal: options?.signal, + }, + ) + + // eslint-disable-next-line no-inner-declarations + async function* streamResponse(): AsyncIterable { + for await (const chunk of stream) { + yield OpenAICompatibleProvider.parseStreamingResponseChunk(chunk) + } + } + + return streamResponse() + } + + static parseRequestMessage( + message: RequestMessage, + ): ChatCompletionMessageParam { + return { + role: message.role, + content: message.content, + } + } + + static parseNonStreamingResponse( + response: ChatCompletion, + ): 
LLMResponseNonStreaming {
+    return {
+      id: response.id,
+      choices: response.choices.map((choice) => ({
+        finish_reason: choice.finish_reason,
+        message: {
+          content: choice.message.content,
+          role: choice.message.role,
+        },
+      })),
+      created: response.created,
+      model: response.model,
+      object: 'chat.completion',
+      system_fingerprint: response.system_fingerprint,
+      usage: response.usage,
+    }
+  }
+
+  static parseStreamingResponseChunk(
+    chunk: ChatCompletionChunk,
+  ): LLMResponseStreaming {
+    return {
+      id: chunk.id,
+      choices: chunk.choices.map((choice) => ({
+        finish_reason: choice.finish_reason ?? null,
+        delta: {
+          content: choice.delta.content ?? null,
+          role: choice.delta.role,
+        },
+      })),
+      created: chunk.created,
+      model: chunk.model,
+      object: 'chat.completion.chunk',
+      system_fingerprint: chunk.system_fingerprint,
+      usage: chunk.usage,
+    }
+  }
+
+  getSupportedModels(): string[] {
+    throw new Error('Not implemented')
+  }
+}
diff --git a/src/utils/ragEngine.ts b/src/utils/ragEngine.ts
index 4992a48..ce5538d 100644
--- a/src/utils/ragEngine.ts
+++ b/src/utils/ragEngine.ts
@@ -25,9 +25,13 @@ export class RAGEngine {
   ): Promise<RAGEngine> {
     const ragEngine = new RAGEngine(app, settings)
     ragEngine.vectorDbManager = await VectorDbManager.create(app)
-    ragEngine.embeddingModel = getEmbeddingModel(settings.embeddingModel, {
-      openAIApiKey: settings.openAIApiKey,
-    })
+    ragEngine.embeddingModel = getEmbeddingModel(
+      settings.embeddingModel,
+      {
+        openAIApiKey: settings.openAIApiKey,
+      },
+      settings.ollamaBaseUrl,
+    )
     return ragEngine
   }
 
@@ -37,9 +41,13 @@ export class RAGEngine {
 
   setSettings(settings: SmartCopilotSettings) {
     this.settings = settings
-    this.embeddingModel = getEmbeddingModel(settings.embeddingModel, {
-      openAIApiKey: settings.openAIApiKey,
-    })
+    this.embeddingModel = getEmbeddingModel(
+      settings.embeddingModel,
+      {
+        openAIApiKey: settings.openAIApiKey,
+      },
+      settings.ollamaBaseUrl,
+    )
   }
 
   // TODO: Implement automatic vault re-indexing when settings are changed.
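
Usage sketch (illustrative, not part of the patch): the snippet below shows how the new Ollama chat and embedding paths could be exercised once the Ollama address is configured. It assumes a local Ollama server at http://127.0.0.1:11434 with llama3.1:8b and nomic-embed-text already pulled, that LLMManager is the default export of src/utils/llm/manager.ts, and that the request and response shapes match those used elsewhere in this diff; treat the import paths and field names as assumptions, not confirmed API.

import LLMManager from 'src/utils/llm/manager' // default export assumed
import { getEmbeddingModel } from 'src/utils/embedding'

async function demoOllama() {
  const ollamaBaseUrl = 'http://127.0.0.1:11434'

  // Chat: 'llama3.1:8b' is listed in OLLAMA_MODELS, so LLMManager routes the
  // request to OllamaOpenAIProvider; an empty base URL would raise
  // LLMABaseUrlNotSetException instead.
  const manager = new LLMManager(
    { openai: '', groq: '', anthropic: '' }, // cloud API keys are not needed for Ollama
    ollamaBaseUrl,
  )
  const response = await manager.generateResponse({
    model: 'llama3.1:8b',
    messages: [{ role: 'user', content: 'Summarize this note in one sentence.' }],
  })
  console.log(response.choices[0].message.content)

  // Embeddings: served from Ollama's OpenAI-compatible /v1 endpoint through
  // NoStainlessOpenAI, which strips the SDK's x-stainless-* headers before sending.
  const embeddingModel = getEmbeddingModel(
    'nomic-embed-text',
    { openAIApiKey: '' }, // unused for the Ollama embedding models
    ollamaBaseUrl,
  )
  const vector = await embeddingModel.getEmbedding('Hello from Obsidian')
  console.log(vector.length) // 768 for nomic-embed-text, 1024 for mxbai-embed-large
}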