Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Search root agent id by rootAgentName specified in opensearch_dashboards.yml #86

Merged
merged 3 commits into from
Dec 28, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions public/plugin.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ interface PublicConfig {
chat: {
// whether chat feature is enabled, UI should hide if false
enabled: boolean;
rootAgentId?: string;
rootAgentName?: string;
};
}

Expand Down Expand Up @@ -75,7 +75,7 @@ export class AssistantPlugin
const checkAccess = (account: Awaited<ReturnType<typeof getAccount>>) =>
account.data.roles.some((role) => ['all_access', 'assistant_user'].includes(role));

if (this.config.chat.enabled && this.config.chat.rootAgentId) {
if (this.config.chat.enabled) {
core.getStartServices().then(async ([coreStart, startDeps]) => {
const CoreContext = createOpenSearchDashboardsReactContext<AssistantServices>({
...coreStart,
Expand Down
2 changes: 1 addition & 1 deletion server/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ const assistantConfig = {
schema: schema.object({
chat: schema.object({
enabled: schema.boolean({ defaultValue: false }),
rootAgentId: schema.maybe(schema.string()),
rootAgentName: schema.maybe(schema.string()),
}),
}),
};
Expand Down
8 changes: 4 additions & 4 deletions server/plugin.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import { setupRoutes } from './routes/index';
import { AssistantPluginSetup, AssistantPluginStart, MessageParser } from './types';
import { BasicInputOutputParser } from './parsers/basic_input_output_parser';
import { VisualizationCardParser } from './parsers/visualization_card_parser';
import { AgentIdNotFoundError } from './routes/chat_routes';
import { AgentNameNotFoundError } from './routes/chat_routes';

export class AssistantPlugin implements Plugin<AssistantPluginSetup, AssistantPluginStart> {
private readonly logger: Logger;
Expand All @@ -37,8 +37,8 @@ export class AssistantPlugin implements Plugin<AssistantPluginSetup, AssistantPl
* Check if user enable the chat without specifying a root agent id.
gaobinlong marked this conversation as resolved.
Show resolved Hide resolved
* If so, gives a warning for guidance.
*/
if (config.chat.enabled && !config.chat.rootAgentId) {
this.logger.warn(AgentIdNotFoundError);
if (config.chat.enabled && !config.chat.rootAgentName) {
this.logger.warn(AgentNameNotFoundError);
}

const router = core.http.createRouter();
Expand All @@ -53,7 +53,7 @@ export class AssistantPlugin implements Plugin<AssistantPluginSetup, AssistantPl
// Register server side APIs
setupRoutes(router, {
messageParsers: this.messageParsers,
rootAgentId: config.chat.rootAgentId,
rootAgentName: config.chat.rootAgentName,
});

core.capabilities.registerProvider(() => ({
Expand Down
44 changes: 19 additions & 25 deletions server/routes/chat_routes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ const llmRequestRoute = {
};
export type LLMRequestSchema = TypeOf<typeof llmRequestRoute.validate.body>;

export const AgentIdNotFoundError =
'rootAgentId is required, please specify one in opensearch_dashboards.yml';
export const AgentNameNotFoundError =
'rootAgentName is required, please specify one in opensearch_dashboards.yml';

const getSessionRoute = {
path: `${ASSISTANT_API.SESSION}/{sessionId}`,
Expand Down Expand Up @@ -135,7 +135,8 @@ export function registerChatRoutes(router: IRouter, routeOptions: RoutesOptions)
context.core.opensearch.client.asCurrentUser,
routeOptions.messageParsers
);
const createChatService = () => new OllyChatService();
const createChatService = (context: RequestHandlerContext) =>
gaobinlong marked this conversation as resolved.
Show resolved Hide resolved
new OllyChatService(context, routeOptions.rootAgentName!);

router.post(
llmRequestRoute,
Expand All @@ -144,29 +145,25 @@ export function registerChatRoutes(router: IRouter, routeOptions: RoutesOptions)
request,
response
): Promise<IOpenSearchDashboardsResponse<HttpResponsePayload | ResponseError>> => {
if (!routeOptions.rootAgentId) {
context.assistant_plugin.logger.error(AgentIdNotFoundError);
return response.custom({ statusCode: 400, body: AgentIdNotFoundError });
if (!routeOptions.rootAgentName) {
context.assistant_plugin.logger.error(AgentNameNotFoundError);
return response.custom({ statusCode: 400, body: AgentNameNotFoundError });
}
const { messages = [], input, sessionId: sessionIdInRequestBody } = request.body;
const storageService = createStorageService(context);
const chatService = createChatService();
const chatService = createChatService(context);
gaobinlong marked this conversation as resolved.
Show resolved Hide resolved

let outputs: Awaited<ReturnType<ChatService['requestLLM']>> | undefined;

/**
* Get final answer from Agent framework
*/
try {
outputs = await chatService.requestLLM(
{
messages,
input,
sessionId: sessionIdInRequestBody,
rootAgentId: routeOptions.rootAgentId,
},
context
);
outputs = await chatService.requestLLM({
messages,
input,
sessionId: sessionIdInRequestBody,
});
} catch (error) {
context.assistant_plugin.logger.error(error);
const sessionId = outputs?.memoryId || sessionIdInRequestBody;
Expand Down Expand Up @@ -329,7 +326,7 @@ export function registerChatRoutes(router: IRouter, routeOptions: RoutesOptions)
request,
response
): Promise<IOpenSearchDashboardsResponse<HttpResponsePayload | ResponseError>> => {
const chatService = createChatService();
const chatService = createChatService(context);

try {
chatService.abortAgentExecution(request.body.sessionId);
wanglam marked this conversation as resolved.
Show resolved Hide resolved
Expand All @@ -349,24 +346,21 @@ export function registerChatRoutes(router: IRouter, routeOptions: RoutesOptions)
request,
response
): Promise<IOpenSearchDashboardsResponse<HttpResponsePayload | ResponseError>> => {
if (!routeOptions.rootAgentId) {
context.assistant_plugin.logger.error(AgentIdNotFoundError);
return response.custom({ statusCode: 400, body: AgentIdNotFoundError });
if (!routeOptions.rootAgentName) {
context.assistant_plugin.logger.error(AgentNameNotFoundError);
return response.custom({ statusCode: 400, body: AgentNameNotFoundError });
}
const { sessionId, interactionId } = request.body;
const storageService = createStorageService(context);
const chatService = createChatService();
const chatService = createChatService(context);

let outputs: Awaited<ReturnType<ChatService['regenerate']>> | undefined;

/**
* Get final answer from Agent framework
*/
try {
outputs = await chatService.regenerate(
{ sessionId, rootAgentId: routeOptions.rootAgentId, interactionId },
context
);
outputs = await chatService.regenerate({ sessionId, interactionId });
} catch (error) {
context.assistant_plugin.logger.error(error);
}
Expand Down
12 changes: 6 additions & 6 deletions server/routes/regenerate.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@ import {
} from '../services/storage/agent_framework_storage_service.mock';
import { httpServerMock } from '../../../../src/core/server/http/http_server.mocks';
import { loggerMock } from '../../../../src/core/server/logging/logger.mock';
import { registerChatRoutes, RegenerateSchema, AgentIdNotFoundError } from './chat_routes';
import { registerChatRoutes, RegenerateSchema, AgentNameNotFoundError } from './chat_routes';
import { ASSISTANT_API } from '../../common/constants/llm';

const mockedLogger = loggerMock.create();

describe('regenerate route when rootAgentId is provided', () => {
describe('regenerate route when rootAgentName is provided', () => {
const router = new Router(
'',
mockedLogger,
Expand All @@ -31,7 +31,7 @@ describe('regenerate route when rootAgentId is provided', () => {
);
registerChatRoutes(router, {
messageParsers: [],
rootAgentId: 'foo',
rootAgentName: 'foo',
});
const regenerateRequest = (payload: RegenerateSchema) =>
triggerHandler(router, {
Expand Down Expand Up @@ -169,7 +169,7 @@ describe('regenerate route when rootAgentId is provided', () => {
});
});

describe('regenerate route when rootAgentId is not provided', () => {
describe('regenerate route when rootAgentName is not provided', () => {
const router = new Router(
'',
mockedLogger,
Expand Down Expand Up @@ -200,13 +200,13 @@ describe('regenerate route when rootAgentId is not provided', () => {
sessionId: 'foo',
})) as Boom;
expect(mockedLogger.error).toBeCalledTimes(1);
expect(mockedLogger.error).toBeCalledWith(AgentIdNotFoundError);
expect(mockedLogger.error).toBeCalledWith(AgentNameNotFoundError);
expect(result.output).toMatchInlineSnapshot(`
Object {
"headers": Object {},
"payload": Object {
"error": "Bad Request",
"message": "rootAgentId is required, please specify one in opensearch_dashboards.yml",
"message": "rootAgentName is required, please specify one in opensearch_dashboards.yml",
"statusCode": 400,
},
"statusCode": 400,
Expand Down
12 changes: 6 additions & 6 deletions server/routes/send_message.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@ import {
} from '../services/storage/agent_framework_storage_service.mock';
import { httpServerMock } from '../../../../src/core/server/http/http_server.mocks';
import { loggerMock } from '../../../../src/core/server/logging/logger.mock';
import { registerChatRoutes, LLMRequestSchema, AgentIdNotFoundError } from './chat_routes';
import { registerChatRoutes, LLMRequestSchema, AgentNameNotFoundError } from './chat_routes';
import { ASSISTANT_API } from '../../common/constants/llm';

const mockedLogger = loggerMock.create();

describe('send_message route when rootAgentId is provided', () => {
describe('send_message route when rootAgentName is provided', () => {
const router = new Router(
'',
mockedLogger,
Expand All @@ -31,7 +31,7 @@ describe('send_message route when rootAgentId is provided', () => {
);
registerChatRoutes(router, {
messageParsers: [],
rootAgentId: 'foo',
rootAgentName: 'foo',
});
const sendMessageRequest = (payload: LLMRequestSchema) =>
triggerHandler(router, {
Expand Down Expand Up @@ -265,7 +265,7 @@ describe('send_message route when rootAgentId is provided', () => {
});
});

describe('send_message route when rootAgentId is not provided', () => {
describe('send_message route when rootAgentName is not provided', () => {
const router = new Router(
'',
mockedLogger,
Expand Down Expand Up @@ -303,13 +303,13 @@ describe('send_message route when rootAgentId is not provided', () => {
sessionId: 'foo',
})) as Boom;
expect(mockedLogger.error).toBeCalledTimes(1);
expect(mockedLogger.error).toBeCalledWith(AgentIdNotFoundError);
expect(mockedLogger.error).toBeCalledWith(AgentNameNotFoundError);
expect(result.output).toMatchInlineSnapshot(`
Object {
"headers": Object {},
"payload": Object {
"error": "Bad Request",
"message": "rootAgentId is required, please specify one in opensearch_dashboards.yml",
"message": "rootAgentName is required, please specify one in opensearch_dashboards.yml",
"statusCode": 400,
},
"statusCode": 400,
Expand Down
Loading
Loading