diff --git a/src/components/chat-view/chat-input/ChatUserInput.tsx b/src/components/chat-view/chat-input/ChatUserInput.tsx index 81f692b..16bff5a 100644 --- a/src/components/chat-view/chat-input/ChatUserInput.tsx +++ b/src/components/chat-view/chat-input/ChatUserInput.tsx @@ -2,15 +2,22 @@ import { useQuery } from '@tanstack/react-query' import { $nodesOfType, LexicalEditor, SerializedEditorState } from 'lexical' import { forwardRef, + useCallback, useEffect, useImperativeHandle, + useMemo, useRef, useState, } from 'react' import { useApp } from '../../../contexts/app-context' import { useDarkModeContext } from '../../../contexts/dark-mode-context' -import { Mentionable, SerializedMentionable } from '../../../types/mentionable' +import { + Mentionable, + MentionableImage, + SerializedMentionable, +} from '../../../types/mentionable' +import { fileToMentionableImage } from '../../../utils/image' import { deserializeMentionable, getMentionableKey, @@ -19,6 +26,7 @@ import { import { openMarkdownFile, readTFileContent } from '../../../utils/obsidian' import { MemoizedSyntaxHighlighterWrapper } from '../SyntaxHighlighterWrapper' +import { ImageUploadButton } from './ImageUploadButton' import LexicalContentEditable from './LexicalContentEditable' import MentionableBadge from './MentionableBadge' import { ModelSelect } from './ModelSelect' @@ -57,7 +65,6 @@ const ChatUserInput = forwardRef( ref, ) => { const app = useApp() - const { isDarkMode } = useDarkModeContext() const editorRef = useRef(null) const contentEditableRef = useRef(null) @@ -139,6 +146,29 @@ const ChatUserInput = forwardRef( } } + const handleCreateImageMentionables = useCallback( + (mentionableImages: MentionableImage[]) => { + const newMentionableImages = mentionableImages.filter( + (m) => + !mentionables.some( + (mentionable) => + getMentionableKey(serializeMentionable(mentionable)) === + getMentionableKey(serializeMentionable(m)), + ), + ) + if (newMentionableImages.length === 0) return + setMentionables([...mentionables, ...newMentionableImages]) + setDisplayedMentionableKey( + getMentionableKey( + serializeMentionable( + newMentionableImages[newMentionableImages.length - 1], + ), + ), + ) + }, + [mentionables, setMentionables], + ) + const handleMentionableDelete = (mentionable: Mentionable) => { const mentionableKey = getMentionableKey( serializeMentionable(mentionable), @@ -158,47 +188,12 @@ const ChatUserInput = forwardRef( }) } - const { data: fileContent } = useQuery({ - queryKey: [ - 'file', - displayedMentionableKey, - mentionables.map((m) => getMentionableKey(serializeMentionable(m))), // should be updated when mentionables change (especially on delete) - ], - queryFn: async () => { - if (!displayedMentionableKey) return null - - const displayedMentionable = mentionables.find( - (m) => - getMentionableKey(serializeMentionable(m)) === - displayedMentionableKey, - ) - - if (!displayedMentionable) return null - - if ( - displayedMentionable.type === 'file' || - displayedMentionable.type === 'current-file' - ) { - if (!displayedMentionable.file) return null - return await readTFileContent(displayedMentionable.file, app.vault) - } else if (displayedMentionable.type === 'block') { - const fileContent = await readTFileContent( - displayedMentionable.file, - app.vault, - ) - - return fileContent - .split('\n') - .slice( - displayedMentionable.startLine - 1, - displayedMentionable.endLine, - ) - .join('\n') - } - - return null - }, - }) + const handleUploadImages = async (images: File[]) => { + const mentionableImages = await Promise.all( + images.map((image) => fileToMentionableImage(image)), + ) + handleCreateImageMentionables(mentionableImages) + } const handleSubmit = (options: { useVaultSearch?: boolean } = {}) => { const content = editorRef.current?.getEditorState()?.toJSON() @@ -235,23 +230,19 @@ const ChatUserInput = forwardRef( setDisplayedMentionableKey(mentionableKey) } }} + isFocused={ + getMentionableKey(serializeMentionable(m)) === + displayedMentionableKey + } /> ))} )} - {fileContent && ( -
- - {fileContent} - -
- )} + { @@ -267,6 +258,7 @@ const ChatUserInput = forwardRef( onEnter={() => handleSubmit({ useVaultSearch: false })} onFocus={onFocus} onMentionNodeMutation={handleMentionNodeMutation} + onCreateImageMentionables={handleCreateImageMentionables} autoFocus={autoFocus} plugins={{ onEnter: { @@ -281,8 +273,11 @@ const ChatUserInput = forwardRef( />
- -
+
+ +
+
+ handleSubmit()} /> { @@ -296,6 +291,84 @@ const ChatUserInput = forwardRef( }, ) +function MentionableContentPreview({ + displayedMentionableKey, + mentionables, +}: { + displayedMentionableKey: string | null + mentionables: Mentionable[] +}) { + const app = useApp() + const { isDarkMode } = useDarkModeContext() + + const displayedMentionable: Mentionable | null = useMemo(() => { + return ( + mentionables.find( + (m) => + getMentionableKey(serializeMentionable(m)) === + displayedMentionableKey, + ) ?? null + ) + }, [displayedMentionableKey, mentionables]) + + const { data: displayFileContent } = useQuery({ + enabled: + !!displayedMentionable && + ['file', 'current-file', 'block'].includes(displayedMentionable.type), + queryKey: [ + 'file', + displayedMentionableKey, + mentionables.map((m) => getMentionableKey(serializeMentionable(m))), // should be updated when mentionables change (especially on delete) + ], + queryFn: async () => { + if (!displayedMentionable) return null + if ( + displayedMentionable.type === 'file' || + displayedMentionable.type === 'current-file' + ) { + if (!displayedMentionable.file) return null + return await readTFileContent(displayedMentionable.file, app.vault) + } else if (displayedMentionable.type === 'block') { + const fileContent = await readTFileContent( + displayedMentionable.file, + app.vault, + ) + + return fileContent + .split('\n') + .slice( + displayedMentionable.startLine - 1, + displayedMentionable.endLine, + ) + .join('\n') + } + + return null + }, + }) + + const displayImage: MentionableImage | null = useMemo(() => { + return displayedMentionable?.type === 'image' ? displayedMentionable : null + }, [displayedMentionable]) + + return displayFileContent ? ( +
+ + {displayFileContent} + +
+ ) : displayImage ? ( +
+ {displayImage.name} +
+ ) : null +} + ChatUserInput.displayName = 'ChatUserInput' export default ChatUserInput diff --git a/src/components/chat-view/chat-input/ImageUploadButton.tsx b/src/components/chat-view/chat-input/ImageUploadButton.tsx new file mode 100644 index 0000000..0bdd03d --- /dev/null +++ b/src/components/chat-view/chat-input/ImageUploadButton.tsx @@ -0,0 +1,30 @@ +import { ImageIcon } from 'lucide-react' + +export function ImageUploadButton({ + onUpload, +}: { + onUpload: (files: File[]) => void +}) { + const handleFileChange = (event: React.ChangeEvent) => { + const files = Array.from(event.target.files ?? []) + if (files.length > 0) { + onUpload(files) + } + } + + return ( + + ) +} diff --git a/src/components/chat-view/chat-input/LexicalContentEditable.tsx b/src/components/chat-view/chat-input/LexicalContentEditable.tsx index b7947fe..192fc3f 100644 --- a/src/components/chat-view/chat-input/LexicalContentEditable.tsx +++ b/src/components/chat-view/chat-input/LexicalContentEditable.tsx @@ -13,8 +13,11 @@ import { LexicalEditor, SerializedEditorState } from 'lexical' import { RefObject, useCallback, useEffect } from 'react' import { useApp } from '../../../contexts/app-context' +import { MentionableImage } from '../../../types/mentionable' import { fuzzySearch } from '../../../utils/fuzzy-search' +import DragDropPaste from './plugins/image/DragDropPastePlugin' +import ImagePastePlugin from './plugins/image/ImagePastePlugin' import AutoLinkMentionPlugin from './plugins/mention/AutoLinkMentionPlugin' import { MentionNode } from './plugins/mention/MentionNode' import MentionPlugin from './plugins/mention/MentionPlugin' @@ -33,6 +36,7 @@ export type LexicalContentEditableProps = { onEnter?: (evt: KeyboardEvent) => void onFocus?: () => void onMentionNodeMutation?: (mutations: NodeMutations) => void + onCreateImageMentionables?: (mentionables: MentionableImage[]) => void initialEditorState?: InitialEditorStateType autoFocus?: boolean plugins?: { @@ -52,6 +56,7 @@ export default function LexicalContentEditable({ onEnter, onFocus, onMentionNodeMutation, + onCreateImageMentionables, initialEditorState, autoFocus = false, plugins, @@ -134,6 +139,8 @@ export default function LexicalContentEditable({ + + {plugins?.templatePopover && ( void onClick: () => void + isFocused: boolean }>) { return ( -
+
{children}
void onClick: () => void + isFocused: boolean }) { const Icon = getMentionableIcon(mentionable) return ( - +
{Icon && ( void onClick: () => void + isFocused: boolean }) { const Icon = getMentionableIcon(mentionable) return ( - - {/* TODO: Update style */} +
{Icon && ( void onClick: () => void + isFocused: boolean }) { const Icon = getMentionableIcon(mentionable) return ( - + {/* TODO: Update style */}
{Icon && ( @@ -119,14 +130,16 @@ function CurrentFileBadge({ mentionable, onDelete, onClick, + isFocused, }: { mentionable: MentionableCurrentFile onDelete: () => void onClick: () => void + isFocused: boolean }) { const Icon = getMentionableIcon(mentionable) return mentionable.file ? ( - +
{Icon && ( void onClick: () => void + isFocused: boolean }) { const Icon = getMentionableIcon(mentionable) return ( - +
{Icon && ( void onClick: () => void + isFocused: boolean }) { const Icon = getMentionableIcon(mentionable) return ( - +
{Icon && ( void + onClick: () => void + isFocused: boolean +}) { + const Icon = getMentionableIcon(mentionable) + return ( + +
+ {Icon && ( + + )} + {mentionable.name} +
+
+ ) +} + export default function MentionableBadge({ mentionable, onDelete, onClick, + isFocused = false, }: { mentionable: Mentionable onDelete: () => void onClick: () => void + isFocused?: boolean }) { switch (mentionable.type) { case 'file': @@ -212,6 +258,7 @@ export default function MentionableBadge({ mentionable={mentionable} onDelete={onDelete} onClick={onClick} + isFocused={isFocused} /> ) case 'folder': @@ -220,6 +267,7 @@ export default function MentionableBadge({ mentionable={mentionable} onDelete={onDelete} onClick={onClick} + isFocused={isFocused} /> ) case 'vault': @@ -228,6 +276,7 @@ export default function MentionableBadge({ mentionable={mentionable} onDelete={onDelete} onClick={onClick} + isFocused={isFocused} /> ) case 'current-file': @@ -236,6 +285,7 @@ export default function MentionableBadge({ mentionable={mentionable} onDelete={onDelete} onClick={onClick} + isFocused={isFocused} /> ) case 'block': @@ -244,6 +294,7 @@ export default function MentionableBadge({ mentionable={mentionable} onDelete={onDelete} onClick={onClick} + isFocused={isFocused} /> ) case 'url': @@ -252,6 +303,16 @@ export default function MentionableBadge({ mentionable={mentionable} onDelete={onDelete} onClick={onClick} + isFocused={isFocused} + /> + ) + case 'image': + return ( + ) } diff --git a/src/components/chat-view/chat-input/ModelSelect.tsx b/src/components/chat-view/chat-input/ModelSelect.tsx index b1bb371..6040d56 100644 --- a/src/components/chat-view/chat-input/ModelSelect.tsx +++ b/src/components/chat-view/chat-input/ModelSelect.tsx @@ -11,12 +11,16 @@ export function ModelSelect() { return ( - { - CHAT_MODEL_OPTIONS.find( - (option) => option.id === settings.chatModelId, - )?.name - } - {isOpen ? : } +
+ { + CHAT_MODEL_OPTIONS.find( + (option) => option.id === settings.chatModelId, + )?.name + } +
+
+ {isOpen ? : } +
diff --git a/src/components/chat-view/chat-input/plugins/image/DragDropPastePlugin.tsx b/src/components/chat-view/chat-input/plugins/image/DragDropPastePlugin.tsx new file mode 100644 index 0000000..a94de8b --- /dev/null +++ b/src/components/chat-view/chat-input/plugins/image/DragDropPastePlugin.tsx @@ -0,0 +1,34 @@ +import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext' +import { DRAG_DROP_PASTE } from '@lexical/rich-text' +import { COMMAND_PRIORITY_LOW } from 'lexical' +import { useEffect } from 'react' + +import { MentionableImage } from '../../../../../types/mentionable' +import { fileToMentionableImage } from '../../../../../utils/image' + +export default function DragDropPaste({ + onCreateImageMentionables, +}: { + onCreateImageMentionables?: (mentionables: MentionableImage[]) => void +}): null { + const [editor] = useLexicalComposerContext() + + useEffect(() => { + return editor.registerCommand( + DRAG_DROP_PASTE, // dispatched in RichTextPlugin + (files) => { + ;(async () => { + const images = files.filter((file) => file.type.startsWith('image/')) + const mentionableImages = await Promise.all( + images.map(async (image) => await fileToMentionableImage(image)), + ) + onCreateImageMentionables?.(mentionableImages) + })() + return true + }, + COMMAND_PRIORITY_LOW, + ) + }, [editor, onCreateImageMentionables]) + + return null +} diff --git a/src/components/chat-view/chat-input/plugins/image/ImagePastePlugin.tsx b/src/components/chat-view/chat-input/plugins/image/ImagePastePlugin.tsx new file mode 100644 index 0000000..01ee760 --- /dev/null +++ b/src/components/chat-view/chat-input/plugins/image/ImagePastePlugin.tsx @@ -0,0 +1,42 @@ +import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext' +import { COMMAND_PRIORITY_LOW, PASTE_COMMAND, PasteCommandType } from 'lexical' +import { useEffect } from 'react' + +import { MentionableImage } from '../../../../../types/mentionable' +import { fileToMentionableImage } from '../../../../../utils/image' + +export default function ImagePastePlugin({ + onCreateImageMentionables, +}: { + onCreateImageMentionables?: (mentionables: MentionableImage[]) => void +}) { + const [editor] = useLexicalComposerContext() + + useEffect(() => { + const handlePaste = (event: PasteCommandType) => { + const clipboardData = + event instanceof ClipboardEvent ? event.clipboardData : null + if (!clipboardData) return false + + const images = Array.from(clipboardData.files).filter((file) => + file.type.startsWith('image/'), + ) + if (images.length === 0) return false + + Promise.all(images.map((image) => fileToMentionableImage(image))).then( + (mentionableImages) => { + onCreateImageMentionables?.(mentionableImages) + }, + ) + return true + } + + return editor.registerCommand( + PASTE_COMMAND, + handlePaste, + COMMAND_PRIORITY_LOW, + ) + }, [editor, onCreateImageMentionables]) + + return null +} diff --git a/src/components/chat-view/chat-input/utils/get-metionable-icon.ts b/src/components/chat-view/chat-input/utils/get-metionable-icon.ts index 268bf5c..1fdc50e 100644 --- a/src/components/chat-view/chat-input/utils/get-metionable-icon.ts +++ b/src/components/chat-view/chat-input/utils/get-metionable-icon.ts @@ -1,4 +1,10 @@ -import { FileIcon, FolderClosedIcon, FoldersIcon, LinkIcon } from 'lucide-react' +import { + FileIcon, + FolderClosedIcon, + FoldersIcon, + ImageIcon, + LinkIcon, +} from 'lucide-react' import { Mentionable } from '../../../../types/mentionable' @@ -16,6 +22,8 @@ export const getMentionableIcon = (mentionable: Mentionable) => { return FileIcon case 'url': return LinkIcon + case 'image': + return ImageIcon default: return null } diff --git a/src/core/llm/anthropic.ts b/src/core/llm/anthropic.ts index 2c88a12..7ab256e 100644 --- a/src/core/llm/anthropic.ts +++ b/src/core/llm/anthropic.ts @@ -1,7 +1,9 @@ import Anthropic from '@anthropic-ai/sdk' import { + ImageBlockParam, MessageParam, MessageStreamEvent, + TextBlockParam, } from '@anthropic-ai/sdk/resources/messages' import { LLMModel } from '../../types/llm/model' @@ -16,6 +18,7 @@ import { LLMResponseStreaming, ResponseUsage, } from '../../types/llm/response' +import { parseImageDataUrl } from '../../utils/image' import { BaseLLMProvider } from './base' import { @@ -43,10 +46,7 @@ export class AnthropicProvider implements BaseLLMProvider { ) } - const systemMessages = request.messages.filter((m) => m.role === 'system') - if (systemMessages.length > 1) { - throw new Error('Anthropic does not support more than one system message') - } + const systemMessage = this.validateSystemMessages(request.messages) try { const response = await this.client.messages.create( @@ -55,8 +55,7 @@ export class AnthropicProvider implements BaseLLMProvider { messages: request.messages .filter((m) => m.role !== 'system') .map((m) => AnthropicProvider.parseRequestMessage(m)), - system: - systemMessages.length > 0 ? systemMessages[0].content : undefined, + system: systemMessage, max_tokens: request.max_tokens ?? AnthropicProvider.DEFAULT_MAX_TOKENS, temperature: request.temperature, @@ -90,10 +89,7 @@ export class AnthropicProvider implements BaseLLMProvider { ) } - const systemMessages = request.messages.filter((m) => m.role === 'system') - if (systemMessages.length > 1) { - throw new Error('Anthropic does not support more than one system message') - } + const systemMessage = this.validateSystemMessages(request.messages) try { const stream = await this.client.messages.create( @@ -102,8 +98,7 @@ export class AnthropicProvider implements BaseLLMProvider { messages: request.messages .filter((m) => m.role !== 'system') .map((m) => AnthropicProvider.parseRequestMessage(m)), - system: - systemMessages.length > 0 ? systemMessages[0].content : undefined, + system: systemMessage, max_tokens: request.max_tokens ?? AnthropicProvider.DEFAULT_MAX_TOKENS, temperature: request.temperature, @@ -176,11 +171,39 @@ export class AnthropicProvider implements BaseLLMProvider { static parseRequestMessage(message: RequestMessage): MessageParam { if (message.role !== 'user' && message.role !== 'assistant') { - throw new Error('Unsupported role') + throw new Error(`Anthropic does not support role: ${message.role}`) } + + if (message.role === 'user' && Array.isArray(message.content)) { + const content = message.content.map( + (part): TextBlockParam | ImageBlockParam => { + switch (part.type) { + case 'text': + return { type: 'text', text: part.text } + case 'image_url': { + const { mimeType, base64Data } = parseImageDataUrl( + part.image_url.url, + ) + AnthropicProvider.validateImageType(mimeType) + return { + type: 'image', + source: { + data: base64Data, + media_type: + mimeType as ImageBlockParam['source']['media_type'], + type: 'base64', + }, + } + } + } + }, + ) + return { role: 'user', content } + } + return { role: message.role, - content: message.content, + content: message.content as string, } } @@ -237,4 +260,37 @@ export class AnthropicProvider implements BaseLLMProvider { model: model, } } + + private validateSystemMessages( + messages: RequestMessage[], + ): string | undefined { + const systemMessages = messages.filter((m) => m.role === 'system') + if (systemMessages.length > 1) { + throw new Error(`Anthropic does not support more than one system message`) + } + const systemMessage = + systemMessages.length > 0 ? systemMessages[0].content : undefined + if (systemMessage && typeof systemMessage !== 'string') { + throw new Error( + `Anthropic only supports string content for system messages`, + ) + } + return systemMessage + } + + private static validateImageType(mimeType: string) { + const SUPPORTED_IMAGE_TYPES = [ + 'image/jpeg', + 'image/png', + 'image/gif', + 'image/webp', + ] + if (!SUPPORTED_IMAGE_TYPES.includes(mimeType)) { + throw new Error( + `Anthropic does not support image type ${mimeType}. Supported types: ${SUPPORTED_IMAGE_TYPES.join( + ', ', + )}`, + ) + } + } } diff --git a/src/core/llm/gemini.ts b/src/core/llm/gemini.ts index a041e3f..dd245bc 100644 --- a/src/core/llm/gemini.ts +++ b/src/core/llm/gemini.ts @@ -17,6 +17,7 @@ import { LLMResponseNonStreaming, LLMResponseStreaming, } from '../../types/llm/response' +import { parseImageDataUrl } from '../../utils/image' import { BaseLLMProvider } from './base' import { @@ -47,7 +48,7 @@ export class GeminiProvider implements BaseLLMProvider { ): Promise { if (!this.apiKey) { throw new LLMAPIKeyNotSetException( - 'Gemini API key is missing. Please set it in settings menu.', + `Gemini API key is missing. Please set it in settings menu.`, ) } @@ -95,7 +96,7 @@ export class GeminiProvider implements BaseLLMProvider { if (isInvalidApiKey) { throw new LLMAPIKeyInvalidException( - 'Gemini API key is invalid. Please update it in settings menu.', + `Gemini API key is invalid. Please update it in settings menu.`, ) } @@ -110,7 +111,7 @@ export class GeminiProvider implements BaseLLMProvider { ): Promise> { if (!this.apiKey) { throw new LLMAPIKeyNotSetException( - 'Gemini API key is missing. Please set it in settings menu.', + `Gemini API key is missing. Please set it in settings menu.`, ) } @@ -154,7 +155,7 @@ export class GeminiProvider implements BaseLLMProvider { if (isInvalidApiKey) { throw new LLMAPIKeyInvalidException( - 'Gemini API key is invalid. Please update it in settings menu.', + `Gemini API key is invalid. Please update it in settings menu.`, ) } @@ -176,6 +177,32 @@ export class GeminiProvider implements BaseLLMProvider { if (message.role === 'system') { return null } + + if (Array.isArray(message.content)) { + return { + role: message.role === 'user' ? 'user' : 'model', + parts: message.content.map((part) => { + switch (part.type) { + case 'text': + return { text: part.text } + case 'image_url': { + const { mimeType, base64Data } = parseImageDataUrl( + part.image_url.url, + ) + GeminiProvider.validateImageType(mimeType) + + return { + inlineData: { + data: base64Data, + mimeType, + }, + } + } + } + }), + } + } + return { role: message.role === 'user' ? 'user' : 'model', parts: [ @@ -244,4 +271,21 @@ export class GeminiProvider implements BaseLLMProvider { : undefined, } } + + private static validateImageType(mimeType: string) { + const SUPPORTED_IMAGE_TYPES = [ + 'image/png', + 'image/jpeg', + 'image/webp', + 'image/heic', + 'image/heif', + ] + if (!SUPPORTED_IMAGE_TYPES.includes(mimeType)) { + throw new Error( + `Gemini does not support image type ${mimeType}. Supported types: ${SUPPORTED_IMAGE_TYPES.join( + ', ', + )}`, + ) + } + } } diff --git a/src/core/llm/groq.ts b/src/core/llm/groq.ts index 8f4abf7..7af12ea 100644 --- a/src/core/llm/groq.ts +++ b/src/core/llm/groq.ts @@ -2,6 +2,7 @@ import Groq from 'groq-sdk' import { ChatCompletion, ChatCompletionChunk, + ChatCompletionContentPart, ChatCompletionMessageParam, } from 'groq-sdk/resources/chat/completions' @@ -119,9 +120,32 @@ export class GroqProvider implements BaseLLMProvider { static parseRequestMessage( message: RequestMessage, ): ChatCompletionMessageParam { - return { - role: message.role, - content: message.content, + switch (message.role) { + case 'user': { + const content = Array.isArray(message.content) + ? message.content.map((part): ChatCompletionContentPart => { + switch (part.type) { + case 'text': + return { type: 'text', text: part.text } + case 'image_url': + return { type: 'image_url', image_url: part.image_url } + } + }) + : message.content + return { role: 'user', content } + } + case 'assistant': { + if (Array.isArray(message.content)) { + throw new Error('Assistant message should be a string') + } + return { role: 'assistant', content: message.content } + } + case 'system': { + if (Array.isArray(message.content)) { + throw new Error('System message should be a string') + } + return { role: 'system', content: message.content } + } } } diff --git a/src/core/llm/openaiMessageAdapter.ts b/src/core/llm/openaiMessageAdapter.ts index 4793b66..aed2a43 100644 --- a/src/core/llm/openaiMessageAdapter.ts +++ b/src/core/llm/openaiMessageAdapter.ts @@ -2,6 +2,7 @@ import OpenAI from 'openai' import { ChatCompletion, ChatCompletionChunk, + ChatCompletionContentPart, ChatCompletionMessageParam, } from 'openai/resources/chat/completions' @@ -83,9 +84,32 @@ export class OpenAIMessageAdapter { static parseRequestMessage( message: RequestMessage, ): ChatCompletionMessageParam { - return { - role: message.role, - content: message.content, + switch (message.role) { + case 'user': { + const content = Array.isArray(message.content) + ? message.content.map((part): ChatCompletionContentPart => { + switch (part.type) { + case 'text': + return { type: 'text', text: part.text } + case 'image_url': + return { type: 'image_url', image_url: part.image_url } + } + }) + : message.content + return { role: 'user', content } + } + case 'assistant': { + if (Array.isArray(message.content)) { + throw new Error('Assistant message should be a string') + } + return { role: 'assistant', content: message.content } + } + case 'system': { + if (Array.isArray(message.content)) { + throw new Error('System message should be a string') + } + return { role: 'system', content: message.content } + } } } diff --git a/src/types/chat.ts b/src/types/chat.ts index e8d8e19..7f70b8e 100644 --- a/src/types/chat.ts +++ b/src/types/chat.ts @@ -3,13 +3,14 @@ import { SerializedEditorState } from 'lexical' import { SelectVector } from '../database/schema' import { LLMModel } from './llm/model' +import { ContentPart } from './llm/request' import { ResponseUsage } from './llm/response' import { Mentionable, SerializedMentionable } from './mentionable' export type ChatUserMessage = { role: 'user' content: SerializedEditorState | null - promptContent: string | null + promptContent: string | ContentPart[] | null id: string mentionables: Mentionable[] similaritySearchResults?: (Omit & { @@ -30,7 +31,7 @@ export type ChatMessage = ChatUserMessage | ChatAssistantMessage export type SerializedChatUserMessage = { role: 'user' content: SerializedEditorState | null - promptContent: string | null + promptContent: string | ContentPart[] | null id: string mentionables: SerializedMentionable[] similaritySearchResults?: (Omit & { diff --git a/src/types/llm/request.ts b/src/types/llm/request.ts index fc741f6..128dbe0 100644 --- a/src/types/llm/request.ts +++ b/src/types/llm/request.ts @@ -31,9 +31,24 @@ export type LLMRequestStreaming = LLMRequestBase & { export type LLMRequest = LLMRequestNonStreaming | LLMRequestStreaming +type TextContent = { + type: 'text' + text: string +} + +type ImageContentPart = { + type: 'image_url' + image_url: { + url: string // URL or base64 encoded image data + } +} + +export type ContentPart = TextContent | ImageContentPart + export type RequestMessage = { role: 'user' | 'assistant' | 'system' - content: string + // ContentParts are only for the 'user' role: + content: string | ContentPart[] } export type LLMOptions = { diff --git a/src/types/mentionable.ts b/src/types/mentionable.ts index fcd8ca2..a0aeb97 100644 --- a/src/types/mentionable.ts +++ b/src/types/mentionable.ts @@ -28,6 +28,12 @@ export type MentionableUrl = { type: 'url' url: string } +export type MentionableImage = { + type: 'image' + name: string + mimeType: string + data: string // base64 +} export type Mentionable = | MentionableFile | MentionableFolder @@ -35,7 +41,7 @@ export type Mentionable = | MentionableCurrentFile | MentionableBlock | MentionableUrl - + | MentionableImage export type SerializedMentionableFile = { type: 'file' file: string @@ -44,9 +50,7 @@ export type SerializedMentionableFolder = { type: 'folder' folder: string } -export type SerializedMentionableVault = { - type: 'vault' -} +export type SerializedMentionableVault = MentionableVault export type SerializedMentionableCurrentFile = { type: 'current-file' file: string | null @@ -58,10 +62,8 @@ export type SerializedMentionableBlock = { startLine: number endLine: number } -export type SerializedMentionableUrl = { - type: 'url' - url: string -} +export type SerializedMentionableUrl = MentionableUrl +export type SerializedMentionableImage = MentionableImage export type SerializedMentionable = | SerializedMentionableFile | SerializedMentionableFolder @@ -69,3 +71,4 @@ export type SerializedMentionable = | SerializedMentionableCurrentFile | SerializedMentionableBlock | SerializedMentionableUrl + | SerializedMentionableImage diff --git a/src/utils/image.ts b/src/utils/image.ts new file mode 100644 index 0000000..62dddac --- /dev/null +++ b/src/utils/image.ts @@ -0,0 +1,34 @@ +import { MentionableImage } from '../types/mentionable' + +export function parseImageDataUrl(dataUrl: string): { + mimeType: string + base64Data: string +} { + const matches = dataUrl.match(/^data:([^;]+);base64,(.+)/) + if (!matches) { + throw new Error('Invalid image data URL format') + } + const [, mimeType, base64Data] = matches + return { mimeType, base64Data } +} + +export async function fileToMentionableImage( + file: File, +): Promise { + const base64Data = await fileToBase64(file) + return { + type: 'image', + name: file.name, + mimeType: file.type, + data: base64Data, + } +} + +function fileToBase64(file: File): Promise { + return new Promise((resolve, reject) => { + const reader = new FileReader() + reader.readAsDataURL(file) + reader.onload = () => resolve(reader.result as string) + reader.onerror = () => reject(new Error('Failed to read file')) + }) +} diff --git a/src/utils/mentionable.ts b/src/utils/mentionable.ts index f6dd50e..90c4ea3 100644 --- a/src/utils/mentionable.ts +++ b/src/utils/mentionable.ts @@ -38,6 +38,13 @@ export const serializeMentionable = ( type: 'url', url: mentionable.url, } + case 'image': + return { + type: 'image', + name: mentionable.name, + mimeType: mentionable.mimeType, + data: mentionable.data, + } } } @@ -103,6 +110,14 @@ export const deserializeMentionable = ( url: mentionable.url, } } + case 'image': { + return { + type: 'image', + name: mentionable.name, + mimeType: mentionable.mimeType, + data: mentionable.data, + } + } } } catch (e) { console.error('Error deserializing mentionable', e) @@ -124,6 +139,8 @@ export function getMentionableKey(mentionable: SerializedMentionable): string { return `block:${mentionable.file}:${mentionable.startLine}:${mentionable.endLine}:${mentionable.content}` case 'url': return `url:${mentionable.url}` + case 'image': + return `image:${mentionable.name}:${mentionable.data.length}:${mentionable.data.slice(-32)}` } } @@ -141,5 +158,7 @@ export function getMentionableName(mentionable: Mentionable): string { return `${mentionable.file.name} (${mentionable.startLine}:${mentionable.endLine})` case 'url': return mentionable.url + case 'image': + return mentionable.name } } diff --git a/src/utils/promptGenerator.ts b/src/utils/promptGenerator.ts index d855a21..9045f05 100644 --- a/src/utils/promptGenerator.ts +++ b/src/utils/promptGenerator.ts @@ -5,11 +5,12 @@ import { QueryProgressState } from '../components/chat-view/QueryProgress' import { RAGEngine } from '../core/rag/ragEngine' import { SelectVector } from '../database/schema' import { ChatMessage, ChatUserMessage } from '../types/chat' -import { RequestMessage } from '../types/llm/request' +import { ContentPart, RequestMessage } from '../types/llm/request' import { MentionableBlock, MentionableFile, MentionableFolder, + MentionableImage, MentionableUrl, MentionableVault, } from '../types/mentionable' @@ -137,7 +138,7 @@ export class PromptGenerator { useVaultSearch?: boolean onQueryProgressChange?: (queryProgress: QueryProgressState) => void }): Promise<{ - promptContent: string + promptContent: ChatUserMessage['promptContent'] shouldUseRAG: boolean similaritySearchResults?: (Omit & { similarity: number @@ -255,8 +256,25 @@ ${await this.getWebsiteContent(url)} ` : '' + const imageDataUrls = message.mentionables + .filter((m): m is MentionableImage => m.type === 'image') + .map(({ data }) => data) + return { - promptContent: `${filePrompt}${blockPrompt}${urlPrompt}\n\n${query}\n\n`, + promptContent: [ + ...imageDataUrls.map( + (data): ContentPart => ({ + type: 'image_url', + image_url: { + url: data, + }, + }), + ), + { + type: 'text', + text: `${filePrompt}${blockPrompt}${urlPrompt}\n\n${query}\n\n`, + }, + ], shouldUseRAG, similaritySearchResults: similaritySearchResults, } diff --git a/styles.css b/styles.css index 1f42cbe..7b22ab9 100644 --- a/styles.css +++ b/styles.css @@ -241,7 +241,13 @@ button:not(.clickable-icon).smtcmp-chat-list-dropdown { align-items: center; height: var(--size-4-4); - .smtcmp-chat-user-input-controls-buttons { + .smtcmp-chat-user-input-controls__model-select-container { + flex-shrink: 1; + overflow: hidden; + } + + .smtcmp-chat-user-input-controls__buttons { + flex-shrink: 0; display: flex; gap: var(--size-4-2); align-items: center; @@ -285,6 +291,10 @@ button:not(.clickable-icon).smtcmp-chat-list-dropdown { overflow: hidden; text-overflow: ellipsis; white-space: nowrap; + + &.smtcmp-chat-user-input-file-badge-focused { + border: 1px solid var(--interactive-accent); + } } .smtcmp-chat-user-input-file-badge:hover { @@ -333,6 +343,11 @@ button:not(.clickable-icon).smtcmp-chat-list-dropdown { max-height: 350px; overflow-y: auto; white-space: pre-line; + + img { + max-width: 100%; + max-height: 350px; + } } /** @@ -709,12 +724,28 @@ button.smtcmp-chat-input-model-select { font-weight: var(--font-medium); color: var(--text-muted); justify-content: flex-start; + align-items: center; + cursor: pointer; height: var(--size-4-4); - width: fit-content; + max-width: 100%; &:hover { color: var(--text-normal); } + + .smtcmp-chat-input-model-select__model-name { + flex-shrink: 1; + overflow: hidden; + text-overflow: ellipsis; + white-space: nowrap; + } + + .smtcmp-chat-input-model-select__icon { + flex-shrink: 0; + display: flex; + align-items: center; + justify-content: center; + } } .smtcmp-query-progress {