Skip to content

Commit

Permalink
feat: add unit test for regenerate route
Browse files Browse the repository at this point in the history
Signed-off-by: SuZhou-Joe <[email protected]>
  • Loading branch information
SuZhou-Joe committed Dec 18, 2023
1 parent b51d668 commit 124aedc
Show file tree
Hide file tree
Showing 2 changed files with 183 additions and 1 deletion.
180 changes: 180 additions & 0 deletions server/routes/regenerate.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

import { ResponseObject } from '@hapi/hapi';
import { Boom } from '@hapi/boom';
import { Router } from '../../../../src/core/server/http/router';
import { enhanceWithContext, triggerHandler } from './router.mock';
import { mockOllyChatService } from '../services/chat/olly_chat_service.mock';
import { mockAgentFrameworkStorageService } from '../services/storage/agent_framework_storage_service.mock';
import { httpServerMock } from '../../../../src/core/server/http/http_server.mocks';
import { loggerMock } from '../../../../src/core/server/logging/logger.mock';
import { registerChatRoutes, RegenerateSchema, AgentIdNotFoundError } from './chat_routes';
import { ASSISTANT_API } from '../../common/constants/llm';

const mockedLogger = loggerMock.create();

describe('regenerate route when rootAgentId is provided', () => {
const router = new Router(
'',
mockedLogger,
enhanceWithContext({
assistant_plugin: {
logger: mockedLogger,
},
})
);
registerChatRoutes(router, {
messageParsers: [],
rootAgentId: 'foo',
});
const regenerateRequest = (payload: RegenerateSchema) =>
triggerHandler(router, {
method: 'put',
path: ASSISTANT_API.REGENERATE,
req: httpServerMock.createRawRequest({
payload: JSON.stringify(payload),
}),
});
beforeEach(() => {
loggerMock.clear(mockedLogger);
});
it('return back successfully when requestLLM returns momery back', async () => {
mockOllyChatService.regenerate.mockImplementationOnce(async () => {
return {
messages: [],
memoryId: 'foo',
};
});
mockAgentFrameworkStorageService.getSession.mockImplementationOnce(async () => {
return {
messages: [],
title: 'foo',
interactions: [],
createdTimeMs: 0,
updatedTimeMs: 0,
};
});
const result = (await regenerateRequest({
sessionId: 'foo',
interactionId: 'bar',
})) as ResponseObject;
expect(result.source).toMatchInlineSnapshot(`
Object {
"createdTimeMs": 0,
"interactions": Array [],
"messages": Array [],
"sessionId": "foo",
"title": "foo",
"updatedTimeMs": 0,
}
`);
});

it('log error when requestLLM throws an error', async () => {
mockOllyChatService.regenerate.mockImplementationOnce(() => {
throw new Error('something went wrong');
});
mockAgentFrameworkStorageService.getSession.mockImplementationOnce(async () => {
return {
messages: [],
title: 'foo',
interactions: [],
createdTimeMs: 0,
updatedTimeMs: 0,
};
});
const result = (await regenerateRequest({
sessionId: 'foo',
interactionId: 'bar',
})) as ResponseObject;
expect(mockedLogger.error).toBeCalledTimes(1);
expect(result.source).toMatchInlineSnapshot(`
Object {
"createdTimeMs": 0,
"interactions": Array [],
"messages": Array [],
"sessionId": "foo",
"title": "foo",
"updatedTimeMs": 0,
}
`);
});

it('return 500 when get session throws an error', async () => {
mockOllyChatService.regenerate.mockImplementationOnce(async () => {
return {
messages: [],
memoryId: 'foo',
};
});
mockAgentFrameworkStorageService.getSession.mockImplementationOnce(() => {
throw new Error('foo');
});
const result = (await regenerateRequest({
sessionId: 'foo',
interactionId: 'bar',
})) as Boom;
expect(mockedLogger.error).toBeCalledTimes(1);
expect(mockedLogger.error).toBeCalledWith(new Error('foo'));
expect(result.output).toMatchInlineSnapshot(`
Object {
"headers": Object {},
"payload": Object {
"error": "Internal Server Error",
"message": "foo",
"statusCode": 500,
},
"statusCode": 500,
}
`);
});
});

describe('regenerate route when rootAgentId is not provided', () => {
const router = new Router(
'',
mockedLogger,
enhanceWithContext({
assistant_plugin: {
logger: mockedLogger,
},
})
);
registerChatRoutes(router, {
messageParsers: [],
});
const regenerateRequest = (payload: RegenerateSchema) =>
triggerHandler(router, {
method: 'put',
path: ASSISTANT_API.REGENERATE,
req: httpServerMock.createRawRequest({
payload: JSON.stringify(payload),
}),
});
beforeEach(() => {
loggerMock.clear(mockedLogger);
});

it('return 400', async () => {
const result = (await regenerateRequest({
interactionId: 'bar',
sessionId: 'foo',
})) as Boom;
expect(mockedLogger.error).toBeCalledTimes(1);
expect(mockedLogger.error).toBeCalledWith(AgentIdNotFoundError);
expect(result.output).toMatchInlineSnapshot(`
Object {
"headers": Object {},
"payload": Object {
"error": "Bad Request",
"message": "rootAgentId is required, please specify one in opensearch_dashboards.yml",
"statusCode": 400,
},
"statusCode": 400,
}
`);
});
});
4 changes: 3 additions & 1 deletion server/services/chat/olly_chat_service.mock.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
* SPDX-License-Identifier: Apache-2.0
*/

import { PublicContract } from '@osd/utility-types';
import { OllyChatService } from './olly_chat_service';

const mockOllyChatService: jest.Mocked<OllyChatService> = {
const mockOllyChatService: jest.Mocked<PublicContract<OllyChatService>> = {
requestLLM: jest.fn(),
regenerate: jest.fn(),
abortAgentExecution: jest.fn(),
};

Expand Down

0 comments on commit 124aedc

Please sign in to comment.