{mentionables.length > 0 && (
{mentionables.map((m) => (
@@ -295,59 +249,40 @@ const ChatUserInput = forwardRef(
)}
-
- {/*
- There was two approach to make mentionable node copy and pasteable.
- 1. use RichTextPlugin and reset text format when paste
- - so I implemented NoFormatPlugin to reset text format when paste
- 2. use PlainTextPlugin and override paste command
- - PlainTextPlugin only pastes text, so we need to implement custom paste handler.
- - https://github.com/facebook/lexical/discussions/5112
- */}
-
+ {
+ if (initialSerializedEditorState) {
+ editor.setEditorState(
+ editor.parseEditorState(initialSerializedEditorState),
+ )
}
- ErrorBoundary={LexicalErrorBoundary}
- />
-
- {autoFocus && }
-
- {
- onChange(editorState.toJSON())
- }}
- />
-
- {
- evt.preventDefault()
- evt.stopPropagation()
- handleSubmit(useVaultSearch)
- }}
- />
-
-
-
-
-
+ }}
+ editorRef={editorRef}
+ contentEditableRef={contentEditableRef}
+ onChange={onChange}
+ onEnter={() => handleSubmit({ useVaultSearch: false })}
+ onFocus={onFocus}
+ onMentionNodeMutation={handleMentionNodeMutation}
+ autoFocus={autoFocus}
+ plugins={{
+ onEnter: {
+ onVaultChat: () => {
+ handleSubmit({ useVaultSearch: true })
+ },
+ },
+ templatePopover: {
+ anchorElement: containerRef.current,
+ },
+ }}
+ />
+
handleSubmit()} />
{
- handleSubmit(true)
+ handleSubmit({ useVaultSearch: true })
}}
/>
diff --git a/src/components/chat-view/chat-input/LexicalContentEditable.tsx b/src/components/chat-view/chat-input/LexicalContentEditable.tsx
new file mode 100644
index 0000000..b7947fe
--- /dev/null
+++ b/src/components/chat-view/chat-input/LexicalContentEditable.tsx
@@ -0,0 +1,146 @@
+import {
+ InitialConfigType,
+ InitialEditorStateType,
+ LexicalComposer,
+} from '@lexical/react/LexicalComposer'
+import { ContentEditable } from '@lexical/react/LexicalContentEditable'
+import { EditorRefPlugin } from '@lexical/react/LexicalEditorRefPlugin'
+import { LexicalErrorBoundary } from '@lexical/react/LexicalErrorBoundary'
+import { HistoryPlugin } from '@lexical/react/LexicalHistoryPlugin'
+import { OnChangePlugin } from '@lexical/react/LexicalOnChangePlugin'
+import { RichTextPlugin } from '@lexical/react/LexicalRichTextPlugin'
+import { LexicalEditor, SerializedEditorState } from 'lexical'
+import { RefObject, useCallback, useEffect } from 'react'
+
+import { useApp } from '../../../contexts/app-context'
+import { fuzzySearch } from '../../../utils/fuzzy-search'
+
+import AutoLinkMentionPlugin from './plugins/mention/AutoLinkMentionPlugin'
+import { MentionNode } from './plugins/mention/MentionNode'
+import MentionPlugin from './plugins/mention/MentionPlugin'
+import NoFormatPlugin from './plugins/no-format/NoFormatPlugin'
+import OnEnterPlugin from './plugins/on-enter/OnEnterPlugin'
+import OnMutationPlugin, {
+ NodeMutations,
+} from './plugins/on-mutation/OnMutationPlugin'
+import CreateTemplatePopoverPlugin from './plugins/template/CreateTemplatePopoverPlugin'
+import TemplatePlugin from './plugins/template/TemplatePlugin'
+
+export type LexicalContentEditableProps = {
+ editorRef: RefObject
+ contentEditableRef: RefObject
+ onChange?: (content: SerializedEditorState) => void
+ onEnter?: (evt: KeyboardEvent) => void
+ onFocus?: () => void
+ onMentionNodeMutation?: (mutations: NodeMutations) => void
+ initialEditorState?: InitialEditorStateType
+ autoFocus?: boolean
+ plugins?: {
+ onEnter?: {
+ onVaultChat: () => void
+ }
+ templatePopover?: {
+ anchorElement: HTMLElement | null
+ }
+ }
+}
+
+export default function LexicalContentEditable({
+ editorRef,
+ contentEditableRef,
+ onChange,
+ onEnter,
+ onFocus,
+ onMentionNodeMutation,
+ initialEditorState,
+ autoFocus = false,
+ plugins,
+}: LexicalContentEditableProps) {
+ const app = useApp()
+
+ const initialConfig: InitialConfigType = {
+ namespace: 'LexicalContentEditable',
+ theme: {
+ root: 'smtcmp-lexical-content-editable-root',
+ paragraph: 'smtcmp-lexical-content-editable-paragraph',
+ },
+ nodes: [MentionNode],
+ editorState: initialEditorState,
+ onError: (error) => {
+ console.error(error)
+ },
+ }
+
+ const searchResultByQuery = useCallback(
+ (query: string) => fuzzySearch(app, query),
+ [app],
+ )
+
+ /*
+ * Using requestAnimationFrame for autoFocus instead of using editor.focus()
+ * due to known issues with editor.focus() when initialConfig.editorState is set
+ * See: https://github.com/facebook/lexical/issues/4460
+ */
+ useEffect(() => {
+ if (autoFocus) {
+ requestAnimationFrame(() => {
+ contentEditableRef.current?.focus()
+ })
+ }
+ }, [autoFocus, contentEditableRef])
+
+ return (
+
+ {/*
+ There was two approach to make mentionable node copy and pasteable.
+ 1. use RichTextPlugin and reset text format when paste
+ - so I implemented NoFormatPlugin to reset text format when paste
+ 2. use PlainTextPlugin and override paste command
+ - PlainTextPlugin only pastes text, so we need to implement custom paste handler.
+ - https://github.com/facebook/lexical/discussions/5112
+ */}
+
+ }
+ ErrorBoundary={LexicalErrorBoundary}
+ />
+
+
+ {
+ onChange?.(editorState.toJSON())
+ }}
+ />
+ {onEnter && (
+
+ )}
+ {
+ onMentionNodeMutation?.(mutations)
+ }}
+ />
+
+
+
+
+ {plugins?.templatePopover && (
+
+ )}
+
+ )
+}
diff --git a/src/components/chat-view/chat-input/ModelSelect.tsx b/src/components/chat-view/chat-input/ModelSelect.tsx
index 883e334..0c4be54 100644
--- a/src/components/chat-view/chat-input/ModelSelect.tsx
+++ b/src/components/chat-view/chat-input/ModelSelect.tsx
@@ -1,8 +1,9 @@
import * as DropdownMenu from '@radix-ui/react-dropdown-menu'
import { ChevronDown, ChevronUp } from 'lucide-react'
import { useState } from 'react'
-import { CHAT_MODEL_OPTIONS } from 'src/constants'
-import { useSettings } from 'src/contexts/settings-context'
+
+import { CHAT_MODEL_OPTIONS } from '../../../constants'
+import { useSettings } from '../../../contexts/settings-context'
export function ModelSelect() {
const { settings, setSettings } = useSettings()
diff --git a/src/components/chat-view/chat-input/plugins/auto-focus/AutoFocusPlugin.tsx b/src/components/chat-view/chat-input/plugins/auto-focus/AutoFocusPlugin.tsx
deleted file mode 100644
index 714a990..0000000
--- a/src/components/chat-view/chat-input/plugins/auto-focus/AutoFocusPlugin.tsx
+++ /dev/null
@@ -1,29 +0,0 @@
-import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext'
-import { useEffect } from 'react'
-
-export type AutoFocusPluginProps = {
- defaultSelection?: 'rootStart' | 'rootEnd'
-}
-
-export default function AutoFocusPlugin({
- defaultSelection,
-}: AutoFocusPluginProps) {
- const [editor] = useLexicalComposerContext()
-
- useEffect(() => {
- editor.focus(
- () => {
- const rootElement = editor.getRootElement()
- if (rootElement) {
- // requestAnimationFrame is required here for unknown reasons, possibly related to the Obsidian plugin environment.
- requestAnimationFrame(() => {
- rootElement.focus()
- })
- }
- },
- { defaultSelection },
- )
- }, [defaultSelection, editor])
-
- return null
-}
diff --git a/src/components/chat-view/chat-input/plugins/mention/MentionPlugin.tsx b/src/components/chat-view/chat-input/plugins/mention/MentionPlugin.tsx
index 603d023..a41693a 100644
--- a/src/components/chat-view/chat-input/plugins/mention/MentionPlugin.tsx
+++ b/src/components/chat-view/chat-input/plugins/mention/MentionPlugin.tsx
@@ -11,11 +11,13 @@ import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext
import { $createTextNode, COMMAND_PRIORITY_NORMAL, TextNode } from 'lexical'
import { useCallback, useMemo, useState } from 'react'
import { createPortal } from 'react-dom'
-import { serializeMentionable } from 'src/utils/mentionable'
import { Mentionable } from '../../../../../types/mentionable'
import { SearchableMentionable } from '../../../../../utils/fuzzy-search'
-import { getMentionableName } from '../../../../../utils/mentionable'
+import {
+ getMentionableName,
+ serializeMentionable,
+} from '../../../../../utils/mentionable'
import { getMentionableIcon } from '../../utils/get-metionable-icon'
import { MenuOption, MenuTextMatch } from '../shared/LexicalMenu'
import {
diff --git a/src/components/chat-view/chat-input/plugins/on-enter/OnEnterPlugin.tsx b/src/components/chat-view/chat-input/plugins/on-enter/OnEnterPlugin.tsx
index 899fa08..8cfa78d 100644
--- a/src/components/chat-view/chat-input/plugins/on-enter/OnEnterPlugin.tsx
+++ b/src/components/chat-view/chat-input/plugins/on-enter/OnEnterPlugin.tsx
@@ -5,8 +5,10 @@ import { useEffect } from 'react'
export default function OnEnterPlugin({
onEnter,
+ onVaultChat,
}: {
- onEnter: (evt: KeyboardEvent, useVaultSearch?: boolean) => void
+ onEnter: (evt: KeyboardEvent) => void
+ onVaultChat?: () => void
}) {
const [editor] = useLexicalComposerContext()
@@ -14,14 +16,22 @@ export default function OnEnterPlugin({
const removeListener = editor.registerCommand(
KEY_ENTER_COMMAND,
(evt: KeyboardEvent) => {
+ if (
+ onVaultChat &&
+ evt.shiftKey &&
+ (Platform.isMacOS ? evt.metaKey : evt.ctrlKey)
+ ) {
+ evt.preventDefault()
+ evt.stopPropagation()
+ onVaultChat()
+ return true
+ }
if (evt.shiftKey) {
- if (Platform.isMacOS ? evt.metaKey : evt.ctrlKey) {
- onEnter(evt, true)
- return true
- }
return false
}
- onEnter(evt, false)
+ evt.preventDefault()
+ evt.stopPropagation()
+ onEnter(evt)
return true
},
COMMAND_PRIORITY_LOW,
@@ -30,7 +40,7 @@ export default function OnEnterPlugin({
return () => {
removeListener()
}
- }, [editor, onEnter])
+ }, [editor, onEnter, onVaultChat])
return null
}
diff --git a/src/components/chat-view/chat-input/plugins/template/CreateTemplatePopoverPlugin.tsx b/src/components/chat-view/chat-input/plugins/template/CreateTemplatePopoverPlugin.tsx
new file mode 100644
index 0000000..a36157e
--- /dev/null
+++ b/src/components/chat-view/chat-input/plugins/template/CreateTemplatePopoverPlugin.tsx
@@ -0,0 +1,132 @@
+import { $generateJSONFromSelectedNodes } from '@lexical/clipboard'
+import { BaseSerializedNode } from '@lexical/clipboard/clipboard'
+import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext'
+import * as Dialog from '@radix-ui/react-dialog'
+import {
+ $getSelection,
+ COMMAND_PRIORITY_LOW,
+ SELECTION_CHANGE_COMMAND,
+} from 'lexical'
+import { CSSProperties, useCallback, useEffect, useRef, useState } from 'react'
+
+import CreateTemplateDialogContent from '../../../CreateTemplateDialog'
+
+export default function CreateTemplatePopoverPlugin({
+ anchorElement,
+ contentEditableElement,
+}: {
+ anchorElement: HTMLElement | null
+ contentEditableElement: HTMLElement | null
+}): JSX.Element | null {
+ const [editor] = useLexicalComposerContext()
+
+ const [popoverStyle, setPopoverStyle] = useState(null)
+ const [isPopoverOpen, setIsPopoverOpen] = useState(false)
+ const [isDialogOpen, setIsDialogOpen] = useState(false)
+ const [selectedSerializedNodes, setSelectedSerializedNodes] = useState<
+ BaseSerializedNode[] | null
+ >(null)
+
+ const popoverRef = useRef(null)
+
+ const getSelectedSerializedNodes = useCallback(():
+ | BaseSerializedNode[]
+ | null => {
+ if (!editor) return null
+ let selectedNodes: BaseSerializedNode[] | null = null
+ editor.update(() => {
+ const selection = $getSelection()
+ if (!selection) return
+ selectedNodes = $generateJSONFromSelectedNodes(editor, selection).nodes
+ if (selectedNodes.length === 0) return null
+ })
+ return selectedNodes
+ }, [editor])
+
+ const updatePopoverPosition = useCallback(() => {
+ if (!anchorElement || !contentEditableElement) return
+ const nativeSelection = document.getSelection()
+ const range = nativeSelection?.getRangeAt(0)
+ if (!range || range.collapsed) {
+ setIsPopoverOpen(false)
+ return
+ }
+ if (!contentEditableElement.contains(range.commonAncestorContainer)) {
+ setIsPopoverOpen(false)
+ return
+ }
+ const rects = Array.from(range.getClientRects())
+ if (rects.length === 0) {
+ setIsPopoverOpen(false)
+ return
+ }
+ const anchorRect = anchorElement.getBoundingClientRect()
+ const idealLeft = rects[rects.length - 1].right - anchorRect.left
+ const paddingX = 8
+ const paddingY = 4
+ const minLeft = (popoverRef.current?.offsetWidth ?? 0) + paddingX
+ const finalLeft = Math.max(minLeft, idealLeft)
+ setPopoverStyle({
+ top: rects[rects.length - 1].bottom - anchorRect.top + paddingY,
+ left: finalLeft,
+ transform: 'translate(-100%, 0)',
+ })
+ setIsPopoverOpen(true)
+ }, [anchorElement, contentEditableElement])
+
+ useEffect(() => {
+ const removeSelectionChangeListener = editor.registerCommand(
+ SELECTION_CHANGE_COMMAND,
+ () => {
+ updatePopoverPosition()
+ return false
+ },
+ COMMAND_PRIORITY_LOW,
+ )
+ return () => {
+ removeSelectionChangeListener()
+ }
+ }, [editor, updatePopoverPosition])
+
+ useEffect(() => {
+ if (!contentEditableElement) return
+ const handleScroll = () => {
+ updatePopoverPosition()
+ }
+ contentEditableElement.addEventListener('scroll', handleScroll)
+ return () => {
+ contentEditableElement.removeEventListener('scroll', handleScroll)
+ }
+ }, [contentEditableElement, updatePopoverPosition])
+
+ return (
+ {
+ if (open) {
+ setSelectedSerializedNodes(getSelectedSerializedNodes())
+ }
+ setIsDialogOpen(open)
+ setIsPopoverOpen(false)
+ }}
+ >
+
+
+
+ setIsDialogOpen(false)}
+ />
+
+ )
+}
diff --git a/src/components/chat-view/chat-input/plugins/template/TemplatePlugin.tsx b/src/components/chat-view/chat-input/plugins/template/TemplatePlugin.tsx
new file mode 100644
index 0000000..1f6eb38
--- /dev/null
+++ b/src/components/chat-view/chat-input/plugins/template/TemplatePlugin.tsx
@@ -0,0 +1,179 @@
+import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext'
+import clsx from 'clsx'
+import {
+ $parseSerializedNode,
+ COMMAND_PRIORITY_NORMAL,
+ TextNode,
+} from 'lexical'
+import { Trash2 } from 'lucide-react'
+import { useCallback, useEffect, useMemo, useState } from 'react'
+import { createPortal } from 'react-dom'
+
+import { useDatabase } from '../../../../../contexts/database-context'
+import { SelectTemplate } from '../../../../../database/schema'
+import { MenuOption } from '../shared/LexicalMenu'
+import {
+ LexicalTypeaheadMenuPlugin,
+ useBasicTypeaheadTriggerMatch,
+} from '../typeahead-menu/LexicalTypeaheadMenuPlugin'
+
+class TemplateTypeaheadOption extends MenuOption {
+ name: string
+ template: SelectTemplate
+
+ constructor(name: string, template: SelectTemplate) {
+ super(name)
+ this.name = name
+ this.template = template
+ }
+}
+
+function TemplateMenuItem({
+ index,
+ isSelected,
+ onClick,
+ onDelete,
+ onMouseEnter,
+ option,
+}: {
+ index: number
+ isSelected: boolean
+ onClick: () => void
+ onDelete: () => void
+ onMouseEnter: () => void
+ option: TemplateTypeaheadOption
+}) {
+ return (
+ option.setRefElement(el)}
+ role="option"
+ aria-selected={isSelected}
+ id={`typeahead-item-${index}`}
+ onMouseEnter={onMouseEnter}
+ onClick={onClick}
+ >
+
+
{option.name}
+
{
+ evt.stopPropagation()
+ evt.preventDefault()
+ onDelete()
+ }}
+ className="smtcmp-template-menu-item-delete"
+ >
+
+
+
+
+ )
+}
+
+export default function TemplatePlugin() {
+ const [editor] = useLexicalComposerContext()
+ const { templateManager } = useDatabase()
+
+ const [queryString, setQueryString] = useState(null)
+ const [searchResults, setSearchResults] = useState([])
+
+ useEffect(() => {
+ if (queryString == null) return
+ templateManager.searchTemplates(queryString).then(setSearchResults)
+ }, [queryString, templateManager])
+
+ const options = useMemo(
+ () =>
+ searchResults.map(
+ (result) => new TemplateTypeaheadOption(result.name, result),
+ ),
+ [searchResults],
+ )
+
+ const checkForTriggerMatch = useBasicTypeaheadTriggerMatch('/', {
+ minLength: 0,
+ })
+
+ const onSelectOption = useCallback(
+ (
+ selectedOption: TemplateTypeaheadOption,
+ nodeToRemove: TextNode | null,
+ closeMenu: () => void,
+ ) => {
+ editor.update(() => {
+ const parsedNodes = selectedOption.template.content.nodes.map((node) =>
+ $parseSerializedNode(node),
+ )
+ if (nodeToRemove) {
+ const parent = nodeToRemove.getParentOrThrow()
+ parent.splice(nodeToRemove.getIndexWithinParent(), 1, parsedNodes)
+ const lastNode = parsedNodes[parsedNodes.length - 1]
+ lastNode.selectEnd()
+ }
+ closeMenu()
+ })
+ },
+ [editor],
+ )
+
+ const handleDelete = useCallback(
+ async (option: TemplateTypeaheadOption) => {
+ await templateManager.deleteTemplate(option.template.id)
+ if (queryString !== null) {
+ const updatedResults =
+ await templateManager.searchTemplates(queryString)
+ setSearchResults(updatedResults)
+ }
+ },
+ [templateManager, queryString],
+ )
+
+ return (
+
+ onQueryChange={setQueryString}
+ onSelectOption={onSelectOption}
+ triggerFn={checkForTriggerMatch}
+ options={options}
+ commandPriority={COMMAND_PRIORITY_NORMAL}
+ menuRenderFn={(
+ anchorElementRef,
+ { selectedIndex, selectOptionAndCleanUp, setHighlightedIndex },
+ ) =>
+ anchorElementRef.current && searchResults.length
+ ? createPortal(
+
+
+ {options.map((option, i: number) => (
+ {
+ setHighlightedIndex(i)
+ selectOptionAndCleanUp(option)
+ }}
+ onDelete={() => {
+ handleDelete(option)
+ }}
+ onMouseEnter={() => {
+ setHighlightedIndex(i)
+ }}
+ key={option.key}
+ option={option}
+ />
+ ))}
+
+
,
+ anchorElementRef.current,
+ )
+ : null
+ }
+ />
+ )
+}
diff --git a/src/components/chat-view/chat-input/plugins/updater/UpdaterPlugin.tsx b/src/components/chat-view/chat-input/plugins/updater/UpdaterPlugin.tsx
deleted file mode 100644
index f767d0f..0000000
--- a/src/components/chat-view/chat-input/plugins/updater/UpdaterPlugin.tsx
+++ /dev/null
@@ -1,23 +0,0 @@
-import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext'
-import { SerializedEditorState } from 'lexical'
-import { Ref, useImperativeHandle } from 'react'
-
-export type UpdaterPluginRef = {
- update: (content: SerializedEditorState) => void
-}
-
-export default function UpdaterPlugin({
- updaterRef,
-}: {
- updaterRef: Ref
-}) {
- const [editor] = useLexicalComposerContext()
-
- useImperativeHandle(updaterRef, () => ({
- update: (content: SerializedEditorState) => {
- editor.setEditorState(editor.parseEditorState(content))
- },
- }))
-
- return null
-}
diff --git a/src/constants.ts b/src/constants.ts
index 9c60545..987fb5e 100644
--- a/src/constants.ts
+++ b/src/constants.ts
@@ -47,7 +47,7 @@ export const APPLY_MODEL_OPTIONS = [
},
]
-// Update table exports in src/db/schema.ts when updating this
+// Update table exports in src/database/schema.ts when updating this
export const EMBEDDING_MODEL_OPTIONS = [
{
name: 'text-embedding-3-small (Recommended)',
diff --git a/src/contexts/database-context.tsx b/src/contexts/database-context.tsx
new file mode 100644
index 0000000..54703f2
--- /dev/null
+++ b/src/contexts/database-context.tsx
@@ -0,0 +1,45 @@
+import { createContext, useContext, useMemo } from 'react'
+
+import { DatabaseManager } from '../database/DatabaseManager'
+import { TemplateManager } from '../database/modules/template/TemplateManager'
+import { VectorManager } from '../database/modules/vector/VectorManager'
+
+type DatabaseContextType = {
+ databaseManager: DatabaseManager
+ vectorManager: VectorManager
+ templateManager: TemplateManager
+}
+
+const DatabaseContext = createContext(null)
+
+export function DatabaseProvider({
+ children,
+ databaseManager,
+}: {
+ children: React.ReactNode
+ databaseManager: DatabaseManager
+}) {
+ const vectorManager = useMemo(() => {
+ return databaseManager.getVectorManager()
+ }, [databaseManager])
+
+ const templateManager = useMemo(() => {
+ return databaseManager.getTemplateManager()
+ }, [databaseManager])
+
+ return (
+
+ {children}
+
+ )
+}
+
+export function useDatabase(): DatabaseContextType {
+ const context = useContext(DatabaseContext)
+ if (!context) {
+ throw new Error('useDatabase must be used within a DatabaseProvider')
+ }
+ return context
+}
diff --git a/src/contexts/dialog-container-context.tsx b/src/contexts/dialog-container-context.tsx
new file mode 100644
index 0000000..9ff59f0
--- /dev/null
+++ b/src/contexts/dialog-container-context.tsx
@@ -0,0 +1,27 @@
+import React, { createContext, useContext } from 'react'
+
+const DialogContainerContext = createContext(null)
+
+export function DialogContainerProvider({
+ children,
+ container,
+}: {
+ children: React.ReactNode
+ container: HTMLElement | null
+}) {
+ return (
+
+ {children}
+
+ )
+}
+
+export function useDialogContainer() {
+ const context = useContext(DialogContainerContext)
+ if (!context) {
+ throw new Error(
+ 'useDialogContainer must be used within a DialogContainerProvider',
+ )
+ }
+ return context
+}
diff --git a/src/contexts/llm-context.tsx b/src/contexts/llm-context.tsx
index 08ba16b..9e6e966 100644
--- a/src/contexts/llm-context.tsx
+++ b/src/contexts/llm-context.tsx
@@ -7,6 +7,7 @@ import {
useState,
} from 'react'
+import LLMManager from '../core/llm/manager'
import {
LLMOptions,
LLMRequestNonStreaming,
@@ -16,7 +17,6 @@ import {
LLMResponseNonStreaming,
LLMResponseStreaming,
} from '../types/llm/response'
-import LLMManager from '../utils/llm/manager'
import { useSettings } from './settings-context'
diff --git a/src/contexts/rag-context.tsx b/src/contexts/rag-context.tsx
index 777d39c..f969225 100644
--- a/src/contexts/rag-context.tsx
+++ b/src/contexts/rag-context.tsx
@@ -1,6 +1,6 @@
import { PropsWithChildren, createContext, useContext } from 'react'
-import { RAGEngine } from '../utils/ragEngine'
+import { RAGEngine } from '../core/rag/ragEngine'
export type RAGContextType = {
ragEngine: RAGEngine | null
diff --git a/src/utils/llm/anthropic.ts b/src/core/llm/anthropic.ts
similarity index 100%
rename from src/utils/llm/anthropic.ts
rename to src/core/llm/anthropic.ts
diff --git a/src/utils/llm/base.ts b/src/core/llm/base.ts
similarity index 100%
rename from src/utils/llm/base.ts
rename to src/core/llm/base.ts
diff --git a/src/utils/llm/exception.ts b/src/core/llm/exception.ts
similarity index 100%
rename from src/utils/llm/exception.ts
rename to src/core/llm/exception.ts
diff --git a/src/utils/llm/groq.ts b/src/core/llm/groq.ts
similarity index 100%
rename from src/utils/llm/groq.ts
rename to src/core/llm/groq.ts
diff --git a/src/utils/llm/manager.ts b/src/core/llm/manager.ts
similarity index 100%
rename from src/utils/llm/manager.ts
rename to src/core/llm/manager.ts
diff --git a/src/utils/llm/ollama.ts b/src/core/llm/ollama.ts
similarity index 97%
rename from src/utils/llm/ollama.ts
rename to src/core/llm/ollama.ts
index 20ab58a..f852e78 100644
--- a/src/utils/llm/ollama.ts
+++ b/src/core/llm/ollama.ts
@@ -1,14 +1,15 @@
import OpenAI from 'openai'
import { FinalRequestOptions } from 'openai/core'
+
import {
LLMOptions,
LLMRequestNonStreaming,
LLMRequestStreaming,
-} from 'src/types/llm/request'
+} from '../../types/llm/request'
import {
LLMResponseNonStreaming,
LLMResponseStreaming,
-} from 'src/types/llm/response'
+} from '../../types/llm/response'
import { BaseLLMProvider } from './base'
import { LLMBaseUrlNotSetException } from './exception'
diff --git a/src/utils/llm/openai.ts b/src/core/llm/openai.ts
similarity index 100%
rename from src/utils/llm/openai.ts
rename to src/core/llm/openai.ts
diff --git a/src/utils/llm/openaiCompatibleProvider.ts b/src/core/llm/openaiCompatibleProvider.ts
similarity index 100%
rename from src/utils/llm/openaiCompatibleProvider.ts
rename to src/core/llm/openaiCompatibleProvider.ts
diff --git a/src/utils/embedding.ts b/src/core/rag/embedding.ts
similarity index 96%
rename from src/utils/embedding.ts
rename to src/core/rag/embedding.ts
index 4a502fe..965504c 100644
--- a/src/utils/embedding.ts
+++ b/src/core/rag/embedding.ts
@@ -1,12 +1,11 @@
import { OpenAI } from 'openai'
-import { EmbeddingModel } from '../types/embedding'
-
+import { EmbeddingModel } from '../../types/embedding'
import {
LLMAPIKeyNotSetException,
LLMBaseUrlNotSetException,
-} from './llm/exception'
-import { NoStainlessOpenAI } from './llm/ollama'
+} from '../llm/exception'
+import { NoStainlessOpenAI } from '../llm/ollama'
export const getEmbeddingModel = (
name: string,
diff --git a/src/utils/ragEngine.ts b/src/core/rag/ragEngine.ts
similarity index 76%
rename from src/utils/ragEngine.ts
rename to src/core/rag/ragEngine.ts
index ce5538d..367012c 100644
--- a/src/utils/ragEngine.ts
+++ b/src/core/rag/ragEngine.ts
@@ -1,42 +1,35 @@
import { App } from 'obsidian'
-import { QueryProgressState } from '../components/chat-view/QueryProgress'
-import { SelectVector } from '../db/schema'
-import { EmbeddingModel } from '../types/embedding'
-import { SmartCopilotSettings } from '../types/settings'
+import { QueryProgressState } from '../../components/chat-view/QueryProgress'
+import { DatabaseManager } from '../../database/DatabaseManager'
+import { VectorManager } from '../../database/modules/vector/VectorManager'
+import { SelectVector } from '../../database/schema'
+import { EmbeddingModel } from '../../types/embedding'
+import { SmartCopilotSettings } from '../../types/settings'
import { getEmbeddingModel } from './embedding'
-import { VectorDbManager } from './vector-db/manager'
export class RAGEngine {
private app: App
private settings: SmartCopilotSettings
- private vectorDbManager: VectorDbManager
+ private vectorManager: VectorManager
private embeddingModel: EmbeddingModel | null = null
- constructor(app: App, settings: SmartCopilotSettings) {
- this.app = app
- this.settings = settings
- }
-
- static async create(
+ constructor(
app: App,
settings: SmartCopilotSettings,
- ): Promise {
- const ragEngine = new RAGEngine(app, settings)
- ragEngine.vectorDbManager = await VectorDbManager.create(app)
- ragEngine.embeddingModel = getEmbeddingModel(
+ dbManager: DatabaseManager,
+ ) {
+ this.app = app
+ this.settings = settings
+ this.vectorManager = dbManager.getVectorManager()
+ this.embeddingModel = getEmbeddingModel(
settings.embeddingModel,
{
openAIApiKey: settings.openAIApiKey,
},
settings.ollamaBaseUrl,
)
- return ragEngine
- }
-
- async cleanup() {
- await this.vectorDbManager.cleanup()
}
setSettings(settings: SmartCopilotSettings) {
@@ -61,7 +54,7 @@ export class RAGEngine {
if (!this.embeddingModel) {
throw new Error('Embedding model is not set')
}
- await this.vectorDbManager.updateVaultIndex(
+ await this.vectorManager.updateVaultIndex(
this.embeddingModel,
{
chunkSize: this.settings.ragOptions.chunkSize,
@@ -102,7 +95,7 @@ export class RAGEngine {
onQueryProgressChange?.({
type: 'querying',
})
- const queryResult = await this.vectorDbManager.performSimilaritySearch(
+ const queryResult = await this.vectorManager.performSimilaritySearch(
queryEmbedding,
this.embeddingModel,
{
diff --git a/src/database/DatabaseManager.ts b/src/database/DatabaseManager.ts
new file mode 100644
index 0000000..40286ec
--- /dev/null
+++ b/src/database/DatabaseManager.ts
@@ -0,0 +1,156 @@
+import { PGlite } from '@electric-sql/pglite'
+import { PgliteDatabase, drizzle } from 'drizzle-orm/pglite'
+import { App, normalizePath, requestUrl } from 'obsidian'
+
+import { PGLITE_DB_PATH } from '../constants'
+
+import migrations from './migrations.json'
+import { TemplateManager } from './modules/template/TemplateManager'
+import { VectorManager } from './modules/vector/VectorManager'
+
+export class DatabaseManager {
+ private app: App
+ private dbPath: string
+ private pgClient: PGlite | null = null
+ private db: PgliteDatabase | null = null
+ private vectorManager: VectorManager
+ private templateManager: TemplateManager
+
+ constructor(app: App, dbPath: string) {
+ this.app = app
+ this.dbPath = dbPath
+ }
+
+ static async create(app: App): Promise {
+ const dbManager = new DatabaseManager(app, normalizePath(PGLITE_DB_PATH))
+ dbManager.db = await dbManager.loadExistingDatabase()
+ if (!dbManager.db) {
+ dbManager.db = await dbManager.createNewDatabase()
+ }
+ await dbManager.migrateDatabase()
+ await dbManager.save()
+
+ dbManager.vectorManager = new VectorManager(app, dbManager)
+ dbManager.templateManager = new TemplateManager(app, dbManager)
+
+ console.log('Smart composer database initialized.')
+ return dbManager
+ }
+
+ getDb() {
+ return this.db
+ }
+
+ getVectorManager() {
+ return this.vectorManager
+ }
+
+ getTemplateManager() {
+ return this.templateManager
+ }
+
+ private async createNewDatabase() {
+ const { fsBundle, wasmModule, vectorExtensionBundlePath } =
+ await this.loadPGliteResources()
+ this.pgClient = await PGlite.create({
+ fsBundle: fsBundle,
+ wasmModule: wasmModule,
+ extensions: {
+ vector: vectorExtensionBundlePath,
+ },
+ })
+ const db = drizzle(this.pgClient)
+ return db
+ }
+
+ private async loadExistingDatabase(): Promise {
+ try {
+ const databaseFileExists = await this.app.vault.adapter.exists(
+ this.dbPath,
+ )
+ if (!databaseFileExists) {
+ return null
+ }
+ const fileBuffer = await this.app.vault.adapter.readBinary(this.dbPath)
+ const fileBlob = new Blob([fileBuffer], { type: 'application/x-gzip' })
+ const { fsBundle, wasmModule, vectorExtensionBundlePath } =
+ await this.loadPGliteResources()
+ this.pgClient = await PGlite.create({
+ loadDataDir: fileBlob,
+ fsBundle: fsBundle,
+ wasmModule: wasmModule,
+ extensions: {
+ vector: vectorExtensionBundlePath,
+ },
+ })
+ return drizzle(this.pgClient)
+ } catch (error) {
+ console.error('Error loading database:', error)
+ return null
+ }
+ }
+
+ private async migrateDatabase(): Promise {
+ try {
+ // Workaround for running Drizzle migrations in a browser environment
+ // This method uses an undocumented API to perform migrations
+ // See: https://github.com/drizzle-team/drizzle-orm/discussions/2532#discussioncomment-10780523
+ // eslint-disable-next-line @typescript-eslint/ban-ts-comment
+ // @ts-expect-error
+ await this.db.dialect.migrate(migrations, this.db.session, {
+ migrationsTable: 'drizzle_migrations',
+ })
+ } catch (error) {
+ console.error('Error migrating database:', error)
+ throw error
+ }
+ }
+
+ async save(): Promise {
+ if (!this.pgClient) {
+ return
+ }
+ try {
+ const blob: Blob = await this.pgClient.dumpDataDir('gzip')
+ await this.app.vault.adapter.writeBinary(
+ this.dbPath,
+ Buffer.from(await blob.arrayBuffer()),
+ )
+ } catch (error) {
+ console.error('Error saving database:', error)
+ }
+ }
+
+ async cleanup() {
+ this.pgClient?.close()
+ this.db = null
+ this.pgClient = null
+ }
+
+ // TODO: This function is a temporary workaround chosen due to the difficulty of bundling postgres.wasm and postgres.data from node_modules into a single JS file. The ultimate goal is to bundle everything into one JS file in the future.
+ private async loadPGliteResources(): Promise<{
+ fsBundle: Blob
+ wasmModule: WebAssembly.Module
+ vectorExtensionBundlePath: URL
+ }> {
+ try {
+ const [fsBundleResponse, wasmResponse] = await Promise.all([
+ requestUrl('https://unpkg.com/@electric-sql/pglite/dist/postgres.data'),
+ requestUrl('https://unpkg.com/@electric-sql/pglite/dist/postgres.wasm'),
+ ])
+
+ const fsBundle = new Blob([fsBundleResponse.arrayBuffer], {
+ type: 'application/octet-stream',
+ })
+ const wasmModule = await WebAssembly.compile(wasmResponse.arrayBuffer)
+ const vectorExtensionBundlePath = new URL(
+ 'https://unpkg.com/@electric-sql/pglite/dist/vector.tar.gz',
+ )
+
+ return { fsBundle, wasmModule, vectorExtensionBundlePath }
+ } catch (error) {
+ console.error('Error loading PGlite resources:', error)
+ throw error
+ }
+ }
+}
diff --git a/src/database/exception.ts b/src/database/exception.ts
new file mode 100644
index 0000000..3838a94
--- /dev/null
+++ b/src/database/exception.ts
@@ -0,0 +1,20 @@
+export class DatabaseException extends Error {
+ constructor(message: string) {
+ super(message)
+ this.name = 'DatabaseException'
+ }
+}
+
+export class DatabaseNotInitializedException extends DatabaseException {
+ constructor(message = 'Database not initialized') {
+ super(message)
+ this.name = 'DatabaseNotInitializedException'
+ }
+}
+
+export class DuplicateTemplateException extends DatabaseException {
+ constructor(templateName: string) {
+ super(`Template with name "${templateName}" already exists`)
+ this.name = 'DuplicateTemplateException'
+ }
+}
diff --git a/src/db/migrations.json b/src/database/migrations.json
similarity index 85%
rename from src/db/migrations.json
rename to src/database/migrations.json
index ddd7d53..0c14f03 100644
--- a/src/db/migrations.json
+++ b/src/database/migrations.json
@@ -43,5 +43,13 @@
"bps": true,
"folderMillis": 1730085029605,
"hash": "e41dde00d7d98da74596c60e796aa78d5cbfec57208347c72f16a6fa11c24ef1"
+ },
+ {
+ "sql": [
+ "CREATE TABLE IF NOT EXISTS \"template\" (\n\t\"id\" uuid PRIMARY KEY DEFAULT gen_random_uuid() NOT NULL,\n\t\"name\" text NOT NULL,\n\t\"content\" jsonb NOT NULL,\n\t\"created_at\" timestamp DEFAULT now() NOT NULL,\n\t\"updated_at\" timestamp DEFAULT now() NOT NULL,\n\tCONSTRAINT \"template_name_unique\" UNIQUE(\"name\")\n);\n"
+ ],
+ "bps": true,
+ "folderMillis": 1730443763982,
+ "hash": "eddf72b8d40619c170b3c12f3d3ce280385b1fc05717f211ab6077c6bea691bf"
}
]
diff --git a/src/database/modules/template/TemplateManager.ts b/src/database/modules/template/TemplateManager.ts
new file mode 100644
index 0000000..df54ebe
--- /dev/null
+++ b/src/database/modules/template/TemplateManager.ts
@@ -0,0 +1,51 @@
+import fuzzysort from 'fuzzysort'
+import { App } from 'obsidian'
+
+import { DatabaseManager } from '../../DatabaseManager'
+import { DuplicateTemplateException } from '../../exception'
+import { InsertTemplate, SelectTemplate } from '../../schema'
+
+import { TemplateRepository } from './TemplateRepository'
+
+export class TemplateManager {
+ private app: App
+ private repository: TemplateRepository
+ private dbManager: DatabaseManager
+
+ constructor(app: App, dbManager: DatabaseManager) {
+ this.app = app
+ this.dbManager = dbManager
+ this.repository = new TemplateRepository(app, dbManager.getDb())
+ }
+
+ async createTemplate(template: InsertTemplate): Promise {
+ const existingTemplate = await this.repository.findByName(template.name)
+ if (existingTemplate) {
+ throw new DuplicateTemplateException(template.name)
+ }
+ const created = await this.repository.create(template)
+ await this.dbManager.save()
+ return created
+ }
+
+ async findAllTemplates(): Promise {
+ return await this.repository.findAll()
+ }
+
+ async searchTemplates(query: string): Promise {
+ const templates = await this.findAllTemplates()
+ const results = fuzzysort.go(query, templates, {
+ keys: ['name'],
+ threshold: 0.2,
+ limit: 20,
+ all: true,
+ })
+ return results.map((result) => result.obj)
+ }
+
+ async deleteTemplate(id: string): Promise {
+ const deleted = await this.repository.delete(id)
+ await this.dbManager.save()
+ return deleted
+ }
+}
diff --git a/src/database/modules/template/TemplateRepository.ts b/src/database/modules/template/TemplateRepository.ts
new file mode 100644
index 0000000..3c32927
--- /dev/null
+++ b/src/database/modules/template/TemplateRepository.ts
@@ -0,0 +1,76 @@
+import { eq } from 'drizzle-orm'
+import { PgliteDatabase } from 'drizzle-orm/pglite'
+import { App } from 'obsidian'
+
+import { DatabaseNotInitializedException } from '../../exception'
+import {
+ type InsertTemplate,
+ type SelectTemplate,
+ templateTable,
+} from '../../schema'
+
+export class TemplateRepository {
+ private app: App
+ private db: PgliteDatabase | null
+
+ constructor(app: App, db: PgliteDatabase | null) {
+ this.app = app
+ this.db = db
+ }
+
+ async create(template: InsertTemplate): Promise {
+ if (!this.db) {
+ throw new DatabaseNotInitializedException()
+ }
+
+ const [created] = await this.db
+ .insert(templateTable)
+ .values(template)
+ .returning()
+ return created
+ }
+
+ async findAll(): Promise {
+ if (!this.db) {
+ throw new DatabaseNotInitializedException()
+ }
+ return await this.db.select().from(templateTable)
+ }
+
+ async findByName(name: string): Promise {
+ if (!this.db) {
+ throw new DatabaseNotInitializedException()
+ }
+ const [template] = await this.db
+ .select()
+ .from(templateTable)
+ .where(eq(templateTable.name, name))
+ return template ?? null
+ }
+
+ async update(
+ id: string,
+ template: Partial,
+ ): Promise {
+ if (!this.db) {
+ throw new DatabaseNotInitializedException()
+ }
+ const [updated] = await this.db
+ .update(templateTable)
+ .set({ ...template, updatedAt: new Date() })
+ .where(eq(templateTable.id, id))
+ .returning()
+ return updated
+ }
+
+ async delete(id: string): Promise {
+ if (!this.db) {
+ throw new DatabaseNotInitializedException()
+ }
+ const [deleted] = await this.db
+ .delete(templateTable)
+ .where(eq(templateTable.id, id))
+ .returning()
+ return !!deleted
+ }
+}
diff --git a/src/utils/vector-db/manager.ts b/src/database/modules/vector/VectorManager.ts
similarity index 88%
rename from src/utils/vector-db/manager.ts
rename to src/database/modules/vector/VectorManager.ts
index 43dc055..6b594e8 100644
--- a/src/utils/vector-db/manager.ts
+++ b/src/database/modules/vector/VectorManager.ts
@@ -1,38 +1,30 @@
import { backOff } from 'exponential-backoff'
import { RecursiveCharacterTextSplitter } from 'langchain/text_splitter'
-import { App, TFile, normalizePath } from 'obsidian'
+import { App, TFile } from 'obsidian'
import pLimit from 'p-limit'
-import { IndexProgress } from '../../components/chat-view/QueryProgress'
-import { PGLITE_DB_PATH } from '../../constants'
-import { InsertVector, SelectVector } from '../../db/schema'
-import { EmbeddingModel } from '../../types/embedding'
+import { IndexProgress } from '../../../components/chat-view/QueryProgress'
import {
LLMAPIKeyInvalidException,
LLMAPIKeyNotSetException,
LLMBaseUrlNotSetException,
-} from '../llm/exception'
-import { openSettingsModalWithError } from '../openSettingsModal'
+} from '../../../core/llm/exception'
+import { InsertVector, SelectVector } from '../../../database/schema'
+import { EmbeddingModel } from '../../../types/embedding'
+import { openSettingsModalWithError } from '../../../utils/openSettingsModal'
+import { DatabaseManager } from '../../DatabaseManager'
-import { VectorDbRepository } from './repository'
+import { VectorRepository } from './VectorRepository'
-export class VectorDbManager {
+export class VectorManager {
private app: App
- private repository: VectorDbRepository
+ private repository: VectorRepository
+ private dbManager: DatabaseManager
- constructor(app: App) {
+ constructor(app: App, dbManager: DatabaseManager) {
this.app = app
- }
-
- static async create(app: App): Promise {
- const manager = new VectorDbManager(app)
- const dbPath = normalizePath(PGLITE_DB_PATH)
- manager.repository = await VectorDbRepository.create(app, dbPath)
- return manager
- }
-
- async cleanup() {
- await this.repository.cleanup()
+ this.dbManager = dbManager
+ this.repository = new VectorRepository(app, dbManager.getDb())
}
async performSimilaritySearch(
@@ -203,7 +195,7 @@ export class VectorDbManager {
throw error
}
} finally {
- await this.repository.save()
+ await this.dbManager.save()
}
}
diff --git a/src/database/modules/vector/VectorRepository.ts b/src/database/modules/vector/VectorRepository.ts
new file mode 100644
index 0000000..a1831ae
--- /dev/null
+++ b/src/database/modules/vector/VectorRepository.ts
@@ -0,0 +1,164 @@
+import {
+ SQL,
+ and,
+ cosineDistance,
+ desc,
+ eq,
+ getTableColumns,
+ gt,
+ inArray,
+ like,
+ or,
+ sql,
+} from 'drizzle-orm'
+import { PgliteDatabase } from 'drizzle-orm/pglite'
+import { App } from 'obsidian'
+
+import { DatabaseNotInitializedException } from '../../exception'
+import { InsertVector, SelectVector, vectorTables } from '../../schema'
+
+export class VectorRepository {
+ private app: App
+ private db: PgliteDatabase | null
+
+ constructor(app: App, db: PgliteDatabase | null) {
+ this.app = app
+ this.db = db
+ }
+
+ async getIndexedFilePaths(embeddingModel: {
+ name: string
+ }): Promise {
+ if (!this.db) {
+ throw new DatabaseNotInitializedException()
+ }
+ const vectors = vectorTables[embeddingModel.name]
+ const indexedFiles = await this.db
+ .select({
+ path: vectors.path,
+ })
+ .from(vectors)
+ return indexedFiles.map((row) => row.path)
+ }
+
+ async getVectorsByFilePath(
+ filePath: string,
+ embeddingModel: { name: string },
+ ): Promise {
+ if (!this.db) {
+ throw new DatabaseNotInitializedException()
+ }
+ const vectors = vectorTables[embeddingModel.name]
+ const fileVectors = await this.db
+ .select()
+ .from(vectors)
+ .where(eq(vectors.path, filePath))
+ return fileVectors
+ }
+
+ async deleteVectorsForSingleFile(
+ filePath: string,
+ embeddingModel: { name: string },
+ ): Promise {
+ if (!this.db) {
+ throw new DatabaseNotInitializedException()
+ }
+ const vectors = vectorTables[embeddingModel.name]
+ await this.db.delete(vectors).where(eq(vectors.path, filePath))
+ }
+
+ async deleteVectorsForMultipleFiles(
+ filePaths: string[],
+ embeddingModel: { name: string },
+ ): Promise {
+ if (!this.db) {
+ throw new DatabaseNotInitializedException()
+ }
+ const vectors = vectorTables[embeddingModel.name]
+ await this.db.delete(vectors).where(inArray(vectors.path, filePaths))
+ }
+
+ async clearAllVectors(embeddingModel: { name: string }): Promise {
+ if (!this.db) {
+ throw new DatabaseNotInitializedException()
+ }
+ const vectors = vectorTables[embeddingModel.name]
+ await this.db.delete(vectors)
+ }
+
+ async insertVectors(
+ data: InsertVector[],
+ embeddingModel: { name: string },
+ ): Promise {
+ if (!this.db) {
+ throw new DatabaseNotInitializedException()
+ }
+ const vectors = vectorTables[embeddingModel.name]
+ await this.db.insert(vectors).values(data)
+ }
+
+ async performSimilaritySearch(
+ queryVector: number[],
+ embeddingModel: { name: string },
+ options: {
+ minSimilarity: number
+ limit: number
+ scope?: {
+ files: string[]
+ folders: string[]
+ }
+ },
+ ): Promise<
+ (Omit & {
+ similarity: number
+ })[]
+ > {
+ if (!this.db) {
+ throw new DatabaseNotInitializedException()
+ }
+ const vectors = vectorTables[embeddingModel.name]
+
+ const similarity = sql`1 - (${cosineDistance(vectors.embedding, queryVector)})`
+ const similarityCondition = gt(similarity, options.minSimilarity)
+
+ const getScopeCondition = (): SQL | undefined => {
+ if (!options.scope) {
+ return undefined
+ }
+ const conditions: (SQL | undefined)[] = []
+ if (options.scope.files.length > 0) {
+ conditions.push(inArray(vectors.path, options.scope.files))
+ }
+ if (options.scope.folders.length > 0) {
+ conditions.push(
+ or(
+ ...options.scope.folders.map((folder) =>
+ like(vectors.path, `${folder}/%`),
+ ),
+ ),
+ )
+ }
+ if (conditions.length === 0) {
+ return undefined
+ }
+ return or(...conditions)
+ }
+ const scopeCondition = getScopeCondition()
+
+ const similaritySearchResult = await this.db
+ .select({
+ ...(() => {
+ // eslint-disable-next-line @typescript-eslint/no-unused-vars
+ const { embedding, ...rest } = getTableColumns(vectors)
+ return rest
+ })(),
+ similarity,
+ })
+ .from(vectors)
+ .where(and(similarityCondition, scopeCondition))
+ .orderBy((t) => desc(t.similarity))
+ .limit(options.limit)
+
+ return similaritySearchResult
+ }
+}
diff --git a/src/db/schema.ts b/src/database/schema.ts
similarity index 76%
rename from src/db/schema.ts
rename to src/database/schema.ts
index 0375914..526752c 100644
--- a/src/db/schema.ts
+++ b/src/database/schema.ts
@@ -5,12 +5,16 @@ import {
pgTable,
serial,
text,
+ timestamp,
+ uuid,
vector,
} from 'drizzle-orm/pg-core'
+import { SerializedLexicalNode } from 'lexical'
import { EMBEDDING_MODEL_OPTIONS } from '../constants'
import { EmbeddingModelName } from '../types/embedding'
+/* Vector Table */
const createVectorTable = (name: string, dimension: number) => {
const sanitizedName = name.replace(/[^a-zA-Z0-9]/g, '_')
return pgTable(
@@ -58,3 +62,19 @@ export const vectorTable1 = vectorTables[EMBEDDING_MODEL_OPTIONS[1].value]
export const vectorTable2 = vectorTables[EMBEDDING_MODEL_OPTIONS[2].value]
export const vectorTable3 = vectorTables[EMBEDDING_MODEL_OPTIONS[3].value]
export const vectorTable4 = vectorTables[EMBEDDING_MODEL_OPTIONS[4].value]
+
+/* Template Table */
+export type TemplateContent = {
+ nodes: SerializedLexicalNode[]
+}
+
+export const templateTable = pgTable('template', {
+ id: uuid('id').defaultRandom().primaryKey(),
+ name: text('name').notNull().unique(),
+ content: jsonb('content').notNull().$type(),
+ createdAt: timestamp('created_at').defaultNow().notNull(),
+ updatedAt: timestamp('updated_at').defaultNow().notNull(),
+})
+
+export type SelectTemplate = typeof templateTable.$inferSelect
+export type InsertTemplate = typeof templateTable.$inferInsert
diff --git a/src/main.ts b/src/main.ts
index 1157093..282666c 100644
--- a/src/main.ts
+++ b/src/main.ts
@@ -4,19 +4,21 @@ import { ApplyView } from './ApplyView'
import { ChatView } from './ChatView'
import { ChatProps } from './components/chat-view/Chat'
import { APPLY_VIEW_TYPE, CHAT_VIEW_TYPE } from './constants'
+import { RAGEngine } from './core/rag/ragEngine'
+import { DatabaseManager } from './database/DatabaseManager'
import { SmartCopilotSettingTab } from './settings/SettingTab'
import {
SmartCopilotSettings,
parseSmartCopilotSettings,
} from './types/settings'
import { getMentionableBlockData } from './utils/obsidian'
-import { RAGEngine } from './utils/ragEngine'
// Remember to rename these classes and interfaces!
export default class SmartCopilotPlugin extends Plugin {
settings: SmartCopilotSettings
initialChatProps?: ChatProps // TODO: change this to use view state like ApplyView
settingsChangeListeners: ((newSettings: SmartCopilotSettings) => void)[] = []
+ dbManager: DatabaseManager | null = null
ragEngine: RAGEngine | null = null
async onload() {
@@ -112,8 +114,8 @@ export default class SmartCopilotPlugin extends Plugin {
}
onunload() {
- this.ragEngine?.cleanup()
- this.ragEngine = null
+ this.dbManager?.cleanup()
+ this.dbManager = null
}
async loadSettings() {
@@ -189,9 +191,17 @@ export default class SmartCopilotPlugin extends Plugin {
chatView.focusMessage()
}
+ async getDbManager(): Promise {
+ if (!this.dbManager) {
+ this.dbManager = await DatabaseManager.create(this.app)
+ }
+ return this.dbManager
+ }
+
async getRAGEngine(): Promise {
+ const dbManager = await this.getDbManager()
if (!this.ragEngine) {
- this.ragEngine = await RAGEngine.create(this.app, this.settings)
+ this.ragEngine = new RAGEngine(this.app, this.settings, dbManager)
}
return this.ragEngine
}
diff --git a/src/utils/promptGenerator.ts b/src/utils/promptGenerator.ts
index ce75072..a1bb4dd 100644
--- a/src/utils/promptGenerator.ts
+++ b/src/utils/promptGenerator.ts
@@ -2,6 +2,7 @@ import { App, TFile, htmlToMarkdown, requestUrl } from 'obsidian'
import { editorStateToPlainText } from '../components/chat-view/chat-input/utils/editor-state-to-plain-text'
import { QueryProgressState } from '../components/chat-view/QueryProgress'
+import { RAGEngine } from '../core/rag/ragEngine'
import { ChatMessage, ChatUserMessage } from '../types/chat'
import { RequestMessage } from '../types/llm/request'
import {
@@ -18,7 +19,6 @@ import {
readMultipleTFiles,
readTFileContent,
} from './obsidian'
-import { RAGEngine } from './ragEngine'
import { tokenCount } from './token'
import { YoutubeTranscript, isYoutubeUrl } from './youtube-transcript'
diff --git a/src/utils/vector-db/repository.ts b/src/utils/vector-db/repository.ts
deleted file mode 100644
index cb3bd2d..0000000
--- a/src/utils/vector-db/repository.ts
+++ /dev/null
@@ -1,283 +0,0 @@
-import { PGlite } from '@electric-sql/pglite'
-import {
- SQL,
- and,
- cosineDistance,
- desc,
- eq,
- getTableColumns,
- gt,
- inArray,
- like,
- or,
- sql,
-} from 'drizzle-orm'
-import { PgliteDatabase, drizzle } from 'drizzle-orm/pglite'
-import { App, requestUrl } from 'obsidian'
-
-import migrations from '../../db/migrations.json'
-import { InsertVector, SelectVector, vectorTables } from '../../db/schema'
-
-export class VectorDbRepository {
- private app: App
- private pgClient: PGlite | null = null
- private db: PgliteDatabase | null = null
- private dbPath: string
-
- constructor(app: App, dbPath: string) {
- this.app = app
- this.dbPath = dbPath
- }
-
- static async create(app: App, dbPath: string): Promise {
- const repo = new VectorDbRepository(app, dbPath)
- repo.db = await repo.loadExistingDatabase()
- if (!repo.db) {
- repo.db = await repo.createNewDatabase()
- }
- await repo.migrateDatabase()
- await repo.save()
- console.log('Smart composer database initialized.')
- return repo
- }
-
- async cleanup() {
- this.pgClient?.close()
- this.db = null
- this.pgClient = null
- }
-
- async getIndexedFilePaths(embeddingModel: {
- name: string
- }): Promise {
- if (!this.db) {
- throw new Error('Database not initialized')
- }
- const vectors = vectorTables[embeddingModel.name]
- const indexedFiles = await this.db
- .select({
- path: vectors.path,
- })
- .from(vectors)
- return indexedFiles.map((row) => row.path)
- }
-
- async getVectorsByFilePath(
- filePath: string,
- embeddingModel: { name: string },
- ): Promise {
- if (!this.db) {
- throw new Error('Database not initialized')
- }
- const vectors = vectorTables[embeddingModel.name]
- const fileVectors = await this.db
- .select()
- .from(vectors)
- .where(eq(vectors.path, filePath))
- return fileVectors
- }
-
- async deleteVectorsForSingleFile(
- filePath: string,
- embeddingModel: { name: string },
- ): Promise {
- if (!this.db) {
- throw new Error('Database not initialized')
- }
- const vectors = vectorTables[embeddingModel.name]
- await this.db.delete(vectors).where(eq(vectors.path, filePath))
- }
-
- async deleteVectorsForMultipleFiles(
- filePaths: string[],
- embeddingModel: { name: string },
- ): Promise {
- if (!this.db) {
- throw new Error('Database not initialized')
- }
- const vectors = vectorTables[embeddingModel.name]
- await this.db.delete(vectors).where(inArray(vectors.path, filePaths))
- }
-
- async clearAllVectors(embeddingModel: { name: string }): Promise {
- if (!this.db) {
- throw new Error('Database not initialized')
- }
- const vectors = vectorTables[embeddingModel.name]
- await this.db.delete(vectors)
- }
-
- async insertVectors(
- data: InsertVector[],
- embeddingModel: { name: string },
- ): Promise {
- if (!this.db) {
- throw new Error('Database not initialized')
- }
- const vectors = vectorTables[embeddingModel.name]
- await this.db.insert(vectors).values(data)
- }
-
- async performSimilaritySearch(
- queryVector: number[],
- embeddingModel: { name: string },
- options: {
- minSimilarity: number
- limit: number
- scope?: {
- files: string[]
- folders: string[]
- }
- },
- ): Promise<
- (Omit & {
- similarity: number
- })[]
- > {
- if (!this.db) {
- throw new Error('Database not initialized')
- }
- const vectors = vectorTables[embeddingModel.name]
-
- const similarity = sql`1 - (${cosineDistance(vectors.embedding, queryVector)})`
- const similarityCondition = gt(similarity, options.minSimilarity)
-
- const getScopeCondition = (): SQL | undefined => {
- if (!options.scope) {
- return undefined
- }
- const conditions: (SQL | undefined)[] = []
- if (options.scope.files.length > 0) {
- conditions.push(inArray(vectors.path, options.scope.files))
- }
- if (options.scope.folders.length > 0) {
- conditions.push(
- or(
- ...options.scope.folders.map((folder) =>
- like(vectors.path, `${folder}/%`),
- ),
- ),
- )
- }
- if (conditions.length === 0) {
- return undefined
- }
- return or(...conditions)
- }
- const scopeCondition = getScopeCondition()
-
- const similaritySearchResult = await this.db
- .select({
- ...(() => {
- const { embedding, ...rest } = getTableColumns(vectors)
- return rest
- })(),
- similarity,
- })
- .from(vectors)
- .where(and(similarityCondition, scopeCondition))
- .orderBy((t) => desc(t.similarity))
- .limit(options.limit)
-
- return similaritySearchResult
- }
-
- async save(): Promise {
- if (!this.pgClient) {
- return
- }
- try {
- const blob: Blob = await this.pgClient.dumpDataDir('gzip')
- await this.app.vault.adapter.writeBinary(
- this.dbPath,
- Buffer.from(await blob.arrayBuffer()),
- )
- } catch (error) {
- console.error('Error saving database:', error)
- }
- }
-
- private async createNewDatabase() {
- const { fsBundle, wasmModule, vectorExtensionBundlePath } =
- await this.loadPGliteResources()
- this.pgClient = await PGlite.create({
- fsBundle: fsBundle,
- wasmModule: wasmModule,
- extensions: {
- vector: vectorExtensionBundlePath,
- },
- })
- const db = drizzle(this.pgClient)
- return db
- }
-
- private async loadExistingDatabase(): Promise {
- try {
- const databaseFileExists = await this.app.vault.adapter.exists(
- this.dbPath,
- )
- if (!databaseFileExists) {
- return null
- }
- const fileBuffer = await this.app.vault.adapter.readBinary(this.dbPath)
- const fileBlob = new Blob([fileBuffer], { type: 'application/x-gzip' })
- const { fsBundle, wasmModule, vectorExtensionBundlePath } =
- await this.loadPGliteResources()
- this.pgClient = await PGlite.create({
- loadDataDir: fileBlob,
- fsBundle: fsBundle,
- wasmModule: wasmModule,
- extensions: {
- vector: vectorExtensionBundlePath,
- },
- })
- return drizzle(this.pgClient)
- } catch (error) {
- console.error('Error loading database:', error)
- return null
- }
- }
-
- private async migrateDatabase(): Promise {
- try {
- // Workaround for running Drizzle migrations in a browser environment
- // This method uses an undocumented API to perform migrations
- // See: https://github.com/drizzle-team/drizzle-orm/discussions/2532#discussioncomment-10780523
- // eslint-disable-next-line @typescript-eslint/ban-ts-comment
- // @ts-expect-error
- await this.db.dialect.migrate(migrations, this.db.session, {
- migrationsTable: 'drizzle_migrations',
- })
- } catch (error) {
- console.error('Error migrating database:', error)
- throw error
- }
- }
-
- // TODO: This function is a temporary workaround chosen due to the difficulty of bundling postgres.wasm and postgres.data from node_modules into a single JS file. The ultimate goal is to bundle everything into one JS file in the future.
- private async loadPGliteResources(): Promise<{
- fsBundle: Blob
- wasmModule: WebAssembly.Module
- vectorExtensionBundlePath: URL
- }> {
- try {
- const [fsBundleResponse, wasmResponse] = await Promise.all([
- requestUrl('https://unpkg.com/@electric-sql/pglite/dist/postgres.data'),
- requestUrl('https://unpkg.com/@electric-sql/pglite/dist/postgres.wasm'),
- ])
-
- const fsBundle = new Blob([fsBundleResponse.arrayBuffer], {
- type: 'application/octet-stream',
- })
- const wasmModule = await WebAssembly.compile(wasmResponse.arrayBuffer)
- const vectorExtensionBundlePath = new URL(
- 'https://unpkg.com/@electric-sql/pglite/dist/vector.tar.gz',
- )
-
- return { fsBundle, wasmModule, vectorExtensionBundlePath }
- } catch (error) {
- console.error('Error loading PGlite resources:', error)
- throw error
- }
- }
-}
diff --git a/styles.css b/styles.css
index 44f856b..5a6724f 100644
--- a/styles.css
+++ b/styles.css
@@ -179,6 +179,7 @@ button:not(.clickable-icon).smtcmp-chat-list-dropdown {
}
.smtcmp-chat-user-input-container {
+ position: relative;
display: flex;
flex-direction: column;
-webkit-app-region: no-drag;
@@ -190,7 +191,6 @@ button:not(.clickable-icon).smtcmp-chat-list-dropdown {
font-size: var(--font-ui-small);
border-radius: var(--radius-s);
outline: none;
- margin-top: var(--size-4-1);
&:focus-within,
&:focus,
@@ -315,7 +315,7 @@ button:not(.clickable-icon).smtcmp-chat-list-dropdown {
* ChatUserInput
*/
-.smtcmp-chat-input-root .mention {
+.smtcmp-lexical-content-editable-root .mention {
background-color: var(--tag-background);
color: var(--tag-color);
padding: var(--size-2-1) calc(var(--size-2-1));
@@ -327,7 +327,7 @@ button:not(.clickable-icon).smtcmp-chat-list-dropdown {
word-break: break-all;
}
-.smtcmp-chat-input-paragraph {
+.smtcmp-lexical-content-editable-paragraph {
margin: 0;
line-height: 1.6;
}
@@ -757,3 +757,90 @@ button.smtcmp-chat-input-model-select {
resize: none;
}
}
+
+.smtcmp-dialog-content {
+ position: fixed;
+ left: calc(50% - var(--size-4-4));
+ top: 50%;
+ z-index: 50;
+ display: grid;
+ width: calc(100% - var(--size-4-8));
+ max-width: 32rem;
+ transform: translate(-50%, -50%);
+ gap: var(--size-4-2);
+ border: var(--border-width) solid var(--background-modifier-border);
+ background-color: var(--background-secondary);
+ padding: var(--size-4-5);
+ transition-duration: 200ms;
+ border-radius: var(--radius-m);
+ box-shadow: 0 25px 50px -12px rgb(0 0 0 / 0.25);
+ margin: var(--size-4-4);
+
+ .smtcmp-dialog-header {
+ margin-bottom: var(--size-4-2);
+ display: grid;
+ gap: var(--size-2-3);
+ }
+
+ .smtcmp-dialog-title {
+ font-size: var(--font-ui-medium);
+ font-weight: var(--font-semibold);
+ line-height: var(--line-height-tight);
+ margin: 0;
+ }
+
+ .smtcmp-dialog-input {
+ display: grid;
+ gap: var(--size-4-1);
+
+ & label {
+ font-size: var(--font-ui-smaller);
+ }
+ }
+
+ .smtcmp-dialog-description {
+ font-size: var(--font-ui-small);
+ color: var(--text-muted);
+ margin: 0;
+ }
+
+ .smtcmp-dialog-footer {
+ margin-top: var(--size-4-2);
+ display: flex;
+ justify-content: flex-end;
+ }
+
+ .smtcmp-dialog-close {
+ position: absolute;
+ right: var(--size-4-4);
+ top: var(--size-4-4);
+ cursor: var(--cursor);
+ opacity: 0.7;
+ transition: opacity 0.2s;
+
+ &:hover {
+ opacity: 1;
+ }
+ }
+}
+
+.smtcmp-template-menu-item {
+ display: flex;
+ align-items: center;
+ justify-content: space-between;
+ gap: var(--size-4-1);
+ width: 100%;
+
+ .smtcmp-template-menu-item-delete {
+ display: flex;
+ align-items: center;
+ padding: var(--size-4-1);
+ margin: calc(var(--size-4-1) * -1);
+ opacity: 0.7;
+ transition: opacity 0.2s;
+
+ &:hover {
+ opacity: 1;
+ }
+ }
+}