diff --git a/public/hooks/use_chat_actions.tsx b/public/hooks/use_chat_actions.tsx
index c78bff31..fd5f3f4a 100644
--- a/public/hooks/use_chat_actions.tsx
+++ b/public/hooks/use_chat_actions.tsx
@@ -162,7 +162,7 @@ export const useChatActions = (): AssistantActions => {
}
};
- const regenerate = async () => {
+ const regenerate = async (interactionId: string) => {
if (chatContext.sessionId) {
const abortController = new AbortController();
abortControllerRef = abortController;
@@ -170,7 +170,11 @@ export const useChatActions = (): AssistantActions => {
try {
const response = await core.services.http.put(`${ASSISTANT_API.REGENERATE}`, {
- body: JSON.stringify({ sessionId: chatContext.sessionId }),
+ body: JSON.stringify({
+ sessionId: chatContext.sessionId,
+ rootAgentId: chatContext.rootAgentId,
+ interactionId,
+ }),
});
if (abortController.signal.aborted) {
diff --git a/public/tabs/chat/messages/message_bubble.test.tsx b/public/tabs/chat/messages/message_bubble.test.tsx
index 4293dec4..d4652f32 100644
--- a/public/tabs/chat/messages/message_bubble.test.tsx
+++ b/public/tabs/chat/messages/message_bubble.test.tsx
@@ -130,6 +130,13 @@ describe('<MessageBubble />', () => {
contentType: 'markdown',
content: 'here are the indices in your cluster: .alert',
}}
+ interaction={{
+ input: 'foo',
+ response: 'bar',
+ conversation_id: 'foo',
+ interaction_id: 'bar',
+ create_time: new Date().toLocaleString(),
+ }}
/>
);
expect(screen.queryAllByTitle('regenerate message')).toHaveLength(1);
diff --git a/public/tabs/chat/messages/message_bubble.tsx b/public/tabs/chat/messages/message_bubble.tsx
index 34199640..8a704e2a 100644
--- a/public/tabs/chat/messages/message_bubble.tsx
+++ b/public/tabs/chat/messages/message_bubble.tsx
@@ -30,7 +30,7 @@ type MessageBubbleProps = {
showActionBar: boolean;
showRegenerate?: boolean;
shouldActionBarVisibleOnHover?: boolean;
- onRegenerate?: () => void;
+ onRegenerate?: (interactionId: string) => void;
} & (
| {
message: IMessage;
@@ -192,17 +192,17 @@ export const MessageBubble: React.FC<MessageBubbleProps> = React.memo((props) =>
)}
- {props.showRegenerate && (
+ {props.showRegenerate && props.interaction?.interaction_id ? (
<EuiButtonIcon
onClick={() => props.onRegenerate?.(props.interaction?.interaction_id || '')}
title="regenerate message"
color="text"
iconType="refresh"
/>
- )}
+ ) : null}
{showFeedback && (
// After feedback, only corresponding thumb icon will be kept and disabled.
<>
diff --git a/public/types.ts b/public/types.ts
index d8f152a3..5fd027e2 100644
--- a/public/types.ts
+++ b/public/types.ts
@@ -16,7 +16,7 @@ export interface AssistantActions {
openChatUI: (sessionId?: string) => void;
executeAction: (suggestedAction: ISuggestedAction, message: IMessage) => void;
abortAction: (sessionId?: string) => void;
- regenerate: () => void;
+ regenerate: (interactionId: string) => void;
}
export interface AppPluginStartDependencies {
diff --git a/server/routes/chat_routes.ts b/server/routes/chat_routes.ts
index 22328032..5e195820 100644
--- a/server/routes/chat_routes.ts
+++ b/server/routes/chat_routes.ts
@@ -13,7 +13,6 @@ import {
} from '../../../../src/core/server';
import { ASSISTANT_API } from '../../common/constants/llm';
import { OllyChatService } from '../services/chat/olly_chat_service';
-import { IMessage, IInput } from '../../common/types/chat_saved_object_attributes';
import { AgentFrameworkStorageService } from '../services/storage/agent_framework_storage_service';
import { RoutesOptions } from '../types';
import { ChatService } from '../services/chat/chat_service';
@@ -64,6 +63,7 @@ const regenerateRoute = {
body: schema.object({
sessionId: schema.string(),
rootAgentId: schema.string(),
+ interactionId: schema.string(),
}),
},
};
@@ -314,42 +314,35 @@ export function registerChatRoutes(router: IRouter, routeOptions: RoutesOptions)
request,
response
): Promise<IOpenSearchDashboardsResponse<HttpResponsePayload | ResponseError>> => {
- const { sessionId, rootAgentId } = request.body;
+ const { sessionId, rootAgentId, interactionId } = request.body;
const storageService = createStorageService(context);
- let messages: IMessage[] = [];
const chatService = createChatService();
+ let outputs: Awaited<ReturnType<ChatService['regenerate']>> | undefined;
+
+ /**
+ * Get final answer from Agent framework
+ */
try {
- const session = await storageService.getSession(sessionId);
- messages.push(...session.messages);
+ outputs = await chatService.regenerate({ sessionId, rootAgentId, interactionId }, context);
} catch (error) {
- return response.custom({ statusCode: error.statusCode || 500, body: error.message });
+ context.assistant_plugin.logger.error(error);
}
- const lastInputIndex = messages.findLastIndex((msg) => msg.type === 'input');
- // Find last input message
- const input = messages[lastInputIndex] as IInput;
- // Take the messages before last input message as memory as regenerate will exclude the last outputs
- messages = messages.slice(0, lastInputIndex);
-
+ /**
+ * Retrieve latest interactions from memory
+ */
try {
- const outputs = await chatService.requestLLM(
- { messages, input, sessionId, rootAgentId },
- context
- );
- const title = input.content.substring(0, 50);
- const saveMessagesResponse = await storageService.saveMessages(
- title,
- sessionId,
- [...messages, input, ...outputs.messages].filter(
- (message) => message.content !== 'AbortError'
- )
- );
+ const conversation = await storageService.getSession(sessionId);
+
return response.ok({
- body: { ...saveMessagesResponse, title },
+ body: {
+ ...conversation,
+ sessionId,
+ },
});
} catch (error) {
- context.assistant_plugin.logger.warn(error);
+ context.assistant_plugin.logger.error(error);
return response.custom({ statusCode: error.statusCode || 500, body: error.message });
}
}
diff --git a/server/services/chat/chat_service.ts b/server/services/chat/chat_service.ts
index ac15adf6..25fe703f 100644
--- a/server/services/chat/chat_service.ts
+++ b/server/services/chat/chat_service.ts
@@ -10,8 +10,15 @@ import { LLMRequestSchema } from '../../routes/chat_routes';
export interface ChatService {
requestLLM(
payload: { messages: IMessage[]; input: IInput; sessionId?: string },
- context: RequestHandlerContext,
- request: OpenSearchDashboardsRequest
+ context: RequestHandlerContext
+ ): Promise<{
+ messages: IMessage[];
+ memoryId: string;
+ }>;
+
+ regenerate(
+ payload: { sessionId: string; interactionId: string; rootAgentId: string },
+ context: RequestHandlerContext
): Promise<{
messages: IMessage[];
memoryId: string;
diff --git a/server/services/chat/olly_chat_service.test.ts b/server/services/chat/olly_chat_service.test.ts
new file mode 100644
index 00000000..1d5f563d
--- /dev/null
+++ b/server/services/chat/olly_chat_service.test.ts
@@ -0,0 +1,176 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+import { OllyChatService } from './olly_chat_service';
+import { CoreRouteHandlerContext } from '../../../../../src/core/server/core_route_handler_context';
+import { coreMock, httpServerMock } from '../../../../../src/core/server/mocks';
+import { loggerMock } from '../../../../../src/core/server/logging/logger.mock';
+
+describe('OllyChatService', () => {
+ const ollyChatService = new OllyChatService();
+ const coreContext = new CoreRouteHandlerContext(
+ coreMock.createInternalStart(),
+ httpServerMock.createOpenSearchDashboardsRequest()
+ );
+ const mockedTransport = coreContext.opensearch.client.asCurrentUser.transport
+ .request as jest.Mock;
+ const contextMock = {
+ core: coreContext,
+ assistant_plugin: {
+ logger: loggerMock.create(),
+ },
+ };
+ beforeEach(() => {
+ mockedTransport.mockClear();
+ });
+ it('requestLLM should invoke client call with correct params', async () => {
+ mockedTransport.mockImplementationOnce(() => {
+ return {
+ body: {
+ inference_results: [
+ {
+ output: [
+ {
+ name: 'memory_id',
+ result: 'foo',
+ },
+ ],
+ },
+ ],
+ },
+ };
+ });
+ const result = await ollyChatService.requestLLM(
+ {
+ messages: [],
+ input: {
+ type: 'input',
+ contentType: 'text',
+ content: 'content',
+ },
+ sessionId: '',
+ rootAgentId: 'rootAgentId',
+ },
+ contextMock
+ );
+ expect(mockedTransport.mock.calls).toMatchInlineSnapshot(`
+ Array [
+ Array [
+ Object {
+ "body": Object {
+ "parameters": Object {
+ "question": "content",
+ "verbose": true,
+ },
+ },
+ "method": "POST",
+ "path": "/_plugins/_ml/agents/rootAgentId/_execute",
+ },
+ Object {
+ "maxRetries": 0,
+ "requestTimeout": 300000,
+ },
+ ],
+ ]
+ `);
+ expect(result).toMatchInlineSnapshot(`
+ Object {
+ "memoryId": "foo",
+ "messages": Array [],
+ }
+ `);
+ });
+
+ it('requestLLM should throw error when transport.request throws error', async () => {
+ mockedTransport.mockImplementationOnce(() => {
+ throw new Error('error');
+ });
+ expect(
+ ollyChatService.requestLLM(
+ {
+ messages: [],
+ input: {
+ type: 'input',
+ contentType: 'text',
+ content: 'content',
+ },
+ sessionId: '',
+ rootAgentId: 'rootAgentId',
+ },
+ contextMock
+ )
+ ).rejects.toMatchInlineSnapshot(`[Error: error]`);
+ });
+
+ it('regenerate should invoke client call with correct params', async () => {
+ mockedTransport.mockImplementationOnce(() => {
+ return {
+ body: {
+ inference_results: [
+ {
+ output: [
+ {
+ name: 'memory_id',
+ result: 'foo',
+ },
+ ],
+ },
+ ],
+ },
+ };
+ });
+ const result = await ollyChatService.regenerate(
+ {
+ sessionId: 'sessionId',
+ rootAgentId: 'rootAgentId',
+ interactionId: 'interactionId',
+ },
+ contextMock
+ );
+ expect(mockedTransport.mock.calls).toMatchInlineSnapshot(`
+ Array [
+ Array [
+ Object {
+ "body": Object {
+ "parameters": Object {
+ "memory_id": "sessionId",
+ "regenerate_interaction_id": "interactionId",
+ "verbose": true,
+ },
+ },
+ "method": "POST",
+ "path": "/_plugins/_ml/agents/rootAgentId/_execute",
+ },
+ Object {
+ "maxRetries": 0,
+ "requestTimeout": 300000,
+ },
+ ],
+ ]
+ `);
+ expect(result).toMatchInlineSnapshot(`
+ Object {
+ "memoryId": "foo",
+ "messages": Array [],
+ }
+ `);
+ });
+
+ it('regenerate should throw error when transport.request throws error', async () => {
+ mockedTransport.mockImplementationOnce(() => {
+ throw new Error('error');
+ });
+ expect(
+ ollyChatService.regenerate(
+ {
+ sessionId: 'sessionId',
+ rootAgentId: 'rootAgentId',
+ interactionId: 'interactionId',
+ },
+ contextMock
+ )
+ ).rejects.toMatchInlineSnapshot(`[Error: error]`);
+ });
+});
diff --git a/server/services/chat/olly_chat_service.ts b/server/services/chat/olly_chat_service.ts
index 13981eda..ed2cb57b 100644
--- a/server/services/chat/olly_chat_service.ts
+++ b/server/services/chat/olly_chat_service.ts
@@ -9,46 +9,35 @@ import { IMessage, IInput } from '../../../common/types/chat_saved_object_attrib
import { ChatService } from './chat_service';
import { ML_COMMONS_BASE_API } from '../../utils/constants';
+interface AgentRunPayload {
+ question?: string;
+ verbose?: boolean;
+ memory_id?: string;
+ regenerate_interaction_id?: string;
+}
+
const MEMORY_ID_FIELD = 'memory_id';
export class OllyChatService implements ChatService {
static abortControllers: Map<string, AbortController> = new Map();
- public async requestLLM(
- payload: { messages: IMessage[]; input: IInput; sessionId?: string; rootAgentId: string },
+ private async requestAgentRun(
+ rootAgentId: string,
+ payload: AgentRunPayload,
context: RequestHandlerContext
- ): Promise<{
- messages: IMessage[];
- memoryId: string;
- }> {
- const { input, sessionId, rootAgentId } = payload;
- const opensearchClient = context.core.opensearch.client.asCurrentUser;
-
- if (payload.sessionId) {
- OllyChatService.abortControllers.set(payload.sessionId, new AbortController());
+ ) {
+ if (payload.memory_id) {
+ OllyChatService.abortControllers.set(payload.memory_id, new AbortController());
}
+ const opensearchClient = context.core.opensearch.client.asCurrentUser;
try {
- /**
- * Wait for an API to fetch root agent id.
- */
- const parametersPayload: {
- question: string;
- verbose?: boolean;
- memory_id?: string;
- } = {
- question: input.content,
- verbose: true,
- };
- if (sessionId) {
- parametersPayload.memory_id = sessionId;
- }
const agentFrameworkResponse = (await opensearchClient.transport.request(
{
method: 'POST',
path: `${ML_COMMONS_BASE_API}/agents/${rootAgentId}/_execute`,
body: {
- parameters: parametersPayload,
+ parameters: payload,
},
},
{
@@ -69,7 +58,6 @@ export class OllyChatService implements ChatService {
}>;
const outputBody = agentFrameworkResponse.body.inference_results?.[0]?.output;
const memoryIdItem = outputBody?.find((item) => item.name === MEMORY_ID_FIELD);
-
return {
/**
* Interactions will be stored in Agent framework,
@@ -81,12 +69,50 @@ export class OllyChatService implements ChatService {
} catch (error) {
throw error;
} finally {
- if (payload.sessionId) {
- OllyChatService.abortControllers.delete(payload.sessionId);
+ if (payload.memory_id) {
+ OllyChatService.abortControllers.delete(payload.memory_id);
}
}
}
+ public async requestLLM(
+ payload: { messages: IMessage[]; input: IInput; sessionId?: string; rootAgentId: string },
+ context: RequestHandlerContext
+ ): Promise<{
+ messages: IMessage[];
+ memoryId: string;
+ }> {
+ const { input, sessionId, rootAgentId } = payload;
+
+ const parametersPayload: Pick<AgentRunPayload, 'question' | 'verbose' | 'memory_id'> = {
+ question: input.content,
+ verbose: true,
+ };
+
+ if (sessionId) {
+ parametersPayload.memory_id = sessionId;
+ }
+
+ return await this.requestAgentRun(rootAgentId, parametersPayload, context);
+ }
+
+ async regenerate(
+ payload: { sessionId: string; interactionId: string; rootAgentId: string },
+ context: RequestHandlerContext
+ ): Promise<{ messages: IMessage[]; memoryId: string }> {
+ const { sessionId, interactionId, rootAgentId } = payload;
+ const parametersPayload: Pick<
+ AgentRunPayload,
+ 'regenerate_interaction_id' | 'verbose' | 'memory_id'
+ > = {
+ memory_id: sessionId,
+ regenerate_interaction_id: interactionId,
+ verbose: true,
+ };
+
+ return await this.requestAgentRun(rootAgentId, parametersPayload, context);
+ }
+
abortAgentExecution(sessionId: string) {
if (OllyChatService.abortControllers.has(sessionId)) {
OllyChatService.abortControllers.get(sessionId)?.abort();
diff --git a/server/types.ts b/server/types.ts
index 948ed5aa..3e93ce8f 100644
--- a/server/types.ts
+++ b/server/types.ts
@@ -4,7 +4,7 @@
*/
import { IMessage, Interaction } from '../common/types/chat_saved_object_attributes';
-import { ILegacyClusterClient, Logger } from '../../../src/core/server';
+import { Logger } from '../../../src/core/server';
// eslint-disable-next-line @typescript-eslint/no-empty-interface
export interface AssistantPluginSetup {}