From d1bb01a8c936d26047c8fd70d21d9667bca14a44 Mon Sep 17 00:00:00 2001 From: Kevin On <40454531+kevin-on@users.noreply.github.com> Date: Wed, 20 Nov 2024 01:06:30 -0500 Subject: [PATCH] Display similarity search results (#107) * Display similarity search results --- src/components/chat-view/Chat.tsx | 94 ++++++++++-------- src/components/chat-view/ChatListDropdown.tsx | 2 +- .../chat-view/SimilaritySearchResults.tsx | 68 +++++++++++++ .../chat-view/chat-input/ChatUserInput.tsx | 6 +- .../modules/vector/VectorRepository.ts | 4 +- src/hooks/useChatHistory.ts | 2 + src/types/chat.ts | 8 ++ src/utils/obsidian.ts | 4 +- src/utils/promptGenerator.ts | 32 +++--- styles.css | 99 +++++++++++++++++++ 10 files changed, 259 insertions(+), 60 deletions(-) create mode 100644 src/components/chat-view/SimilaritySearchResults.tsx diff --git a/src/components/chat-view/Chat.tsx b/src/components/chat-view/Chat.tsx index ea0eb57..224a41c 100644 --- a/src/components/chat-view/Chat.tsx +++ b/src/components/chat-view/Chat.tsx @@ -44,6 +44,7 @@ import { editorStateToPlainText } from './chat-input/utils/editor-state-to-plain import { ChatListDropdown } from './ChatListDropdown' import QueryProgress, { QueryProgressState } from './QueryProgress' import ReactMarkdown from './ReactMarkdown' +import SimilaritySearchResults from './SimilaritySearchResults' // Add an empty line here const getNewInputMessage = (app: App): ChatUserMessage => { @@ -539,48 +540,57 @@ const Chat = forwardRef((props, ref) => {
{chatMessages.map((message, index) => message.role === 'user' ? ( - registerChatUserInputRef(message.id, ref)} - initialSerializedEditorState={message.content} - onChange={(content) => { - setChatMessages((prevChatHistory) => - prevChatHistory.map((msg) => - msg.role === 'user' && msg.id === message.id - ? { - ...msg, - content, - } - : msg, - ), - ) - }} - onSubmit={(content, useVaultSearch) => { - if (editorStateToPlainText(content).trim() === '') return - handleSubmit( - [ - ...chatMessages.slice(0, index), - { - ...message, - content, - }, - ], - useVaultSearch, - ) - chatUserInputRefs.current.get(inputMessage.id)?.focus() - }} - onFocus={() => { - setFocusedMessageId(message.id) - }} - mentionables={message.mentionables} - setMentionables={(mentionables) => { - setChatMessages((prevChatHistory) => - prevChatHistory.map((msg) => - msg.id === message.id ? { ...msg, mentionables } : msg, - ), - ) - }} - /> +
+ registerChatUserInputRef(message.id, ref)} + initialSerializedEditorState={message.content} + onChange={(content) => { + setChatMessages((prevChatHistory) => + prevChatHistory.map((msg) => + msg.role === 'user' && msg.id === message.id + ? { + ...msg, + content, + } + : msg, + ), + ) + }} + onSubmit={(content, useVaultSearch) => { + if (editorStateToPlainText(content).trim() === '') return + handleSubmit( + [ + ...chatMessages.slice(0, index), + { + role: 'user', + content: content, + promptContent: null, + id: message.id, + mentionables: message.mentionables, + }, + ], + useVaultSearch, + ) + chatUserInputRefs.current.get(inputMessage.id)?.focus() + }} + onFocus={() => { + setFocusedMessageId(message.id) + }} + mentionables={message.mentionables} + setMentionables={(mentionables) => { + setChatMessages((prevChatHistory) => + prevChatHistory.map((msg) => + msg.id === message.id ? { ...msg, mentionables } : msg, + ), + ) + }} + /> + {message.similaritySearchResults && ( + + )} +
) : ( - +
    {chatList.length === 0 ? (
  • diff --git a/src/components/chat-view/SimilaritySearchResults.tsx b/src/components/chat-view/SimilaritySearchResults.tsx new file mode 100644 index 0000000..272016e --- /dev/null +++ b/src/components/chat-view/SimilaritySearchResults.tsx @@ -0,0 +1,68 @@ +import path from 'path' + +import { ChevronDown, ChevronRight } from 'lucide-react' +import { useState } from 'react' + +import { useApp } from '../../contexts/app-context' +import { SelectVector } from '../../database/schema' +import { openMarkdownFile } from '../../utils/obsidian' + +function SimiliartySearchItem({ + chunk, +}: { + chunk: Omit & { + similarity: number + } +}) { + const app = useApp() + + const handleClick = () => { + openMarkdownFile(app, chunk.path, chunk.metadata.startLine) + } + return ( +
    +
    + {path.basename(chunk.path)} +
    +
    + {`${chunk.metadata.startLine} - ${chunk.metadata.endLine}`} +
    +
    + ) +} + +export default function SimilaritySearchResults({ + similaritySearchResults, +}: { + similaritySearchResults: (Omit & { + similarity: number + })[] +}) { + const [isOpen, setIsOpen] = useState(false) + + return ( +
    +
    { + setIsOpen(!isOpen) + }} + className="smtcmp-similarity-search-results__trigger" + > + {isOpen ? : } +
    Show Referenced Documents ({similaritySearchResults.length})
    +
    + {isOpen && ( +
    + {similaritySearchResults.map((chunk) => ( + + ))} +
    + )} +
    + ) +} diff --git a/src/components/chat-view/chat-input/ChatUserInput.tsx b/src/components/chat-view/chat-input/ChatUserInput.tsx index 20a0d0a..81f692b 100644 --- a/src/components/chat-view/chat-input/ChatUserInput.tsx +++ b/src/components/chat-view/chat-input/ChatUserInput.tsx @@ -226,7 +226,11 @@ const ChatUserInput = forwardRef( mentionableKey === displayedMentionableKey ) { // open file on click again - openMarkdownFile(app, m.file.path) + openMarkdownFile( + app, + m.file.path, + m.type === 'block' ? m.startLine : undefined, + ) } else { setDisplayedMentionableKey(mentionableKey) } diff --git a/src/database/modules/vector/VectorRepository.ts b/src/database/modules/vector/VectorRepository.ts index a1831ae..9b75610 100644 --- a/src/database/modules/vector/VectorRepository.ts +++ b/src/database/modules/vector/VectorRepository.ts @@ -145,7 +145,7 @@ export class VectorRepository { } const scopeCondition = getScopeCondition() - const similaritySearchResult = await this.db + const similaritySearchResults = await this.db .select({ ...(() => { // eslint-disable-next-line @typescript-eslint/no-unused-vars @@ -159,6 +159,6 @@ export class VectorRepository { .orderBy((t) => desc(t.similarity)) .limit(options.limit) - return similaritySearchResult + return similaritySearchResults } } diff --git a/src/hooks/useChatHistory.ts b/src/hooks/useChatHistory.ts index 0543364..05ff29b 100644 --- a/src/hooks/useChatHistory.ts +++ b/src/hooks/useChatHistory.ts @@ -26,6 +26,7 @@ const serializeChatMessage = (message: ChatMessage): SerializedChatMessage => { promptContent: message.promptContent, id: message.id, mentionables: message.mentionables.map(serializeMentionable), + similaritySearchResults: message.similaritySearchResults, } case 'assistant': return { @@ -50,6 +51,7 @@ const deserializeChatMessage = ( mentionables: message.mentionables .map((m) => deserializeMentionable(m, app)) .filter((m): m is Mentionable => m !== null), + similaritySearchResults: message.similaritySearchResults, } } case 'assistant': diff --git a/src/types/chat.ts b/src/types/chat.ts index acca208..7984858 100644 --- a/src/types/chat.ts +++ b/src/types/chat.ts @@ -1,5 +1,7 @@ import { SerializedEditorState } from 'lexical' +import { SelectVector } from '../database/schema' + import { Mentionable, SerializedMentionable } from './mentionable' export type ChatUserMessage = { @@ -8,6 +10,9 @@ export type ChatUserMessage = { promptContent: string | null id: string mentionables: Mentionable[] + similaritySearchResults?: (Omit & { + similarity: number + })[] } export type ChatAssistantMessage = { role: 'assistant' @@ -22,6 +27,9 @@ export type SerializedChatUserMessage = { promptContent: string | null id: string mentionables: SerializedMentionable[] + similaritySearchResults?: (Omit & { + similarity: number + })[] } export type SerializedChatAssistantMessage = { role: 'assistant' diff --git a/src/utils/obsidian.ts b/src/utils/obsidian.ts index d0e7a61..14bee02 100644 --- a/src/utils/obsidian.ts +++ b/src/utils/obsidian.ts @@ -113,12 +113,12 @@ export function openMarkdownFile( if (startLine) { const view = existingLeaf.view as MarkdownView - view.setEphemeralState({ line: startLine }) + view.setEphemeralState({ line: startLine - 1 }) // -1 because line is 0-indexed } } else { const leaf = app.workspace.getLeaf('tab') leaf.openFile(file, { - eState: startLine ? { line: startLine } : undefined, + eState: startLine ? { line: startLine - 1 } : undefined, // -1 because line is 0-indexed }) } } diff --git a/src/utils/promptGenerator.ts b/src/utils/promptGenerator.ts index 900a312..d855a21 100644 --- a/src/utils/promptGenerator.ts +++ b/src/utils/promptGenerator.ts @@ -3,6 +3,7 @@ import { App, TFile, htmlToMarkdown, requestUrl } from 'obsidian' import { editorStateToPlainText } from '../components/chat-view/chat-input/utils/editor-state-to-plain-text' import { QueryProgressState } from '../components/chat-view/QueryProgress' import { RAGEngine } from '../core/rag/ragEngine' +import { SelectVector } from '../database/schema' import { ChatMessage, ChatUserMessage } from '../types/chat' import { RequestMessage } from '../types/llm/request' import { @@ -57,18 +58,18 @@ export class PromptGenerator { throw new Error('Last message is not a user message') } - const { promptContent, shouldUseRAG } = await this.compileUserMessagePrompt( - { + const { promptContent, shouldUseRAG, similaritySearchResults } = + await this.compileUserMessagePrompt({ message: lastUserMessage, useVaultSearch, - onQueryProgressChange: onQueryProgressChange, - }, - ) + onQueryProgressChange, + }) let compiledMessages = [ ...messages.slice(0, -1), { ...lastUserMessage, - promptContent: promptContent, + promptContent, + similaritySearchResults, }, ] @@ -76,12 +77,14 @@ export class PromptGenerator { compiledMessages = await Promise.all( compiledMessages.map(async (message) => { if (message.role === 'user' && !message.promptContent) { - const { promptContent } = await this.compileUserMessagePrompt({ - message, - }) + const { promptContent, similaritySearchResults } = + await this.compileUserMessagePrompt({ + message, + }) return { ...message, - promptContent: promptContent, + promptContent, + similaritySearchResults, } } return message @@ -136,6 +139,9 @@ export class PromptGenerator { }): Promise<{ promptContent: string shouldUseRAG: boolean + similaritySearchResults?: (Omit & { + similarity: number + })[] }> { if (!message.content) { return { @@ -144,6 +150,7 @@ export class PromptGenerator { } } const query = editorStateToPlainText(message.content) + let similaritySearchResults = undefined useVaultSearch = // eslint-disable-next-line @typescript-eslint/prefer-nullish-coalescing @@ -183,7 +190,7 @@ export class PromptGenerator { let filePrompt: string if (shouldUseRAG) { - const results = useVaultSearch + similaritySearchResults = useVaultSearch ? await ( await this.getRagEngine() ).processQuery({ @@ -201,7 +208,7 @@ export class PromptGenerator { onQueryProgressChange: onQueryProgressChange, }) filePrompt = `## Potentially Relevant Snippets from the current vault -${results +${similaritySearchResults .map(({ path, content, metadata }) => { const contentWithLineNumbers = this.addLineNumbersToContent({ content, @@ -251,6 +258,7 @@ ${await this.getWebsiteContent(url)} return { promptContent: `${filePrompt}${blockPrompt}${urlPrompt}\n\n${query}\n\n`, shouldUseRAG, + similaritySearchResults: similaritySearchResults, } } diff --git a/styles.css b/styles.css index b248fac..dede952 100644 --- a/styles.css +++ b/styles.css @@ -151,6 +151,18 @@ button:not(.clickable-icon).smtcmp-chat-list-dropdown { gap: var(--size-4-1); padding: 0 var(--size-4-3) var(--size-4-5) var(--size-4-3); margin: var(--size-4-2) calc(var(--size-4-3) * -1) 0; + + .smtcmp-chat-messages-user { + display: flex; + flex-direction: column; + gap: var(--size-4-1); + } + + .smtcmp-chat-messages-assistant { + display: flex; + flex-direction: column; + gap: var(--size-4-1); + } } .obsidian-default-textarea { @@ -463,6 +475,12 @@ button:not(.clickable-icon).smtcmp-chat-list-dropdown { border-radius: var(--radius-s); } +.smtcmp-chat-list-dropdown-content li { + display: flex; + align-items: center; + justify-content: space-between; +} + .smtcmp-code-block { position: relative; border: 1px solid var(--background-modifier-border); @@ -856,3 +874,84 @@ button.smtcmp-chat-input-model-select { } } } + +.smtcmp-assistant-message-footer { + display: flex; + align-items: center; + justify-content: end; + + button { + display: flex; + align-items: center; + justify-content: center; + width: 26px; + height: 26px; + padding: 0; + background-color: transparent; + border-color: transparent; + box-shadow: none; + color: var(--text-muted); + + &:hover { + background-color: var(--background-modifier-hover); + } + } +} + +.smtcmp-popover-content { + background-color: var(--background-primary); + border: 1px solid var(--background-modifier-border); + border-radius: var(--radius-s); + padding: var(--size-4-1) var(--size-4-2); + font-size: var(--font-smallest); + animation: fadeIn 0.1s ease-in-out; +} + +.smtcmp-similarity-search-results { + display: flex; + flex-direction: column; + font-size: var(--font-smaller); + padding-top: var(--size-4-1); + padding-bottom: var(--size-4-1); + user-select: none; + + .smtcmp-similarity-search-results__trigger { + display: flex; + align-items: center; + gap: var(--size-4-1); + padding: var(--size-4-1); + border-radius: var(--radius-s); + cursor: pointer; + &:hover { + background-color: var(--background-modifier-hover); + } + } + + .smtcmp-similarity-search-item { + display: flex; + align-items: center; + justify-content: start; + gap: var(--size-4-2); + padding: var(--size-4-1); + border-radius: var(--radius-s); + cursor: pointer; + + &:hover { + background-color: var(--background-modifier-hover); + } + + .smtcmp-similarity-search-item__path { + flex-shrink: 1; + white-space: nowrap; + overflow: hidden; + text-overflow: ellipsis; + font-size: var(--font-smallest); + } + + .smtcmp-similarity-search-item__line-numbers { + flex-shrink: 0; + margin-left: auto; + font-size: var(--font-smallest); + } + } +}