Skip to content

Commit

Permalink
feat: change implementation of basic_input_output to built-in parser (#…
Browse files Browse the repository at this point in the history
…10)

* feat: change implementation of basic_input_output to built-in parser

Signed-off-by: SuZhou-Joe <[email protected]>

* feat: update CHANGELOG

Signed-off-by: SuZhou-Joe <[email protected]>

* feat: enable build input-output message

Signed-off-by: SuZhou-Joe <[email protected]>

* feat: sort interactions

Signed-off-by: SuZhou-Joe <[email protected]>

* feat: remove useless code

Signed-off-by: SuZhou-Joe <[email protected]>

* feat: use parent_interaction_id as traceId

Signed-off-by: SuZhou-Joe <[email protected]>

---------

Signed-off-by: SuZhou-Joe <[email protected]>
  • Loading branch information
SuZhou-Joe committed Dec 1, 2023
1 parent 701a148 commit c3d0027
Show file tree
Hide file tree
Showing 8 changed files with 99 additions and 69 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)

### 📈 Features/Enhancements

- Add support for registerMessageParser ([#5](https://github.com/opensearch-project/dashboards-assistant/pull/5))
- Add support for registerMessageParser ([#5](https://github.com/opensearch-project/dashboards-assistant/pull/5))
- Change implementation of basic_input_output to built-in parser ([#10](https://github.com/opensearch-project/dashboards-assistant/pull/10))
28 changes: 28 additions & 0 deletions server/parsers/basic_input_output_parser.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

import { IInput, IOutput } from '../../common/types/chat_saved_object_attributes';
import { Interaction } from '../types';

export const BasicInputOutputParser = {
order: 0,
id: 'output_message',
async parserProvider(interaction: Interaction) {
const inputItem: IInput = {
type: 'input',
contentType: 'text',
content: interaction.input,
};
const outputItems: IOutput[] = [
{
type: 'output',
contentType: 'markdown',
content: interaction.response,
traceId: interaction.parent_interaction_id,
},
];
return [inputItem, ...outputItems];
},
};
25 changes: 16 additions & 9 deletions server/plugin.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import { setupRoutes } from './routes/index';
import { chatSavedObject } from './saved_objects/chat_saved_object';
import { AssistantPluginSetup, AssistantPluginStart, MessageParser } from './types';
import { chatConfigSavedObject } from './saved_objects/chat_config_saved_object';
import { BasicInputOutputParser } from './parsers/basic_input_output_parser';

export class AssistantPlugin implements Plugin<AssistantPluginSetup, AssistantPluginStart> {
private readonly logger: Logger;
Expand Down Expand Up @@ -55,7 +56,9 @@ export class AssistantPlugin implements Plugin<AssistantPluginSetup, AssistantPl
});

// Register server side APIs
setupRoutes(router);
setupRoutes(router, {
messageParsers: this.messageParsers,
});

core.savedObjects.registerType(chatSavedObject);
core.savedObjects.registerType(chatConfigSavedObject);
Expand All @@ -66,15 +69,19 @@ export class AssistantPlugin implements Plugin<AssistantPluginSetup, AssistantPl
},
}));

return {
registerMessageParser: (messageParser: MessageParser) => {
const findItem = this.messageParsers.find((item) => item.id === messageParser.id);
if (findItem) {
throw new Error(`There is already a messageParser whose id is ${messageParser.id}`);
}
const registerMessageParser = (messageParser: MessageParser) => {
const findItem = this.messageParsers.find((item) => item.id === messageParser.id);
if (findItem) {
throw new Error(`There is already a messageParser whose id is ${messageParser.id}`);
}

this.messageParsers.push(messageParser);
},
this.messageParsers.push(messageParser);
};

registerMessageParser(BasicInputOutputParser);

return {
registerMessageParser,
removeMessageParser: (parserId: MessageParser['id']) => {
const findIndex = this.messageParsers.findIndex((item) => item.id === parserId);
if (findIndex < 0) {
Expand Down
24 changes: 12 additions & 12 deletions server/routes/chat_routes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import {
} from '../../../../src/core/server';
import { ASSISTANT_API } from '../../common/constants/llm';
import { OllyChatService } from '../services/chat/olly_chat_service';
import { SavedObjectsStorageService } from '../services/storage/saved_objects_storage_service';
import { IMessage, IInput } from '../../common/types/chat_saved_object_attributes';
import { AgentFrameworkStorageService } from '../services/storage/agent_framework_storage_service';
import { RoutesOptions } from '../types';
Expand Down Expand Up @@ -63,6 +62,7 @@ const regenerateRoute = {
validate: {
body: schema.object({
sessionId: schema.string(),
rootAgentId: schema.string(),
}),
},
};
Expand Down Expand Up @@ -105,9 +105,12 @@ const updateSessionRoute = {
},
};

export function registerChatRoutes(router: IRouter) {
export function registerChatRoutes(router: IRouter, routeOptions: RoutesOptions) {
const createStorageService = (context: RequestHandlerContext) =>
new AgentFrameworkStorageService(context.core.opensearch.client.asCurrentUser);
new AgentFrameworkStorageService(
context.core.opensearch.client.asCurrentUser,
routeOptions.messageParsers
);
const createChatService = () => new OllyChatService();

router.post(
Expand All @@ -117,15 +120,14 @@ export function registerChatRoutes(router: IRouter) {
request,
response
): Promise<IOpenSearchDashboardsResponse<HttpResponsePayload | ResponseError>> => {
const { messages = [], input, sessionId: sessionIdInRequestBody } = request.body;
const { messages = [], input, sessionId: sessionIdInRequestBody, rootAgentId } = request.body;
const storageService = createStorageService(context);
const chatService = createChatService();

try {
const outputs = await chatService.requestLLM(
{ messages, input, sessionId: sessionIdInRequestBody },
context,
request
{ messages, input, sessionId: sessionIdInRequestBody, rootAgentId },
context
);
const sessionId = outputs.memoryId;
const finalMessage = await storageService.getSession(sessionId);
Expand Down Expand Up @@ -250,7 +252,7 @@ export function registerChatRoutes(router: IRouter) {
request,
response
): Promise<IOpenSearchDashboardsResponse<HttpResponsePayload | ResponseError>> => {
const { sessionId } = request.body;
const { sessionId, rootAgentId } = request.body;
const storageService = createStorageService(context);
let messages: IMessage[] = [];
const chatService = createChatService();
Expand All @@ -270,10 +272,8 @@ export function registerChatRoutes(router: IRouter) {

try {
const outputs = await chatService.requestLLM(
{ messages, input, sessionId },
context,
// @ts-ignore
request
{ messages, input, sessionId, rootAgentId },
context
);
const title = input.content.substring(0, 50);
const saveMessagesResponse = await storageService.saveMessages(
Expand Down
4 changes: 2 additions & 2 deletions server/routes/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import { IRouter } from '../../../../src/core/server';
import { registerChatRoutes } from './chat_routes';
import { registerLangchainRoutes } from './langchain_routes';

export function setupRoutes(router: IRouter) {
registerChatRoutes(router);
export function setupRoutes(router: IRouter, routeOptions: RoutesOptions) {
registerChatRoutes(router, routeOptions);
registerLangchainRoutes(router);
}
29 changes: 10 additions & 19 deletions server/services/chat/olly_chat_service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,44 +3,37 @@
* SPDX-License-Identifier: Apache-2.0
*/

import { Run } from 'langchain/callbacks';
import { v4 as uuid } from 'uuid';
import { ApiResponse } from '@opensearch-project/opensearch';
import { OpenSearchDashboardsRequest, RequestHandlerContext } from '../../../../../src/core/server';
import { IMessage, IInput } from '../../../common/types/chat_saved_object_attributes';
import { convertToTraces } from '../../../common/utils/llm_chat/traces';
import { OpenSearchTracer } from '../../olly/callbacks/opensearch_tracer';
import { LLMModelFactory } from '../../olly/models/llm_model_factory';
import { PPLTools } from '../../olly/tools/tool_sets/ppl';
import { buildOutputs } from '../../olly/utils/output_builders/build_outputs';
import { AbortAgentExecutionSchema, LLMRequestSchema } from '../../routes/chat_routes';
import { PPLGenerationRequestSchema } from '../../routes/langchain_routes';
import { ChatService } from './chat_service';
import { LLMRequestSchema } from '../../routes/chat_routes';

const MEMORY_ID_FIELD = 'memory_id';
const RESPONSE_FIELD = 'response';

export class OllyChatService implements ChatService {
static abortControllers: Map<string, AbortController> = new Map();

public async requestLLM(
payload: { messages: IMessage[]; input: IInput; sessionId?: string },
context: RequestHandlerContext,
request: OpenSearchDashboardsRequest<unknown, unknown, LLMRequestSchema, 'post'>
payload: { messages: IMessage[]; input: IInput; sessionId?: string; rootAgentId: string },
context: RequestHandlerContext
): Promise<{
messages: IMessage[];
memoryId: string;
}> {
const { input, sessionId, rootAgentId } = request.body;
const { input, sessionId, rootAgentId } = payload;
const opensearchClient = context.core.opensearch.client.asCurrentUser;

if (payload.sessionId) {
OllyChatService.abortControllers.set(payload.sessionId, new AbortController());
}

try {
const runs: Run[] = [];

/**
* Wait for an API to fetch root agent id.
*/
Expand All @@ -66,17 +59,15 @@ export class OllyChatService implements ChatService {
output: Array<{ name: string; result?: string }>;
}>;
}>;
const outputBody =
agentFrameworkResponse.body.inference_results?.[0]?.output ||
agentFrameworkResponse.body.inference_results?.[0]?.output;
const outputBody = agentFrameworkResponse.body.inference_results?.[0]?.output;
const memoryIdItem = outputBody?.find((item) => item.name === MEMORY_ID_FIELD);
const reversedOutputBody = [...outputBody].reverse();
const finalAnswerItem = reversedOutputBody.find((item) => item.name === RESPONSE_FIELD);

const agentFrameworkAnswer = finalAnswerItem?.result || '';

return {
messages: buildOutputs(input.content, agentFrameworkAnswer, '', {}, convertToTraces(runs)),
/**
* Interactions will be stored in Agent framework,
* thus we do not need to return the latest message back.
*/
messages: [],
memoryId: memoryIdItem?.result || '',
};
} catch (error) {
Expand Down
54 changes: 28 additions & 26 deletions server/services/storage/agent_framework_storage_service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import { ApiResponse } from '@opensearch-project/opensearch/.';
import { OpenSearchClient } from '../../../../../src/core/server';
import { LLM_INDEX } from '../../../common/constants/llm';
import {
IInput,
IMessage,
Expand All @@ -15,44 +14,47 @@ import {
} from '../../../common/types/chat_saved_object_attributes';
import { GetSessionsSchema } from '../../routes/chat_routes';
import { StorageService } from './storage_service';
import { Interaction, MessageParser } from '../../types';
import { MessageParserRunner } from '../../utils/message_parser_runner';

export class AgentFrameworkStorageService implements StorageService {
constructor(private readonly client: OpenSearchClient) {}
constructor(
private readonly client: OpenSearchClient,
private readonly messageParsers: MessageParser[] = []
) {}
async getSession(sessionId: string): Promise<ISession> {
const session = (await this.client.transport.request({
method: 'GET',
path: `/_plugins/_ml/memory/conversation/${sessionId}/_list`,
})) as ApiResponse<{
interactions: Array<{
input: string;
response: string;
parent_interaction_id: string;
interaction_id: string;
}>;
interactions: Interaction[];
}>;
const messageParserRunner = new MessageParserRunner(this.messageParsers);
const finalInteractions: Interaction[] = [...session.body.interactions];

/**
* Sort interactions according to create_time
*/
finalInteractions.sort((interactionA, interactionB) => {
const { create_time: createTimeA } = interactionA;
const { create_time: createTimeB } = interactionB;
const createTimeMSA = +new Date(createTimeA);
const createTimeMSB = +new Date(createTimeB);
if (isNaN(createTimeMSA) || isNaN(createTimeMSB)) {
return 0;
}
return createTimeMSA - createTimeMSB;
});
let finalMessages: IMessage[] = [];
for (const interaction of finalInteractions) {
finalMessages = [...finalMessages, ...(await messageParserRunner.run(interaction))];
}
return {
title: 'test',
version: 1,
createdTimeMs: Date.now(),
updatedTimeMs: Date.now(),
messages: session.body.interactions
.filter((item) => !item.parent_interaction_id)
.reduce((total, current) => {
const inputItem: IInput = {
type: 'input',
contentType: 'text',
content: current.input,
};
const outputItems: IOutput[] = [
{
type: 'output',
contentType: 'markdown',
content: current.response,
traceId: current.interaction_id,
},
];
return [...total, inputItem, ...outputItems];
}, [] as IMessage[]),
messages: finalMessages,
};
}

Expand Down
1 change: 1 addition & 0 deletions server/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ export interface Interaction {
interaction_id: string;
create_time: string;
additional_info: Record<string, unknown>;
parent_interaction_id: string;
}

export interface MessageParser {
Expand Down

0 comments on commit c3d0027

Please sign in to comment.