Skip to content

Commit

Permalink
Add interaction into message props (#12)
Browse files Browse the repository at this point in the history
* feat: add mechannism to register messageParser

Signed-off-by: SuZhou-Joe <[email protected]>

* feat: update

Signed-off-by: SuZhou-Joe <[email protected]>

* feat: add interaction into message_bubble.tsx

Signed-off-by: SuZhou-Joe <[email protected]>

* feat: update CHANGELOG

Signed-off-by: SuZhou-Joe <[email protected]>

* fix: lint checker

Signed-off-by: SuZhou-Joe <[email protected]>

* feat: remove useless code

Signed-off-by: SuZhou-Joe <[email protected]>

* feat: update

Signed-off-by: SuZhou-Joe <[email protected]>

---------

Signed-off-by: SuZhou-Joe <[email protected]>
  • Loading branch information
SuZhou-Joe committed Dec 1, 2023
1 parent c3d0027 commit 1e47afe
Show file tree
Hide file tree
Showing 15 changed files with 109 additions and 33 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)
### 📈 Features/Enhancements

- Add support for registerMessageParser ([#5](https://github.com/opensearch-project/dashboards-assistant/pull/5))
- Change implementation of basic_input_output to built-in parser ([#10](https://github.com/opensearch-project/dashboards-assistant/pull/10))
- Change implementation of basic_input_output to built-in parser ([#10](https://github.com/opensearch-project/dashboards-assistant/pull/10))
- Add interactions into ChatState and pass specific interaction into message_bubble ([#12](https://github.com/opensearch-project/dashboards-assistant/pull/12))
11 changes: 11 additions & 0 deletions common/types/chat_saved_object_attributes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,23 @@
export const CHAT_SAVED_OBJECT = 'assistant-chat';
export const SAVED_OBJECT_VERSION = 1;

export interface Interaction {
input: string;
response: string;
conversation_id: string;
interaction_id: string;
create_time: string;
additional_info: Record<string, unknown>;
parent_interaction_id?: string;
}

export interface ISession {
title: string;
version: number;
createdTimeMs: number;
updatedTimeMs: number;
messages: IMessage[];
interactions: Interaction[];
}

export interface ISessionFindResponse {
Expand Down
31 changes: 27 additions & 4 deletions public/hooks/use_chat_actions.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@
*/

import { ASSISTANT_API } from '../../common/constants/llm';
import { IMessage, ISuggestedAction } from '../../common/types/chat_saved_object_attributes';
import {
IMessage,
ISuggestedAction,
Interaction,
} from '../../common/types/chat_saved_object_attributes';
import { useChatContext } from '../contexts/chat_context';
import { useCore } from '../contexts/core_context';
import { AssistantActions } from '../types';
Expand All @@ -14,6 +18,7 @@ interface SendResponse {
sessionId: string;
title: string;
messages: IMessage[];
interactions: Interaction[];
}

interface SetParagraphResponse {
Expand Down Expand Up @@ -56,7 +61,13 @@ export const useChatActions = (): AssistantActions => {
if (!chatContext.title) {
chatContext.setTitle(response.title);
}
chatStateDispatch({ type: 'receive', payload: response.messages });
chatStateDispatch({
type: 'receive',
payload: {
messages: response.messages,
interactions: response.interactions,
},
});
} catch (error) {
if (abortController.signal.aborted) return;
chatStateDispatch({ type: 'error', payload: error });
Expand All @@ -79,7 +90,13 @@ export const useChatActions = (): AssistantActions => {
}
const session = await core.services.sessionLoad.load(sessionId);
if (session) {
chatStateDispatch({ type: 'receive', payload: session.messages });
chatStateDispatch({
type: 'receive',
payload: {
messages: session.messages,
interactions: session.interactions,
},
});
}
};

Expand Down Expand Up @@ -156,7 +173,13 @@ export const useChatActions = (): AssistantActions => {
if (abortController.signal.aborted) {
return;
}
chatStateDispatch({ type: 'receive', payload: response.messages });
chatStateDispatch({
type: 'receive',
payload: {
messages: response.messages,
interactions: response.interactions,
},
});
} catch (error) {
if (abortController.signal.aborted) {
return;
Expand Down
15 changes: 12 additions & 3 deletions public/hooks/use_chat_state.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@

import { produce } from 'immer';
import React, { useContext, useMemo, useReducer } from 'react';
import { IMessage } from '../../common/types/chat_saved_object_attributes';
import { IMessage, Interaction } from '../../common/types/chat_saved_object_attributes';

interface ChatState {
messages: IMessage[];
interactions: Interaction[];
llmResponding: boolean;
llmError?: Error;
}
Expand All @@ -18,7 +19,13 @@ type ChatStateAction =
| { type: 'abort' }
| { type: 'reset' }
| { type: 'send'; payload: IMessage }
| { type: 'receive'; payload: ChatState['messages'] }
| {
type: 'receive';
payload: {
messages: ChatState['messages'];
interactions: ChatState['interactions'];
};
}
| {
type: 'error';
payload: NonNullable<ChatState['llmError']> | { body: NonNullable<ChatState['llmError']> };
Expand All @@ -31,6 +38,7 @@ interface IChatStateContext {
const ChatStateContext = React.createContext<IChatStateContext | null>(null);

const initialState: ChatState = {
interactions: [],
messages: [],
llmResponding: false,
};
Expand All @@ -48,7 +56,8 @@ const chatStateReducer: React.Reducer<ChatState, ChatStateAction> = (state, acti
break;

case 'receive':
draft.messages = action.payload;
draft.messages = action.payload.messages;
draft.interactions = action.payload.interactions;
draft.llmResponding = false;
draft.llmError = undefined;
break;
Expand Down
8 changes: 7 additions & 1 deletion public/tabs/chat/chat_page.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,13 @@ export const ChatPage: React.FC<ChatPageProps> = (props) => {
}
const session = await core.services.sessionLoad.load(chatContext.sessionId);
if (session) {
chatStateDispatch({ type: 'receive', payload: session.messages });
chatStateDispatch({
type: 'receive',
payload: {
messages: session.messages,
interactions: session.interactions,
},
});
}
}, [chatContext.sessionId, chatStateDispatch]);

Expand Down
14 changes: 13 additions & 1 deletion public/tabs/chat/chat_page_content.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@ import {
EuiText,
} from '@elastic/eui';
import React, { useLayoutEffect, useRef } from 'react';
import { IMessage, ISuggestedAction } from '../../../common/types/chat_saved_object_attributes';
import {
IMessage,
ISuggestedAction,
Interaction,
} from '../../../common/types/chat_saved_object_attributes';
import { TermsAndConditions } from '../../components/terms_and_conditions';
import { useChatContext } from '../../contexts/chat_context';
import { useChatState } from '../../hooks/use_chat_state';
Expand Down Expand Up @@ -120,6 +124,13 @@ export const ChatPageContent: React.FC<ChatPageContentProps> = React.memo((props
// Only show suggestion on llm outputs after last user input
const showSuggestions = i > lastInputIndex;

let interaction: Interaction | undefined;
if (message.type === 'output' && message.traceId) {
interaction = chatState.interactions.find(
(item) => item.interaction_id === message.traceId
);
}

return (
<React.Fragment key={i}>
<ToolsUsed message={message} />
Expand All @@ -129,6 +140,7 @@ export const ChatPageContent: React.FC<ChatPageContentProps> = React.memo((props
showRegenerate={isLatestOutput}
shouldActionBarVisibleOnHover={!isLatestOutput}
onRegenerate={chatActions.regenerate}
interaction={interaction}
>
<MessageContent message={message} />
{/* <MessageFooter message={message} previousInput={findPreviousInput(array, i)} />*/}
Expand Down
7 changes: 6 additions & 1 deletion public/tabs/chat/messages/message_bubble.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,11 @@ import React, { useCallback } from 'react';
import { IconType } from '@elastic/eui/src/components/icon/icon';
import cx from 'classnames';
import chatIcon from '../../../assets/chat.svg';
import { IMessage, IOutput } from '../../../../common/types/chat_saved_object_attributes';
import {
IMessage,
IOutput,
Interaction,
} from '../../../../common/types/chat_saved_object_attributes';
import { useFeedback } from '../../../hooks/use_feed_back';

type MessageBubbleProps = {
Expand All @@ -30,6 +34,7 @@ type MessageBubbleProps = {
} & (
| {
message: IMessage;
interaction?: Interaction;
}
| {
loading: boolean;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,15 @@ describe('build outputs', () => {
it('sanitizes markdown outputs', () => {
const outputs = buildOutputs(
'test question',
'normal text<b onmouseover=alert("XSS testing!")></b> <img src="image.jpg" alt="image" width="500" height="600"> !!!!!!![](https://badurl) ![image](https://badurl) [good link](https://link)',
'normal text<b onmouseover=alert("XSS testing!")></b> <img src="image.jpg" alt="image" width="500" height="600"> !!!!!!![](http://evil.com/) ![image](http://evil.com/) [good link](https://link)',
'test-session',
{},
[]
);
expect(outputs).toEqual([
{
content:
'normal text<b></b> [](https://badurl) [image](https://badurl) [good link](https://link)',
'normal text<b></b> [](http://evil.com/) [image](http://evil.com/) [good link](https://link)',
contentType: 'markdown',
traceId: 'test-session',
suggestedActions: [],
Expand Down
5 changes: 2 additions & 3 deletions server/parsers/basic_input_output_parser.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
* SPDX-License-Identifier: Apache-2.0
*/

import { IInput, IOutput } from '../../common/types/chat_saved_object_attributes';
import { Interaction } from '../types';
import { IInput, IOutput, Interaction } from '../../common/types/chat_saved_object_attributes';

export const BasicInputOutputParser = {
order: 0,
Expand All @@ -20,7 +19,7 @@ export const BasicInputOutputParser = {
type: 'output',
contentType: 'markdown',
content: interaction.response,
traceId: interaction.parent_interaction_id,
traceId: interaction.interaction_id,
},
];
return [inputItem, ...outputItems];
Expand Down
1 change: 1 addition & 0 deletions server/routes/chat_routes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ export function registerChatRoutes(router: IRouter, routeOptions: RoutesOptions)
messages: finalMessage.messages,
sessionId: outputs.memoryId,
title: finalMessage.title,
interactions: finalMessage.interactions,
},
});
} catch (error) {
Expand Down
6 changes: 3 additions & 3 deletions server/services/storage/agent_framework_storage_service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,14 @@
import { ApiResponse } from '@opensearch-project/opensearch/.';
import { OpenSearchClient } from '../../../../../src/core/server';
import {
IInput,
IMessage,
IOutput,
ISession,
ISessionFindResponse,
Interaction,
} from '../../../common/types/chat_saved_object_attributes';
import { GetSessionsSchema } from '../../routes/chat_routes';
import { StorageService } from './storage_service';
import { Interaction, MessageParser } from '../../types';
import { MessageParser } from '../../types';
import { MessageParserRunner } from '../../utils/message_parser_runner';

export class AgentFrameworkStorageService implements StorageService {
Expand Down Expand Up @@ -55,6 +54,7 @@ export class AgentFrameworkStorageService implements StorageService {
createdTimeMs: Date.now(),
updatedTimeMs: Date.now(),
messages: finalMessages,
interactions: finalInteractions,
};
}

Expand Down
6 changes: 5 additions & 1 deletion server/services/storage/saved_objects_storage_service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
* SPDX-License-Identifier: Apache-2.0
*/

import { MessageParser } from '../../types';
import { SavedObjectsClientContract } from '../../../../../src/core/server';
import {
CHAT_SAVED_OBJECT,
Expand All @@ -15,7 +16,10 @@ import { GetSessionsSchema } from '../../routes/chat_routes';
import { StorageService } from './storage_service';

export class SavedObjectsStorageService implements StorageService {
constructor(private readonly client: SavedObjectsClientContract) {}
constructor(
private readonly client: SavedObjectsClientContract,
private readonly messageParsers: MessageParser[]
) {}

private convertUpdatedTimeField(updatedAt: string | undefined) {
return updatedAt ? new Date(updatedAt).getTime() : undefined;
Expand Down
12 changes: 1 addition & 11 deletions server/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,14 @@
* SPDX-License-Identifier: Apache-2.0
*/

import { IMessage } from '../common/types/chat_saved_object_attributes';
import { IMessage, Interaction } from '../common/types/chat_saved_object_attributes';
import { ILegacyClusterClient, Logger } from '../../../src/core/server';

// eslint-disable-next-line @typescript-eslint/no-empty-interface
export interface AssistantPluginSetup {}
// eslint-disable-next-line @typescript-eslint/no-empty-interface
export interface AssistantPluginStart {}

export interface Interaction {
input: string;
response: string;
conversation_id: string;
interaction_id: string;
create_time: string;
additional_info: Record<string, unknown>;
parent_interaction_id: string;
}

export interface MessageParser {
/**
* The id of the parser, should be unique among the parsers.
Expand Down
15 changes: 15 additions & 0 deletions server/utils/message_parser_runner.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ describe('MessageParserRunner', () => {
await messageParserRunner.run({
response: 'output',
input: 'input',
conversation_id: '',
interaction_id: '',
create_time: '',
additional_info: {},
parent_interaction_id: ''
})
).toEqual([
{
Expand Down Expand Up @@ -95,6 +100,11 @@ describe('MessageParserRunner', () => {
await messageParserRunner.run({
response: 'output',
input: 'input',
conversation_id: '',
interaction_id: '',
create_time: '',
additional_info: {},
parent_interaction_id: ''
})
).toEqual([
{
Expand Down Expand Up @@ -144,6 +154,11 @@ describe('MessageParserRunner', () => {
await messageParserRunner.run({
response: 'output',
input: 'input',
conversation_id: '',
interaction_id: '',
create_time: '',
additional_info: {},
parent_interaction_id: ''
})
).toEqual([]);
});
Expand Down
4 changes: 2 additions & 2 deletions server/utils/message_parser_runner.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
* SPDX-License-Identifier: Apache-2.0
*/

import { IMessage } from '../../common/types/chat_saved_object_attributes';
import { Interaction, MessageParser } from '../types';
import { IMessage, Interaction } from '../../common/types/chat_saved_object_attributes';
import { MessageParser } from '../types';

export class MessageParserRunner {
constructor(private readonly messageParsers: MessageParser[]) {}
Expand Down

0 comments on commit 1e47afe

Please sign in to comment.