Skip to content

Commit

Permalink
Merge pull request #2 from opensearch-project/feature/switch-chat-api
Browse files Browse the repository at this point in the history
feat: use agent framework API to generate answer
  • Loading branch information
SuZhou-Joe authored Nov 20, 2023
2 parents 841a273 + c6dbfb1 commit 88eb43e
Show file tree
Hide file tree
Showing 9 changed files with 163 additions and 77 deletions.
4 changes: 2 additions & 2 deletions babel.config.js
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ module.exports = function (api) {
],
plugins: [
[require('@babel/plugin-transform-runtime'), { regenerator: true }],
require('@babel/plugin-proposal-class-properties'),
require('@babel/plugin-proposal-object-rest-spread'),
require('@babel/plugin-transform-class-properties'),
require('@babel/plugin-transform-object-rest-spread'),
[require('@babel/plugin-transform-modules-commonjs'), { allowTopLevelThis: true }],
],
};
Expand Down
4 changes: 4 additions & 0 deletions public/chat_header_button.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ export const HeaderChatButton: React.FC<HeaderChatButtonProps> = (props) => {
const [traceId, setTraceId] = useState<string | undefined>(undefined);
const [chatSize, setChatSize] = useState<number | 'fullscreen' | 'dock-right'>('dock-right');
const flyoutFullScreen = chatSize === 'fullscreen';
const [rootAgentId, setRootAgentId] = useState<string>(
new URL(window.location.href).searchParams.get('agent_id') || ''
);

if (!flyoutLoaded && flyoutVisible) flyoutLoaded = true;

Expand Down Expand Up @@ -73,6 +76,7 @@ export const HeaderChatButton: React.FC<HeaderChatButtonProps> = (props) => {
setTitle,
traceId,
setTraceId,
rootAgentId,
}),
[
appId,
Expand Down
1 change: 1 addition & 0 deletions public/contexts/chat_context.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ export interface IChatContext {
setTitle: React.Dispatch<React.SetStateAction<string | undefined>>;
traceId?: string;
setTraceId: React.Dispatch<React.SetStateAction<string | undefined>>;
rootAgentId?: string;
}
export const ChatContext = React.createContext<IChatContext | null>(null);

Expand Down
1 change: 1 addition & 0 deletions public/hooks/use_chat_actions.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ export const useChatActions = (): AssistantActions => {
// do not send abort signal to http client to allow LLM call run in background
body: JSON.stringify({
sessionId: chatContext.sessionId,
rootAgentId: chatContext.rootAgentId,
...(!chatContext.sessionId && { messages: chatState.messages }), // include all previous messages for new chats
input,
}),
Expand Down
37 changes: 15 additions & 22 deletions server/routes/chat_routes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,15 @@ import { ASSISTANT_API } from '../../common/constants/llm';
import { OllyChatService } from '../services/chat/olly_chat_service';
import { SavedObjectsStorageService } from '../services/storage/saved_objects_storage_service';
import { IMessage, IInput } from '../../common/types/chat_saved_object_attributes';
import { AgentFrameworkStorageService } from '../services/storage/agent_framework_storage_service';

const llmRequestRoute = {
path: ASSISTANT_API.SEND_MESSAGE,
validate: {
body: schema.object({
sessionId: schema.maybe(schema.string()),
messages: schema.maybe(schema.arrayOf(schema.any())),
rootAgentId: schema.string(),
input: schema.object({
type: schema.literal('input'),
context: schema.object({
Expand Down Expand Up @@ -104,7 +106,7 @@ const updateSessionRoute = {

export function registerChatRoutes(router: IRouter) {
const createStorageService = (context: RequestHandlerContext) =>
new SavedObjectsStorageService(context.core.savedObjects.client);
new AgentFrameworkStorageService(context.core.opensearch.client.asCurrentUser);
const createChatService = () => new OllyChatService();

router.post(
Expand All @@ -114,34 +116,25 @@ export function registerChatRoutes(router: IRouter) {
request,
response
): Promise<IOpenSearchDashboardsResponse<HttpResponsePayload | ResponseError>> => {
const { sessionId, input, messages = [] } = request.body;
const { messages = [], input, sessionId: sessionIdInRequestBody } = request.body;
const storageService = createStorageService(context);
const chatService = createChatService();

// get history from the chat object for existing chats
if (sessionId && messages.length === 0) {
try {
const session = await storageService.getSession(sessionId);
messages.push(...session.messages);
} catch (error) {
return response.custom({ statusCode: error.statusCode || 500, body: error.message });
}
}

try {
const outputs = await chatService.requestLLM(
{ messages, input, sessionId },
{ messages, input, sessionId: sessionIdInRequestBody },
context,
request
);
const title = input.content.substring(0, 50);
const saveMessagesResponse = await storageService.saveMessages(
title,
sessionId,
[...messages, input, ...outputs].filter((message) => message.content !== 'AbortError')
);
const sessionId = outputs.memoryId;
const finalMessage = await storageService.getSession(sessionId);

return response.ok({
body: { ...saveMessagesResponse, title },
body: {
messages: finalMessage.messages,
sessionId: outputs.memoryId,
title: finalMessage.title
},
});
} catch (error) {
context.assistant_plugin.logger.warn(error);
Expand Down Expand Up @@ -278,13 +271,13 @@ export function registerChatRoutes(router: IRouter) {
const outputs = await chatService.requestLLM(
{ messages, input, sessionId },
context,
request
request as any
);
const title = input.content.substring(0, 50);
const saveMessagesResponse = await storageService.saveMessages(
title,
sessionId,
[...messages, input, ...outputs].filter((message) => message.content !== 'AbortError')
[...messages, input, ...outputs.messages].filter((message) => message.content !== 'AbortError')
);
return response.ok({
body: { ...saveMessagesResponse, title },
Expand Down
5 changes: 4 additions & 1 deletion server/services/chat/chat_service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@ export interface ChatService {
payload: { messages: IMessage[]; input: IInput; sessionId?: string },
context: RequestHandlerContext,
request: OpenSearchDashboardsRequest<unknown, unknown, LLMRequestSchema, 'post'>
): Promise<IMessage[]>;
): Promise<{
messages: IMessage[];
memoryId: string;
}>;
generatePPL(
context: RequestHandlerContext,
request: OpenSearchDashboardsRequest<unknown, unknown, PPLGenerationRequestSchema, 'post'>
Expand Down
110 changes: 58 additions & 52 deletions server/services/chat/olly_chat_service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,87 +5,93 @@

import { Run } from 'langchain/callbacks';
import { v4 as uuid } from 'uuid';
import { ApiResponse } from '@opensearch-project/opensearch';
import { OpenSearchDashboardsRequest, RequestHandlerContext } from '../../../../../src/core/server';
import { IMessage, IInput } from '../../../common/types/chat_saved_object_attributes';
import { convertToTraces } from '../../../common/utils/llm_chat/traces';
import { chatAgentInit } from '../../olly/agents/agent_helpers';
import { OpenSearchTracer } from '../../olly/callbacks/opensearch_tracer';
import { requestSuggestionsChain } from '../../olly/chains/suggestions_generator';
import { memoryInit } from '../../olly/memory/chat_agent_memory';
import { LLMModelFactory } from '../../olly/models/llm_model_factory';
import { initTools } from '../../olly/tools/tools_helper';
import { PPLTools } from '../../olly/tools/tool_sets/ppl';
import { buildOutputs } from '../../olly/utils/output_builders/build_outputs';
import { AbortAgentExecutionSchema, LLMRequestSchema } from '../../routes/chat_routes';
import { PPLGenerationRequestSchema } from '../../routes/langchain_routes';
import { ChatService } from './chat_service';

const MEMORY_ID_FIELD = 'memory_id';
const RESPONSE_FIELD = 'response';

export class OllyChatService implements ChatService {
static abortControllers: Map<string, AbortController> = new Map();

public async requestLLM(
payload: { messages: IMessage[]; input: IInput; sessionId?: string },
context: RequestHandlerContext,
request: OpenSearchDashboardsRequest
): Promise<IMessage[]> {
const traceId = uuid();
const observabilityClient = context.assistant_plugin.observabilityClient.asScoped(request);
request: OpenSearchDashboardsRequest<unknown, unknown, LLMRequestSchema, 'post'>
): Promise<{
messages: IMessage[];
memoryId: string;
}> {
const { input, sessionId, rootAgentId } = request.body;
const opensearchClient = context.core.opensearch.client.asCurrentUser;
const savedObjectsClient = context.core.savedObjects.client;

if (payload.sessionId) {
OllyChatService.abortControllers.set(payload.sessionId, new AbortController());
}

try {
const runs: Run[] = [];
const callbacks = [new OpenSearchTracer(opensearchClient, traceId, runs)];
const model = LLMModelFactory.createModel({ client: opensearchClient });
const embeddings = LLMModelFactory.createEmbeddings({ client: opensearchClient });
const pluginTools = initTools(
model,
embeddings,
opensearchClient,
observabilityClient,
savedObjectsClient,
callbacks
);
const memory = memoryInit(payload.messages);
const chatAgent = chatAgentInit(
model,
pluginTools.flatMap((tool) => tool.toolsList),
callbacks,
memory
);
const agentResponse = await chatAgent.run(
payload.input.content,
payload.sessionId ? OllyChatService.abortControllers.get(payload.sessionId) : undefined
);

const suggestions = await requestSuggestionsChain(
model,
pluginTools.flatMap((tool) => tool.toolsList),
memory,
callbacks
);
/**
* Wait for an API to fetch root agent id.
*/
const parametersPayload: {
question: string;
verbose?: boolean;
memory_id?: string;
} = {
question: input.content,
verbose: true,
};
if (sessionId) {
parametersPayload.memory_id = sessionId;
}
const agentFrameworkResponse = (await opensearchClient.transport.request({
method: 'POST',
path: `/_plugins/_ml/agents/${rootAgentId}/_execute`,
body: {
parameters: parametersPayload,
},
})) as ApiResponse<{
inference_results: Array<{
output: Array<{ name: string; result?: string }>;
}>;
}>;
const outputBody =
agentFrameworkResponse.body.inference_results?.[0]?.output ||
agentFrameworkResponse.body.inference_results?.[0]?.output;
const memoryIdItem = outputBody?.find((item) => item.name === MEMORY_ID_FIELD);
const reversedOutputBody = [...outputBody].reverse();
const finalAnswerItem = reversedOutputBody.find((item) => item.name === RESPONSE_FIELD);

return buildOutputs(
payload.input.content,
agentResponse,
traceId,
suggestions,
convertToTraces(runs)
);
const agentFrameworkAnswer = finalAnswerItem?.result || '';

return {
messages: buildOutputs(input.content, agentFrameworkAnswer, '', {}, convertToTraces(runs)),
memoryId: memoryIdItem?.result || '',
};
} catch (error) {
context.assistant_plugin.logger.error(error);
return [
{
type: 'output',
traceId,
contentType: 'error',
content: error.message,
},
];
return {
messages: [
{
type: 'output',
traceId: '',
contentType: 'error',
content: error.message,
},
],
memoryId: '',
};
} finally {
if (payload.sessionId) {
OllyChatService.abortControllers.delete(payload.sessionId);
Expand Down
76 changes: 76 additions & 0 deletions server/services/storage/agent_framework_storage_service.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

import { ApiResponse } from '@opensearch-project/opensearch/.';
import { OpenSearchClient } from '../../../../../src/core/server';
import { LLM_INDEX } from '../../../common/constants/llm';
import {
IInput,
IMessage,
IOutput,
ISession,
ISessionFindResponse,
} from '../../../common/types/chat_saved_object_attributes';
import { GetSessionsSchema } from '../../routes/chat_routes';
import { StorageService } from './storage_service';

export class AgentFrameworkStorageService implements StorageService {
constructor(private readonly client: OpenSearchClient) {}
async getSession(sessionId: string): Promise<ISession> {
const session = (await this.client.transport.request({
method: 'GET',
path: `/_plugins/_ml/memory/conversation/${sessionId}`,
})) as ApiResponse<{
interactions: Array<{
input: string;
response: string;
parent_interaction_id: string;
interaction_id: string;
}>;
}>;
return {
title: 'test',
version: 1,
createdTimeMs: Date.now(),
updatedTimeMs: Date.now(),
messages: session.body.interactions
.filter((item) => !item.parent_interaction_id)
.reduce((total, current) => {
const inputItem: IInput = {
type: 'input',
contentType: 'text',
content: current.input,
};
const outputItems: IOutput[] = [
{
type: 'output',
contentType: 'markdown',
content: current.response,
traceId: current.interaction_id,
},
];
return [...total, inputItem, ...outputItems];
}, [] as IMessage[]),
};
}

async getSessions(query: GetSessionsSchema): Promise<ISessionFindResponse> {
throw new Error('Method not implemented.');
}

async saveMessages(
title: string,
sessionId: string | undefined,
messages: IMessage[]
): Promise<{ sessionId: string; messages: IMessage[] }> {
throw new Error('Method not implemented.');
}
deleteSession(sessionId: string): Promise<{}> {
throw new Error('Method not implemented.');
}
updateSession(sessionId: string, title: string): Promise<{}> {
throw new Error('Method not implemented.');
}
}
2 changes: 2 additions & 0 deletions server/services/storage/storage_service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,6 @@ export interface StorageService {
sessionId: string | undefined,
messages: IMessage[]
): Promise<{ sessionId: string; messages: IMessage[] }>;
deleteSession(sessionId: string): Promise<{}>;
updateSession(sessionId: string, title: string): Promise<{}>;
}

0 comments on commit 88eb43e

Please sign in to comment.