Skip to content

Commit

Permalink
Display similarity search results (#107)
Browse files Browse the repository at this point in the history
* Display similarity search results
  • Loading branch information
kevin-on authored Nov 20, 2024
1 parent abb1cac commit d1bb01a
Show file tree
Hide file tree
Showing 10 changed files with 259 additions and 60 deletions.
94 changes: 52 additions & 42 deletions src/components/chat-view/Chat.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ import { editorStateToPlainText } from './chat-input/utils/editor-state-to-plain
import { ChatListDropdown } from './ChatListDropdown'
import QueryProgress, { QueryProgressState } from './QueryProgress'
import ReactMarkdown from './ReactMarkdown'
import SimilaritySearchResults from './SimilaritySearchResults'

// Add an empty line here
const getNewInputMessage = (app: App): ChatUserMessage => {
Expand Down Expand Up @@ -539,48 +540,57 @@ const Chat = forwardRef<ChatRef, ChatProps>((props, ref) => {
<div className="smtcmp-chat-messages" ref={chatMessagesRef}>
{chatMessages.map((message, index) =>
message.role === 'user' ? (
<ChatUserInput
key={message.id}
ref={(ref) => registerChatUserInputRef(message.id, ref)}
initialSerializedEditorState={message.content}
onChange={(content) => {
setChatMessages((prevChatHistory) =>
prevChatHistory.map((msg) =>
msg.role === 'user' && msg.id === message.id
? {
...msg,
content,
}
: msg,
),
)
}}
onSubmit={(content, useVaultSearch) => {
if (editorStateToPlainText(content).trim() === '') return
handleSubmit(
[
...chatMessages.slice(0, index),
{
...message,
content,
},
],
useVaultSearch,
)
chatUserInputRefs.current.get(inputMessage.id)?.focus()
}}
onFocus={() => {
setFocusedMessageId(message.id)
}}
mentionables={message.mentionables}
setMentionables={(mentionables) => {
setChatMessages((prevChatHistory) =>
prevChatHistory.map((msg) =>
msg.id === message.id ? { ...msg, mentionables } : msg,
),
)
}}
/>
<div key={message.id} className="smtcmp-chat-messages-user">
<ChatUserInput
ref={(ref) => registerChatUserInputRef(message.id, ref)}
initialSerializedEditorState={message.content}
onChange={(content) => {
setChatMessages((prevChatHistory) =>
prevChatHistory.map((msg) =>
msg.role === 'user' && msg.id === message.id
? {
...msg,
content,
}
: msg,
),
)
}}
onSubmit={(content, useVaultSearch) => {
if (editorStateToPlainText(content).trim() === '') return
handleSubmit(
[
...chatMessages.slice(0, index),
{
role: 'user',
content: content,
promptContent: null,
id: message.id,
mentionables: message.mentionables,
},
],
useVaultSearch,
)
chatUserInputRefs.current.get(inputMessage.id)?.focus()
}}
onFocus={() => {
setFocusedMessageId(message.id)
}}
mentionables={message.mentionables}
setMentionables={(mentionables) => {
setChatMessages((prevChatHistory) =>
prevChatHistory.map((msg) =>
msg.id === message.id ? { ...msg, mentionables } : msg,
),
)
}}
/>
{message.similaritySearchResults && (
<SimilaritySearchResults
similaritySearchResults={message.similaritySearchResults}
/>
)}
</div>
) : (
<ReactMarkdownItem
key={message.id}
Expand Down
2 changes: 1 addition & 1 deletion src/components/chat-view/ChatListDropdown.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ export function ChatListDropdown({
</DropdownMenu.Trigger>

<DropdownMenu.Portal>
<DropdownMenu.Content className="smtcmp-popover">
<DropdownMenu.Content className="smtcmp-popover smtcmp-chat-list-dropdown-content">
<ul>
{chatList.length === 0 ? (
<li className="smtcmp-chat-list-dropdown-empty">
Expand Down
68 changes: 68 additions & 0 deletions src/components/chat-view/SimilaritySearchResults.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import path from 'path'

import { ChevronDown, ChevronRight } from 'lucide-react'
import { useState } from 'react'

import { useApp } from '../../contexts/app-context'
import { SelectVector } from '../../database/schema'
import { openMarkdownFile } from '../../utils/obsidian'

function SimiliartySearchItem({
chunk,
}: {
chunk: Omit<SelectVector, 'embedding'> & {
similarity: number
}
}) {
const app = useApp()

const handleClick = () => {
openMarkdownFile(app, chunk.path, chunk.metadata.startLine)
}
return (
<div onClick={handleClick} className="smtcmp-similarity-search-item">
<div className="smtcmp-similarity-search-item__path">
{path.basename(chunk.path)}
</div>
<div className="smtcmp-similarity-search-item__line-numbers">
{`${chunk.metadata.startLine} - ${chunk.metadata.endLine}`}
</div>
</div>
)
}

export default function SimilaritySearchResults({
similaritySearchResults,
}: {
similaritySearchResults: (Omit<SelectVector, 'embedding'> & {
similarity: number
})[]
}) {
const [isOpen, setIsOpen] = useState(false)

return (
<div className="smtcmp-similarity-search-results">
<div
onClick={() => {
setIsOpen(!isOpen)
}}
className="smtcmp-similarity-search-results__trigger"
>
{isOpen ? <ChevronDown size={16} /> : <ChevronRight size={16} />}
<div>Show Referenced Documents ({similaritySearchResults.length})</div>
</div>
{isOpen && (
<div
style={{
display: 'flex',
flexDirection: 'column',
}}
>
{similaritySearchResults.map((chunk) => (
<SimiliartySearchItem key={chunk.id} chunk={chunk} />
))}
</div>
)}
</div>
)
}
6 changes: 5 additions & 1 deletion src/components/chat-view/chat-input/ChatUserInput.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,11 @@ const ChatUserInput = forwardRef<ChatUserInputRef, ChatUserInputProps>(
mentionableKey === displayedMentionableKey
) {
// open file on click again
openMarkdownFile(app, m.file.path)
openMarkdownFile(
app,
m.file.path,
m.type === 'block' ? m.startLine : undefined,
)
} else {
setDisplayedMentionableKey(mentionableKey)
}
Expand Down
4 changes: 2 additions & 2 deletions src/database/modules/vector/VectorRepository.ts
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ export class VectorRepository {
}
const scopeCondition = getScopeCondition()

const similaritySearchResult = await this.db
const similaritySearchResults = await this.db
.select({
...(() => {
// eslint-disable-next-line @typescript-eslint/no-unused-vars
Expand All @@ -159,6 +159,6 @@ export class VectorRepository {
.orderBy((t) => desc(t.similarity))
.limit(options.limit)

return similaritySearchResult
return similaritySearchResults
}
}
2 changes: 2 additions & 0 deletions src/hooks/useChatHistory.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ const serializeChatMessage = (message: ChatMessage): SerializedChatMessage => {
promptContent: message.promptContent,
id: message.id,
mentionables: message.mentionables.map(serializeMentionable),
similaritySearchResults: message.similaritySearchResults,
}
case 'assistant':
return {
Expand All @@ -50,6 +51,7 @@ const deserializeChatMessage = (
mentionables: message.mentionables
.map((m) => deserializeMentionable(m, app))
.filter((m): m is Mentionable => m !== null),
similaritySearchResults: message.similaritySearchResults,
}
}
case 'assistant':
Expand Down
8 changes: 8 additions & 0 deletions src/types/chat.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import { SerializedEditorState } from 'lexical'

import { SelectVector } from '../database/schema'

import { Mentionable, SerializedMentionable } from './mentionable'

export type ChatUserMessage = {
Expand All @@ -8,6 +10,9 @@ export type ChatUserMessage = {
promptContent: string | null
id: string
mentionables: Mentionable[]
similaritySearchResults?: (Omit<SelectVector, 'embedding'> & {
similarity: number
})[]
}
export type ChatAssistantMessage = {
role: 'assistant'
Expand All @@ -22,6 +27,9 @@ export type SerializedChatUserMessage = {
promptContent: string | null
id: string
mentionables: SerializedMentionable[]
similaritySearchResults?: (Omit<SelectVector, 'embedding'> & {
similarity: number
})[]
}
export type SerializedChatAssistantMessage = {
role: 'assistant'
Expand Down
4 changes: 2 additions & 2 deletions src/utils/obsidian.ts
Original file line number Diff line number Diff line change
Expand Up @@ -113,12 +113,12 @@ export function openMarkdownFile(

if (startLine) {
const view = existingLeaf.view as MarkdownView
view.setEphemeralState({ line: startLine })
view.setEphemeralState({ line: startLine - 1 }) // -1 because line is 0-indexed
}
} else {
const leaf = app.workspace.getLeaf('tab')
leaf.openFile(file, {
eState: startLine ? { line: startLine } : undefined,
eState: startLine ? { line: startLine - 1 } : undefined, // -1 because line is 0-indexed
})
}
}
32 changes: 20 additions & 12 deletions src/utils/promptGenerator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import { App, TFile, htmlToMarkdown, requestUrl } from 'obsidian'
import { editorStateToPlainText } from '../components/chat-view/chat-input/utils/editor-state-to-plain-text'
import { QueryProgressState } from '../components/chat-view/QueryProgress'
import { RAGEngine } from '../core/rag/ragEngine'
import { SelectVector } from '../database/schema'
import { ChatMessage, ChatUserMessage } from '../types/chat'
import { RequestMessage } from '../types/llm/request'
import {
Expand Down Expand Up @@ -57,31 +58,33 @@ export class PromptGenerator {
throw new Error('Last message is not a user message')
}

const { promptContent, shouldUseRAG } = await this.compileUserMessagePrompt(
{
const { promptContent, shouldUseRAG, similaritySearchResults } =
await this.compileUserMessagePrompt({
message: lastUserMessage,
useVaultSearch,
onQueryProgressChange: onQueryProgressChange,
},
)
onQueryProgressChange,
})
let compiledMessages = [
...messages.slice(0, -1),
{
...lastUserMessage,
promptContent: promptContent,
promptContent,
similaritySearchResults,
},
]

// Safeguard: ensure all user messages have parsed content
compiledMessages = await Promise.all(
compiledMessages.map(async (message) => {
if (message.role === 'user' && !message.promptContent) {
const { promptContent } = await this.compileUserMessagePrompt({
message,
})
const { promptContent, similaritySearchResults } =
await this.compileUserMessagePrompt({
message,
})
return {
...message,
promptContent: promptContent,
promptContent,
similaritySearchResults,
}
}
return message
Expand Down Expand Up @@ -136,6 +139,9 @@ export class PromptGenerator {
}): Promise<{
promptContent: string
shouldUseRAG: boolean
similaritySearchResults?: (Omit<SelectVector, 'embedding'> & {
similarity: number
})[]
}> {
if (!message.content) {
return {
Expand All @@ -144,6 +150,7 @@ export class PromptGenerator {
}
}
const query = editorStateToPlainText(message.content)
let similaritySearchResults = undefined

useVaultSearch =
// eslint-disable-next-line @typescript-eslint/prefer-nullish-coalescing
Expand Down Expand Up @@ -183,7 +190,7 @@ export class PromptGenerator {

let filePrompt: string
if (shouldUseRAG) {
const results = useVaultSearch
similaritySearchResults = useVaultSearch
? await (
await this.getRagEngine()
).processQuery({
Expand All @@ -201,7 +208,7 @@ export class PromptGenerator {
onQueryProgressChange: onQueryProgressChange,
})
filePrompt = `## Potentially Relevant Snippets from the current vault
${results
${similaritySearchResults
.map(({ path, content, metadata }) => {
const contentWithLineNumbers = this.addLineNumbersToContent({
content,
Expand Down Expand Up @@ -251,6 +258,7 @@ ${await this.getWebsiteContent(url)}
return {
promptContent: `${filePrompt}${blockPrompt}${urlPrompt}\n\n${query}\n\n`,
shouldUseRAG,
similaritySearchResults: similaritySearchResults,
}
}

Expand Down
Loading

0 comments on commit d1bb01a

Please sign in to comment.