diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index f9b8db8..6bdabc8 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -44,7 +44,7 @@ We use PGlite and Drizzle ORM for database management in this project. This sect To update the database schema: -1. Modify the existing schema as needed in the `src/db/schema.ts` file. +1. Modify the existing schema as needed in the `src/database/schema.ts` file. 2. After making changes, run the following command to generate migration files: ``` @@ -58,7 +58,7 @@ To update the database schema: npm run migrate:compile ``` - This will create or update the `src/db/migrations.json` file. Note that migration files in the 'drizzle' directory won't affect the project until they are compiled into this JSON file, which is used in the actual migration process. + This will create or update the `src/database/migrations.json` file. Note that migration files in the 'drizzle' directory won't affect the project until they are compiled into this JSON file, which is used in the actual migration process. ### Handling Migration Files diff --git a/DEVELOPMENT.md b/DEVELOPMENT.md index fc51395..ee87f62 100644 --- a/DEVELOPMENT.md +++ b/DEVELOPMENT.md @@ -4,7 +4,7 @@ PGlite typically uses the `node:fs` module to load bundle files. However, Obsidian plugins run in a browser-like environment where `node:fs` is not available. This presents a challenge in implementing PGlite in Obsidian's environment. -To address this, we developed a workaround in `src/utils/vector-db/repository.ts`: +To address this, we developed a workaround in `src/database/DatabaseManager.ts`: 1. Manually fetch required PGlite resources (Postgres data, WebAssembly module, and Vector extension). 2. Use PGlite's option to directly set bundle files or URLs when initializing the database. diff --git a/compile-migration.js b/compile-migration.js index 241124d..4cbd8bc 100644 --- a/compile-migration.js +++ b/compile-migration.js @@ -4,7 +4,10 @@ const fs = require('node:fs/promises') async function compileMigrations() { const migrations = readMigrationFiles({ migrationsFolder: './drizzle/' }) - await fs.writeFile('./src/db/migrations.json', JSON.stringify(migrations)) + await fs.writeFile( + './src/database/migrations.json', + JSON.stringify(migrations), + ) console.log('Migrations compiled!') } diff --git a/drizzle.config.ts b/drizzle.config.ts index 7cf0a6d..a20c3c0 100644 --- a/drizzle.config.ts +++ b/drizzle.config.ts @@ -2,5 +2,5 @@ import { defineConfig } from 'drizzle-kit' export default defineConfig({ dialect: 'postgresql', - schema: './src/db/schema.ts', + schema: './src/database/schema.ts', }) diff --git a/drizzle/0005_create_template_table.sql b/drizzle/0005_create_template_table.sql new file mode 100644 index 0000000..375777a --- /dev/null +++ b/drizzle/0005_create_template_table.sql @@ -0,0 +1,8 @@ +CREATE TABLE IF NOT EXISTS "template" ( + "id" uuid PRIMARY KEY DEFAULT gen_random_uuid() NOT NULL, + "name" text NOT NULL, + "content" jsonb NOT NULL, + "created_at" timestamp DEFAULT now() NOT NULL, + "updated_at" timestamp DEFAULT now() NOT NULL, + CONSTRAINT "template_name_unique" UNIQUE("name") +); diff --git a/drizzle/meta/0005_snapshot.json b/drizzle/meta/0005_snapshot.json new file mode 100644 index 0000000..1ee5466 --- /dev/null +++ b/drizzle/meta/0005_snapshot.json @@ -0,0 +1,370 @@ +{ + "id": "410315fb-45e8-44ba-91f2-deadeac035d6", + "prevId": "c810ceee-4769-451b-a42b-31afbe8d434b", + "version": "7", + "dialect": "postgresql", + "tables": { + "public.template": { + "name": "template", + "schema": "", + "columns": { + "id": { + "name": "id", + "type": "uuid", + "primaryKey": true, + "notNull": true, + "default": "gen_random_uuid()" + }, + "name": { + "name": "name", + "type": "text", + "primaryKey": false, + "notNull": true + }, + "content": { + "name": "content", + "type": "jsonb", + "primaryKey": false, + "notNull": true + }, + "created_at": { + "name": "created_at", + "type": "timestamp", + "primaryKey": false, + "notNull": true, + "default": "now()" + }, + "updated_at": { + "name": "updated_at", + "type": "timestamp", + "primaryKey": false, + "notNull": true, + "default": "now()" + } + }, + "indexes": {}, + "foreignKeys": {}, + "compositePrimaryKeys": {}, + "uniqueConstraints": { + "template_name_unique": { + "name": "template_name_unique", + "nullsNotDistinct": false, + "columns": ["name"] + } + }, + "checkConstraints": {} + }, + "public.vector_data_text_embedding_3_small": { + "name": "vector_data_text_embedding_3_small", + "schema": "", + "columns": { + "id": { + "name": "id", + "type": "serial", + "primaryKey": true, + "notNull": true + }, + "path": { + "name": "path", + "type": "text", + "primaryKey": false, + "notNull": true + }, + "mtime": { + "name": "mtime", + "type": "bigint", + "primaryKey": false, + "notNull": true + }, + "content": { + "name": "content", + "type": "text", + "primaryKey": false, + "notNull": true + }, + "embedding": { + "name": "embedding", + "type": "vector(1536)", + "primaryKey": false, + "notNull": false + }, + "metadata": { + "name": "metadata", + "type": "jsonb", + "primaryKey": false, + "notNull": true + } + }, + "indexes": { + "embeddingIndex_text_embedding_3_small": { + "name": "embeddingIndex_text_embedding_3_small", + "columns": [ + { + "expression": "embedding", + "isExpression": false, + "asc": true, + "nulls": "last", + "opclass": "vector_cosine_ops" + } + ], + "isUnique": false, + "concurrently": false, + "method": "hnsw", + "with": {} + } + }, + "foreignKeys": {}, + "compositePrimaryKeys": {}, + "uniqueConstraints": {}, + "checkConstraints": {} + }, + "public.vector_data_text_embedding_3_large": { + "name": "vector_data_text_embedding_3_large", + "schema": "", + "columns": { + "id": { + "name": "id", + "type": "serial", + "primaryKey": true, + "notNull": true + }, + "path": { + "name": "path", + "type": "text", + "primaryKey": false, + "notNull": true + }, + "mtime": { + "name": "mtime", + "type": "bigint", + "primaryKey": false, + "notNull": true + }, + "content": { + "name": "content", + "type": "text", + "primaryKey": false, + "notNull": true + }, + "embedding": { + "name": "embedding", + "type": "vector(3072)", + "primaryKey": false, + "notNull": false + }, + "metadata": { + "name": "metadata", + "type": "jsonb", + "primaryKey": false, + "notNull": true + } + }, + "indexes": {}, + "foreignKeys": {}, + "compositePrimaryKeys": {}, + "uniqueConstraints": {}, + "checkConstraints": {} + }, + "public.vector_data_nomic_embed_text": { + "name": "vector_data_nomic_embed_text", + "schema": "", + "columns": { + "id": { + "name": "id", + "type": "serial", + "primaryKey": true, + "notNull": true + }, + "path": { + "name": "path", + "type": "text", + "primaryKey": false, + "notNull": true + }, + "mtime": { + "name": "mtime", + "type": "bigint", + "primaryKey": false, + "notNull": true + }, + "content": { + "name": "content", + "type": "text", + "primaryKey": false, + "notNull": true + }, + "embedding": { + "name": "embedding", + "type": "vector(768)", + "primaryKey": false, + "notNull": false + }, + "metadata": { + "name": "metadata", + "type": "jsonb", + "primaryKey": false, + "notNull": true + } + }, + "indexes": { + "embeddingIndex_nomic_embed_text": { + "name": "embeddingIndex_nomic_embed_text", + "columns": [ + { + "expression": "embedding", + "isExpression": false, + "asc": true, + "nulls": "last", + "opclass": "vector_cosine_ops" + } + ], + "isUnique": false, + "concurrently": false, + "method": "hnsw", + "with": {} + } + }, + "foreignKeys": {}, + "compositePrimaryKeys": {}, + "uniqueConstraints": {}, + "checkConstraints": {} + }, + "public.vector_data_mxbai_embed_large": { + "name": "vector_data_mxbai_embed_large", + "schema": "", + "columns": { + "id": { + "name": "id", + "type": "serial", + "primaryKey": true, + "notNull": true + }, + "path": { + "name": "path", + "type": "text", + "primaryKey": false, + "notNull": true + }, + "mtime": { + "name": "mtime", + "type": "bigint", + "primaryKey": false, + "notNull": true + }, + "content": { + "name": "content", + "type": "text", + "primaryKey": false, + "notNull": true + }, + "embedding": { + "name": "embedding", + "type": "vector(1024)", + "primaryKey": false, + "notNull": false + }, + "metadata": { + "name": "metadata", + "type": "jsonb", + "primaryKey": false, + "notNull": true + } + }, + "indexes": { + "embeddingIndex_mxbai_embed_large": { + "name": "embeddingIndex_mxbai_embed_large", + "columns": [ + { + "expression": "embedding", + "isExpression": false, + "asc": true, + "nulls": "last", + "opclass": "vector_cosine_ops" + } + ], + "isUnique": false, + "concurrently": false, + "method": "hnsw", + "with": {} + } + }, + "foreignKeys": {}, + "compositePrimaryKeys": {}, + "uniqueConstraints": {}, + "checkConstraints": {} + }, + "public.vector_data_bge_m3": { + "name": "vector_data_bge_m3", + "schema": "", + "columns": { + "id": { + "name": "id", + "type": "serial", + "primaryKey": true, + "notNull": true + }, + "path": { + "name": "path", + "type": "text", + "primaryKey": false, + "notNull": true + }, + "mtime": { + "name": "mtime", + "type": "bigint", + "primaryKey": false, + "notNull": true + }, + "content": { + "name": "content", + "type": "text", + "primaryKey": false, + "notNull": true + }, + "embedding": { + "name": "embedding", + "type": "vector(1024)", + "primaryKey": false, + "notNull": false + }, + "metadata": { + "name": "metadata", + "type": "jsonb", + "primaryKey": false, + "notNull": true + } + }, + "indexes": { + "embeddingIndex_bge_m3": { + "name": "embeddingIndex_bge_m3", + "columns": [ + { + "expression": "embedding", + "isExpression": false, + "asc": true, + "nulls": "last", + "opclass": "vector_cosine_ops" + } + ], + "isUnique": false, + "concurrently": false, + "method": "hnsw", + "with": {} + } + }, + "foreignKeys": {}, + "compositePrimaryKeys": {}, + "uniqueConstraints": {}, + "checkConstraints": {} + } + }, + "enums": {}, + "schemas": {}, + "sequences": {}, + "views": {}, + "_meta": { + "columns": {}, + "schemas": {}, + "tables": {} + } +} diff --git a/drizzle/meta/_journal.json b/drizzle/meta/_journal.json index f02a664..3788376 100644 --- a/drizzle/meta/_journal.json +++ b/drizzle/meta/_journal.json @@ -36,6 +36,13 @@ "when": 1730085029605, "tag": "0004_create_vector_data_bge_m3", "breakpoints": true + }, + { + "idx": 5, + "version": "7", + "when": 1730443763982, + "tag": "0005_create_template_table", + "breakpoints": true } ] } diff --git a/esbuild.config.mjs b/esbuild.config.mjs index 0e8e42f..9edf9a7 100644 --- a/esbuild.config.mjs +++ b/esbuild.config.mjs @@ -31,6 +31,7 @@ const context = await esbuild.context({ '@lezer/common', '@lezer/highlight', '@lezer/lr', + '@lexical/clipboard/clipboard', ...builtins, ], format: 'cjs', diff --git a/package-lock.json b/package-lock.json index 0178ca4..4c04625 100644 --- a/package-lock.json +++ b/package-lock.json @@ -11,10 +11,13 @@ "dependencies": { "@anthropic-ai/sdk": "^0.27.3", "@electric-sql/pglite": "^0.2.12", + "@lexical/clipboard": "^0.17.1", "@lexical/react": "^0.17.1", + "@radix-ui/react-dialog": "^1.1.2", "@radix-ui/react-dropdown-menu": "^2.1.2", "@radix-ui/react-tooltip": "^1.1.3", "@tanstack/react-query": "^5.56.2", + "clsx": "^2.1.1", "diff": "^7.0.0", "drizzle-orm": "^0.35.2", "exponential-backoff": "^3.1.1", @@ -1727,7 +1730,8 @@ }, "node_modules/@lexical/clipboard": { "version": "0.17.1", - "license": "MIT", + "resolved": "https://registry.npmjs.org/@lexical/clipboard/-/clipboard-0.17.1.tgz", + "integrity": "sha512-OVqnEfWX8XN5xxuMPo6BfgGKHREbz++D5V5ISOiml0Z8fV/TQkdgwqbBJcUdJHGRHWSUwdK7CWGs/VALvVvZyw==", "dependencies": { "@lexical/html": "0.17.1", "@lexical/list": "0.17.1", @@ -2057,6 +2061,41 @@ } } }, + "node_modules/@radix-ui/react-dialog": { + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/@radix-ui/react-dialog/-/react-dialog-1.1.2.tgz", + "integrity": "sha512-Yj4dZtqa2o+kG61fzB0H2qUvmwBA2oyQroGLyNtBj1beo1khoQ3q1a2AO8rrQYjd8256CO9+N8L9tvsS+bnIyA==", + "dependencies": { + "@radix-ui/primitive": "1.1.0", + "@radix-ui/react-compose-refs": "1.1.0", + "@radix-ui/react-context": "1.1.1", + "@radix-ui/react-dismissable-layer": "1.1.1", + "@radix-ui/react-focus-guards": "1.1.1", + "@radix-ui/react-focus-scope": "1.1.0", + "@radix-ui/react-id": "1.1.0", + "@radix-ui/react-portal": "1.1.2", + "@radix-ui/react-presence": "1.1.1", + "@radix-ui/react-primitive": "2.0.0", + "@radix-ui/react-slot": "1.1.0", + "@radix-ui/react-use-controllable-state": "1.1.0", + "aria-hidden": "^1.1.1", + "react-remove-scroll": "2.6.0" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", + "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, "node_modules/@radix-ui/react-direction": { "version": "1.1.0", "license": "MIT", @@ -3667,6 +3706,14 @@ "node": ">=12" } }, + "node_modules/clsx": { + "version": "2.1.1", + "resolved": "https://registry.npmjs.org/clsx/-/clsx-2.1.1.tgz", + "integrity": "sha512-eYm0QWBtUrBWZWG0d386OGAw16Z995PiOVo2B7bjWSbHedGl5e0ZWaq65kOGgUSNesEIDkB9ISbTg/JK9dhCZA==", + "engines": { + "node": ">=6" + } + }, "node_modules/co": { "version": "4.6.0", "dev": true, diff --git a/package.json b/package.json index 83b3d9d..3b3a54b 100644 --- a/package.json +++ b/package.json @@ -46,10 +46,13 @@ "dependencies": { "@anthropic-ai/sdk": "^0.27.3", "@electric-sql/pglite": "^0.2.12", + "@lexical/clipboard": "^0.17.1", "@lexical/react": "^0.17.1", + "@radix-ui/react-dialog": "^1.1.2", "@radix-ui/react-dropdown-menu": "^2.1.2", "@radix-ui/react-tooltip": "^1.1.3", "@tanstack/react-query": "^5.56.2", + "clsx": "^2.1.1", "diff": "^7.0.0", "drizzle-orm": "^0.35.2", "exponential-backoff": "^3.1.1", diff --git a/src/ChatView.tsx b/src/ChatView.tsx index 3ee35db..9c5a047 100644 --- a/src/ChatView.tsx +++ b/src/ChatView.tsx @@ -7,6 +7,8 @@ import Chat, { ChatProps, ChatRef } from './components/chat-view/Chat' import { CHAT_VIEW_TYPE } from './constants' import { AppProvider } from './contexts/app-context' import { DarkModeProvider } from './contexts/dark-mode-context' +import { DatabaseProvider } from './contexts/database-context' +import { DialogContainerProvider } from './contexts/dialog-container-context' import { LLMProvider } from './contexts/llm-context' import { RAGProvider } from './contexts/rag-context' import { SettingsProvider } from './contexts/settings-context' @@ -67,6 +69,7 @@ export class ChatView extends ItemView { }, }, }) + const dbManager = await this.plugin.getDbManager() const ragEngine = await this.plugin.getRAGEngine() this.root.render( @@ -80,13 +83,19 @@ export class ChatView extends ItemView { > - - - - - - - + + + + + + + + + + + diff --git a/src/components/chat-view/Chat.tsx b/src/components/chat-view/Chat.tsx index dc753cf..7dcdb01 100644 --- a/src/components/chat-view/Chat.tsx +++ b/src/components/chat-view/Chat.tsx @@ -18,6 +18,11 @@ import { useApp } from '../../contexts/app-context' import { useLLM } from '../../contexts/llm-context' import { useRAG } from '../../contexts/rag-context' import { useSettings } from '../../contexts/settings-context' +import { + LLMAPIKeyInvalidException, + LLMAPIKeyNotSetException, + LLMBaseUrlNotSetException, +} from '../../core/llm/exception' import { useChatHistory } from '../../hooks/useChatHistory' import { ChatMessage, ChatUserMessage } from '../../types/chat' import { @@ -26,11 +31,6 @@ import { MentionableCurrentFile, } from '../../types/mentionable' import { applyChangesToFile } from '../../utils/apply' -import { - LLMAPIKeyInvalidException, - LLMAPIKeyNotSetException, - LLMBaseUrlNotSetException, -} from '../../utils/llm/exception' import { getMentionableKey, serializeMentionable, @@ -505,7 +505,7 @@ const Chat = forwardRef((props, ref) => { registerChatUserInputRef(message.id, ref)} - message={message.content} + initialSerializedEditorState={message.content} onChange={(content) => { setChatMessages((prevChatHistory) => prevChatHistory.map((msg) => @@ -561,7 +561,7 @@ const Chat = forwardRef((props, ref) => { registerChatUserInputRef(inputMessage.id, ref)} - message={inputMessage.content} + initialSerializedEditorState={inputMessage.content} onChange={(content) => { setInputMessage((prevInputMessage) => ({ ...prevInputMessage, diff --git a/src/components/chat-view/CreateTemplateDialog.tsx b/src/components/chat-view/CreateTemplateDialog.tsx new file mode 100644 index 0000000..97df795 --- /dev/null +++ b/src/components/chat-view/CreateTemplateDialog.tsx @@ -0,0 +1,125 @@ +import { $generateNodesFromSerializedNodes } from '@lexical/clipboard' +import { BaseSerializedNode } from '@lexical/clipboard/clipboard' +import { InitialEditorStateType } from '@lexical/react/LexicalComposer' +import * as Dialog from '@radix-ui/react-dialog' +import { $insertNodes, LexicalEditor } from 'lexical' +import { X } from 'lucide-react' +import { Notice } from 'obsidian' +import { useRef, useState } from 'react' + +import { useDatabase } from '../../contexts/database-context' +import { useDialogContainer } from '../../contexts/dialog-container-context' +import { DuplicateTemplateException } from '../../database/exception' + +import LexicalContentEditable from './chat-input/LexicalContentEditable' + +/* + * This component must be used inside + * The modal={false} prop is required because modal mode blocks pointer events across the entire page, + * which would conflict with lexical editor popovers + */ +export default function CreateTemplateDialogContent({ + selectedSerializedNodes, + onClose, +}: { + selectedSerializedNodes?: BaseSerializedNode[] | null + onClose: () => void +}) { + const container = useDialogContainer() + const { templateManager } = useDatabase() + + const [templateName, setTemplateName] = useState('') + const editorRef = useRef(null) + const contentEditableRef = useRef(null) + + const initialEditorState: InitialEditorStateType = ( + editor: LexicalEditor, + ) => { + if (!selectedSerializedNodes) return + editor.update(() => { + const parsedNodes = $generateNodesFromSerializedNodes( + selectedSerializedNodes, + ) + $insertNodes(parsedNodes) + }) + } + + const onSubmit = async () => { + try { + if (!editorRef.current) return + const serializedEditorState = editorRef.current.toJSON() + const nodes = serializedEditorState.editorState.root.children + if (nodes.length === 0) { + new Notice('Please enter a content for your template') + return + } + if (templateName.trim().length === 0) { + new Notice('Please enter a name for your template') + return + } + + await templateManager.createTemplate({ + name: templateName, + content: { nodes }, + }) + new Notice(`Template created: ${templateName}`) + setTemplateName('') + onClose() + } catch (error) { + if (error instanceof DuplicateTemplateException) { + new Notice('A template with this name already exists') + } else { + console.error(error) + new Notice('Failed to create template') + } + } + } + + return ( + + +
+ + Create template + + + Create a new template from the selected nodes + +
+ +
+ + setTemplateName(e.target.value)} + onKeyDown={(e) => { + if (e.key === 'Enter') { + e.stopPropagation() + e.preventDefault() + onSubmit() + } + }} + /> +
+ +
+ +
+ +
+ +
+ + + + +
+
+ ) +} diff --git a/src/components/chat-view/QueryProgress.tsx b/src/components/chat-view/QueryProgress.tsx index bb5bd34..4117416 100644 --- a/src/components/chat-view/QueryProgress.tsx +++ b/src/components/chat-view/QueryProgress.tsx @@ -1,4 +1,4 @@ -import { SelectVector } from '../../db/schema' +import { SelectVector } from '../../database/schema' export type QueryProgressState = | { diff --git a/src/components/chat-view/chat-input/ChatUserInput.tsx b/src/components/chat-view/chat-input/ChatUserInput.tsx index c34c8ff..20a0d0a 100644 --- a/src/components/chat-view/chat-input/ChatUserInput.tsx +++ b/src/components/chat-view/chat-input/ChatUserInput.tsx @@ -1,50 +1,29 @@ -import { - InitialConfigType, - LexicalComposer, -} from '@lexical/react/LexicalComposer' -import { ContentEditable } from '@lexical/react/LexicalContentEditable' -import { EditorRefPlugin } from '@lexical/react/LexicalEditorRefPlugin' -import { LexicalErrorBoundary } from '@lexical/react/LexicalErrorBoundary' -import { HistoryPlugin } from '@lexical/react/LexicalHistoryPlugin' -import { OnChangePlugin } from '@lexical/react/LexicalOnChangePlugin' -import { RichTextPlugin } from '@lexical/react/LexicalRichTextPlugin' import { useQuery } from '@tanstack/react-query' import { $nodesOfType, LexicalEditor, SerializedEditorState } from 'lexical' import { forwardRef, - useCallback, useEffect, useImperativeHandle, useRef, useState, } from 'react' -import { - deserializeMentionable, - serializeMentionable, -} from 'src/utils/mentionable' import { useApp } from '../../../contexts/app-context' import { useDarkModeContext } from '../../../contexts/dark-mode-context' import { Mentionable, SerializedMentionable } from '../../../types/mentionable' -import { fuzzySearch } from '../../../utils/fuzzy-search' -import { getMentionableKey } from '../../../utils/mentionable' +import { + deserializeMentionable, + getMentionableKey, + serializeMentionable, +} from '../../../utils/mentionable' import { openMarkdownFile, readTFileContent } from '../../../utils/obsidian' import { MemoizedSyntaxHighlighterWrapper } from '../SyntaxHighlighterWrapper' +import LexicalContentEditable from './LexicalContentEditable' import MentionableBadge from './MentionableBadge' import { ModelSelect } from './ModelSelect' -import AutoFocusPlugin from './plugins/auto-focus/AutoFocusPlugin' -import AutoLinkMentionPlugin from './plugins/mention/AutoLinkMentionPlugin' import { MentionNode } from './plugins/mention/MentionNode' -import MentionPlugin from './plugins/mention/MentionPlugin' -import NoFormatPlugin from './plugins/no-format/NoFormatPlugin' -import OnEnterPlugin from './plugins/on-enter/OnEnterPlugin' -import OnMutationPlugin, { - NodeMutations, -} from './plugins/on-mutation/OnMutationPlugin' -import UpdaterPlugin, { - UpdaterPluginRef, -} from './plugins/updater/UpdaterPlugin' +import { NodeMutations } from './plugins/on-mutation/OnMutationPlugin' import { SubmitButton } from './SubmitButton' import { VaultChatButton } from './VaultChatButton' @@ -53,7 +32,7 @@ export type ChatUserInputRef = { } export type ChatUserInputProps = { - message: SerializedEditorState | null // TODO: fix name to initialContent + initialSerializedEditorState: SerializedEditorState | null onChange: (content: SerializedEditorState) => void onSubmit: (content: SerializedEditorState, useVaultSearch?: boolean) => void onFocus: () => void @@ -66,7 +45,7 @@ export type ChatUserInputProps = { const ChatUserInput = forwardRef( ( { - message, + initialSerializedEditorState, onChange, onSubmit, onFocus, @@ -82,7 +61,7 @@ const ChatUserInput = forwardRef( const editorRef = useRef(null) const contentEditableRef = useRef(null) - const updaterRef = useRef(null) + const containerRef = useRef(null) const [displayedMentionableKey, setDisplayedMentionableKey] = useState< string | null @@ -100,31 +79,6 @@ const ChatUserInput = forwardRef( }, })) - const initialConfig: InitialConfigType = { - namespace: 'ChatUserInput', - theme: { - root: 'smtcmp-chat-input-root', - paragraph: 'smtcmp-chat-input-paragraph', - }, - nodes: [MentionNode], - onError: (error) => { - console.error(error) - }, - } - - // initialize editor state - useEffect(() => { - if (message) { - updaterRef.current?.update(message) - } - // eslint-disable-next-line react-hooks/exhaustive-deps - }, []) - - const searchResultByQuery = useCallback( - (query: string) => fuzzySearch(app, query), - [app], - ) - const handleMentionNodeMutation = ( mutations: NodeMutations, ) => { @@ -246,13 +200,13 @@ const ChatUserInput = forwardRef( }, }) - const handleSubmit = (useVaultSearch?: boolean) => { + const handleSubmit = (options: { useVaultSearch?: boolean } = {}) => { const content = editorRef.current?.getEditorState()?.toJSON() - content && onSubmit(content, useVaultSearch) + content && onSubmit(content, options.useVaultSearch) } return ( -
+
{mentionables.length > 0 && (
{mentionables.map((m) => ( @@ -295,59 +249,40 @@ const ChatUserInput = forwardRef(
)} - - {/* - There was two approach to make mentionable node copy and pasteable. - 1. use RichTextPlugin and reset text format when paste - - so I implemented NoFormatPlugin to reset text format when paste - 2. use PlainTextPlugin and override paste command - - PlainTextPlugin only pastes text, so we need to implement custom paste handler. - - https://github.com/facebook/lexical/discussions/5112 - */} - + { + if (initialSerializedEditorState) { + editor.setEditorState( + editor.parseEditorState(initialSerializedEditorState), + ) } - ErrorBoundary={LexicalErrorBoundary} - /> - - {autoFocus && } - - { - onChange(editorState.toJSON()) - }} - /> - - { - evt.preventDefault() - evt.stopPropagation() - handleSubmit(useVaultSearch) - }} - /> - - - - - + }} + editorRef={editorRef} + contentEditableRef={contentEditableRef} + onChange={onChange} + onEnter={() => handleSubmit({ useVaultSearch: false })} + onFocus={onFocus} + onMentionNodeMutation={handleMentionNodeMutation} + autoFocus={autoFocus} + plugins={{ + onEnter: { + onVaultChat: () => { + handleSubmit({ useVaultSearch: true }) + }, + }, + templatePopover: { + anchorElement: containerRef.current, + }, + }} + /> +
handleSubmit()} /> { - handleSubmit(true) + handleSubmit({ useVaultSearch: true }) }} />
diff --git a/src/components/chat-view/chat-input/LexicalContentEditable.tsx b/src/components/chat-view/chat-input/LexicalContentEditable.tsx new file mode 100644 index 0000000..b7947fe --- /dev/null +++ b/src/components/chat-view/chat-input/LexicalContentEditable.tsx @@ -0,0 +1,146 @@ +import { + InitialConfigType, + InitialEditorStateType, + LexicalComposer, +} from '@lexical/react/LexicalComposer' +import { ContentEditable } from '@lexical/react/LexicalContentEditable' +import { EditorRefPlugin } from '@lexical/react/LexicalEditorRefPlugin' +import { LexicalErrorBoundary } from '@lexical/react/LexicalErrorBoundary' +import { HistoryPlugin } from '@lexical/react/LexicalHistoryPlugin' +import { OnChangePlugin } from '@lexical/react/LexicalOnChangePlugin' +import { RichTextPlugin } from '@lexical/react/LexicalRichTextPlugin' +import { LexicalEditor, SerializedEditorState } from 'lexical' +import { RefObject, useCallback, useEffect } from 'react' + +import { useApp } from '../../../contexts/app-context' +import { fuzzySearch } from '../../../utils/fuzzy-search' + +import AutoLinkMentionPlugin from './plugins/mention/AutoLinkMentionPlugin' +import { MentionNode } from './plugins/mention/MentionNode' +import MentionPlugin from './plugins/mention/MentionPlugin' +import NoFormatPlugin from './plugins/no-format/NoFormatPlugin' +import OnEnterPlugin from './plugins/on-enter/OnEnterPlugin' +import OnMutationPlugin, { + NodeMutations, +} from './plugins/on-mutation/OnMutationPlugin' +import CreateTemplatePopoverPlugin from './plugins/template/CreateTemplatePopoverPlugin' +import TemplatePlugin from './plugins/template/TemplatePlugin' + +export type LexicalContentEditableProps = { + editorRef: RefObject + contentEditableRef: RefObject + onChange?: (content: SerializedEditorState) => void + onEnter?: (evt: KeyboardEvent) => void + onFocus?: () => void + onMentionNodeMutation?: (mutations: NodeMutations) => void + initialEditorState?: InitialEditorStateType + autoFocus?: boolean + plugins?: { + onEnter?: { + onVaultChat: () => void + } + templatePopover?: { + anchorElement: HTMLElement | null + } + } +} + +export default function LexicalContentEditable({ + editorRef, + contentEditableRef, + onChange, + onEnter, + onFocus, + onMentionNodeMutation, + initialEditorState, + autoFocus = false, + plugins, +}: LexicalContentEditableProps) { + const app = useApp() + + const initialConfig: InitialConfigType = { + namespace: 'LexicalContentEditable', + theme: { + root: 'smtcmp-lexical-content-editable-root', + paragraph: 'smtcmp-lexical-content-editable-paragraph', + }, + nodes: [MentionNode], + editorState: initialEditorState, + onError: (error) => { + console.error(error) + }, + } + + const searchResultByQuery = useCallback( + (query: string) => fuzzySearch(app, query), + [app], + ) + + /* + * Using requestAnimationFrame for autoFocus instead of using editor.focus() + * due to known issues with editor.focus() when initialConfig.editorState is set + * See: https://github.com/facebook/lexical/issues/4460 + */ + useEffect(() => { + if (autoFocus) { + requestAnimationFrame(() => { + contentEditableRef.current?.focus() + }) + } + }, [autoFocus, contentEditableRef]) + + return ( + + {/* + There was two approach to make mentionable node copy and pasteable. + 1. use RichTextPlugin and reset text format when paste + - so I implemented NoFormatPlugin to reset text format when paste + 2. use PlainTextPlugin and override paste command + - PlainTextPlugin only pastes text, so we need to implement custom paste handler. + - https://github.com/facebook/lexical/discussions/5112 + */} + + } + ErrorBoundary={LexicalErrorBoundary} + /> + + + { + onChange?.(editorState.toJSON()) + }} + /> + {onEnter && ( + + )} + { + onMentionNodeMutation?.(mutations) + }} + /> + + + + + {plugins?.templatePopover && ( + + )} + + ) +} diff --git a/src/components/chat-view/chat-input/ModelSelect.tsx b/src/components/chat-view/chat-input/ModelSelect.tsx index 883e334..0c4be54 100644 --- a/src/components/chat-view/chat-input/ModelSelect.tsx +++ b/src/components/chat-view/chat-input/ModelSelect.tsx @@ -1,8 +1,9 @@ import * as DropdownMenu from '@radix-ui/react-dropdown-menu' import { ChevronDown, ChevronUp } from 'lucide-react' import { useState } from 'react' -import { CHAT_MODEL_OPTIONS } from 'src/constants' -import { useSettings } from 'src/contexts/settings-context' + +import { CHAT_MODEL_OPTIONS } from '../../../constants' +import { useSettings } from '../../../contexts/settings-context' export function ModelSelect() { const { settings, setSettings } = useSettings() diff --git a/src/components/chat-view/chat-input/plugins/auto-focus/AutoFocusPlugin.tsx b/src/components/chat-view/chat-input/plugins/auto-focus/AutoFocusPlugin.tsx deleted file mode 100644 index 714a990..0000000 --- a/src/components/chat-view/chat-input/plugins/auto-focus/AutoFocusPlugin.tsx +++ /dev/null @@ -1,29 +0,0 @@ -import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext' -import { useEffect } from 'react' - -export type AutoFocusPluginProps = { - defaultSelection?: 'rootStart' | 'rootEnd' -} - -export default function AutoFocusPlugin({ - defaultSelection, -}: AutoFocusPluginProps) { - const [editor] = useLexicalComposerContext() - - useEffect(() => { - editor.focus( - () => { - const rootElement = editor.getRootElement() - if (rootElement) { - // requestAnimationFrame is required here for unknown reasons, possibly related to the Obsidian plugin environment. - requestAnimationFrame(() => { - rootElement.focus() - }) - } - }, - { defaultSelection }, - ) - }, [defaultSelection, editor]) - - return null -} diff --git a/src/components/chat-view/chat-input/plugins/mention/MentionPlugin.tsx b/src/components/chat-view/chat-input/plugins/mention/MentionPlugin.tsx index 603d023..a41693a 100644 --- a/src/components/chat-view/chat-input/plugins/mention/MentionPlugin.tsx +++ b/src/components/chat-view/chat-input/plugins/mention/MentionPlugin.tsx @@ -11,11 +11,13 @@ import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext import { $createTextNode, COMMAND_PRIORITY_NORMAL, TextNode } from 'lexical' import { useCallback, useMemo, useState } from 'react' import { createPortal } from 'react-dom' -import { serializeMentionable } from 'src/utils/mentionable' import { Mentionable } from '../../../../../types/mentionable' import { SearchableMentionable } from '../../../../../utils/fuzzy-search' -import { getMentionableName } from '../../../../../utils/mentionable' +import { + getMentionableName, + serializeMentionable, +} from '../../../../../utils/mentionable' import { getMentionableIcon } from '../../utils/get-metionable-icon' import { MenuOption, MenuTextMatch } from '../shared/LexicalMenu' import { diff --git a/src/components/chat-view/chat-input/plugins/on-enter/OnEnterPlugin.tsx b/src/components/chat-view/chat-input/plugins/on-enter/OnEnterPlugin.tsx index 899fa08..8cfa78d 100644 --- a/src/components/chat-view/chat-input/plugins/on-enter/OnEnterPlugin.tsx +++ b/src/components/chat-view/chat-input/plugins/on-enter/OnEnterPlugin.tsx @@ -5,8 +5,10 @@ import { useEffect } from 'react' export default function OnEnterPlugin({ onEnter, + onVaultChat, }: { - onEnter: (evt: KeyboardEvent, useVaultSearch?: boolean) => void + onEnter: (evt: KeyboardEvent) => void + onVaultChat?: () => void }) { const [editor] = useLexicalComposerContext() @@ -14,14 +16,22 @@ export default function OnEnterPlugin({ const removeListener = editor.registerCommand( KEY_ENTER_COMMAND, (evt: KeyboardEvent) => { + if ( + onVaultChat && + evt.shiftKey && + (Platform.isMacOS ? evt.metaKey : evt.ctrlKey) + ) { + evt.preventDefault() + evt.stopPropagation() + onVaultChat() + return true + } if (evt.shiftKey) { - if (Platform.isMacOS ? evt.metaKey : evt.ctrlKey) { - onEnter(evt, true) - return true - } return false } - onEnter(evt, false) + evt.preventDefault() + evt.stopPropagation() + onEnter(evt) return true }, COMMAND_PRIORITY_LOW, @@ -30,7 +40,7 @@ export default function OnEnterPlugin({ return () => { removeListener() } - }, [editor, onEnter]) + }, [editor, onEnter, onVaultChat]) return null } diff --git a/src/components/chat-view/chat-input/plugins/template/CreateTemplatePopoverPlugin.tsx b/src/components/chat-view/chat-input/plugins/template/CreateTemplatePopoverPlugin.tsx new file mode 100644 index 0000000..a36157e --- /dev/null +++ b/src/components/chat-view/chat-input/plugins/template/CreateTemplatePopoverPlugin.tsx @@ -0,0 +1,132 @@ +import { $generateJSONFromSelectedNodes } from '@lexical/clipboard' +import { BaseSerializedNode } from '@lexical/clipboard/clipboard' +import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext' +import * as Dialog from '@radix-ui/react-dialog' +import { + $getSelection, + COMMAND_PRIORITY_LOW, + SELECTION_CHANGE_COMMAND, +} from 'lexical' +import { CSSProperties, useCallback, useEffect, useRef, useState } from 'react' + +import CreateTemplateDialogContent from '../../../CreateTemplateDialog' + +export default function CreateTemplatePopoverPlugin({ + anchorElement, + contentEditableElement, +}: { + anchorElement: HTMLElement | null + contentEditableElement: HTMLElement | null +}): JSX.Element | null { + const [editor] = useLexicalComposerContext() + + const [popoverStyle, setPopoverStyle] = useState(null) + const [isPopoverOpen, setIsPopoverOpen] = useState(false) + const [isDialogOpen, setIsDialogOpen] = useState(false) + const [selectedSerializedNodes, setSelectedSerializedNodes] = useState< + BaseSerializedNode[] | null + >(null) + + const popoverRef = useRef(null) + + const getSelectedSerializedNodes = useCallback((): + | BaseSerializedNode[] + | null => { + if (!editor) return null + let selectedNodes: BaseSerializedNode[] | null = null + editor.update(() => { + const selection = $getSelection() + if (!selection) return + selectedNodes = $generateJSONFromSelectedNodes(editor, selection).nodes + if (selectedNodes.length === 0) return null + }) + return selectedNodes + }, [editor]) + + const updatePopoverPosition = useCallback(() => { + if (!anchorElement || !contentEditableElement) return + const nativeSelection = document.getSelection() + const range = nativeSelection?.getRangeAt(0) + if (!range || range.collapsed) { + setIsPopoverOpen(false) + return + } + if (!contentEditableElement.contains(range.commonAncestorContainer)) { + setIsPopoverOpen(false) + return + } + const rects = Array.from(range.getClientRects()) + if (rects.length === 0) { + setIsPopoverOpen(false) + return + } + const anchorRect = anchorElement.getBoundingClientRect() + const idealLeft = rects[rects.length - 1].right - anchorRect.left + const paddingX = 8 + const paddingY = 4 + const minLeft = (popoverRef.current?.offsetWidth ?? 0) + paddingX + const finalLeft = Math.max(minLeft, idealLeft) + setPopoverStyle({ + top: rects[rects.length - 1].bottom - anchorRect.top + paddingY, + left: finalLeft, + transform: 'translate(-100%, 0)', + }) + setIsPopoverOpen(true) + }, [anchorElement, contentEditableElement]) + + useEffect(() => { + const removeSelectionChangeListener = editor.registerCommand( + SELECTION_CHANGE_COMMAND, + () => { + updatePopoverPosition() + return false + }, + COMMAND_PRIORITY_LOW, + ) + return () => { + removeSelectionChangeListener() + } + }, [editor, updatePopoverPosition]) + + useEffect(() => { + if (!contentEditableElement) return + const handleScroll = () => { + updatePopoverPosition() + } + contentEditableElement.addEventListener('scroll', handleScroll) + return () => { + contentEditableElement.removeEventListener('scroll', handleScroll) + } + }, [contentEditableElement, updatePopoverPosition]) + + return ( + { + if (open) { + setSelectedSerializedNodes(getSelectedSerializedNodes()) + } + setIsDialogOpen(open) + setIsPopoverOpen(false) + }} + > + + + + setIsDialogOpen(false)} + /> + + ) +} diff --git a/src/components/chat-view/chat-input/plugins/template/TemplatePlugin.tsx b/src/components/chat-view/chat-input/plugins/template/TemplatePlugin.tsx new file mode 100644 index 0000000..1f6eb38 --- /dev/null +++ b/src/components/chat-view/chat-input/plugins/template/TemplatePlugin.tsx @@ -0,0 +1,179 @@ +import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext' +import clsx from 'clsx' +import { + $parseSerializedNode, + COMMAND_PRIORITY_NORMAL, + TextNode, +} from 'lexical' +import { Trash2 } from 'lucide-react' +import { useCallback, useEffect, useMemo, useState } from 'react' +import { createPortal } from 'react-dom' + +import { useDatabase } from '../../../../../contexts/database-context' +import { SelectTemplate } from '../../../../../database/schema' +import { MenuOption } from '../shared/LexicalMenu' +import { + LexicalTypeaheadMenuPlugin, + useBasicTypeaheadTriggerMatch, +} from '../typeahead-menu/LexicalTypeaheadMenuPlugin' + +class TemplateTypeaheadOption extends MenuOption { + name: string + template: SelectTemplate + + constructor(name: string, template: SelectTemplate) { + super(name) + this.name = name + this.template = template + } +} + +function TemplateMenuItem({ + index, + isSelected, + onClick, + onDelete, + onMouseEnter, + option, +}: { + index: number + isSelected: boolean + onClick: () => void + onDelete: () => void + onMouseEnter: () => void + option: TemplateTypeaheadOption +}) { + return ( +
  • option.setRefElement(el)} + role="option" + aria-selected={isSelected} + id={`typeahead-item-${index}`} + onMouseEnter={onMouseEnter} + onClick={onClick} + > +
    +
    {option.name}
    +
    { + evt.stopPropagation() + evt.preventDefault() + onDelete() + }} + className="smtcmp-template-menu-item-delete" + > + +
    +
    +
  • + ) +} + +export default function TemplatePlugin() { + const [editor] = useLexicalComposerContext() + const { templateManager } = useDatabase() + + const [queryString, setQueryString] = useState(null) + const [searchResults, setSearchResults] = useState([]) + + useEffect(() => { + if (queryString == null) return + templateManager.searchTemplates(queryString).then(setSearchResults) + }, [queryString, templateManager]) + + const options = useMemo( + () => + searchResults.map( + (result) => new TemplateTypeaheadOption(result.name, result), + ), + [searchResults], + ) + + const checkForTriggerMatch = useBasicTypeaheadTriggerMatch('/', { + minLength: 0, + }) + + const onSelectOption = useCallback( + ( + selectedOption: TemplateTypeaheadOption, + nodeToRemove: TextNode | null, + closeMenu: () => void, + ) => { + editor.update(() => { + const parsedNodes = selectedOption.template.content.nodes.map((node) => + $parseSerializedNode(node), + ) + if (nodeToRemove) { + const parent = nodeToRemove.getParentOrThrow() + parent.splice(nodeToRemove.getIndexWithinParent(), 1, parsedNodes) + const lastNode = parsedNodes[parsedNodes.length - 1] + lastNode.selectEnd() + } + closeMenu() + }) + }, + [editor], + ) + + const handleDelete = useCallback( + async (option: TemplateTypeaheadOption) => { + await templateManager.deleteTemplate(option.template.id) + if (queryString !== null) { + const updatedResults = + await templateManager.searchTemplates(queryString) + setSearchResults(updatedResults) + } + }, + [templateManager, queryString], + ) + + return ( + + onQueryChange={setQueryString} + onSelectOption={onSelectOption} + triggerFn={checkForTriggerMatch} + options={options} + commandPriority={COMMAND_PRIORITY_NORMAL} + menuRenderFn={( + anchorElementRef, + { selectedIndex, selectOptionAndCleanUp, setHighlightedIndex }, + ) => + anchorElementRef.current && searchResults.length + ? createPortal( +
    +
      + {options.map((option, i: number) => ( + { + setHighlightedIndex(i) + selectOptionAndCleanUp(option) + }} + onDelete={() => { + handleDelete(option) + }} + onMouseEnter={() => { + setHighlightedIndex(i) + }} + key={option.key} + option={option} + /> + ))} +
    +
    , + anchorElementRef.current, + ) + : null + } + /> + ) +} diff --git a/src/components/chat-view/chat-input/plugins/updater/UpdaterPlugin.tsx b/src/components/chat-view/chat-input/plugins/updater/UpdaterPlugin.tsx deleted file mode 100644 index f767d0f..0000000 --- a/src/components/chat-view/chat-input/plugins/updater/UpdaterPlugin.tsx +++ /dev/null @@ -1,23 +0,0 @@ -import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext' -import { SerializedEditorState } from 'lexical' -import { Ref, useImperativeHandle } from 'react' - -export type UpdaterPluginRef = { - update: (content: SerializedEditorState) => void -} - -export default function UpdaterPlugin({ - updaterRef, -}: { - updaterRef: Ref -}) { - const [editor] = useLexicalComposerContext() - - useImperativeHandle(updaterRef, () => ({ - update: (content: SerializedEditorState) => { - editor.setEditorState(editor.parseEditorState(content)) - }, - })) - - return null -} diff --git a/src/constants.ts b/src/constants.ts index 9c60545..987fb5e 100644 --- a/src/constants.ts +++ b/src/constants.ts @@ -47,7 +47,7 @@ export const APPLY_MODEL_OPTIONS = [ }, ] -// Update table exports in src/db/schema.ts when updating this +// Update table exports in src/database/schema.ts when updating this export const EMBEDDING_MODEL_OPTIONS = [ { name: 'text-embedding-3-small (Recommended)', diff --git a/src/contexts/database-context.tsx b/src/contexts/database-context.tsx new file mode 100644 index 0000000..54703f2 --- /dev/null +++ b/src/contexts/database-context.tsx @@ -0,0 +1,45 @@ +import { createContext, useContext, useMemo } from 'react' + +import { DatabaseManager } from '../database/DatabaseManager' +import { TemplateManager } from '../database/modules/template/TemplateManager' +import { VectorManager } from '../database/modules/vector/VectorManager' + +type DatabaseContextType = { + databaseManager: DatabaseManager + vectorManager: VectorManager + templateManager: TemplateManager +} + +const DatabaseContext = createContext(null) + +export function DatabaseProvider({ + children, + databaseManager, +}: { + children: React.ReactNode + databaseManager: DatabaseManager +}) { + const vectorManager = useMemo(() => { + return databaseManager.getVectorManager() + }, [databaseManager]) + + const templateManager = useMemo(() => { + return databaseManager.getTemplateManager() + }, [databaseManager]) + + return ( + + {children} + + ) +} + +export function useDatabase(): DatabaseContextType { + const context = useContext(DatabaseContext) + if (!context) { + throw new Error('useDatabase must be used within a DatabaseProvider') + } + return context +} diff --git a/src/contexts/dialog-container-context.tsx b/src/contexts/dialog-container-context.tsx new file mode 100644 index 0000000..9ff59f0 --- /dev/null +++ b/src/contexts/dialog-container-context.tsx @@ -0,0 +1,27 @@ +import React, { createContext, useContext } from 'react' + +const DialogContainerContext = createContext(null) + +export function DialogContainerProvider({ + children, + container, +}: { + children: React.ReactNode + container: HTMLElement | null +}) { + return ( + + {children} + + ) +} + +export function useDialogContainer() { + const context = useContext(DialogContainerContext) + if (!context) { + throw new Error( + 'useDialogContainer must be used within a DialogContainerProvider', + ) + } + return context +} diff --git a/src/contexts/llm-context.tsx b/src/contexts/llm-context.tsx index 08ba16b..9e6e966 100644 --- a/src/contexts/llm-context.tsx +++ b/src/contexts/llm-context.tsx @@ -7,6 +7,7 @@ import { useState, } from 'react' +import LLMManager from '../core/llm/manager' import { LLMOptions, LLMRequestNonStreaming, @@ -16,7 +17,6 @@ import { LLMResponseNonStreaming, LLMResponseStreaming, } from '../types/llm/response' -import LLMManager from '../utils/llm/manager' import { useSettings } from './settings-context' diff --git a/src/contexts/rag-context.tsx b/src/contexts/rag-context.tsx index 777d39c..f969225 100644 --- a/src/contexts/rag-context.tsx +++ b/src/contexts/rag-context.tsx @@ -1,6 +1,6 @@ import { PropsWithChildren, createContext, useContext } from 'react' -import { RAGEngine } from '../utils/ragEngine' +import { RAGEngine } from '../core/rag/ragEngine' export type RAGContextType = { ragEngine: RAGEngine | null diff --git a/src/utils/llm/anthropic.ts b/src/core/llm/anthropic.ts similarity index 100% rename from src/utils/llm/anthropic.ts rename to src/core/llm/anthropic.ts diff --git a/src/utils/llm/base.ts b/src/core/llm/base.ts similarity index 100% rename from src/utils/llm/base.ts rename to src/core/llm/base.ts diff --git a/src/utils/llm/exception.ts b/src/core/llm/exception.ts similarity index 100% rename from src/utils/llm/exception.ts rename to src/core/llm/exception.ts diff --git a/src/utils/llm/groq.ts b/src/core/llm/groq.ts similarity index 100% rename from src/utils/llm/groq.ts rename to src/core/llm/groq.ts diff --git a/src/utils/llm/manager.ts b/src/core/llm/manager.ts similarity index 100% rename from src/utils/llm/manager.ts rename to src/core/llm/manager.ts diff --git a/src/utils/llm/ollama.ts b/src/core/llm/ollama.ts similarity index 97% rename from src/utils/llm/ollama.ts rename to src/core/llm/ollama.ts index 20ab58a..f852e78 100644 --- a/src/utils/llm/ollama.ts +++ b/src/core/llm/ollama.ts @@ -1,14 +1,15 @@ import OpenAI from 'openai' import { FinalRequestOptions } from 'openai/core' + import { LLMOptions, LLMRequestNonStreaming, LLMRequestStreaming, -} from 'src/types/llm/request' +} from '../../types/llm/request' import { LLMResponseNonStreaming, LLMResponseStreaming, -} from 'src/types/llm/response' +} from '../../types/llm/response' import { BaseLLMProvider } from './base' import { LLMBaseUrlNotSetException } from './exception' diff --git a/src/utils/llm/openai.ts b/src/core/llm/openai.ts similarity index 100% rename from src/utils/llm/openai.ts rename to src/core/llm/openai.ts diff --git a/src/utils/llm/openaiCompatibleProvider.ts b/src/core/llm/openaiCompatibleProvider.ts similarity index 100% rename from src/utils/llm/openaiCompatibleProvider.ts rename to src/core/llm/openaiCompatibleProvider.ts diff --git a/src/utils/embedding.ts b/src/core/rag/embedding.ts similarity index 96% rename from src/utils/embedding.ts rename to src/core/rag/embedding.ts index 4a502fe..965504c 100644 --- a/src/utils/embedding.ts +++ b/src/core/rag/embedding.ts @@ -1,12 +1,11 @@ import { OpenAI } from 'openai' -import { EmbeddingModel } from '../types/embedding' - +import { EmbeddingModel } from '../../types/embedding' import { LLMAPIKeyNotSetException, LLMBaseUrlNotSetException, -} from './llm/exception' -import { NoStainlessOpenAI } from './llm/ollama' +} from '../llm/exception' +import { NoStainlessOpenAI } from '../llm/ollama' export const getEmbeddingModel = ( name: string, diff --git a/src/utils/ragEngine.ts b/src/core/rag/ragEngine.ts similarity index 76% rename from src/utils/ragEngine.ts rename to src/core/rag/ragEngine.ts index ce5538d..367012c 100644 --- a/src/utils/ragEngine.ts +++ b/src/core/rag/ragEngine.ts @@ -1,42 +1,35 @@ import { App } from 'obsidian' -import { QueryProgressState } from '../components/chat-view/QueryProgress' -import { SelectVector } from '../db/schema' -import { EmbeddingModel } from '../types/embedding' -import { SmartCopilotSettings } from '../types/settings' +import { QueryProgressState } from '../../components/chat-view/QueryProgress' +import { DatabaseManager } from '../../database/DatabaseManager' +import { VectorManager } from '../../database/modules/vector/VectorManager' +import { SelectVector } from '../../database/schema' +import { EmbeddingModel } from '../../types/embedding' +import { SmartCopilotSettings } from '../../types/settings' import { getEmbeddingModel } from './embedding' -import { VectorDbManager } from './vector-db/manager' export class RAGEngine { private app: App private settings: SmartCopilotSettings - private vectorDbManager: VectorDbManager + private vectorManager: VectorManager private embeddingModel: EmbeddingModel | null = null - constructor(app: App, settings: SmartCopilotSettings) { - this.app = app - this.settings = settings - } - - static async create( + constructor( app: App, settings: SmartCopilotSettings, - ): Promise { - const ragEngine = new RAGEngine(app, settings) - ragEngine.vectorDbManager = await VectorDbManager.create(app) - ragEngine.embeddingModel = getEmbeddingModel( + dbManager: DatabaseManager, + ) { + this.app = app + this.settings = settings + this.vectorManager = dbManager.getVectorManager() + this.embeddingModel = getEmbeddingModel( settings.embeddingModel, { openAIApiKey: settings.openAIApiKey, }, settings.ollamaBaseUrl, ) - return ragEngine - } - - async cleanup() { - await this.vectorDbManager.cleanup() } setSettings(settings: SmartCopilotSettings) { @@ -61,7 +54,7 @@ export class RAGEngine { if (!this.embeddingModel) { throw new Error('Embedding model is not set') } - await this.vectorDbManager.updateVaultIndex( + await this.vectorManager.updateVaultIndex( this.embeddingModel, { chunkSize: this.settings.ragOptions.chunkSize, @@ -102,7 +95,7 @@ export class RAGEngine { onQueryProgressChange?.({ type: 'querying', }) - const queryResult = await this.vectorDbManager.performSimilaritySearch( + const queryResult = await this.vectorManager.performSimilaritySearch( queryEmbedding, this.embeddingModel, { diff --git a/src/database/DatabaseManager.ts b/src/database/DatabaseManager.ts new file mode 100644 index 0000000..40286ec --- /dev/null +++ b/src/database/DatabaseManager.ts @@ -0,0 +1,156 @@ +import { PGlite } from '@electric-sql/pglite' +import { PgliteDatabase, drizzle } from 'drizzle-orm/pglite' +import { App, normalizePath, requestUrl } from 'obsidian' + +import { PGLITE_DB_PATH } from '../constants' + +import migrations from './migrations.json' +import { TemplateManager } from './modules/template/TemplateManager' +import { VectorManager } from './modules/vector/VectorManager' + +export class DatabaseManager { + private app: App + private dbPath: string + private pgClient: PGlite | null = null + private db: PgliteDatabase | null = null + private vectorManager: VectorManager + private templateManager: TemplateManager + + constructor(app: App, dbPath: string) { + this.app = app + this.dbPath = dbPath + } + + static async create(app: App): Promise { + const dbManager = new DatabaseManager(app, normalizePath(PGLITE_DB_PATH)) + dbManager.db = await dbManager.loadExistingDatabase() + if (!dbManager.db) { + dbManager.db = await dbManager.createNewDatabase() + } + await dbManager.migrateDatabase() + await dbManager.save() + + dbManager.vectorManager = new VectorManager(app, dbManager) + dbManager.templateManager = new TemplateManager(app, dbManager) + + console.log('Smart composer database initialized.') + return dbManager + } + + getDb() { + return this.db + } + + getVectorManager() { + return this.vectorManager + } + + getTemplateManager() { + return this.templateManager + } + + private async createNewDatabase() { + const { fsBundle, wasmModule, vectorExtensionBundlePath } = + await this.loadPGliteResources() + this.pgClient = await PGlite.create({ + fsBundle: fsBundle, + wasmModule: wasmModule, + extensions: { + vector: vectorExtensionBundlePath, + }, + }) + const db = drizzle(this.pgClient) + return db + } + + private async loadExistingDatabase(): Promise { + try { + const databaseFileExists = await this.app.vault.adapter.exists( + this.dbPath, + ) + if (!databaseFileExists) { + return null + } + const fileBuffer = await this.app.vault.adapter.readBinary(this.dbPath) + const fileBlob = new Blob([fileBuffer], { type: 'application/x-gzip' }) + const { fsBundle, wasmModule, vectorExtensionBundlePath } = + await this.loadPGliteResources() + this.pgClient = await PGlite.create({ + loadDataDir: fileBlob, + fsBundle: fsBundle, + wasmModule: wasmModule, + extensions: { + vector: vectorExtensionBundlePath, + }, + }) + return drizzle(this.pgClient) + } catch (error) { + console.error('Error loading database:', error) + return null + } + } + + private async migrateDatabase(): Promise { + try { + // Workaround for running Drizzle migrations in a browser environment + // This method uses an undocumented API to perform migrations + // See: https://github.com/drizzle-team/drizzle-orm/discussions/2532#discussioncomment-10780523 + // eslint-disable-next-line @typescript-eslint/ban-ts-comment + // @ts-expect-error + await this.db.dialect.migrate(migrations, this.db.session, { + migrationsTable: 'drizzle_migrations', + }) + } catch (error) { + console.error('Error migrating database:', error) + throw error + } + } + + async save(): Promise { + if (!this.pgClient) { + return + } + try { + const blob: Blob = await this.pgClient.dumpDataDir('gzip') + await this.app.vault.adapter.writeBinary( + this.dbPath, + Buffer.from(await blob.arrayBuffer()), + ) + } catch (error) { + console.error('Error saving database:', error) + } + } + + async cleanup() { + this.pgClient?.close() + this.db = null + this.pgClient = null + } + + // TODO: This function is a temporary workaround chosen due to the difficulty of bundling postgres.wasm and postgres.data from node_modules into a single JS file. The ultimate goal is to bundle everything into one JS file in the future. + private async loadPGliteResources(): Promise<{ + fsBundle: Blob + wasmModule: WebAssembly.Module + vectorExtensionBundlePath: URL + }> { + try { + const [fsBundleResponse, wasmResponse] = await Promise.all([ + requestUrl('https://unpkg.com/@electric-sql/pglite/dist/postgres.data'), + requestUrl('https://unpkg.com/@electric-sql/pglite/dist/postgres.wasm'), + ]) + + const fsBundle = new Blob([fsBundleResponse.arrayBuffer], { + type: 'application/octet-stream', + }) + const wasmModule = await WebAssembly.compile(wasmResponse.arrayBuffer) + const vectorExtensionBundlePath = new URL( + 'https://unpkg.com/@electric-sql/pglite/dist/vector.tar.gz', + ) + + return { fsBundle, wasmModule, vectorExtensionBundlePath } + } catch (error) { + console.error('Error loading PGlite resources:', error) + throw error + } + } +} diff --git a/src/database/exception.ts b/src/database/exception.ts new file mode 100644 index 0000000..3838a94 --- /dev/null +++ b/src/database/exception.ts @@ -0,0 +1,20 @@ +export class DatabaseException extends Error { + constructor(message: string) { + super(message) + this.name = 'DatabaseException' + } +} + +export class DatabaseNotInitializedException extends DatabaseException { + constructor(message = 'Database not initialized') { + super(message) + this.name = 'DatabaseNotInitializedException' + } +} + +export class DuplicateTemplateException extends DatabaseException { + constructor(templateName: string) { + super(`Template with name "${templateName}" already exists`) + this.name = 'DuplicateTemplateException' + } +} diff --git a/src/db/migrations.json b/src/database/migrations.json similarity index 85% rename from src/db/migrations.json rename to src/database/migrations.json index ddd7d53..0c14f03 100644 --- a/src/db/migrations.json +++ b/src/database/migrations.json @@ -43,5 +43,13 @@ "bps": true, "folderMillis": 1730085029605, "hash": "e41dde00d7d98da74596c60e796aa78d5cbfec57208347c72f16a6fa11c24ef1" + }, + { + "sql": [ + "CREATE TABLE IF NOT EXISTS \"template\" (\n\t\"id\" uuid PRIMARY KEY DEFAULT gen_random_uuid() NOT NULL,\n\t\"name\" text NOT NULL,\n\t\"content\" jsonb NOT NULL,\n\t\"created_at\" timestamp DEFAULT now() NOT NULL,\n\t\"updated_at\" timestamp DEFAULT now() NOT NULL,\n\tCONSTRAINT \"template_name_unique\" UNIQUE(\"name\")\n);\n" + ], + "bps": true, + "folderMillis": 1730443763982, + "hash": "eddf72b8d40619c170b3c12f3d3ce280385b1fc05717f211ab6077c6bea691bf" } ] diff --git a/src/database/modules/template/TemplateManager.ts b/src/database/modules/template/TemplateManager.ts new file mode 100644 index 0000000..df54ebe --- /dev/null +++ b/src/database/modules/template/TemplateManager.ts @@ -0,0 +1,51 @@ +import fuzzysort from 'fuzzysort' +import { App } from 'obsidian' + +import { DatabaseManager } from '../../DatabaseManager' +import { DuplicateTemplateException } from '../../exception' +import { InsertTemplate, SelectTemplate } from '../../schema' + +import { TemplateRepository } from './TemplateRepository' + +export class TemplateManager { + private app: App + private repository: TemplateRepository + private dbManager: DatabaseManager + + constructor(app: App, dbManager: DatabaseManager) { + this.app = app + this.dbManager = dbManager + this.repository = new TemplateRepository(app, dbManager.getDb()) + } + + async createTemplate(template: InsertTemplate): Promise { + const existingTemplate = await this.repository.findByName(template.name) + if (existingTemplate) { + throw new DuplicateTemplateException(template.name) + } + const created = await this.repository.create(template) + await this.dbManager.save() + return created + } + + async findAllTemplates(): Promise { + return await this.repository.findAll() + } + + async searchTemplates(query: string): Promise { + const templates = await this.findAllTemplates() + const results = fuzzysort.go(query, templates, { + keys: ['name'], + threshold: 0.2, + limit: 20, + all: true, + }) + return results.map((result) => result.obj) + } + + async deleteTemplate(id: string): Promise { + const deleted = await this.repository.delete(id) + await this.dbManager.save() + return deleted + } +} diff --git a/src/database/modules/template/TemplateRepository.ts b/src/database/modules/template/TemplateRepository.ts new file mode 100644 index 0000000..3c32927 --- /dev/null +++ b/src/database/modules/template/TemplateRepository.ts @@ -0,0 +1,76 @@ +import { eq } from 'drizzle-orm' +import { PgliteDatabase } from 'drizzle-orm/pglite' +import { App } from 'obsidian' + +import { DatabaseNotInitializedException } from '../../exception' +import { + type InsertTemplate, + type SelectTemplate, + templateTable, +} from '../../schema' + +export class TemplateRepository { + private app: App + private db: PgliteDatabase | null + + constructor(app: App, db: PgliteDatabase | null) { + this.app = app + this.db = db + } + + async create(template: InsertTemplate): Promise { + if (!this.db) { + throw new DatabaseNotInitializedException() + } + + const [created] = await this.db + .insert(templateTable) + .values(template) + .returning() + return created + } + + async findAll(): Promise { + if (!this.db) { + throw new DatabaseNotInitializedException() + } + return await this.db.select().from(templateTable) + } + + async findByName(name: string): Promise { + if (!this.db) { + throw new DatabaseNotInitializedException() + } + const [template] = await this.db + .select() + .from(templateTable) + .where(eq(templateTable.name, name)) + return template ?? null + } + + async update( + id: string, + template: Partial, + ): Promise { + if (!this.db) { + throw new DatabaseNotInitializedException() + } + const [updated] = await this.db + .update(templateTable) + .set({ ...template, updatedAt: new Date() }) + .where(eq(templateTable.id, id)) + .returning() + return updated + } + + async delete(id: string): Promise { + if (!this.db) { + throw new DatabaseNotInitializedException() + } + const [deleted] = await this.db + .delete(templateTable) + .where(eq(templateTable.id, id)) + .returning() + return !!deleted + } +} diff --git a/src/utils/vector-db/manager.ts b/src/database/modules/vector/VectorManager.ts similarity index 88% rename from src/utils/vector-db/manager.ts rename to src/database/modules/vector/VectorManager.ts index 43dc055..6b594e8 100644 --- a/src/utils/vector-db/manager.ts +++ b/src/database/modules/vector/VectorManager.ts @@ -1,38 +1,30 @@ import { backOff } from 'exponential-backoff' import { RecursiveCharacterTextSplitter } from 'langchain/text_splitter' -import { App, TFile, normalizePath } from 'obsidian' +import { App, TFile } from 'obsidian' import pLimit from 'p-limit' -import { IndexProgress } from '../../components/chat-view/QueryProgress' -import { PGLITE_DB_PATH } from '../../constants' -import { InsertVector, SelectVector } from '../../db/schema' -import { EmbeddingModel } from '../../types/embedding' +import { IndexProgress } from '../../../components/chat-view/QueryProgress' import { LLMAPIKeyInvalidException, LLMAPIKeyNotSetException, LLMBaseUrlNotSetException, -} from '../llm/exception' -import { openSettingsModalWithError } from '../openSettingsModal' +} from '../../../core/llm/exception' +import { InsertVector, SelectVector } from '../../../database/schema' +import { EmbeddingModel } from '../../../types/embedding' +import { openSettingsModalWithError } from '../../../utils/openSettingsModal' +import { DatabaseManager } from '../../DatabaseManager' -import { VectorDbRepository } from './repository' +import { VectorRepository } from './VectorRepository' -export class VectorDbManager { +export class VectorManager { private app: App - private repository: VectorDbRepository + private repository: VectorRepository + private dbManager: DatabaseManager - constructor(app: App) { + constructor(app: App, dbManager: DatabaseManager) { this.app = app - } - - static async create(app: App): Promise { - const manager = new VectorDbManager(app) - const dbPath = normalizePath(PGLITE_DB_PATH) - manager.repository = await VectorDbRepository.create(app, dbPath) - return manager - } - - async cleanup() { - await this.repository.cleanup() + this.dbManager = dbManager + this.repository = new VectorRepository(app, dbManager.getDb()) } async performSimilaritySearch( @@ -203,7 +195,7 @@ export class VectorDbManager { throw error } } finally { - await this.repository.save() + await this.dbManager.save() } } diff --git a/src/database/modules/vector/VectorRepository.ts b/src/database/modules/vector/VectorRepository.ts new file mode 100644 index 0000000..a1831ae --- /dev/null +++ b/src/database/modules/vector/VectorRepository.ts @@ -0,0 +1,164 @@ +import { + SQL, + and, + cosineDistance, + desc, + eq, + getTableColumns, + gt, + inArray, + like, + or, + sql, +} from 'drizzle-orm' +import { PgliteDatabase } from 'drizzle-orm/pglite' +import { App } from 'obsidian' + +import { DatabaseNotInitializedException } from '../../exception' +import { InsertVector, SelectVector, vectorTables } from '../../schema' + +export class VectorRepository { + private app: App + private db: PgliteDatabase | null + + constructor(app: App, db: PgliteDatabase | null) { + this.app = app + this.db = db + } + + async getIndexedFilePaths(embeddingModel: { + name: string + }): Promise { + if (!this.db) { + throw new DatabaseNotInitializedException() + } + const vectors = vectorTables[embeddingModel.name] + const indexedFiles = await this.db + .select({ + path: vectors.path, + }) + .from(vectors) + return indexedFiles.map((row) => row.path) + } + + async getVectorsByFilePath( + filePath: string, + embeddingModel: { name: string }, + ): Promise { + if (!this.db) { + throw new DatabaseNotInitializedException() + } + const vectors = vectorTables[embeddingModel.name] + const fileVectors = await this.db + .select() + .from(vectors) + .where(eq(vectors.path, filePath)) + return fileVectors + } + + async deleteVectorsForSingleFile( + filePath: string, + embeddingModel: { name: string }, + ): Promise { + if (!this.db) { + throw new DatabaseNotInitializedException() + } + const vectors = vectorTables[embeddingModel.name] + await this.db.delete(vectors).where(eq(vectors.path, filePath)) + } + + async deleteVectorsForMultipleFiles( + filePaths: string[], + embeddingModel: { name: string }, + ): Promise { + if (!this.db) { + throw new DatabaseNotInitializedException() + } + const vectors = vectorTables[embeddingModel.name] + await this.db.delete(vectors).where(inArray(vectors.path, filePaths)) + } + + async clearAllVectors(embeddingModel: { name: string }): Promise { + if (!this.db) { + throw new DatabaseNotInitializedException() + } + const vectors = vectorTables[embeddingModel.name] + await this.db.delete(vectors) + } + + async insertVectors( + data: InsertVector[], + embeddingModel: { name: string }, + ): Promise { + if (!this.db) { + throw new DatabaseNotInitializedException() + } + const vectors = vectorTables[embeddingModel.name] + await this.db.insert(vectors).values(data) + } + + async performSimilaritySearch( + queryVector: number[], + embeddingModel: { name: string }, + options: { + minSimilarity: number + limit: number + scope?: { + files: string[] + folders: string[] + } + }, + ): Promise< + (Omit & { + similarity: number + })[] + > { + if (!this.db) { + throw new DatabaseNotInitializedException() + } + const vectors = vectorTables[embeddingModel.name] + + const similarity = sql`1 - (${cosineDistance(vectors.embedding, queryVector)})` + const similarityCondition = gt(similarity, options.minSimilarity) + + const getScopeCondition = (): SQL | undefined => { + if (!options.scope) { + return undefined + } + const conditions: (SQL | undefined)[] = [] + if (options.scope.files.length > 0) { + conditions.push(inArray(vectors.path, options.scope.files)) + } + if (options.scope.folders.length > 0) { + conditions.push( + or( + ...options.scope.folders.map((folder) => + like(vectors.path, `${folder}/%`), + ), + ), + ) + } + if (conditions.length === 0) { + return undefined + } + return or(...conditions) + } + const scopeCondition = getScopeCondition() + + const similaritySearchResult = await this.db + .select({ + ...(() => { + // eslint-disable-next-line @typescript-eslint/no-unused-vars + const { embedding, ...rest } = getTableColumns(vectors) + return rest + })(), + similarity, + }) + .from(vectors) + .where(and(similarityCondition, scopeCondition)) + .orderBy((t) => desc(t.similarity)) + .limit(options.limit) + + return similaritySearchResult + } +} diff --git a/src/db/schema.ts b/src/database/schema.ts similarity index 76% rename from src/db/schema.ts rename to src/database/schema.ts index 0375914..526752c 100644 --- a/src/db/schema.ts +++ b/src/database/schema.ts @@ -5,12 +5,16 @@ import { pgTable, serial, text, + timestamp, + uuid, vector, } from 'drizzle-orm/pg-core' +import { SerializedLexicalNode } from 'lexical' import { EMBEDDING_MODEL_OPTIONS } from '../constants' import { EmbeddingModelName } from '../types/embedding' +/* Vector Table */ const createVectorTable = (name: string, dimension: number) => { const sanitizedName = name.replace(/[^a-zA-Z0-9]/g, '_') return pgTable( @@ -58,3 +62,19 @@ export const vectorTable1 = vectorTables[EMBEDDING_MODEL_OPTIONS[1].value] export const vectorTable2 = vectorTables[EMBEDDING_MODEL_OPTIONS[2].value] export const vectorTable3 = vectorTables[EMBEDDING_MODEL_OPTIONS[3].value] export const vectorTable4 = vectorTables[EMBEDDING_MODEL_OPTIONS[4].value] + +/* Template Table */ +export type TemplateContent = { + nodes: SerializedLexicalNode[] +} + +export const templateTable = pgTable('template', { + id: uuid('id').defaultRandom().primaryKey(), + name: text('name').notNull().unique(), + content: jsonb('content').notNull().$type(), + createdAt: timestamp('created_at').defaultNow().notNull(), + updatedAt: timestamp('updated_at').defaultNow().notNull(), +}) + +export type SelectTemplate = typeof templateTable.$inferSelect +export type InsertTemplate = typeof templateTable.$inferInsert diff --git a/src/main.ts b/src/main.ts index 1157093..282666c 100644 --- a/src/main.ts +++ b/src/main.ts @@ -4,19 +4,21 @@ import { ApplyView } from './ApplyView' import { ChatView } from './ChatView' import { ChatProps } from './components/chat-view/Chat' import { APPLY_VIEW_TYPE, CHAT_VIEW_TYPE } from './constants' +import { RAGEngine } from './core/rag/ragEngine' +import { DatabaseManager } from './database/DatabaseManager' import { SmartCopilotSettingTab } from './settings/SettingTab' import { SmartCopilotSettings, parseSmartCopilotSettings, } from './types/settings' import { getMentionableBlockData } from './utils/obsidian' -import { RAGEngine } from './utils/ragEngine' // Remember to rename these classes and interfaces! export default class SmartCopilotPlugin extends Plugin { settings: SmartCopilotSettings initialChatProps?: ChatProps // TODO: change this to use view state like ApplyView settingsChangeListeners: ((newSettings: SmartCopilotSettings) => void)[] = [] + dbManager: DatabaseManager | null = null ragEngine: RAGEngine | null = null async onload() { @@ -112,8 +114,8 @@ export default class SmartCopilotPlugin extends Plugin { } onunload() { - this.ragEngine?.cleanup() - this.ragEngine = null + this.dbManager?.cleanup() + this.dbManager = null } async loadSettings() { @@ -189,9 +191,17 @@ export default class SmartCopilotPlugin extends Plugin { chatView.focusMessage() } + async getDbManager(): Promise { + if (!this.dbManager) { + this.dbManager = await DatabaseManager.create(this.app) + } + return this.dbManager + } + async getRAGEngine(): Promise { + const dbManager = await this.getDbManager() if (!this.ragEngine) { - this.ragEngine = await RAGEngine.create(this.app, this.settings) + this.ragEngine = new RAGEngine(this.app, this.settings, dbManager) } return this.ragEngine } diff --git a/src/utils/promptGenerator.ts b/src/utils/promptGenerator.ts index ce75072..a1bb4dd 100644 --- a/src/utils/promptGenerator.ts +++ b/src/utils/promptGenerator.ts @@ -2,6 +2,7 @@ import { App, TFile, htmlToMarkdown, requestUrl } from 'obsidian' import { editorStateToPlainText } from '../components/chat-view/chat-input/utils/editor-state-to-plain-text' import { QueryProgressState } from '../components/chat-view/QueryProgress' +import { RAGEngine } from '../core/rag/ragEngine' import { ChatMessage, ChatUserMessage } from '../types/chat' import { RequestMessage } from '../types/llm/request' import { @@ -18,7 +19,6 @@ import { readMultipleTFiles, readTFileContent, } from './obsidian' -import { RAGEngine } from './ragEngine' import { tokenCount } from './token' import { YoutubeTranscript, isYoutubeUrl } from './youtube-transcript' diff --git a/src/utils/vector-db/repository.ts b/src/utils/vector-db/repository.ts deleted file mode 100644 index cb3bd2d..0000000 --- a/src/utils/vector-db/repository.ts +++ /dev/null @@ -1,283 +0,0 @@ -import { PGlite } from '@electric-sql/pglite' -import { - SQL, - and, - cosineDistance, - desc, - eq, - getTableColumns, - gt, - inArray, - like, - or, - sql, -} from 'drizzle-orm' -import { PgliteDatabase, drizzle } from 'drizzle-orm/pglite' -import { App, requestUrl } from 'obsidian' - -import migrations from '../../db/migrations.json' -import { InsertVector, SelectVector, vectorTables } from '../../db/schema' - -export class VectorDbRepository { - private app: App - private pgClient: PGlite | null = null - private db: PgliteDatabase | null = null - private dbPath: string - - constructor(app: App, dbPath: string) { - this.app = app - this.dbPath = dbPath - } - - static async create(app: App, dbPath: string): Promise { - const repo = new VectorDbRepository(app, dbPath) - repo.db = await repo.loadExistingDatabase() - if (!repo.db) { - repo.db = await repo.createNewDatabase() - } - await repo.migrateDatabase() - await repo.save() - console.log('Smart composer database initialized.') - return repo - } - - async cleanup() { - this.pgClient?.close() - this.db = null - this.pgClient = null - } - - async getIndexedFilePaths(embeddingModel: { - name: string - }): Promise { - if (!this.db) { - throw new Error('Database not initialized') - } - const vectors = vectorTables[embeddingModel.name] - const indexedFiles = await this.db - .select({ - path: vectors.path, - }) - .from(vectors) - return indexedFiles.map((row) => row.path) - } - - async getVectorsByFilePath( - filePath: string, - embeddingModel: { name: string }, - ): Promise { - if (!this.db) { - throw new Error('Database not initialized') - } - const vectors = vectorTables[embeddingModel.name] - const fileVectors = await this.db - .select() - .from(vectors) - .where(eq(vectors.path, filePath)) - return fileVectors - } - - async deleteVectorsForSingleFile( - filePath: string, - embeddingModel: { name: string }, - ): Promise { - if (!this.db) { - throw new Error('Database not initialized') - } - const vectors = vectorTables[embeddingModel.name] - await this.db.delete(vectors).where(eq(vectors.path, filePath)) - } - - async deleteVectorsForMultipleFiles( - filePaths: string[], - embeddingModel: { name: string }, - ): Promise { - if (!this.db) { - throw new Error('Database not initialized') - } - const vectors = vectorTables[embeddingModel.name] - await this.db.delete(vectors).where(inArray(vectors.path, filePaths)) - } - - async clearAllVectors(embeddingModel: { name: string }): Promise { - if (!this.db) { - throw new Error('Database not initialized') - } - const vectors = vectorTables[embeddingModel.name] - await this.db.delete(vectors) - } - - async insertVectors( - data: InsertVector[], - embeddingModel: { name: string }, - ): Promise { - if (!this.db) { - throw new Error('Database not initialized') - } - const vectors = vectorTables[embeddingModel.name] - await this.db.insert(vectors).values(data) - } - - async performSimilaritySearch( - queryVector: number[], - embeddingModel: { name: string }, - options: { - minSimilarity: number - limit: number - scope?: { - files: string[] - folders: string[] - } - }, - ): Promise< - (Omit & { - similarity: number - })[] - > { - if (!this.db) { - throw new Error('Database not initialized') - } - const vectors = vectorTables[embeddingModel.name] - - const similarity = sql`1 - (${cosineDistance(vectors.embedding, queryVector)})` - const similarityCondition = gt(similarity, options.minSimilarity) - - const getScopeCondition = (): SQL | undefined => { - if (!options.scope) { - return undefined - } - const conditions: (SQL | undefined)[] = [] - if (options.scope.files.length > 0) { - conditions.push(inArray(vectors.path, options.scope.files)) - } - if (options.scope.folders.length > 0) { - conditions.push( - or( - ...options.scope.folders.map((folder) => - like(vectors.path, `${folder}/%`), - ), - ), - ) - } - if (conditions.length === 0) { - return undefined - } - return or(...conditions) - } - const scopeCondition = getScopeCondition() - - const similaritySearchResult = await this.db - .select({ - ...(() => { - const { embedding, ...rest } = getTableColumns(vectors) - return rest - })(), - similarity, - }) - .from(vectors) - .where(and(similarityCondition, scopeCondition)) - .orderBy((t) => desc(t.similarity)) - .limit(options.limit) - - return similaritySearchResult - } - - async save(): Promise { - if (!this.pgClient) { - return - } - try { - const blob: Blob = await this.pgClient.dumpDataDir('gzip') - await this.app.vault.adapter.writeBinary( - this.dbPath, - Buffer.from(await blob.arrayBuffer()), - ) - } catch (error) { - console.error('Error saving database:', error) - } - } - - private async createNewDatabase() { - const { fsBundle, wasmModule, vectorExtensionBundlePath } = - await this.loadPGliteResources() - this.pgClient = await PGlite.create({ - fsBundle: fsBundle, - wasmModule: wasmModule, - extensions: { - vector: vectorExtensionBundlePath, - }, - }) - const db = drizzle(this.pgClient) - return db - } - - private async loadExistingDatabase(): Promise { - try { - const databaseFileExists = await this.app.vault.adapter.exists( - this.dbPath, - ) - if (!databaseFileExists) { - return null - } - const fileBuffer = await this.app.vault.adapter.readBinary(this.dbPath) - const fileBlob = new Blob([fileBuffer], { type: 'application/x-gzip' }) - const { fsBundle, wasmModule, vectorExtensionBundlePath } = - await this.loadPGliteResources() - this.pgClient = await PGlite.create({ - loadDataDir: fileBlob, - fsBundle: fsBundle, - wasmModule: wasmModule, - extensions: { - vector: vectorExtensionBundlePath, - }, - }) - return drizzle(this.pgClient) - } catch (error) { - console.error('Error loading database:', error) - return null - } - } - - private async migrateDatabase(): Promise { - try { - // Workaround for running Drizzle migrations in a browser environment - // This method uses an undocumented API to perform migrations - // See: https://github.com/drizzle-team/drizzle-orm/discussions/2532#discussioncomment-10780523 - // eslint-disable-next-line @typescript-eslint/ban-ts-comment - // @ts-expect-error - await this.db.dialect.migrate(migrations, this.db.session, { - migrationsTable: 'drizzle_migrations', - }) - } catch (error) { - console.error('Error migrating database:', error) - throw error - } - } - - // TODO: This function is a temporary workaround chosen due to the difficulty of bundling postgres.wasm and postgres.data from node_modules into a single JS file. The ultimate goal is to bundle everything into one JS file in the future. - private async loadPGliteResources(): Promise<{ - fsBundle: Blob - wasmModule: WebAssembly.Module - vectorExtensionBundlePath: URL - }> { - try { - const [fsBundleResponse, wasmResponse] = await Promise.all([ - requestUrl('https://unpkg.com/@electric-sql/pglite/dist/postgres.data'), - requestUrl('https://unpkg.com/@electric-sql/pglite/dist/postgres.wasm'), - ]) - - const fsBundle = new Blob([fsBundleResponse.arrayBuffer], { - type: 'application/octet-stream', - }) - const wasmModule = await WebAssembly.compile(wasmResponse.arrayBuffer) - const vectorExtensionBundlePath = new URL( - 'https://unpkg.com/@electric-sql/pglite/dist/vector.tar.gz', - ) - - return { fsBundle, wasmModule, vectorExtensionBundlePath } - } catch (error) { - console.error('Error loading PGlite resources:', error) - throw error - } - } -} diff --git a/styles.css b/styles.css index 44f856b..5a6724f 100644 --- a/styles.css +++ b/styles.css @@ -179,6 +179,7 @@ button:not(.clickable-icon).smtcmp-chat-list-dropdown { } .smtcmp-chat-user-input-container { + position: relative; display: flex; flex-direction: column; -webkit-app-region: no-drag; @@ -190,7 +191,6 @@ button:not(.clickable-icon).smtcmp-chat-list-dropdown { font-size: var(--font-ui-small); border-radius: var(--radius-s); outline: none; - margin-top: var(--size-4-1); &:focus-within, &:focus, @@ -315,7 +315,7 @@ button:not(.clickable-icon).smtcmp-chat-list-dropdown { * ChatUserInput */ -.smtcmp-chat-input-root .mention { +.smtcmp-lexical-content-editable-root .mention { background-color: var(--tag-background); color: var(--tag-color); padding: var(--size-2-1) calc(var(--size-2-1)); @@ -327,7 +327,7 @@ button:not(.clickable-icon).smtcmp-chat-list-dropdown { word-break: break-all; } -.smtcmp-chat-input-paragraph { +.smtcmp-lexical-content-editable-paragraph { margin: 0; line-height: 1.6; } @@ -757,3 +757,90 @@ button.smtcmp-chat-input-model-select { resize: none; } } + +.smtcmp-dialog-content { + position: fixed; + left: calc(50% - var(--size-4-4)); + top: 50%; + z-index: 50; + display: grid; + width: calc(100% - var(--size-4-8)); + max-width: 32rem; + transform: translate(-50%, -50%); + gap: var(--size-4-2); + border: var(--border-width) solid var(--background-modifier-border); + background-color: var(--background-secondary); + padding: var(--size-4-5); + transition-duration: 200ms; + border-radius: var(--radius-m); + box-shadow: 0 25px 50px -12px rgb(0 0 0 / 0.25); + margin: var(--size-4-4); + + .smtcmp-dialog-header { + margin-bottom: var(--size-4-2); + display: grid; + gap: var(--size-2-3); + } + + .smtcmp-dialog-title { + font-size: var(--font-ui-medium); + font-weight: var(--font-semibold); + line-height: var(--line-height-tight); + margin: 0; + } + + .smtcmp-dialog-input { + display: grid; + gap: var(--size-4-1); + + & label { + font-size: var(--font-ui-smaller); + } + } + + .smtcmp-dialog-description { + font-size: var(--font-ui-small); + color: var(--text-muted); + margin: 0; + } + + .smtcmp-dialog-footer { + margin-top: var(--size-4-2); + display: flex; + justify-content: flex-end; + } + + .smtcmp-dialog-close { + position: absolute; + right: var(--size-4-4); + top: var(--size-4-4); + cursor: var(--cursor); + opacity: 0.7; + transition: opacity 0.2s; + + &:hover { + opacity: 1; + } + } +} + +.smtcmp-template-menu-item { + display: flex; + align-items: center; + justify-content: space-between; + gap: var(--size-4-1); + width: 100%; + + .smtcmp-template-menu-item-delete { + display: flex; + align-items: center; + padding: var(--size-4-1); + margin: calc(var(--size-4-1) * -1); + opacity: 0.7; + transition: opacity 0.2s; + + &:hover { + opacity: 1; + } + } +}