From 15dcaf111969c90c19389c904093deba3f94861b Mon Sep 17 00:00:00 2001 From: SuZhou-Joe Date: Wed, 22 Nov 2023 16:44:22 +0800 Subject: [PATCH] feat: change implementation of basic_input_output to built-in parser (#10) * feat: change implementation of basic_input_output to built-in parser Signed-off-by: SuZhou-Joe * feat: update CHANGELOG Signed-off-by: SuZhou-Joe * feat: enable build input-output message Signed-off-by: SuZhou-Joe * feat: sort interactions Signed-off-by: SuZhou-Joe * feat: remove useless code Signed-off-by: SuZhou-Joe * feat: use parent_interaction_id as traceId Signed-off-by: SuZhou-Joe --------- Signed-off-by: SuZhou-Joe --- CHANGELOG.md | 3 +- server/parsers/basic_input_output_parser.ts | 28 ++++++++++ server/plugin.ts | 25 +++++---- server/routes/chat_routes.ts | 24 ++++----- server/routes/index.ts | 4 +- server/services/chat/olly_chat_service.ts | 29 ++++------ .../agent_framework_storage_service.ts | 54 ++++++++++--------- server/types.ts | 1 + 8 files changed, 99 insertions(+), 69 deletions(-) create mode 100644 server/parsers/basic_input_output_parser.ts diff --git a/CHANGELOG.md b/CHANGELOG.md index 3577f8af..78c3fdb1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,4 +4,5 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) ### 📈 Features/Enhancements -- Add support for registerMessageParser ([#5](https://github.com/opensearch-project/dashboards-assistant/pull/5)) \ No newline at end of file +- Add support for registerMessageParser ([#5](https://github.com/opensearch-project/dashboards-assistant/pull/5)) +- Change implementation of basic_input_output to built-in parser ([#10](https://github.com/opensearch-project/dashboards-assistant/pull/10)) \ No newline at end of file diff --git a/server/parsers/basic_input_output_parser.ts b/server/parsers/basic_input_output_parser.ts new file mode 100644 index 00000000..7880e4e4 --- /dev/null +++ b/server/parsers/basic_input_output_parser.ts @@ -0,0 +1,28 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { IInput, IOutput } from '../../common/types/chat_saved_object_attributes'; +import { Interaction } from '../types'; + +export const BasicInputOutputParser = { + order: 0, + id: 'output_message', + async parserProvider(interaction: Interaction) { + const inputItem: IInput = { + type: 'input', + contentType: 'text', + content: interaction.input, + }; + const outputItems: IOutput[] = [ + { + type: 'output', + contentType: 'markdown', + content: interaction.response, + traceId: interaction.parent_interaction_id, + }, + ]; + return [inputItem, ...outputItems]; + }, +}; diff --git a/server/plugin.ts b/server/plugin.ts index 47906ee0..745e3b0c 100644 --- a/server/plugin.ts +++ b/server/plugin.ts @@ -23,6 +23,7 @@ import { setupRoutes } from './routes/index'; import { chatSavedObject } from './saved_objects/chat_saved_object'; import { AssistantPluginSetup, AssistantPluginStart, MessageParser } from './types'; import { chatConfigSavedObject } from './saved_objects/chat_config_saved_object'; +import { BasicInputOutputParser } from './parsers/basic_input_output_parser'; export class AssistantPlugin implements Plugin { private readonly logger: Logger; @@ -55,7 +56,9 @@ export class AssistantPlugin implements Plugin { - const findItem = this.messageParsers.find((item) => item.id === messageParser.id); - if (findItem) { - throw new Error(`There is already a messageParser whose id is ${messageParser.id}`); - } + const registerMessageParser = (messageParser: MessageParser) => { + const findItem = this.messageParsers.find((item) => item.id === messageParser.id); + if (findItem) { + throw new Error(`There is already a messageParser whose id is ${messageParser.id}`); + } - this.messageParsers.push(messageParser); - }, + this.messageParsers.push(messageParser); + }; + + registerMessageParser(BasicInputOutputParser); + + return { + registerMessageParser, removeMessageParser: (parserId: MessageParser['id']) => { const findIndex = this.messageParsers.findIndex((item) => item.id === parserId); if (findIndex < 0) { diff --git a/server/routes/chat_routes.ts b/server/routes/chat_routes.ts index de4a4dbf..946df488 100644 --- a/server/routes/chat_routes.ts +++ b/server/routes/chat_routes.ts @@ -13,7 +13,6 @@ import { } from '../../../../src/core/server'; import { ASSISTANT_API } from '../../common/constants/llm'; import { OllyChatService } from '../services/chat/olly_chat_service'; -import { SavedObjectsStorageService } from '../services/storage/saved_objects_storage_service'; import { IMessage, IInput } from '../../common/types/chat_saved_object_attributes'; import { AgentFrameworkStorageService } from '../services/storage/agent_framework_storage_service'; import { RoutesOptions } from '../types'; @@ -63,6 +62,7 @@ const regenerateRoute = { validate: { body: schema.object({ sessionId: schema.string(), + rootAgentId: schema.string(), }), }, }; @@ -105,9 +105,12 @@ const updateSessionRoute = { }, }; -export function registerChatRoutes(router: IRouter) { +export function registerChatRoutes(router: IRouter, routeOptions: RoutesOptions) { const createStorageService = (context: RequestHandlerContext) => - new AgentFrameworkStorageService(context.core.opensearch.client.asCurrentUser); + new AgentFrameworkStorageService( + context.core.opensearch.client.asCurrentUser, + routeOptions.messageParsers + ); const createChatService = () => new OllyChatService(); router.post( @@ -117,15 +120,14 @@ export function registerChatRoutes(router: IRouter) { request, response ): Promise> => { - const { messages = [], input, sessionId: sessionIdInRequestBody } = request.body; + const { messages = [], input, sessionId: sessionIdInRequestBody, rootAgentId } = request.body; const storageService = createStorageService(context); const chatService = createChatService(); try { const outputs = await chatService.requestLLM( - { messages, input, sessionId: sessionIdInRequestBody }, - context, - request + { messages, input, sessionId: sessionIdInRequestBody, rootAgentId }, + context ); const sessionId = outputs.memoryId; const finalMessage = await storageService.getSession(sessionId); @@ -250,7 +252,7 @@ export function registerChatRoutes(router: IRouter) { request, response ): Promise> => { - const { sessionId } = request.body; + const { sessionId, rootAgentId } = request.body; const storageService = createStorageService(context); let messages: IMessage[] = []; const chatService = createChatService(); @@ -270,10 +272,8 @@ export function registerChatRoutes(router: IRouter) { try { const outputs = await chatService.requestLLM( - { messages, input, sessionId }, - context, - // @ts-ignore - request + { messages, input, sessionId, rootAgentId }, + context ); const title = input.content.substring(0, 50); const saveMessagesResponse = await storageService.saveMessages( diff --git a/server/routes/index.ts b/server/routes/index.ts index 52426cf2..093bb313 100644 --- a/server/routes/index.ts +++ b/server/routes/index.ts @@ -8,7 +8,7 @@ import { IRouter } from '../../../../src/core/server'; import { registerChatRoutes } from './chat_routes'; import { registerLangchainRoutes } from './langchain_routes'; -export function setupRoutes(router: IRouter) { - registerChatRoutes(router); +export function setupRoutes(router: IRouter, routeOptions: RoutesOptions) { + registerChatRoutes(router, routeOptions); registerLangchainRoutes(router); } diff --git a/server/services/chat/olly_chat_service.ts b/server/services/chat/olly_chat_service.ts index 68a92ed4..8dadc991 100644 --- a/server/services/chat/olly_chat_service.ts +++ b/server/services/chat/olly_chat_service.ts @@ -3,35 +3,30 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { Run } from 'langchain/callbacks'; import { v4 as uuid } from 'uuid'; import { ApiResponse } from '@opensearch-project/opensearch'; import { OpenSearchDashboardsRequest, RequestHandlerContext } from '../../../../../src/core/server'; import { IMessage, IInput } from '../../../common/types/chat_saved_object_attributes'; -import { convertToTraces } from '../../../common/utils/llm_chat/traces'; import { OpenSearchTracer } from '../../olly/callbacks/opensearch_tracer'; import { LLMModelFactory } from '../../olly/models/llm_model_factory'; import { PPLTools } from '../../olly/tools/tool_sets/ppl'; -import { buildOutputs } from '../../olly/utils/output_builders/build_outputs'; -import { AbortAgentExecutionSchema, LLMRequestSchema } from '../../routes/chat_routes'; import { PPLGenerationRequestSchema } from '../../routes/langchain_routes'; import { ChatService } from './chat_service'; +import { LLMRequestSchema } from '../../routes/chat_routes'; const MEMORY_ID_FIELD = 'memory_id'; -const RESPONSE_FIELD = 'response'; export class OllyChatService implements ChatService { static abortControllers: Map = new Map(); public async requestLLM( - payload: { messages: IMessage[]; input: IInput; sessionId?: string }, - context: RequestHandlerContext, - request: OpenSearchDashboardsRequest + payload: { messages: IMessage[]; input: IInput; sessionId?: string; rootAgentId: string }, + context: RequestHandlerContext ): Promise<{ messages: IMessage[]; memoryId: string; }> { - const { input, sessionId, rootAgentId } = request.body; + const { input, sessionId, rootAgentId } = payload; const opensearchClient = context.core.opensearch.client.asCurrentUser; if (payload.sessionId) { @@ -39,8 +34,6 @@ export class OllyChatService implements ChatService { } try { - const runs: Run[] = []; - /** * Wait for an API to fetch root agent id. */ @@ -66,17 +59,15 @@ export class OllyChatService implements ChatService { output: Array<{ name: string; result?: string }>; }>; }>; - const outputBody = - agentFrameworkResponse.body.inference_results?.[0]?.output || - agentFrameworkResponse.body.inference_results?.[0]?.output; + const outputBody = agentFrameworkResponse.body.inference_results?.[0]?.output; const memoryIdItem = outputBody?.find((item) => item.name === MEMORY_ID_FIELD); - const reversedOutputBody = [...outputBody].reverse(); - const finalAnswerItem = reversedOutputBody.find((item) => item.name === RESPONSE_FIELD); - - const agentFrameworkAnswer = finalAnswerItem?.result || ''; return { - messages: buildOutputs(input.content, agentFrameworkAnswer, '', {}, convertToTraces(runs)), + /** + * Interactions will be stored in Agent framework, + * thus we do not need to return the latest message back. + */ + messages: [], memoryId: memoryIdItem?.result || '', }; } catch (error) { diff --git a/server/services/storage/agent_framework_storage_service.ts b/server/services/storage/agent_framework_storage_service.ts index 05165acc..3448fd67 100644 --- a/server/services/storage/agent_framework_storage_service.ts +++ b/server/services/storage/agent_framework_storage_service.ts @@ -5,7 +5,6 @@ import { ApiResponse } from '@opensearch-project/opensearch/.'; import { OpenSearchClient } from '../../../../../src/core/server'; -import { LLM_INDEX } from '../../../common/constants/llm'; import { IInput, IMessage, @@ -15,44 +14,47 @@ import { } from '../../../common/types/chat_saved_object_attributes'; import { GetSessionsSchema } from '../../routes/chat_routes'; import { StorageService } from './storage_service'; +import { Interaction, MessageParser } from '../../types'; +import { MessageParserRunner } from '../../utils/message_parser_runner'; export class AgentFrameworkStorageService implements StorageService { - constructor(private readonly client: OpenSearchClient) {} + constructor( + private readonly client: OpenSearchClient, + private readonly messageParsers: MessageParser[] = [] + ) {} async getSession(sessionId: string): Promise { const session = (await this.client.transport.request({ method: 'GET', path: `/_plugins/_ml/memory/conversation/${sessionId}/_list`, })) as ApiResponse<{ - interactions: Array<{ - input: string; - response: string; - parent_interaction_id: string; - interaction_id: string; - }>; + interactions: Interaction[]; }>; + const messageParserRunner = new MessageParserRunner(this.messageParsers); + const finalInteractions: Interaction[] = [...session.body.interactions]; + + /** + * Sort interactions according to create_time + */ + finalInteractions.sort((interactionA, interactionB) => { + const { create_time: createTimeA } = interactionA; + const { create_time: createTimeB } = interactionB; + const createTimeMSA = +new Date(createTimeA); + const createTimeMSB = +new Date(createTimeB); + if (isNaN(createTimeMSA) || isNaN(createTimeMSB)) { + return 0; + } + return createTimeMSA - createTimeMSB; + }); + let finalMessages: IMessage[] = []; + for (const interaction of finalInteractions) { + finalMessages = [...finalMessages, ...(await messageParserRunner.run(interaction))]; + } return { title: 'test', version: 1, createdTimeMs: Date.now(), updatedTimeMs: Date.now(), - messages: session.body.interactions - .filter((item) => !item.parent_interaction_id) - .reduce((total, current) => { - const inputItem: IInput = { - type: 'input', - contentType: 'text', - content: current.input, - }; - const outputItems: IOutput[] = [ - { - type: 'output', - contentType: 'markdown', - content: current.response, - traceId: current.interaction_id, - }, - ]; - return [...total, inputItem, ...outputItems]; - }, [] as IMessage[]), + messages: finalMessages, }; } diff --git a/server/types.ts b/server/types.ts index 625f51ba..9d9e7520 100644 --- a/server/types.ts +++ b/server/types.ts @@ -18,6 +18,7 @@ export interface Interaction { interaction_id: string; create_time: string; additional_info: Record; + parent_interaction_id: string; } export interface MessageParser {