Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement rag prototpye #22

Merged
merged 32 commits into from
Oct 23, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
ecb31f1
Implement rag prototpye
kevin-on Oct 16, 2024
2483cf1
Merge branch 'main' into kevin/rag
kevin-on Oct 18, 2024
32b93b1
Add settings for RAG
kevin-on Oct 20, 2024
4de4721
Add MarkdownReferenceBlock
kevin-on Oct 20, 2024
7c3d772
Add MentionableFolder and MentionableVault
kevin-on Oct 20, 2024
2a437ad
Implement RAG
kevin-on Oct 20, 2024
797a65b
Update comment
kevin-on Oct 20, 2024
67f2139
Show query progress after submit
kevin-on Oct 20, 2024
61cf2b0
Add comment
kevin-on Oct 20, 2024
e3e83a3
Style QueryProgress
realsnoopso Oct 21, 2024
1e9ba17
Style block header
realsnoopso Oct 21, 2024
61a1dd3
Add icons on Mentionable
realsnoopso Oct 21, 2024
b1e3064
Add icon to Mentionables
realsnoopso Oct 21, 2024
58b193c
Fix lint errors
realsnoopso Oct 21, 2024
10ee558
Add Rag Vault Button
realsnoopso Oct 21, 2024
74b995a
Add Drizzle ORM
kevin-on Oct 21, 2024
27b923e
Merge branch 'kevin/rag' of https://github.com/glowingjade/obsidian-s…
kevin-on Oct 21, 2024
2dc3bfe
Setup migration
kevin-on Oct 21, 2024
1b6a2c5
Handle Vault Search button
kevin-on Oct 21, 2024
a4639d6
Add notice when indexing vault through command
kevin-on Oct 22, 2024
3fa0ec3
Update embedding types
kevin-on Oct 22, 2024
6da3310
Fix lint
kevin-on Oct 22, 2024
b4a8e48
Update dependencies
kevin-on Oct 22, 2024
696bb8d
Resolve review
kevin-on Oct 23, 2024
02743ea
Rename function and variables
kevin-on Oct 23, 2024
9239fdb
Resolve review
kevin-on Oct 23, 2024
6955479
Resolve review
kevin-on Oct 23, 2024
f2ef25d
Fix bug when initializing ragEngine
kevin-on Oct 23, 2024
c0bec53
Fix bug in rag engine intialization
kevin-on Oct 23, 2024
4e74929
Insert vectors batch-wise
kevin-on Oct 23, 2024
f434b60
Set private functions
kevin-on Oct 23, 2024
1f03d90
fix tests by mocking obsidian
glowingjade Oct 23, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
"@tanstack/react-query": "^5.56.2",
"diff": "^7.0.0",
"drizzle-orm": "^0.35.2",
"exponential-backoff": "^3.1.1",
"fuzzysort": "^3.1.0",
"groq-sdk": "^0.7.0",
"js-tiktoken": "^1.0.15",
Expand Down
11 changes: 6 additions & 5 deletions src/ChatView.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@ import { SettingsProvider } from './contexts/settings-context'
import SmartCopilotPlugin from './main'
import { MentionableBlockData } from './types/mentionable'
import { SmartCopilotSettings } from './types/settings'
import { RAGEngine } from './utils/ragEngine'

export class ChatView extends ItemView {
private root: Root | null = null
private settings: SmartCopilotSettings
private initialChatProps?: ChatProps
private chatRef: React.RefObject<ChatRef> = React.createRef()
private onRAGEngineChange: (ragEngine: RAGEngine) => void

constructor(
leaf: WorkspaceLeaf,
Expand All @@ -27,6 +29,9 @@ export class ChatView extends ItemView {
super(leaf)
this.settings = plugin.settings
this.initialChatProps = plugin.initialChatProps
this.onRAGEngineChange = (ragEngine) => {
plugin.ragEngine = ragEngine
}
}

getViewType() {
Expand Down Expand Up @@ -70,11 +75,7 @@ export class ChatView extends ItemView {
>
<DarkModeProvider>
<LLMProvider>
<RAGProvider
setPluginRAGEngine={(ragEngine) =>
(this.plugin.ragEngine = ragEngine)
}
>
<RAGProvider onRAGEngineChange={this.onRAGEngineChange}>
<QueryClientProvider client={queryClient}>
<React.StrictMode>
<Chat ref={this.chatRef} {...this.initialChatProps} />
Expand Down
4 changes: 2 additions & 2 deletions src/components/chat-view/Chat.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ const getNewInputMessage = (app: App): ChatUserMessage => {
return {
role: 'user',
content: null,
parsedContent: null,
promptContent: null,
id: uuidv4(),
mentionables: [
{
Expand Down Expand Up @@ -199,7 +199,7 @@ const Chat = forwardRef<ChatRef, ChatProps>((props, ref) => {
await promptGenerator.generateRequestMessages({
messages: newChatHistory,
useVaultSearch,
setQueryProgress,
onQueryProgressChange: setQueryProgress,
})
setQueryProgress({
type: 'idle',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import { createPortal } from 'react-dom'
import { serializeMentionable } from 'src/utils/mentionable'

import { Mentionable } from '../../../../../types/mentionable'
import { SearchResultItem } from '../../../../../utils/fuzzy-search'
import { SearchableMentionable } from '../../../../../utils/fuzzy-search'
import { getMentionableName } from '../../../../../utils/mentionable'
import { getMentionableIcon } from '../../utils/get-metionable-icon'
import { MenuOption, MenuTextMatch } from '../shared/LexicalMenu'
Expand Down Expand Up @@ -105,7 +105,7 @@ class MentionTypeaheadOption extends MenuOption {
mentionable: Mentionable
icon: React.ReactNode

constructor(result: SearchResultItem) {
constructor(result: SearchableMentionable) {
switch (result.type) {
case 'file':
super(result.file.path)
Expand Down Expand Up @@ -167,7 +167,7 @@ export default function NewMentionsPlugin({
searchResultByQuery,
onAddMention,
}: {
searchResultByQuery: (query: string) => SearchResultItem[]
searchResultByQuery: (query: string) => SearchableMentionable[]
onAddMention: (mentionable: Mentionable) => void
}): JSX.Element | null {
const [editor] = useLexicalComposerContext()
Expand Down
33 changes: 10 additions & 23 deletions src/contexts/rag-context.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@ import {
createContext,
useContext,
useEffect,
useMemo,
useState,
} from 'react'

import { getEmbeddingModel } from '../utils/embedding'
import { RAGEngine } from '../utils/ragEngine'

import { useApp } from './app-context'
Expand All @@ -19,40 +18,28 @@ export type RAGContextType = {
const RAGContext = createContext<RAGContextType | null>(null)

export function RAGProvider({
setPluginRAGEngine,
children,
}: PropsWithChildren<{ setPluginRAGEngine: (ragEngine: RAGEngine) => void }>) {
onRAGEngineChange,
}: PropsWithChildren<{ onRAGEngineChange: (ragEngine: RAGEngine) => void }>) {
const app = useApp()
const { settings } = useSettings()
const [ragEngine, setRagEngine] = useState<RAGEngine | null>(null)

const ragEngine = useMemo(() => {
return new RAGEngine(app, settings)
useEffect(() => {
RAGEngine.create(app, settings).then(setRagEngine)
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [app])

const embeddingModel = useMemo(() => {
return getEmbeddingModel(settings.embeddingModel, {
openAIApiKey: settings.openAIApiKey,
})
}, [settings.embeddingModel, settings.openAIApiKey])

useEffect(() => {
void ragEngine.initialize(embeddingModel)
setPluginRAGEngine(ragEngine)
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [ragEngine])
if (ragEngine) {
onRAGEngineChange(ragEngine)
}
}, [ragEngine, onRAGEngineChange])

useEffect(() => {
ragEngine?.setSettings(settings)
}, [ragEngine, settings])
glowingjade marked this conversation as resolved.
Show resolved Hide resolved

useEffect(() => {
if (!ragEngine || !embeddingModel) {
return
}
ragEngine.setEmbeddingModel(embeddingModel)
}, [ragEngine, embeddingModel])

return (
<RAGContext.Provider value={{ ragEngine }}>{children}</RAGContext.Provider>
)
Expand Down
4 changes: 2 additions & 2 deletions src/hooks/useChatHistory.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ const serializeChatMessage = (message: ChatMessage): SerializedChatMessage => {
return {
role: 'user',
content: message.content,
parsedContent: message.parsedContent,
promptContent: message.promptContent,
id: message.id,
mentionables: message.mentionables.map(serializeMentionable),
}
Expand All @@ -45,7 +45,7 @@ const deserializeChatMessage = (
return {
role: 'user',
content: message.content,
parsedContent: message.parsedContent,
promptContent: message.promptContent,
id: message.id,
mentionables: message.mentionables
.map((m) => deserializeMentionable(m, app))
Expand Down
17 changes: 13 additions & 4 deletions src/main.ts
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,9 @@ export default class SmartCopilotPlugin extends Plugin {
callback: async () => {
const notice = new Notice('Re-indexing vault...', 0)
try {
await this.ragEngine?.updateVaultIndex(
{ overwrite: true },
const ragEngine = await this.getRAGEngine()
await ragEngine.updateVaultIndex(
{ reindexAll: true },
(queryProgress) => {
if (queryProgress.type === 'indexing') {
const { completedChunks, totalChunks } =
Expand Down Expand Up @@ -81,8 +82,9 @@ export default class SmartCopilotPlugin extends Plugin {
callback: async () => {
const notice = new Notice('Updating vault index...', 0)
try {
await this.ragEngine?.updateVaultIndex(
{ overwrite: false },
const ragEngine = await this.getRAGEngine()
await ragEngine.updateVaultIndex(
{ reindexAll: false },
(queryProgress) => {
if (queryProgress.type === 'indexing') {
const { completedChunks, totalChunks } =
Expand Down Expand Up @@ -181,4 +183,11 @@ export default class SmartCopilotPlugin extends Plugin {
chatView.addSelectionToChat(data)
chatView.focusMessage()
}

async getRAGEngine(): Promise<RAGEngine> {
if (!this.ragEngine) {
this.ragEngine = await RAGEngine.create(this.app, this.settings)
}
return this.ragEngine
}
}
4 changes: 2 additions & 2 deletions src/types/chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import { Mentionable, SerializedMentionable } from './mentionable'
export type ChatUserMessage = {
role: 'user'
content: SerializedEditorState | null
parsedContent: string | null
promptContent: string | null
id: string
mentionables: Mentionable[]
}
Expand All @@ -19,7 +19,7 @@ export type ChatMessage = ChatUserMessage | ChatAssistantMessage
export type SerializedChatUserMessage = {
role: 'user'
content: SerializedEditorState | null
parsedContent: string | null
promptContent: string | null
id: string
mentionables: SerializedMentionable[]
}
Expand Down
69 changes: 28 additions & 41 deletions src/utils/fuzzy-search.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,11 @@ import {

import { calculateFileDistance, getOpenFiles } from './obsidian'

export type SearchResultItem =
export type SearchableMentionable =
| MentionableFile
| MentionableFolder
| MentionableVault

type SearchItem = FolderWithMetadata | FileWithMetadata | VaultSearchItem

type VaultSearchItem = {
type: 'vault'
path: string
Expand All @@ -36,6 +34,8 @@ type FolderWithMetadata = {
folder: TFolder
}

type SearchItem = FolderWithMetadata | FileWithMetadata | VaultSearchItem

function scoreFnWithBoost(score: number, searchItem: SearchItem) {
switch (searchItem.type) {
case 'file': {
Expand Down Expand Up @@ -76,7 +76,7 @@ function scoreFnWithBoost(score: number, searchItem: SearchItem) {
function getEmptyQueryResult(
searchItems: SearchItem[],
limit: number,
): SearchResultItem[] {
): SearchableMentionable[] {
// Sort files based on a custom scoring function
const sortedFiles = searchItems.sort((a, b) => {
const scoreA = scoreFnWithBoost(0.5, a) // Use 0.5 as a base score
Expand All @@ -85,27 +85,12 @@ function getEmptyQueryResult(
})

// Return only the top 'limit' files
return sortedFiles.slice(0, limit).map((item) => {
switch (item.type) {
case 'file':
return {
type: 'file',
file: item.file,
}
case 'folder':
return {
type: 'folder',
folder: item.folder,
}
case 'vault':
return {
type: 'vault',
}
}
})
return sortedFiles
.slice(0, limit)
.map((item) => searchItemToMentionable(item))
}

export function fuzzySearch(app: App, query: string): SearchResultItem[] {
export function fuzzySearch(app: App, query: string): SearchableMentionable[] {
const currentFile = app.workspace.getActiveFile()
const openFiles = getOpenFiles(app)

Expand Down Expand Up @@ -158,22 +143,24 @@ export function fuzzySearch(app: App, query: string): SearchResultItem[] {
scoreFn: (result) => scoreFnWithBoost(result.score, result.obj),
})

return results.map((result) => {
switch (result.obj.type) {
case 'file':
return {
type: 'file',
file: result.obj.file,
}
case 'folder':
return {
type: 'folder',
folder: result.obj.folder,
}
case 'vault':
return {
type: 'vault',
}
}
})
return results.map((result) => searchItemToMentionable(result.obj))
}

function searchItemToMentionable(item: SearchItem): SearchableMentionable {
switch (item.type) {
case 'file':
return {
type: 'file',
file: item.file,
}
case 'folder':
return {
type: 'folder',
folder: item.folder,
}
case 'vault':
return {
type: 'vault',
}
}
}
Loading
Loading