Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Display similarity search results #107

Merged
merged 3 commits into from
Nov 20, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 52 additions & 42 deletions src/components/chat-view/Chat.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ import { editorStateToPlainText } from './chat-input/utils/editor-state-to-plain
import { ChatListDropdown } from './ChatListDropdown'
import QueryProgress, { QueryProgressState } from './QueryProgress'
import ReactMarkdown from './ReactMarkdown'
import SimilaritySearchResults from './SimilaritySearchResults'

// Add an empty line here
const getNewInputMessage = (app: App): ChatUserMessage => {
Expand Down Expand Up @@ -539,48 +540,57 @@ const Chat = forwardRef<ChatRef, ChatProps>((props, ref) => {
<div className="smtcmp-chat-messages" ref={chatMessagesRef}>
{chatMessages.map((message, index) =>
message.role === 'user' ? (
<ChatUserInput
key={message.id}
ref={(ref) => registerChatUserInputRef(message.id, ref)}
initialSerializedEditorState={message.content}
onChange={(content) => {
setChatMessages((prevChatHistory) =>
prevChatHistory.map((msg) =>
msg.role === 'user' && msg.id === message.id
? {
...msg,
content,
}
: msg,
),
)
}}
onSubmit={(content, useVaultSearch) => {
if (editorStateToPlainText(content).trim() === '') return
handleSubmit(
[
...chatMessages.slice(0, index),
{
...message,
content,
},
],
useVaultSearch,
)
chatUserInputRefs.current.get(inputMessage.id)?.focus()
}}
onFocus={() => {
setFocusedMessageId(message.id)
}}
mentionables={message.mentionables}
setMentionables={(mentionables) => {
setChatMessages((prevChatHistory) =>
prevChatHistory.map((msg) =>
msg.id === message.id ? { ...msg, mentionables } : msg,
),
)
}}
/>
<div key={message.id} className="smtcmp-chat-messages-user">
<ChatUserInput
ref={(ref) => registerChatUserInputRef(message.id, ref)}
initialSerializedEditorState={message.content}
onChange={(content) => {
setChatMessages((prevChatHistory) =>
prevChatHistory.map((msg) =>
msg.role === 'user' && msg.id === message.id
? {
...msg,
content,
}
: msg,
),
)
}}
onSubmit={(content, useVaultSearch) => {
if (editorStateToPlainText(content).trim() === '') return
handleSubmit(
[
...chatMessages.slice(0, index),
{
role: 'user',
content: content,
promptContent: null,
id: message.id,
mentionables: message.mentionables,
glowingjade marked this conversation as resolved.
Show resolved Hide resolved
},
],
useVaultSearch,
)
chatUserInputRefs.current.get(inputMessage.id)?.focus()
}}
onFocus={() => {
setFocusedMessageId(message.id)
}}
mentionables={message.mentionables}
setMentionables={(mentionables) => {
setChatMessages((prevChatHistory) =>
prevChatHistory.map((msg) =>
msg.id === message.id ? { ...msg, mentionables } : msg,
),
)
}}
/>
{message.similaritySearchResults && (
<SimilaritySearchResults
similaritySearchResults={message.similaritySearchResults}
/>
)}
</div>
) : (
<ReactMarkdownItem
key={message.id}
Expand Down
2 changes: 1 addition & 1 deletion src/components/chat-view/ChatListDropdown.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ export function ChatListDropdown({
</DropdownMenu.Trigger>

<DropdownMenu.Portal>
<DropdownMenu.Content className="smtcmp-popover">
<DropdownMenu.Content className="smtcmp-popover smtcmp-chat-list-dropdown-content">
<ul>
{chatList.length === 0 ? (
<li className="smtcmp-chat-list-dropdown-empty">
Expand Down
68 changes: 68 additions & 0 deletions src/components/chat-view/SimilaritySearchResults.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import path from 'path'

import { ChevronDown, ChevronRight } from 'lucide-react'
import { useState } from 'react'

import { useApp } from '../../contexts/app-context'
import { SelectVector } from '../../database/schema'
import { openMarkdownFile } from '../../utils/obsidian'

function SimiliartySearchItem({
chunk,
}: {
chunk: Omit<SelectVector, 'embedding'> & {
similarity: number
}
}) {
const app = useApp()

const handleClick = () => {
openMarkdownFile(app, chunk.path, chunk.metadata.startLine)
}
return (
<div onClick={handleClick} className="smtcmp-similarity-search-item">
<div className="smtcmp-similarity-search-item__path">
{path.basename(chunk.path)}
</div>
<div className="smtcmp-similarity-search-item__line-numbers">
{`${chunk.metadata.startLine} - ${chunk.metadata.endLine}`}
</div>
</div>
)
}

export default function SimilaritySearchResults({
similaritySearchResults,
}: {
similaritySearchResults: (Omit<SelectVector, 'embedding'> & {
similarity: number
})[]
}) {
const [isOpen, setIsOpen] = useState(false)

return (
<div className="smtcmp-similarity-search-results">
<div
onClick={() => {
setIsOpen(!isOpen)
}}
className="smtcmp-similarity-search-results__trigger"
>
{isOpen ? <ChevronDown size={16} /> : <ChevronRight size={16} />}
<div>Show Referenced Documents ({similaritySearchResults.length})</div>
</div>
{isOpen && (
<div
style={{
display: 'flex',
flexDirection: 'column',
}}
>
{similaritySearchResults.map((chunk) => (
<SimiliartySearchItem key={chunk.id} chunk={chunk} />
))}
</div>
)}
</div>
)
}
4 changes: 2 additions & 2 deletions src/database/modules/vector/VectorRepository.ts
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ export class VectorRepository {
}
const scopeCondition = getScopeCondition()

const similaritySearchResult = await this.db
const similaritySearchResults = await this.db
.select({
...(() => {
// eslint-disable-next-line @typescript-eslint/no-unused-vars
Expand All @@ -159,6 +159,6 @@ export class VectorRepository {
.orderBy((t) => desc(t.similarity))
.limit(options.limit)

return similaritySearchResult
return similaritySearchResults
}
}
2 changes: 2 additions & 0 deletions src/hooks/useChatHistory.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ const serializeChatMessage = (message: ChatMessage): SerializedChatMessage => {
promptContent: message.promptContent,
id: message.id,
mentionables: message.mentionables.map(serializeMentionable),
similaritySearchResults: message.similaritySearchResults,
}
case 'assistant':
return {
Expand All @@ -50,6 +51,7 @@ const deserializeChatMessage = (
mentionables: message.mentionables
.map((m) => deserializeMentionable(m, app))
.filter((m): m is Mentionable => m !== null),
similaritySearchResults: message.similaritySearchResults,
}
}
case 'assistant':
Expand Down
8 changes: 8 additions & 0 deletions src/types/chat.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import { SerializedEditorState } from 'lexical'

import { SelectVector } from '../database/schema'

import { Mentionable, SerializedMentionable } from './mentionable'

export type ChatUserMessage = {
Expand All @@ -8,6 +10,9 @@ export type ChatUserMessage = {
promptContent: string | null
id: string
mentionables: Mentionable[]
similaritySearchResults?: (Omit<SelectVector, 'embedding'> & {
similarity: number
})[]
}
export type ChatAssistantMessage = {
role: 'assistant'
Expand All @@ -22,6 +27,9 @@ export type SerializedChatUserMessage = {
promptContent: string | null
id: string
mentionables: SerializedMentionable[]
similaritySearchResults?: (Omit<SelectVector, 'embedding'> & {
similarity: number
})[]
}
export type SerializedChatAssistantMessage = {
role: 'assistant'
Expand Down
4 changes: 2 additions & 2 deletions src/utils/obsidian.ts
Original file line number Diff line number Diff line change
Expand Up @@ -113,12 +113,12 @@ export function openMarkdownFile(

if (startLine) {
const view = existingLeaf.view as MarkdownView
view.setEphemeralState({ line: startLine })
view.setEphemeralState({ line: startLine - 1 })
}
} else {
const leaf = app.workspace.getLeaf('tab')
leaf.openFile(file, {
eState: startLine ? { line: startLine } : undefined,
eState: startLine ? { line: startLine - 1 } : undefined,
glowingjade marked this conversation as resolved.
Show resolved Hide resolved
})
}
}
32 changes: 20 additions & 12 deletions src/utils/promptGenerator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import { App, TFile, htmlToMarkdown, requestUrl } from 'obsidian'
import { editorStateToPlainText } from '../components/chat-view/chat-input/utils/editor-state-to-plain-text'
import { QueryProgressState } from '../components/chat-view/QueryProgress'
import { RAGEngine } from '../core/rag/ragEngine'
import { SelectVector } from '../database/schema'
import { ChatMessage, ChatUserMessage } from '../types/chat'
import { RequestMessage } from '../types/llm/request'
import {
Expand Down Expand Up @@ -57,31 +58,33 @@ export class PromptGenerator {
throw new Error('Last message is not a user message')
}

const { promptContent, shouldUseRAG } = await this.compileUserMessagePrompt(
{
const { promptContent, shouldUseRAG, similaritySearchResults } =
await this.compileUserMessagePrompt({
message: lastUserMessage,
useVaultSearch,
onQueryProgressChange: onQueryProgressChange,
},
)
onQueryProgressChange,
})
let compiledMessages = [
...messages.slice(0, -1),
{
...lastUserMessage,
promptContent: promptContent,
promptContent,
similaritySearchResults,
},
]

// Safeguard: ensure all user messages have parsed content
compiledMessages = await Promise.all(
compiledMessages.map(async (message) => {
if (message.role === 'user' && !message.promptContent) {
const { promptContent } = await this.compileUserMessagePrompt({
message,
})
const { promptContent, similaritySearchResults } =
await this.compileUserMessagePrompt({
message,
})
return {
...message,
promptContent: promptContent,
promptContent,
similaritySearchResults,
}
}
return message
Expand Down Expand Up @@ -136,6 +139,9 @@ export class PromptGenerator {
}): Promise<{
promptContent: string
shouldUseRAG: boolean
similaritySearchResults?: (Omit<SelectVector, 'embedding'> & {
similarity: number
})[]
}> {
if (!message.content) {
return {
Expand All @@ -144,6 +150,7 @@ export class PromptGenerator {
}
}
const query = editorStateToPlainText(message.content)
let similaritySearchResults = undefined

useVaultSearch =
// eslint-disable-next-line @typescript-eslint/prefer-nullish-coalescing
Expand Down Expand Up @@ -183,7 +190,7 @@ export class PromptGenerator {

let filePrompt: string
if (shouldUseRAG) {
const results = useVaultSearch
similaritySearchResults = useVaultSearch
? await (
await this.getRagEngine()
).processQuery({
Expand All @@ -201,7 +208,7 @@ export class PromptGenerator {
onQueryProgressChange: onQueryProgressChange,
})
filePrompt = `## Potentially Relevant Snippets from the current vault
${results
${similaritySearchResults
.map(({ path, content, metadata }) => {
const contentWithLineNumbers = this.addLineNumbersToContent({
content,
Expand Down Expand Up @@ -251,6 +258,7 @@ ${await this.getWebsiteContent(url)}
return {
promptContent: `${filePrompt}${blockPrompt}${urlPrompt}\n\n${query}\n\n`,
shouldUseRAG,
similaritySearchResults: similaritySearchResults,
}
}

Expand Down
Loading
Loading