Skip to content

Commit

Permalink
feat: add mocks and unit test for router.ts
Browse files Browse the repository at this point in the history
Signed-off-by: SuZhou-Joe <[email protected]>
  • Loading branch information
SuZhou-Joe committed Dec 11, 2023
1 parent 36ac34b commit 55e831c
Show file tree
Hide file tree
Showing 8 changed files with 581 additions and 21 deletions.
37 changes: 29 additions & 8 deletions server/routes/chat_routes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import { OllyChatService } from '../services/chat/olly_chat_service';
import { IMessage, IInput } from '../../common/types/chat_saved_object_attributes';
import { AgentFrameworkStorageService } from '../services/storage/agent_framework_storage_service';
import { RoutesOptions } from '../types';
import { ChatService } from '../services/chat/chat_service';

const llmRequestRoute = {
path: ASSISTANT_API.SEND_MESSAGE,
Expand Down Expand Up @@ -145,24 +146,44 @@ export function registerChatRoutes(router: IRouter, routeOptions: RoutesOptions)
const storageService = createStorageService(context);
const chatService = createChatService();

let outputs: Awaited<ReturnType<ChatService['requestLLM']>> | undefined;

/**
* Get final answer from Agent framework
*/
try {
const outputs = await chatService.requestLLM(
outputs = await chatService.requestLLM(
{ messages, input, sessionId: sessionIdInRequestBody, rootAgentId },
context
);
const sessionId = outputs.memoryId;
const finalMessage = await storageService.getSession(sessionId);
} catch (error) {
context.assistant_plugin.logger.error(error);
const sessionId = outputs?.memoryId || sessionIdInRequestBody;
if (!sessionId) {
return response.custom({ statusCode: error.statusCode || 500, body: error.message });
}
}

/**
* Retrieve latest interactions from memory
*/
const sessionId = outputs?.memoryId || (sessionIdInRequestBody as string);
try {
if (!sessionId) {
throw new Error('Not a valid conversation');
}
const conversation = await storageService.getSession(sessionId);

return response.ok({
body: {
messages: finalMessage.messages,
sessionId: outputs.memoryId,
title: finalMessage.title,
interactions: finalMessage.interactions,
messages: conversation.messages,
sessionId,
title: conversation.title,
interactions: conversation.interactions,
},
});
} catch (error) {
context.assistant_plugin.logger.warn(error);
context.assistant_plugin.logger.error(error);
return response.custom({ statusCode: error.statusCode || 500, body: error.message });
}
}
Expand Down
87 changes: 87 additions & 0 deletions server/routes/get_session.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

import { ResponseObject } from '@hapi/hapi';
import { Boom } from '@hapi/boom';
import { Router } from '../../../../src/core/server/http/router';
import { enhanceWithContext, triggerHandler } from './router.mock';
import { httpServerMock } from '../../../../src/core/server/http/http_server.mocks';
import { mockAgentFrameworkStorageService } from '../services/storage/agent_framework_storage_service.mock';
import { loggerMock } from '../../../../src/core/server/logging/logger.mock';
import { GetSessionSchema, registerChatRoutes } from './chat_routes';
import { ASSISTANT_API } from '../../common/constants/llm';

const mockedLogger = loggerMock.create();

const router = new Router(
'',
mockedLogger,
enhanceWithContext({
assistant_plugin: {
logger: mockedLogger,
},
})
);
registerChatRoutes(router, {
messageParsers: [],
});

describe('getSession route', () => {
const getSessionRequest = (payload: GetSessionSchema) =>
triggerHandler(router, {
method: 'get',
path: `${ASSISTANT_API.SESSION}/{sessionId}`,
req: httpServerMock.createRawRequest({
params: payload,
}),
});
beforeEach(() => {
loggerMock.clear(mockedLogger);
});
it('return back successfully when getSession returns session back', async () => {
mockAgentFrameworkStorageService.getSession.mockImplementationOnce(async () => {
return {
messages: [],
title: 'foo',
interactions: [],
createdTimeMs: 0,
updatedTimeMs: 0,
};
});
const result = (await getSessionRequest({
sessionId: '1',
})) as ResponseObject;
expect(result.source).toMatchInlineSnapshot(`
Object {
"createdTimeMs": 0,
"interactions": Array [],
"messages": Array [],
"title": "foo",
"updatedTimeMs": 0,
}
`);
});

it('return 500 when getSession throws error', async () => {
mockAgentFrameworkStorageService.getSession.mockImplementationOnce(() => {
throw new Error('getSession error');
});
const result = (await getSessionRequest({
sessionId: '1',
})) as Boom;
expect(mockedLogger.error).toBeCalledTimes(1);
expect(result.output).toMatchInlineSnapshot(`
Object {
"headers": Object {},
"payload": Object {
"error": "Internal Server Error",
"message": "getSession error",
"statusCode": 500,
},
"statusCode": 500,
}
`);
});
});
83 changes: 83 additions & 0 deletions server/routes/get_sessions.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

import { ResponseObject } from '@hapi/hapi';
import { Boom } from '@hapi/boom';
import { Router } from '../../../../src/core/server/http/router';
import { enhanceWithContext, triggerHandler } from './router.mock';
import { mockAgentFrameworkStorageService } from '../services/storage/agent_framework_storage_service.mock';
import { httpServerMock } from '../../../../src/core/server/http/http_server.mocks';
import { loggerMock } from '../../../../src/core/server/logging/logger.mock';
import { GetSessionsSchema, registerChatRoutes } from './chat_routes';
import { ASSISTANT_API } from '../../common/constants/llm';

const mockedLogger = loggerMock.create();

const router = new Router(
'',
mockedLogger,
enhanceWithContext({
assistant_plugin: {
logger: mockedLogger,
},
})
);
registerChatRoutes(router, {
messageParsers: [],
});

describe('getSessions route', () => {
const getSessionsRequest = (payload: GetSessionsSchema) =>
triggerHandler(router, {
method: 'get',
path: `${ASSISTANT_API.SESSIONS}`,
req: httpServerMock.createRawRequest({
query: payload,
}),
});
beforeEach(() => {
loggerMock.clear(mockedLogger);
});
it('return back successfully when getSessions returns sessions back', async () => {
mockAgentFrameworkStorageService.getSessions.mockImplementationOnce(async () => {
return {
objects: [],
total: 0,
};
});
const result = (await getSessionsRequest({
perPage: 10,
page: 1,
})) as ResponseObject;
expect(result.source).toMatchInlineSnapshot(`
Object {
"objects": Array [],
"total": 0,
}
`);
});

it('return 500 when getSessions throws error', async () => {
mockAgentFrameworkStorageService.getSessions.mockImplementationOnce(() => {
throw new Error('getSessions error');
});
const result = (await getSessionsRequest({
perPage: 10,
page: 1,
})) as Boom;
expect(mockedLogger.error).toBeCalledTimes(1);
expect(result.output).toMatchInlineSnapshot(`
Object {
"headers": Object {},
"payload": Object {
"error": "Internal Server Error",
"message": "getSessions error",
"statusCode": 500,
},
"statusCode": 500,
}
`);
});
});
125 changes: 125 additions & 0 deletions server/routes/router.mock.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

import {
Auth,
AuthenticationData,
Request,
ResponseObject,
ResponseToolkit,
ServerRealm,
ServerStateCookieOptions,
} from '@hapi/hapi';
// @ts-ignore
import Response from '@hapi/hapi/lib/response';
import { ProxyHandlerOptions } from '@hapi/h2o2';
import { ReplyFileHandlerOptions } from '@hapi/inert';
import { httpServerMock } from '../../../../src/core/server/http/http_server.mocks';
import {
OpenSearchDashboardsRequest,
OpenSearchDashboardsResponseFactory,
Router,
} from '../../../../src/core/server/http/router';
import { CoreRouteHandlerContext } from '../../../../src/core/server/core_route_handler_context';
import { coreMock } from '../../../../src/core/server/mocks';

/**
* For hapi, ResponseToolkit is an internal implementation
* so we have to create a MockResponseToolkit to mock the behavior.
* This class should be put under OSD core,
*/
export class MockResponseToolkit implements ResponseToolkit {
abandon: symbol = Symbol('abandon');
close: symbol = Symbol('close');
context: unknown;
continue: symbol = Symbol('continue');
realm: ServerRealm = {
modifiers: {
route: {
prefix: '',
vhost: '',
},
},
parent: null,
plugin: '',
pluginOptions: {},
plugins: [],
settings: {
files: {
relativeTo: '',
},
bind: {},
},
};
request: Readonly<Request> = httpServerMock.createRawRequest();
authenticated(): Auth {
throw new Error('Method not implemented.');
}
entity(
options?:
| { etag?: string | undefined; modified?: string | undefined; vary?: boolean | undefined }
| undefined
): ResponseObject | undefined {
throw new Error('Method not implemented.');
}
redirect(uri?: string | undefined): ResponseObject {
throw new Error('Method not implemented.');
}
state(
name: string,
value: string | object,
options?: ServerStateCookieOptions | undefined
): void {
throw new Error('Method not implemented.');
}
unauthenticated(error: Error, data?: AuthenticationData | undefined): void {
throw new Error('Method not implemented.');
}
unstate(name: string, options?: ServerStateCookieOptions | undefined): void {
throw new Error('Method not implemented.');
}
file(path: string, options?: ReplyFileHandlerOptions | undefined): ResponseObject {
throw new Error('Method not implemented.');
}
proxy(options: ProxyHandlerOptions): Promise<ResponseObject> {
throw new Error('Method not implemented.');
}
response(payload: unknown) {
return new Response(payload);
}
}

const enhanceWithContext = (otherContext?: object) => (fn: (...args: unknown[]) => unknown) => (
req: OpenSearchDashboardsRequest,
res: OpenSearchDashboardsResponseFactory
) => {
const context = new CoreRouteHandlerContext(coreMock.createInternalStart(), req);
return fn.call(
null,
{
core: context,
...otherContext,
},
req,
res
);
};

const triggerHandler = async (
router: Router,
options: {
method: string;
path: string;
req: Request;
}
) => {
const allRoutes = router.getRoutes();
const findRoute = allRoutes.find(
(item) => item.method === options.method && item.path === options.path
);
return await findRoute?.handler(options.req, new MockResponseToolkit());
};

export { enhanceWithContext, triggerHandler };
Loading

0 comments on commit 55e831c

Please sign in to comment.