From f836af35f0d7e86a52e4bf82b7547d836bd550b9 Mon Sep 17 00:00:00 2001 From: tygao Date: Tue, 28 May 2024 23:09:41 +0800 Subject: [PATCH] test: add tests for chat route Signed-off-by: tygao --- server/routes/chat_routes.test.ts | 185 ++++++++++++++++++++++++++++-- 1 file changed, 174 insertions(+), 11 deletions(-) diff --git a/server/routes/chat_routes.test.ts b/server/routes/chat_routes.test.ts index 41f9c195..cbe03ad9 100644 --- a/server/routes/chat_routes.test.ts +++ b/server/routes/chat_routes.test.ts @@ -14,9 +14,24 @@ import { mockOllyChatService } from '../services/chat/olly_chat_service.mock'; import { loggerMock } from '../../../../src/core/server/logging/logger.mock'; import { registerChatRoutes } from './chat_routes'; import { ASSISTANT_API } from '../../common/constants/llm'; +import { getOpenSearchClientTransport } from '../utils/get_opensearch_client_transport'; -const mockedLogger = loggerMock.create(); +jest.mock('../utils/get_opensearch_client_transport'); + +beforeEach(() => { + (getOpenSearchClientTransport as jest.Mock).mockImplementation(({ dataSourceId }) => { + if (dataSourceId) { + return 'dataSource-client'; + } else { + return 'client'; + } + }); +}); +afterEach(() => { + (getOpenSearchClientTransport as jest.Mock).mockClear(); +}); +const mockedLogger = loggerMock.create(); const router = new Router( '', mockedLogger, @@ -30,38 +45,90 @@ registerChatRoutes(router, { messageParsers: [], }); -const triggerDeleteConversation = (conversationId: string) => +const triggerDeleteConversation = (conversationId: string, dataSourceId?: string) => triggerHandler(router, { method: 'delete', path: `${ASSISTANT_API.CONVERSATION}/{conversationId}`, - req: httpServerMock.createRawRequest({ params: { conversationId } }), + req: httpServerMock.createRawRequest({ + params: { conversationId }, + ...(dataSourceId + ? { + query: { + dataSourceId, + }, + } + : {}), + }), }); const triggerUpdateConversation = ( params: { conversationId: string }, - payload: { title: string } + payload: { title: string }, + dataSourceId?: string ) => triggerHandler(router, { method: 'put', path: `${ASSISTANT_API.CONVERSATION}/{conversationId}`, - req: httpServerMock.createRawRequest({ params, payload }), + req: httpServerMock.createRawRequest({ + params, + payload, + ...(dataSourceId + ? { + query: { + dataSourceId, + }, + } + : {}), + }), }); -const triggerGetTrace = (interactionId: string) => +const triggerGetTrace = (interactionId: string, dataSourceId?: string) => triggerHandler(router, { method: 'get', path: `${ASSISTANT_API.TRACE}/{interactionId}`, - req: httpServerMock.createRawRequest({ params: { interactionId } }), + req: httpServerMock.createRawRequest({ + params: { interactionId }, + ...(dataSourceId + ? { + query: { + dataSourceId, + }, + } + : {}), + }), }); -const triggerAbortAgentExecution = (conversationId: string) => +const triggerAbortAgentExecution = (conversationId: string, dataSourceId?: string) => triggerHandler(router, { method: 'post', path: ASSISTANT_API.ABORT_AGENT_EXECUTION, - req: httpServerMock.createRawRequest({ payload: { conversationId } }), + req: httpServerMock.createRawRequest({ + payload: { conversationId }, + ...(dataSourceId + ? { + query: { + dataSourceId, + }, + } + : {}), + }), }); -const triggerFeedback = (params: { interactionId: string }, payload: { satisfaction: boolean }) => +const triggerFeedback = ( + params: { interactionId: string }, + payload: { satisfaction: boolean }, + dataSourceId?: string +) => triggerHandler(router, { method: 'put', path: `${ASSISTANT_API.FEEDBACK}/{interactionId}`, - req: httpServerMock.createRawRequest({ params, payload }), + req: httpServerMock.createRawRequest({ + params, + payload, + ...(dataSourceId + ? { + query: { + dataSourceId, + }, + } + : {}), + }), }); describe('chat routes', () => { @@ -80,6 +147,22 @@ describe('chat routes', () => { expect(mockAgentFrameworkStorageService.deleteConversation).not.toHaveBeenCalled(); const result = (await triggerDeleteConversation('foo')) as ResponseObject; + expect(getOpenSearchClientTransport.mock.results[0].value).toBe('client'); + expect(mockAgentFrameworkStorageService.deleteConversation).toHaveBeenCalledWith('foo'); + expect(result.source).toMatchInlineSnapshot(` + Object { + "success": true, + } + `); + }); + + it('should call delete conversation with passed data source id and get data source transport', async () => { + mockAgentFrameworkStorageService.deleteConversation.mockResolvedValueOnce({ + success: true, + }); + expect(mockAgentFrameworkStorageService.deleteConversation).not.toHaveBeenCalled(); + const result = (await triggerDeleteConversation('foo', 'data_source_id')) as ResponseObject; + expect(getOpenSearchClientTransport.mock.results[0].value).toBe('dataSource-client'); expect(mockAgentFrameworkStorageService.deleteConversation).toHaveBeenCalledWith('foo'); expect(result.source).toMatchInlineSnapshot(` Object { @@ -109,6 +192,7 @@ describe('chat routes', () => { { conversationId: 'foo' }, { title: 'new-title' } )) as ResponseObject; + expect(getOpenSearchClientTransport.mock.results[0].value).toBe('client'); expect(mockAgentFrameworkStorageService.updateConversation).toHaveBeenCalledWith( 'foo', 'new-title' @@ -120,6 +204,29 @@ describe('chat routes', () => { `); }); + it('should call update conversation with passed data source id and title then get data source transport', async () => { + mockAgentFrameworkStorageService.updateConversation.mockResolvedValueOnce({ + success: true, + }); + + expect(mockAgentFrameworkStorageService.updateConversation).not.toHaveBeenCalled(); + const result = (await triggerUpdateConversation( + { conversationId: 'foo' }, + { title: 'new-title' }, + 'data_source_id' + )) as ResponseObject; + expect(mockAgentFrameworkStorageService.updateConversation).toHaveBeenCalledWith( + 'foo', + 'new-title' + ); + expect(getOpenSearchClientTransport.mock.results[0].value).toBe('dataSource-client'); + expect(result.source).toMatchInlineSnapshot(` + Object { + "success": true, + } + `); + }); + it('should log error and return 500 error when failed to update conversation', async () => { mockAgentFrameworkStorageService.updateConversation.mockRejectedValueOnce(new Error()); @@ -149,6 +256,27 @@ describe('chat routes', () => { expect(mockAgentFrameworkStorageService.getTraces).not.toHaveBeenCalled(); const result = (await triggerGetTrace('interaction-1')) as ResponseObject; + expect(getOpenSearchClientTransport.mock.results[0].value).toBe('client'); + expect(mockAgentFrameworkStorageService.getTraces).toHaveBeenCalledWith('interaction-1'); + expect(result.source).toEqual(getTraceResultMock); + }); + + it('should call get traces with passed data source id and get data source transport', async () => { + const getTraceResultMock = [ + { + interactionId: 'interaction-1', + createTime: '', + input: 'foo', + output: 'bar', + origin: '', + traceNumber: 0, + }, + ]; + mockAgentFrameworkStorageService.getTraces.mockResolvedValueOnce(getTraceResultMock); + + expect(mockAgentFrameworkStorageService.getTraces).not.toHaveBeenCalled(); + const result = (await triggerGetTrace('interaction-1', 'data_source_id')) as ResponseObject; + expect(getOpenSearchClientTransport.mock.results[0].value).toBe('dataSource-client'); expect(mockAgentFrameworkStorageService.getTraces).toHaveBeenCalledWith('interaction-1'); expect(result.source).toEqual(getTraceResultMock); }); @@ -169,6 +297,16 @@ describe('chat routes', () => { await triggerAbortAgentExecution('foo'); expect(mockOllyChatService.abortAgentExecution).toHaveBeenCalledWith('foo'); + expect(getOpenSearchClientTransport.mock.results[0].value).toBe('client'); + expect(mockedLogger.info).toHaveBeenCalledWith('Abort agent execution: foo'); + }); + + it('should call get abort agent with passed data source id and get data source transport ', async () => { + expect(mockOllyChatService.abortAgentExecution).not.toHaveBeenCalled(); + + await triggerAbortAgentExecution('foo', 'data_source_id'); + expect(mockOllyChatService.abortAgentExecution).toHaveBeenCalledWith('foo'); + expect(getOpenSearchClientTransport.mock.results[0].value).toBe('dataSource-client'); expect(mockedLogger.info).toHaveBeenCalledWith('Abort agent execution: foo'); }); @@ -200,6 +338,31 @@ describe('chat routes', () => { { interactionId: 'foo' }, { satisfaction: true } )) as ResponseObject; + expect(getOpenSearchClientTransport.mock.results[0].value).toBe('client'); + expect(mockAgentFrameworkStorageService.updateInteraction).toHaveBeenCalledWith('foo', { + feedback: { + satisfaction: true, + }, + }); + expect(result.source).toMatchInlineSnapshot(` + Object { + "success": true, + } + `); + }); + + it('should call update interaction with passed data source id and get data source transport', async () => { + mockAgentFrameworkStorageService.updateConversation.mockResolvedValueOnce({ + success: true, + }); + + expect(mockAgentFrameworkStorageService.updateConversation).not.toHaveBeenCalled(); + const result = (await triggerFeedback( + { interactionId: 'foo' }, + { satisfaction: true }, + 'data_source_id' + )) as ResponseObject; + expect(getOpenSearchClientTransport.mock.results[0].value).toBe('dataSource-client'); expect(mockAgentFrameworkStorageService.updateInteraction).toHaveBeenCalledWith('foo', { feedback: { satisfaction: true,