diff --git a/public/chat_header_button.tsx b/public/chat_header_button.tsx index db1fc958..df46bdf2 100644 --- a/public/chat_header_button.tsx +++ b/public/chat_header_button.tsx @@ -42,9 +42,6 @@ export const HeaderChatButton: React.FC = (props) => { const [inputFocus, setInputFocus] = useState(false); const flyoutFullScreen = chatSize === 'fullscreen'; const inputRef = useRef(null); - const [rootAgentId, setRootAgentId] = useState( - new URL(window.location.href).searchParams.get('agent_id') || '' - ); if (!flyoutLoaded && flyoutVisible) flyoutLoaded = true; @@ -80,7 +77,6 @@ export const HeaderChatButton: React.FC = (props) => { setTitle, traceId, setTraceId, - rootAgentId, }), [ appId, diff --git a/public/contexts/chat_context.tsx b/public/contexts/chat_context.tsx index 5d340e3d..04df189a 100644 --- a/public/contexts/chat_context.tsx +++ b/public/contexts/chat_context.tsx @@ -25,7 +25,6 @@ export interface IChatContext { setTitle: React.Dispatch>; traceId?: string; setTraceId: React.Dispatch>; - rootAgentId?: string; } export const ChatContext = React.createContext(null); diff --git a/public/hooks/use_chat_actions.test.tsx b/public/hooks/use_chat_actions.test.tsx index 4fca1f39..405e2df9 100644 --- a/public/hooks/use_chat_actions.test.tsx +++ b/public/hooks/use_chat_actions.test.tsx @@ -64,7 +64,6 @@ describe('useChatActions hook', () => { const setTraceIdMock = jest.fn(); const chatContextMock = { - rootAgentId: 'root_agent_id_mock', selectedTabId: 'chat', setSessionId: jest.fn(), setTitle: jest.fn(), @@ -109,7 +108,6 @@ describe('useChatActions hook', () => { // it should call send message api expect(httpMock.post).toHaveBeenCalledWith(ASSISTANT_API.SEND_MESSAGE, { body: JSON.stringify({ - rootAgentId: 'root_agent_id_mock', messages: [], input: INPUT_MESSAGE, }), @@ -173,7 +171,6 @@ describe('useChatActions hook', () => { // sending message with the suggestion expect(httpMock.post).toHaveBeenCalledWith(ASSISTANT_API.SEND_MESSAGE, { body: JSON.stringify({ - rootAgentId: 'root_agent_id_mock', messages: [], input: { type: 'input', content: 'message that send as input', contentType: 'text' }, }), @@ -255,7 +252,6 @@ describe('useChatActions hook', () => { expect(httpMock.put).toHaveBeenCalledWith(ASSISTANT_API.REGENERATE, { body: JSON.stringify({ sessionId: 'session_id_mock', - rootAgentId: 'root_agent_id_mock', interactionId: 'interaction_id_mock', }), }); @@ -281,7 +277,6 @@ describe('useChatActions hook', () => { expect(httpMock.put).toHaveBeenCalledWith(ASSISTANT_API.REGENERATE, { body: JSON.stringify({ sessionId: 'session_id_mock', - rootAgentId: 'root_agent_id_mock', interactionId: 'interaction_id_mock', }), }); diff --git a/public/hooks/use_chat_actions.tsx b/public/hooks/use_chat_actions.tsx index b1a56de9..244252bf 100644 --- a/public/hooks/use_chat_actions.tsx +++ b/public/hooks/use_chat_actions.tsx @@ -38,7 +38,6 @@ export const useChatActions = (): AssistantActions => { // do not send abort signal to http client to allow LLM call run in background body: JSON.stringify({ sessionId: chatContext.sessionId, - rootAgentId: chatContext.rootAgentId, ...(!chatContext.sessionId && { messages: chatState.messages }), // include all previous messages for new chats input, }), @@ -168,7 +167,6 @@ export const useChatActions = (): AssistantActions => { const response = await core.services.http.put(`${ASSISTANT_API.REGENERATE}`, { body: JSON.stringify({ sessionId: chatContext.sessionId, - rootAgentId: chatContext.rootAgentId, interactionId, }), }); diff --git a/public/index.ts b/public/index.ts index 30e31105..cdce4ade 100644 --- a/public/index.ts +++ b/public/index.ts @@ -11,3 +11,5 @@ export { AssistantPlugin as Plugin }; export function plugin(initializerContext: PluginInitializerContext) { return new AssistantPlugin(initializerContext); } + +export { AssistantSetup } from './types'; diff --git a/public/plugin.tsx b/public/plugin.tsx index 1929c229..38a38795 100644 --- a/public/plugin.tsx +++ b/public/plugin.tsx @@ -30,6 +30,7 @@ interface PublicConfig { chat: { // whether chat feature is enabled, UI should hide if false enabled: boolean; + rootAgentId?: string; }; } @@ -74,7 +75,7 @@ export class AssistantPlugin const checkAccess = (account: Awaited>) => account.data.roles.some((role) => ['all_access', 'assistant_user'].includes(role)); - if (this.config.chat.enabled) { + if (this.config.chat.enabled && this.config.chat.rootAgentId) { core.getStartServices().then(async ([coreStart, startDeps]) => { const CoreContext = createOpenSearchDashboardsReactContext({ ...coreStart, diff --git a/server/index.ts b/server/index.ts index 8e2e488d..6b5e9028 100644 --- a/server/index.ts +++ b/server/index.ts @@ -11,12 +11,13 @@ export function plugin(initializerContext: PluginInitializerContext) { return new AssistantPlugin(initializerContext); } -export { AssistantPluginSetup, AssistantPluginStart } from './types'; +export { AssistantPluginSetup, AssistantPluginStart, MessageParser } from './types'; const assistantConfig = { schema: schema.object({ chat: schema.object({ enabled: schema.boolean({ defaultValue: false }), + rootAgentId: schema.maybe(schema.string()), }), }), }; diff --git a/server/plugin.ts b/server/plugin.ts index e4ebc3f5..bbca5bff 100644 --- a/server/plugin.ts +++ b/server/plugin.ts @@ -16,6 +16,7 @@ import { setupRoutes } from './routes/index'; import { AssistantPluginSetup, AssistantPluginStart, MessageParser } from './types'; import { BasicInputOutputParser } from './parsers/basic_input_output_parser'; import { VisualizationCardParser } from './parsers/visualization_card_parser'; +import { AgentIdNotFoundError } from './routes/chat_routes'; export class AssistantPlugin implements Plugin { private readonly logger: Logger; @@ -25,12 +26,21 @@ export class AssistantPlugin implements Plugin { this.logger.debug('Assistant: Setup'); const config = await this.initializerContext.config .create() .pipe(first()) .toPromise(); + + /** + * Check if user enable the chat without specifying a root agent id. + * If so, gives a warning for guidance. + */ + if (config.chat.enabled && !config.chat.rootAgentId) { + this.logger.warn(AgentIdNotFoundError); + } + const router = core.http.createRouter(); core.http.registerRouteHandlerContext('assistant_plugin', () => { @@ -43,6 +53,7 @@ export class AssistantPlugin implements Plugin ({ diff --git a/server/routes/chat_routes.ts b/server/routes/chat_routes.ts index 5e195820..72f87e3d 100644 --- a/server/routes/chat_routes.ts +++ b/server/routes/chat_routes.ts @@ -23,7 +23,6 @@ const llmRequestRoute = { body: schema.object({ sessionId: schema.maybe(schema.string()), messages: schema.maybe(schema.arrayOf(schema.any())), - rootAgentId: schema.string(), input: schema.object({ type: schema.literal('input'), context: schema.object({ @@ -37,6 +36,9 @@ const llmRequestRoute = { }; export type LLMRequestSchema = TypeOf; +export const AgentIdNotFoundError = + 'rootAgentId is required, please specify one in opensearch_dashboards.yml'; + const getSessionRoute = { path: `${ASSISTANT_API.SESSION}/{sessionId}`, validate: { @@ -62,7 +64,6 @@ const regenerateRoute = { validate: { body: schema.object({ sessionId: schema.string(), - rootAgentId: schema.string(), interactionId: schema.string(), }), }, @@ -142,7 +143,11 @@ export function registerChatRoutes(router: IRouter, routeOptions: RoutesOptions) request, response ): Promise> => { - const { messages = [], input, sessionId: sessionIdInRequestBody, rootAgentId } = request.body; + if (!routeOptions.rootAgentId) { + context.assistant_plugin.logger.error(AgentIdNotFoundError); + return response.custom({ statusCode: 400, body: AgentIdNotFoundError }); + } + const { messages = [], input, sessionId: sessionIdInRequestBody } = request.body; const storageService = createStorageService(context); const chatService = createChatService(); @@ -153,7 +158,12 @@ export function registerChatRoutes(router: IRouter, routeOptions: RoutesOptions) */ try { outputs = await chatService.requestLLM( - { messages, input, sessionId: sessionIdInRequestBody, rootAgentId }, + { + messages, + input, + sessionId: sessionIdInRequestBody, + rootAgentId: routeOptions.rootAgentId, + }, context ); } catch (error) { @@ -314,7 +324,11 @@ export function registerChatRoutes(router: IRouter, routeOptions: RoutesOptions) request, response ): Promise> => { - const { sessionId, rootAgentId, interactionId } = request.body; + if (!routeOptions.rootAgentId) { + context.assistant_plugin.logger.error(AgentIdNotFoundError); + return response.custom({ statusCode: 400, body: AgentIdNotFoundError }); + } + const { sessionId, interactionId } = request.body; const storageService = createStorageService(context); const chatService = createChatService(); @@ -324,7 +338,10 @@ export function registerChatRoutes(router: IRouter, routeOptions: RoutesOptions) * Get final answer from Agent framework */ try { - outputs = await chatService.regenerate({ sessionId, rootAgentId, interactionId }, context); + outputs = await chatService.regenerate( + { sessionId, rootAgentId: routeOptions.rootAgentId, interactionId }, + context + ); } catch (error) { context.assistant_plugin.logger.error(error); } diff --git a/server/routes/regenerate.test.ts b/server/routes/regenerate.test.ts new file mode 100644 index 00000000..3b4dcd83 --- /dev/null +++ b/server/routes/regenerate.test.ts @@ -0,0 +1,180 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { ResponseObject } from '@hapi/hapi'; +import { Boom } from '@hapi/boom'; +import { Router } from '../../../../src/core/server/http/router'; +import { enhanceWithContext, triggerHandler } from './router.mock'; +import { mockOllyChatService } from '../services/chat/olly_chat_service.mock'; +import { mockAgentFrameworkStorageService } from '../services/storage/agent_framework_storage_service.mock'; +import { httpServerMock } from '../../../../src/core/server/http/http_server.mocks'; +import { loggerMock } from '../../../../src/core/server/logging/logger.mock'; +import { registerChatRoutes, RegenerateSchema, AgentIdNotFoundError } from './chat_routes'; +import { ASSISTANT_API } from '../../common/constants/llm'; + +const mockedLogger = loggerMock.create(); + +describe('regenerate route when rootAgentId is provided', () => { + const router = new Router( + '', + mockedLogger, + enhanceWithContext({ + assistant_plugin: { + logger: mockedLogger, + }, + }) + ); + registerChatRoutes(router, { + messageParsers: [], + rootAgentId: 'foo', + }); + const regenerateRequest = (payload: RegenerateSchema) => + triggerHandler(router, { + method: 'put', + path: ASSISTANT_API.REGENERATE, + req: httpServerMock.createRawRequest({ + payload: JSON.stringify(payload), + }), + }); + beforeEach(() => { + loggerMock.clear(mockedLogger); + }); + it('return back successfully when regenerate returns momery back', async () => { + mockOllyChatService.regenerate.mockImplementationOnce(async () => { + return { + messages: [], + memoryId: 'foo', + }; + }); + mockAgentFrameworkStorageService.getSession.mockImplementationOnce(async () => { + return { + messages: [], + title: 'foo', + interactions: [], + createdTimeMs: 0, + updatedTimeMs: 0, + }; + }); + const result = (await regenerateRequest({ + sessionId: 'foo', + interactionId: 'bar', + })) as ResponseObject; + expect(result.source).toMatchInlineSnapshot(` + Object { + "createdTimeMs": 0, + "interactions": Array [], + "messages": Array [], + "sessionId": "foo", + "title": "foo", + "updatedTimeMs": 0, + } + `); + }); + + it('log error when regenerate throws an error', async () => { + mockOllyChatService.regenerate.mockImplementationOnce(() => { + throw new Error('something went wrong'); + }); + mockAgentFrameworkStorageService.getSession.mockImplementationOnce(async () => { + return { + messages: [], + title: 'foo', + interactions: [], + createdTimeMs: 0, + updatedTimeMs: 0, + }; + }); + const result = (await regenerateRequest({ + sessionId: 'foo', + interactionId: 'bar', + })) as ResponseObject; + expect(mockedLogger.error).toBeCalledTimes(1); + expect(result.source).toMatchInlineSnapshot(` + Object { + "createdTimeMs": 0, + "interactions": Array [], + "messages": Array [], + "sessionId": "foo", + "title": "foo", + "updatedTimeMs": 0, + } + `); + }); + + it('return 500 when get session throws an error', async () => { + mockOllyChatService.regenerate.mockImplementationOnce(async () => { + return { + messages: [], + memoryId: 'foo', + }; + }); + mockAgentFrameworkStorageService.getSession.mockImplementationOnce(() => { + throw new Error('foo'); + }); + const result = (await regenerateRequest({ + sessionId: 'foo', + interactionId: 'bar', + })) as Boom; + expect(mockedLogger.error).toBeCalledTimes(1); + expect(mockedLogger.error).toBeCalledWith(new Error('foo')); + expect(result.output).toMatchInlineSnapshot(` + Object { + "headers": Object {}, + "payload": Object { + "error": "Internal Server Error", + "message": "foo", + "statusCode": 500, + }, + "statusCode": 500, + } + `); + }); +}); + +describe('regenerate route when rootAgentId is not provided', () => { + const router = new Router( + '', + mockedLogger, + enhanceWithContext({ + assistant_plugin: { + logger: mockedLogger, + }, + }) + ); + registerChatRoutes(router, { + messageParsers: [], + }); + const regenerateRequest = (payload: RegenerateSchema) => + triggerHandler(router, { + method: 'put', + path: ASSISTANT_API.REGENERATE, + req: httpServerMock.createRawRequest({ + payload: JSON.stringify(payload), + }), + }); + beforeEach(() => { + loggerMock.clear(mockedLogger); + }); + + it('return 400', async () => { + const result = (await regenerateRequest({ + interactionId: 'bar', + sessionId: 'foo', + })) as Boom; + expect(mockedLogger.error).toBeCalledTimes(1); + expect(mockedLogger.error).toBeCalledWith(AgentIdNotFoundError); + expect(result.output).toMatchInlineSnapshot(` + Object { + "headers": Object {}, + "payload": Object { + "error": "Bad Request", + "message": "rootAgentId is required, please specify one in opensearch_dashboards.yml", + "statusCode": 400, + }, + "statusCode": 400, + } + `); + }); +}); diff --git a/server/routes/router.mock.ts b/server/routes/router.mock.ts index 94400675..bdba471d 100644 --- a/server/routes/router.mock.ts +++ b/server/routes/router.mock.ts @@ -20,10 +20,12 @@ import { httpServerMock } from '../../../../src/core/server/http/http_server.moc import { OpenSearchDashboardsRequest, OpenSearchDashboardsResponseFactory, + RouteMethod, Router, } from '../../../../src/core/server/http/router'; import { CoreRouteHandlerContext } from '../../../../src/core/server/core_route_handler_context'; import { coreMock } from '../../../../src/core/server/mocks'; +import { ContextEnhancer } from '../../../../src/core/server/http/router/router'; /** * For hapi, ResponseToolkit is an internal implementation @@ -91,7 +93,7 @@ export class MockResponseToolkit implements ResponseToolkit { } } -const enhanceWithContext = (otherContext?: object) => (fn: (...args: unknown[]) => unknown) => ( +const enhanceWithContext = (((otherContext?: object) => (fn: (...args: unknown[]) => unknown) => ( req: OpenSearchDashboardsRequest, res: OpenSearchDashboardsResponseFactory ) => { @@ -105,7 +107,9 @@ const enhanceWithContext = (otherContext?: object) => (fn: (...args: unknown[]) req, res ); -}; +}) as unknown) as ( + otherContext?: object +) => ContextEnhancer; const triggerHandler = async ( router: Router, diff --git a/server/routes/send_message.test.ts b/server/routes/send_message.test.ts index 75554a58..f08bace6 100644 --- a/server/routes/send_message.test.ts +++ b/server/routes/send_message.test.ts @@ -11,25 +11,25 @@ import { mockOllyChatService } from '../services/chat/olly_chat_service.mock'; import { mockAgentFrameworkStorageService } from '../services/storage/agent_framework_storage_service.mock'; import { httpServerMock } from '../../../../src/core/server/http/http_server.mocks'; import { loggerMock } from '../../../../src/core/server/logging/logger.mock'; -import { registerChatRoutes, LLMRequestSchema } from './chat_routes'; +import { registerChatRoutes, LLMRequestSchema, AgentIdNotFoundError } from './chat_routes'; import { ASSISTANT_API } from '../../common/constants/llm'; const mockedLogger = loggerMock.create(); -const router = new Router( - '', - mockedLogger, - enhanceWithContext({ - assistant_plugin: { - logger: mockedLogger, - }, - }) -); -registerChatRoutes(router, { - messageParsers: [], -}); - -describe('send_message route', () => { +describe('send_message route when rootAgentId is provided', () => { + const router = new Router( + '', + mockedLogger, + enhanceWithContext({ + assistant_plugin: { + logger: mockedLogger, + }, + }) + ); + registerChatRoutes(router, { + messageParsers: [], + rootAgentId: 'foo', + }); const sendMessageRequest = (payload: LLMRequestSchema) => triggerHandler(router, { method: 'post', @@ -58,7 +58,6 @@ describe('send_message route', () => { }; }); const result = (await sendMessageRequest({ - rootAgentId: 'foo', input: { content: '1', contentType: 'text', @@ -81,7 +80,6 @@ describe('send_message route', () => { throw new Error('something went wrong'); }); const result = (await sendMessageRequest({ - rootAgentId: 'foo', input: { content: '1', contentType: 'text', @@ -111,7 +109,6 @@ describe('send_message route', () => { }; }); const result = (await sendMessageRequest({ - rootAgentId: 'foo', input: { content: '1', contentType: 'text', @@ -147,7 +144,6 @@ describe('send_message route', () => { }; }); const result = (await sendMessageRequest({ - rootAgentId: 'foo', input: { content: '1', contentType: 'text', @@ -180,7 +176,6 @@ describe('send_message route', () => { throw new Error('foo'); }); const result = (await sendMessageRequest({ - rootAgentId: 'foo', input: { content: '1', contentType: 'text', @@ -206,3 +201,56 @@ describe('send_message route', () => { `); }); }); + +describe('send_message route when rootAgentId is not provided', () => { + const router = new Router( + '', + mockedLogger, + enhanceWithContext({ + assistant_plugin: { + logger: mockedLogger, + }, + }) + ); + registerChatRoutes(router, { + messageParsers: [], + }); + const sendMessageRequest = (payload: LLMRequestSchema) => + triggerHandler(router, { + method: 'post', + path: ASSISTANT_API.SEND_MESSAGE, + req: httpServerMock.createRawRequest({ + payload: JSON.stringify(payload), + }), + }); + beforeEach(() => { + loggerMock.clear(mockedLogger); + }); + + it('return 400', async () => { + const result = (await sendMessageRequest({ + input: { + content: '1', + contentType: 'text', + type: 'input', + context: { + appId: '', + }, + }, + sessionId: 'foo', + })) as Boom; + expect(mockedLogger.error).toBeCalledTimes(1); + expect(mockedLogger.error).toBeCalledWith(AgentIdNotFoundError); + expect(result.output).toMatchInlineSnapshot(` + Object { + "headers": Object {}, + "payload": Object { + "error": "Bad Request", + "message": "rootAgentId is required, please specify one in opensearch_dashboards.yml", + "statusCode": 400, + }, + "statusCode": 400, + } + `); + }); +}); diff --git a/server/services/chat/olly_chat_service.mock.ts b/server/services/chat/olly_chat_service.mock.ts index c4b91df3..0aded8c6 100644 --- a/server/services/chat/olly_chat_service.mock.ts +++ b/server/services/chat/olly_chat_service.mock.ts @@ -3,10 +3,12 @@ * SPDX-License-Identifier: Apache-2.0 */ +import { PublicContract } from '@osd/utility-types'; import { OllyChatService } from './olly_chat_service'; -const mockOllyChatService: jest.Mocked = { +const mockOllyChatService: jest.Mocked> = { requestLLM: jest.fn(), + regenerate: jest.fn(), abortAgentExecution: jest.fn(), }; diff --git a/server/types.ts b/server/types.ts index 719d76c2..c1857607 100644 --- a/server/types.ts +++ b/server/types.ts @@ -6,8 +6,10 @@ import { IMessage, Interaction } from '../common/types/chat_saved_object_attributes'; import { Logger } from '../../../src/core/server'; -// eslint-disable-next-line @typescript-eslint/no-empty-interface -export interface AssistantPluginSetup {} +export interface AssistantPluginSetup { + registerMessageParser: (message: MessageParser) => void; + removeMessageParser: (parserId: MessageParser['id']) => void; +} // eslint-disable-next-line @typescript-eslint/no-empty-interface export interface AssistantPluginStart {} @@ -35,6 +37,7 @@ export interface MessageParser { export interface RoutesOptions { messageParsers: MessageParser[]; + rootAgentId?: string; } declare module '../../../src/core/server' {