diff --git a/CHANGELOG.md b/CHANGELOG.md index 78c3fdb1..055c2ab4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,4 +5,5 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) ### 📈 Features/Enhancements - Add support for registerMessageParser ([#5](https://github.com/opensearch-project/dashboards-assistant/pull/5)) -- Change implementation of basic_input_output to built-in parser ([#10](https://github.com/opensearch-project/dashboards-assistant/pull/10)) \ No newline at end of file +- Change implementation of basic_input_output to built-in parser ([#10](https://github.com/opensearch-project/dashboards-assistant/pull/10)) +- Add interactions into ChatState and pass specific interaction into message_bubble ([#12](https://github.com/opensearch-project/dashboards-assistant/pull/12)) \ No newline at end of file diff --git a/common/types/chat_saved_object_attributes.ts b/common/types/chat_saved_object_attributes.ts index fb3e0d2a..0421cd54 100644 --- a/common/types/chat_saved_object_attributes.ts +++ b/common/types/chat_saved_object_attributes.ts @@ -6,12 +6,23 @@ export const CHAT_SAVED_OBJECT = 'assistant-chat'; export const SAVED_OBJECT_VERSION = 1; +export interface Interaction { + input: string; + response: string; + conversation_id: string; + interaction_id: string; + create_time: string; + additional_info: Record; + parent_interaction_id?: string; +} + export interface ISession { title: string; version: number; createdTimeMs: number; updatedTimeMs: number; messages: IMessage[]; + interactions: Interaction[]; } export interface ISessionFindResponse { diff --git a/public/hooks/use_chat_actions.tsx b/public/hooks/use_chat_actions.tsx index 6561a89d..2dd26964 100644 --- a/public/hooks/use_chat_actions.tsx +++ b/public/hooks/use_chat_actions.tsx @@ -4,7 +4,11 @@ */ import { ASSISTANT_API } from '../../common/constants/llm'; -import { IMessage, ISuggestedAction } from '../../common/types/chat_saved_object_attributes'; +import { + IMessage, + ISuggestedAction, + Interaction, +} from '../../common/types/chat_saved_object_attributes'; import { useChatContext } from '../contexts/chat_context'; import { useCore } from '../contexts/core_context'; import { AssistantActions } from '../types'; @@ -14,6 +18,7 @@ interface SendResponse { sessionId: string; title: string; messages: IMessage[]; + interactions: Interaction[]; } interface SetParagraphResponse { @@ -56,7 +61,13 @@ export const useChatActions = (): AssistantActions => { if (!chatContext.title) { chatContext.setTitle(response.title); } - chatStateDispatch({ type: 'receive', payload: response.messages }); + chatStateDispatch({ + type: 'receive', + payload: { + messages: response.messages, + interactions: response.interactions, + }, + }); } catch (error) { if (abortController.signal.aborted) return; chatStateDispatch({ type: 'error', payload: error }); @@ -79,7 +90,13 @@ export const useChatActions = (): AssistantActions => { } const session = await core.services.sessionLoad.load(sessionId); if (session) { - chatStateDispatch({ type: 'receive', payload: session.messages }); + chatStateDispatch({ + type: 'receive', + payload: { + messages: session.messages, + interactions: session.interactions, + }, + }); } }; @@ -156,7 +173,13 @@ export const useChatActions = (): AssistantActions => { if (abortController.signal.aborted) { return; } - chatStateDispatch({ type: 'receive', payload: response.messages }); + chatStateDispatch({ + type: 'receive', + payload: { + messages: response.messages, + interactions: response.interactions, + }, + }); } catch (error) { if (abortController.signal.aborted) { return; diff --git a/public/hooks/use_chat_state.tsx b/public/hooks/use_chat_state.tsx index 9f3a66f5..13bb8542 100644 --- a/public/hooks/use_chat_state.tsx +++ b/public/hooks/use_chat_state.tsx @@ -5,10 +5,11 @@ import { produce } from 'immer'; import React, { useContext, useMemo, useReducer } from 'react'; -import { IMessage } from '../../common/types/chat_saved_object_attributes'; +import { IMessage, Interaction } from '../../common/types/chat_saved_object_attributes'; interface ChatState { messages: IMessage[]; + interactions: Interaction[]; llmResponding: boolean; llmError?: Error; } @@ -18,7 +19,13 @@ type ChatStateAction = | { type: 'abort' } | { type: 'reset' } | { type: 'send'; payload: IMessage } - | { type: 'receive'; payload: ChatState['messages'] } + | { + type: 'receive'; + payload: { + messages: ChatState['messages']; + interactions: ChatState['interactions']; + }; + } | { type: 'error'; payload: NonNullable | { body: NonNullable }; @@ -31,6 +38,7 @@ interface IChatStateContext { const ChatStateContext = React.createContext(null); const initialState: ChatState = { + interactions: [], messages: [], llmResponding: false, }; @@ -48,7 +56,8 @@ const chatStateReducer: React.Reducer = (state, acti break; case 'receive': - draft.messages = action.payload; + draft.messages = action.payload.messages; + draft.interactions = action.payload.interactions; draft.llmResponding = false; draft.llmError = undefined; break; diff --git a/public/tabs/chat/chat_page.tsx b/public/tabs/chat/chat_page.tsx index f8101d2b..025ed2b1 100644 --- a/public/tabs/chat/chat_page.tsx +++ b/public/tabs/chat/chat_page.tsx @@ -31,7 +31,13 @@ export const ChatPage: React.FC = (props) => { } const session = await core.services.sessionLoad.load(chatContext.sessionId); if (session) { - chatStateDispatch({ type: 'receive', payload: session.messages }); + chatStateDispatch({ + type: 'receive', + payload: { + messages: session.messages, + interactions: session.interactions, + }, + }); } }, [chatContext.sessionId, chatStateDispatch]); diff --git a/public/tabs/chat/chat_page_content.tsx b/public/tabs/chat/chat_page_content.tsx index 16ebc5b7..359318db 100644 --- a/public/tabs/chat/chat_page_content.tsx +++ b/public/tabs/chat/chat_page_content.tsx @@ -14,7 +14,11 @@ import { EuiText, } from '@elastic/eui'; import React, { useLayoutEffect, useRef } from 'react'; -import { IMessage, ISuggestedAction } from '../../../common/types/chat_saved_object_attributes'; +import { + IMessage, + ISuggestedAction, + Interaction, +} from '../../../common/types/chat_saved_object_attributes'; import { TermsAndConditions } from '../../components/terms_and_conditions'; import { useChatContext } from '../../contexts/chat_context'; import { useChatState } from '../../hooks/use_chat_state'; @@ -120,6 +124,13 @@ export const ChatPageContent: React.FC = React.memo((props // Only show suggestion on llm outputs after last user input const showSuggestions = i > lastInputIndex; + let interaction: Interaction | undefined; + if (message.type === 'output' && message.traceId) { + interaction = chatState.interactions.find( + (item) => item.interaction_id === message.traceId + ); + } + return ( @@ -129,6 +140,7 @@ export const ChatPageContent: React.FC = React.memo((props showRegenerate={isLatestOutput} shouldActionBarVisibleOnHover={!isLatestOutput} onRegenerate={chatActions.regenerate} + interaction={interaction} > {/* */} diff --git a/public/tabs/chat/messages/message_bubble.tsx b/public/tabs/chat/messages/message_bubble.tsx index 451e36af..4c0133ac 100644 --- a/public/tabs/chat/messages/message_bubble.tsx +++ b/public/tabs/chat/messages/message_bubble.tsx @@ -19,7 +19,11 @@ import React, { useCallback } from 'react'; import { IconType } from '@elastic/eui/src/components/icon/icon'; import cx from 'classnames'; import chatIcon from '../../../assets/chat.svg'; -import { IMessage, IOutput } from '../../../../common/types/chat_saved_object_attributes'; +import { + IMessage, + IOutput, + Interaction, +} from '../../../../common/types/chat_saved_object_attributes'; import { useFeedback } from '../../../hooks/use_feed_back'; type MessageBubbleProps = { @@ -30,6 +34,7 @@ type MessageBubbleProps = { } & ( | { message: IMessage; + interaction?: Interaction; } | { loading: boolean; diff --git a/server/olly/utils/output_builders/__tests__/build_outputs.test.ts b/server/olly/utils/output_builders/__tests__/build_outputs.test.ts index c796f35a..77815a06 100644 --- a/server/olly/utils/output_builders/__tests__/build_outputs.test.ts +++ b/server/olly/utils/output_builders/__tests__/build_outputs.test.ts @@ -35,7 +35,7 @@ describe('build outputs', () => { it('sanitizes markdown outputs', () => { const outputs = buildOutputs( 'test question', - 'normal text image !!!!!!![](https://badurl) ![image](https://badurl) [good link](https://link)', + 'normal text image !!!!!!![](http://evil.com/) ![image](http://evil.com/) [good link](https://link)', 'test-session', {}, [] @@ -43,7 +43,7 @@ describe('build outputs', () => { expect(outputs).toEqual([ { content: - 'normal text [](https://badurl) [image](https://badurl) [good link](https://link)', + 'normal text [](http://evil.com/) [image](http://evil.com/) [good link](https://link)', contentType: 'markdown', traceId: 'test-session', suggestedActions: [], diff --git a/server/parsers/basic_input_output_parser.ts b/server/parsers/basic_input_output_parser.ts index 7880e4e4..7febe7b7 100644 --- a/server/parsers/basic_input_output_parser.ts +++ b/server/parsers/basic_input_output_parser.ts @@ -3,8 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { IInput, IOutput } from '../../common/types/chat_saved_object_attributes'; -import { Interaction } from '../types'; +import { IInput, IOutput, Interaction } from '../../common/types/chat_saved_object_attributes'; export const BasicInputOutputParser = { order: 0, @@ -20,7 +19,7 @@ export const BasicInputOutputParser = { type: 'output', contentType: 'markdown', content: interaction.response, - traceId: interaction.parent_interaction_id, + traceId: interaction.interaction_id, }, ]; return [inputItem, ...outputItems]; diff --git a/server/routes/chat_routes.ts b/server/routes/chat_routes.ts index 946df488..632babe4 100644 --- a/server/routes/chat_routes.ts +++ b/server/routes/chat_routes.ts @@ -137,6 +137,7 @@ export function registerChatRoutes(router: IRouter, routeOptions: RoutesOptions) messages: finalMessage.messages, sessionId: outputs.memoryId, title: finalMessage.title, + interactions: finalMessage.interactions, }, }); } catch (error) { diff --git a/server/services/storage/agent_framework_storage_service.ts b/server/services/storage/agent_framework_storage_service.ts index 3448fd67..5393105e 100644 --- a/server/services/storage/agent_framework_storage_service.ts +++ b/server/services/storage/agent_framework_storage_service.ts @@ -6,15 +6,14 @@ import { ApiResponse } from '@opensearch-project/opensearch/.'; import { OpenSearchClient } from '../../../../../src/core/server'; import { - IInput, IMessage, - IOutput, ISession, ISessionFindResponse, + Interaction, } from '../../../common/types/chat_saved_object_attributes'; import { GetSessionsSchema } from '../../routes/chat_routes'; import { StorageService } from './storage_service'; -import { Interaction, MessageParser } from '../../types'; +import { MessageParser } from '../../types'; import { MessageParserRunner } from '../../utils/message_parser_runner'; export class AgentFrameworkStorageService implements StorageService { @@ -55,6 +54,7 @@ export class AgentFrameworkStorageService implements StorageService { createdTimeMs: Date.now(), updatedTimeMs: Date.now(), messages: finalMessages, + interactions: finalInteractions, }; } diff --git a/server/services/storage/saved_objects_storage_service.ts b/server/services/storage/saved_objects_storage_service.ts index 78fcffb4..f85bba48 100644 --- a/server/services/storage/saved_objects_storage_service.ts +++ b/server/services/storage/saved_objects_storage_service.ts @@ -3,6 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ +import { MessageParser } from '../../types'; import { SavedObjectsClientContract } from '../../../../../src/core/server'; import { CHAT_SAVED_OBJECT, @@ -15,7 +16,10 @@ import { GetSessionsSchema } from '../../routes/chat_routes'; import { StorageService } from './storage_service'; export class SavedObjectsStorageService implements StorageService { - constructor(private readonly client: SavedObjectsClientContract) {} + constructor( + private readonly client: SavedObjectsClientContract, + private readonly messageParsers: MessageParser[] + ) {} private convertUpdatedTimeField(updatedAt: string | undefined) { return updatedAt ? new Date(updatedAt).getTime() : undefined; diff --git a/server/types.ts b/server/types.ts index 9d9e7520..5b692036 100644 --- a/server/types.ts +++ b/server/types.ts @@ -3,7 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { IMessage } from '../common/types/chat_saved_object_attributes'; +import { IMessage, Interaction } from '../common/types/chat_saved_object_attributes'; import { ILegacyClusterClient, Logger } from '../../../src/core/server'; // eslint-disable-next-line @typescript-eslint/no-empty-interface @@ -11,16 +11,6 @@ export interface AssistantPluginSetup {} // eslint-disable-next-line @typescript-eslint/no-empty-interface export interface AssistantPluginStart {} -export interface Interaction { - input: string; - response: string; - conversation_id: string; - interaction_id: string; - create_time: string; - additional_info: Record; - parent_interaction_id: string; -} - export interface MessageParser { /** * The id of the parser, should be unique among the parsers. diff --git a/server/utils/message_parser_runner.test.ts b/server/utils/message_parser_runner.test.ts index e931f47a..ca4032a6 100644 --- a/server/utils/message_parser_runner.test.ts +++ b/server/utils/message_parser_runner.test.ts @@ -26,6 +26,11 @@ describe('MessageParserRunner', () => { await messageParserRunner.run({ response: 'output', input: 'input', + conversation_id: '', + interaction_id: '', + create_time: '', + additional_info: {}, + parent_interaction_id: '' }) ).toEqual([ { @@ -95,6 +100,11 @@ describe('MessageParserRunner', () => { await messageParserRunner.run({ response: 'output', input: 'input', + conversation_id: '', + interaction_id: '', + create_time: '', + additional_info: {}, + parent_interaction_id: '' }) ).toEqual([ { @@ -144,6 +154,11 @@ describe('MessageParserRunner', () => { await messageParserRunner.run({ response: 'output', input: 'input', + conversation_id: '', + interaction_id: '', + create_time: '', + additional_info: {}, + parent_interaction_id: '' }) ).toEqual([]); }); diff --git a/server/utils/message_parser_runner.ts b/server/utils/message_parser_runner.ts index 2f5d7d59..60534247 100644 --- a/server/utils/message_parser_runner.ts +++ b/server/utils/message_parser_runner.ts @@ -3,8 +3,8 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { IMessage } from '../../common/types/chat_saved_object_attributes'; -import { Interaction, MessageParser } from '../types'; +import { IMessage, Interaction } from '../../common/types/chat_saved_object_attributes'; +import { MessageParser } from '../types'; export class MessageParserRunner { constructor(private readonly messageParsers: MessageParser[]) {}