diff --git a/CHANGELOG.md b/CHANGELOG.md index 61c07477..27ece560 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,3 +18,4 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) - Hide notebook feature when MDS enabled and remove security dashboard plugin dependency ([#201](https://github.com/opensearch-project/dashboards-assistant/pull/201)) - Refactor default data source retriever ([#197](https://github.com/opensearch-project/dashboards-assistant/pull/197)) - Add patch style for fixed components ([#203](https://github.com/opensearch-project/dashboards-assistant/pull/203)) +- Reset chat and reload history after data source change ([#194](https://github.com/opensearch-project/dashboards-assistant/pull/194)) diff --git a/public/chat_header_button.tsx b/public/chat_header_button.tsx index 9723010a..773363f4 100644 --- a/public/chat_header_button.tsx +++ b/public/chat_header_button.tsx @@ -7,6 +7,7 @@ import { EuiBadge, EuiFieldText, EuiIcon } from '@elastic/eui'; import classNames from 'classnames'; import React, { useCallback, useEffect, useMemo, useRef, useState } from 'react'; import { useEffectOnce } from 'react-use'; + import { ApplicationStart, SIDECAR_DOCKED_MODE } from '../../../src/core/public'; // TODO: Replace with getChrome().logos.Chat.url import chatIcon from './assets/chat.svg'; diff --git a/public/components/agent_framework_traces_flyout_body.test.tsx b/public/components/agent_framework_traces_flyout_body.test.tsx index c204a039..6cdbbcd6 100644 --- a/public/components/agent_framework_traces_flyout_body.test.tsx +++ b/public/components/agent_framework_traces_flyout_body.test.tsx @@ -5,10 +5,12 @@ import React from 'react'; import '@testing-library/jest-dom/extend-expect'; -import { act, waitFor, render, screen, fireEvent } from '@testing-library/react'; +import { waitFor, render, screen, fireEvent } from '@testing-library/react'; import * as chatContextExports from '../contexts/chat_context'; +import * as coreContextExports from '../contexts/core_context'; import { AgentFrameworkTracesFlyoutBody } from './agent_framework_traces_flyout_body'; import { TAB_ID } from '../utils/constants'; +import { BehaviorSubject, Subject } from 'rxjs'; jest.mock('./agent_framework_traces', () => { return { @@ -17,6 +19,20 @@ jest.mock('./agent_framework_traces', () => { }); describe(' spec', () => { + let dataSourceIdUpdates$: Subject; + beforeEach(() => { + dataSourceIdUpdates$ = new Subject(); + jest.spyOn(coreContextExports, 'useCore').mockImplementation(() => { + return { + services: { + dataSource: { + dataSourceIdUpdates$, + }, + }, + }; + }); + }); + it('show back button if interactionId exists', async () => { const onCloseMock = jest.fn(); jest.spyOn(chatContextExports, 'useChatContext').mockReturnValue({ @@ -70,4 +86,19 @@ describe(' spec', () => { expect(onCloseMock).toHaveBeenCalledWith(TAB_ID.HISTORY); }); }); + + it('should set tab to chat after data source changed', () => { + const setSelectedTabIdMock = jest.fn(); + jest.spyOn(chatContextExports, 'useChatContext').mockReturnValue({ + interactionId: 'test-interaction-id', + flyoutFullScreen: true, + setSelectedTabId: setSelectedTabIdMock, + preSelectedTabId: TAB_ID.HISTORY, + }); + render(); + + expect(setSelectedTabIdMock).not.toHaveBeenCalled(); + dataSourceIdUpdates$.next('foo'); + expect(setSelectedTabIdMock).toHaveBeenCalled(); + }); }); diff --git a/public/components/agent_framework_traces_flyout_body.tsx b/public/components/agent_framework_traces_flyout_body.tsx index 5cf12b0c..b18f8004 100644 --- a/public/components/agent_framework_traces_flyout_body.tsx +++ b/public/components/agent_framework_traces_flyout_body.tsx @@ -13,14 +13,26 @@ import { EuiButtonIcon, EuiPageHeaderSection, } from '@elastic/eui'; -import React from 'react'; +import React, { useEffect } from 'react'; import { useChatContext } from '../contexts/chat_context'; +import { useCore } from '../../public/contexts'; import { AgentFrameworkTraces } from './agent_framework_traces'; import { TAB_ID } from '../utils/constants'; export const AgentFrameworkTracesFlyoutBody: React.FC = () => { + const core = useCore(); const chatContext = useChatContext(); const interactionId = chatContext.interactionId; + + useEffect(() => { + const subscription = core.services.dataSource.dataSourceIdUpdates$.subscribe(() => { + chatContext.setSelectedTabId(TAB_ID.CHAT); + }); + return () => { + subscription.unsubscribe(); + }; + }, [core.services.dataSource, chatContext.setSelectedTabId]); + if (!interactionId) { return null; } diff --git a/public/hooks/use_chat_actions.test.tsx b/public/hooks/use_chat_actions.test.tsx index 26ca0096..eb1db675 100644 --- a/public/hooks/use_chat_actions.test.tsx +++ b/public/hooks/use_chat_actions.test.tsx @@ -26,7 +26,14 @@ jest.mock('../services/conversations_service', () => { jest.mock('../services/conversation_load_service', () => { return { ConversationLoadService: jest.fn().mockImplementation(() => { - return { load: jest.fn().mockReturnValue({ messages: [], interactions: [] }) }; + const conversationLoadMock = { + abortController: new AbortController(), + load: jest.fn().mockImplementation(async () => { + conversationLoadMock.abortController = new AbortController(); + return { messages: [], interactions: [] }; + }), + }; + return conversationLoadMock; }), }; }); @@ -126,7 +133,7 @@ describe('useChatActions hook', () => { messages: [SEND_MESSAGE_RESPONSE.messages[0]], input: INPUT_MESSAGE, }), - query: await dataSourceServiceMock.getDataSourceQuery(), + query: dataSourceServiceMock.getDataSourceQuery(), }); // it should send dispatch `receive` action to remove the message without messageId @@ -201,7 +208,7 @@ describe('useChatActions hook', () => { messages: [], input: { type: 'input', content: 'message that send as input', contentType: 'text' }, }), - query: await dataSourceServiceMock.getDataSourceQuery(), + query: dataSourceServiceMock.getDataSourceQuery(), }); }); @@ -264,7 +271,7 @@ describe('useChatActions hook', () => { expect(chatStateDispatchMock).toHaveBeenCalledWith({ type: 'abort' }); expect(httpMock.post).toHaveBeenCalledWith(ASSISTANT_API.ABORT_AGENT_EXECUTION, { body: JSON.stringify({ conversationId: 'conversation_id_to_abort' }), - query: await dataSourceServiceMock.getDataSourceQuery(), + query: dataSourceServiceMock.getDataSourceQuery(), }); }); @@ -292,7 +299,7 @@ describe('useChatActions hook', () => { conversationId: 'conversation_id_mock', interactionId: 'interaction_id_mock', }), - query: await dataSourceServiceMock.getDataSourceQuery(), + query: dataSourceServiceMock.getDataSourceQuery(), }); expect(chatStateDispatchMock).toHaveBeenCalledWith( expect.objectContaining({ type: 'receive', payload: { messages: [], interactions: [] } }) @@ -312,6 +319,7 @@ describe('useChatActions hook', () => { it('should not handle regenerate response if the regenerate operation has already aborted', async () => { const AbortControllerMock = jest.spyOn(window, 'AbortController').mockImplementation(() => ({ signal: { aborted: true }, + abort: jest.fn(), })); httpMock.put.mockResolvedValue(SEND_MESSAGE_RESPONSE); @@ -328,7 +336,7 @@ describe('useChatActions hook', () => { conversationId: 'conversation_id_mock', interactionId: 'interaction_id_mock', }), - query: await dataSourceServiceMock.getDataSourceQuery(), + query: dataSourceServiceMock.getDataSourceQuery(), }); expect(chatStateDispatchMock).not.toHaveBeenCalledWith( expect.objectContaining({ type: 'receive' }) @@ -353,6 +361,7 @@ describe('useChatActions hook', () => { it('should not handle regenerate error if the regenerate operation has already aborted', async () => { const AbortControllerMock = jest.spyOn(window, 'AbortController').mockImplementation(() => ({ signal: { aborted: true }, + abort: jest.fn(), })); httpMock.put.mockImplementationOnce(() => { throw new Error(); @@ -369,4 +378,43 @@ describe('useChatActions hook', () => { ); AbortControllerMock.mockRestore(); }); + + it('should clear chat title, conversation id, flyoutComponent and call reset action', async () => { + const { result } = renderHook(() => useChatActions()); + result.current.resetChat(); + + expect(chatContextMock.setConversationId).toHaveBeenLastCalledWith(undefined); + expect(chatContextMock.setTitle).toHaveBeenLastCalledWith(undefined); + expect(chatContextMock.setFlyoutComponent).toHaveBeenLastCalledWith(null); + + expect(chatStateDispatchMock).toHaveBeenLastCalledWith({ type: 'reset' }); + }); + + it('should abort send action after reset chat', async () => { + const abortFn = jest.fn(); + const AbortControllerMock = jest.spyOn(window, 'AbortController').mockImplementation(() => ({ + signal: { aborted: true }, + abort: abortFn, + })); + const { result } = renderHook(() => useChatActions()); + await result.current.send(INPUT_MESSAGE); + result.current.resetChat(); + + expect(abortFn).toHaveBeenCalled(); + AbortControllerMock.mockRestore(); + }); + + it('should abort load action after reset chat', async () => { + const abortFn = jest.fn(); + const AbortControllerMock = jest.spyOn(window, 'AbortController').mockImplementation(() => ({ + signal: { aborted: true }, + abort: abortFn, + })); + const { result } = renderHook(() => useChatActions()); + await result.current.loadChat('conversation_id_mock'); + result.current.resetChat(); + + expect(abortFn).toHaveBeenCalled(); + AbortControllerMock.mockRestore(); + }); }); diff --git a/public/hooks/use_chat_actions.tsx b/public/hooks/use_chat_actions.tsx index 962420f2..5c3c239a 100644 --- a/public/hooks/use_chat_actions.tsx +++ b/public/hooks/use_chat_actions.tsx @@ -35,7 +35,7 @@ export const useChatActions = (): AssistantActions => { ...(!chatContext.conversationId && { messages: chatState.messages }), // include all previous messages for new chats input, }), - query: await core.services.dataSource.getDataSourceQuery(), + query: core.services.dataSource.getDataSourceQuery(), }); if (abortController.signal.aborted) return; // Refresh history list after new conversation created if new conversation saved and history list page visible @@ -106,6 +106,15 @@ export const useChatActions = (): AssistantActions => { } }; + const resetChat = () => { + abortControllerRef?.abort(); + core.services.conversationLoad.abortController?.abort(); + chatContext.setConversationId(undefined); + chatContext.setTitle(undefined); + chatContext.setFlyoutComponent(null); + chatStateDispatch({ type: 'reset' }); + }; + const openChatUI = () => { chatContext.setFlyoutVisible(true); chatContext.setSelectedTabId(TAB_ID.CHAT); @@ -163,7 +172,7 @@ export const useChatActions = (): AssistantActions => { // abort agent execution await core.services.http.post(`${ASSISTANT_API.ABORT_AGENT_EXECUTION}`, { body: JSON.stringify({ conversationId }), - query: await core.services.dataSource.getDataSourceQuery(), + query: core.services.dataSource.getDataSourceQuery(), }); } }; @@ -180,7 +189,7 @@ export const useChatActions = (): AssistantActions => { conversationId: chatContext.conversationId, interactionId, }), - query: await core.services.dataSource.getDataSourceQuery(), + query: core.services.dataSource.getDataSourceQuery(), }); if (abortController.signal.aborted) { @@ -225,5 +234,5 @@ export const useChatActions = (): AssistantActions => { } }; - return { send, loadChat, executeAction, openChatUI, abortAction, regenerate }; + return { send, loadChat, resetChat, executeAction, openChatUI, abortAction, regenerate }; }; diff --git a/public/hooks/use_conversations.ts b/public/hooks/use_conversations.ts index 2828a4af..516a982e 100644 --- a/public/hooks/use_conversations.ts +++ b/public/hooks/use_conversations.ts @@ -14,13 +14,13 @@ export const useDeleteConversation = () => { const abortControllerRef = useRef(); const deleteConversation = useCallback( - async (conversationId: string) => { + (conversationId: string) => { abortControllerRef.current = new AbortController(); dispatch({ type: 'request' }); return core.services.http .delete(`${ASSISTANT_API.CONVERSATION}/${conversationId}`, { signal: abortControllerRef.current.signal, - query: await core.services.dataSource.getDataSourceQuery(), + query: core.services.dataSource.getDataSourceQuery(), }) .then((payload) => { dispatch({ type: 'success', payload }); @@ -53,7 +53,7 @@ export const usePatchConversation = () => { const abortControllerRef = useRef(); const patchConversation = useCallback( - async (conversationId: string, title: string) => { + (conversationId: string, title: string) => { abortControllerRef.current = new AbortController(); dispatch({ type: 'request' }); return core.services.http @@ -61,7 +61,7 @@ export const usePatchConversation = () => { body: JSON.stringify({ title, }), - query: await core.services.dataSource.getDataSourceQuery(), + query: core.services.dataSource.getDataSourceQuery(), signal: abortControllerRef.current.signal, }) .then((payload) => dispatch({ type: 'success', payload })) diff --git a/public/hooks/use_feed_back.test.tsx b/public/hooks/use_feed_back.test.tsx index 74de2031..943aab2f 100644 --- a/public/hooks/use_feed_back.test.tsx +++ b/public/hooks/use_feed_back.test.tsx @@ -84,7 +84,7 @@ describe('useFeedback hook', () => { body: JSON.stringify({ satisfaction: true, }), - query: await dataSourceMock.getDataSourceQuery(), + query: dataSourceMock.getDataSourceQuery(), } ); expect(result.current.feedbackResult).toBe(true); @@ -119,7 +119,7 @@ describe('useFeedback hook', () => { body: JSON.stringify({ satisfaction: true, }), - query: await dataSourceMock.getDataSourceQuery(), + query: dataSourceMock.getDataSourceQuery(), } ); expect(result.current.feedbackResult).toBe(undefined); diff --git a/public/hooks/use_feed_back.tsx b/public/hooks/use_feed_back.tsx index 2222b301..7a37bd37 100644 --- a/public/hooks/use_feed_back.tsx +++ b/public/hooks/use_feed_back.tsx @@ -38,7 +38,7 @@ export const useFeedback = (interaction?: Interaction | null) => { try { await core.services.http.put(`${ASSISTANT_API.FEEDBACK}/${message.interactionId}`, { body: JSON.stringify(body), - query: await core.services.dataSource.getDataSourceQuery(), + query: core.services.dataSource.getDataSourceQuery(), }); setFeedbackResult(correct); } catch (error) { diff --git a/public/hooks/use_fetch_agentframework_traces.ts b/public/hooks/use_fetch_agentframework_traces.ts index 43925855..7ad43668 100644 --- a/public/hooks/use_fetch_agentframework_traces.ts +++ b/public/hooks/use_fetch_agentframework_traces.ts @@ -22,26 +22,24 @@ export const useFetchAgentFrameworkTraces = (interactionId: string) => { return; } - core.services.dataSource.getDataSourceQuery().then((query) => { - core.services.http - .get(`${ASSISTANT_API.TRACE}/${interactionId}`, { - signal: abortController.signal, - query, + core.services.http + .get(`${ASSISTANT_API.TRACE}/${interactionId}`, { + signal: abortController.signal, + query: core.services.dataSource.getDataSourceQuery(), + }) + .then((payload) => + dispatch({ + type: 'success', + payload, }) - .then((payload) => - dispatch({ - type: 'success', - payload, - }) - ) - .catch((error) => { - if (error.name === 'AbortError') return; - dispatch({ type: 'failure', error }); - }); - }); + ) + .catch((error) => { + if (error.name === 'AbortError') return; + dispatch({ type: 'failure', error }); + }); return () => abortController.abort(); - }, [core.services.http, interactionId]); + }, [core.services.http, interactionId, core.services.dataSource]); return { ...state }; }; diff --git a/public/plugin.tsx b/public/plugin.tsx index 8e5c1151..89c4f545 100644 --- a/public/plugin.tsx +++ b/public/plugin.tsx @@ -5,6 +5,7 @@ import { EuiLoadingSpinner } from '@elastic/eui'; import React, { lazy, Suspense } from 'react'; +import { Subscription } from 'rxjs'; import { CoreSetup, CoreStart, Plugin, PluginInitializerContext } from '../../../src/core/public'; import { createOpenSearchDashboardsReactContext, @@ -61,6 +62,7 @@ export class AssistantPlugin private config: ConfigSchema; incontextInsightRegistry: IncontextInsightRegistry | undefined; private dataSourceService: DataSourceService; + private resetChatSubscription: Subscription | undefined; constructor(initializerContext: PluginInitializerContext) { this.config = initializerContext.config.get(); @@ -108,6 +110,12 @@ export class AssistantPlugin const username = account.user_name; this.incontextInsightRegistry?.setIsEnabled(this.config.incontextInsight.enabled); + if (this.dataSourceService.isMDSEnabled()) { + this.resetChatSubscription = this.dataSourceService.dataSourceIdUpdates$.subscribe(() => { + assistantActions.resetChat?.(); + }); + } + coreStart.chrome.navControls.registerRight({ order: 10000, mount: toMountPoint( @@ -163,5 +171,6 @@ export class AssistantPlugin public stop() { this.dataSourceService.stop(); + this.resetChatSubscription?.unsubscribe(); } } diff --git a/public/services/__tests__/data_source_service.test.ts b/public/services/__tests__/data_source_service.test.ts index d1db20dc..a14e5faf 100644 --- a/public/services/__tests__/data_source_service.test.ts +++ b/public/services/__tests__/data_source_service.test.ts @@ -111,6 +111,22 @@ describe('DataSourceService', () => { }); expect(await dataSource.getDataSourceId$().pipe(first()).toPromise()).toBe('foo'); }); + + it('should not fire change for same data source id', async () => { + const { dataSource, defaultDataSourceSelection$ } = setup({ + dataSourceSelection: new Map(), + defaultDataSourceId: 'foo', + }); + const observerFn = jest.fn(); + dataSource.getDataSourceId$().subscribe(observerFn); + + expect(observerFn).toHaveBeenCalledTimes(1); + dataSource.setDataSourceId('foo'); + expect(observerFn).toHaveBeenCalledTimes(1); + + defaultDataSourceSelection$.next('foo'); + expect(observerFn).toHaveBeenCalledTimes(1); + }); }); describe('isMDSEnabled', () => { @@ -126,23 +142,23 @@ describe('DataSourceService', () => { describe('getDataSourceQuery', () => { it('should return empty object if MDS not enabled', async () => { const { dataSource } = setup({ dataSourceManagement: undefined }); - expect(await dataSource.getDataSourceQuery()).toEqual({}); + expect(dataSource.getDataSourceQuery()).toEqual({}); }); it('should return empty object if data source id is empty', async () => { const { dataSource } = setup({ dataSourceSelection: new Map([['test', [{ label: '', id: '' }]]]), }); - expect(await dataSource.getDataSourceQuery()).toEqual({}); + expect(dataSource.getDataSourceQuery()).toEqual({}); }); it('should return query object with provided data source id', async () => { const { dataSource } = setup({ defaultDataSourceId: 'foo' }); - expect(await dataSource.getDataSourceQuery()).toEqual({ dataSourceId: 'foo' }); + expect(dataSource.getDataSourceQuery()).toEqual({ dataSourceId: 'foo' }); }); it('should throw error if data source id not exists', async () => { const { dataSource } = setup(); let error; try { - await dataSource.getDataSourceQuery(); + dataSource.getDataSourceQuery(); } catch (e) { error = e; } @@ -209,4 +225,15 @@ describe('DataSourceService', () => { dataSource.setDataSourceId('bar'); expect(observerFn).toHaveBeenCalledTimes(3); }); + + it('should emit new data source id updates after data source id change', () => { + const { dataSource } = setup(); + const observerFn = jest.fn(); + dataSource.dataSourceIdUpdates$.subscribe(observerFn); + dataSource.setDataSourceId('foo'); + expect(observerFn).toHaveBeenCalledTimes(1); + + dataSource.setDataSourceId('bar'); + expect(observerFn).toHaveBeenCalledTimes(2); + }); }); diff --git a/public/services/conversation_load_service.ts b/public/services/conversation_load_service.ts index 866ad6d6..b2875a20 100644 --- a/public/services/conversation_load_service.ts +++ b/public/services/conversation_load_service.ts @@ -26,7 +26,7 @@ export class ConversationLoadService { `${ASSISTANT_API.CONVERSATION}/${conversationId}`, { signal: this.abortController.signal, - query: await this._dataSource.getDataSourceQuery(), + query: this._dataSource.getDataSourceQuery(), } ); this.status$.next('idle'); diff --git a/public/services/conversations_service.ts b/public/services/conversations_service.ts index ab188070..52973fea 100644 --- a/public/services/conversations_service.ts +++ b/public/services/conversations_service.ts @@ -29,7 +29,10 @@ export class ConversationsService { } load = async ( - query?: Pick + query?: Pick< + SavedObjectsFindOptions, + 'page' | 'perPage' | 'fields' | 'sortField' | 'sortOrder' | 'search' | 'searchFields' + > ) => { this.abortController?.abort(); this.abortController = new AbortController(); @@ -40,7 +43,7 @@ export class ConversationsService { await this._http.get(ASSISTANT_API.CONVERSATIONS, { query: { ...this._options, - ...(await this._dataSource.getDataSourceQuery()), + ...this._dataSource.getDataSourceQuery(), } as HttpFetchQuery, signal: this.abortController.signal, }) diff --git a/public/services/data_source_service.mock.ts b/public/services/data_source_service.mock.ts index 0ca0bea1..d8496607 100644 --- a/public/services/data_source_service.mock.ts +++ b/public/services/data_source_service.mock.ts @@ -10,14 +10,11 @@ export class DataSourceServiceMock { } getDataSourceQuery() { - const result = this._isMDSEnabled + return this._isMDSEnabled ? { dataSourceId: '', } : {}; - return new Promise((resolve) => { - resolve(result); - }); } isMDSEnabled() { diff --git a/public/services/data_source_service.ts b/public/services/data_source_service.ts index f89565e4..02f4bc21 100644 --- a/public/services/data_source_service.ts +++ b/public/services/data_source_service.ts @@ -3,8 +3,8 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { BehaviorSubject, Subscription, combineLatest, of } from 'rxjs'; -import { first, map } from 'rxjs/operators'; +import { BehaviorSubject, Observable, Subject, Subscription, combineLatest, of } from 'rxjs'; +import { distinctUntilChanged, map } from 'rxjs/operators'; import type { IUiSettingsClient } from '../../../../src/core/public'; import type { DataSourceOption } from '../../../../src/plugins/data_source_management/public/components/data_source_menu/types'; @@ -36,6 +36,9 @@ export class DataSourceService { private uiSettings: IUiSettingsClient | undefined; private dataSourceManagement: DataSourceManagementPluginSetup | undefined; private dataSourceSelectionSubscription: Subscription | undefined; + private finalDataSourceId: string | null = null; + dataSourceIdUpdates$ = new Subject(); + private getDataSourceIdSubscription: Subscription | undefined; constructor() {} @@ -54,11 +57,11 @@ export class DataSourceService { }); } - async getDataSourceQuery() { + getDataSourceQuery() { if (!this.isMDSEnabled()) { return {}; } - const dataSourceId = await this.getDataSourceId$().pipe(first()).toPromise(); + const dataSourceId = this.finalDataSourceId; if (dataSourceId === null) { throw new Error('No data source id'); } @@ -74,24 +77,24 @@ export class DataSourceService { } setDataSourceId(newDataSourceId: string | null) { - if (this.dataSourceId$.getValue() === newDataSourceId) { - return; - } this.dataSourceId$.next(newDataSourceId); } getDataSourceId$() { return combineLatest([ this.dataSourceId$, - this.dataSourceManagement?.getDefaultDataSourceId$?.(this.uiSettings) ?? of(null), - ]).pipe( - map(([selectedDataSourceId, defaultDataSourceId]) => { - if (selectedDataSourceId !== null) { - return selectedDataSourceId; - } - return defaultDataSourceId; - }) - ); + (this.dataSourceManagement?.getDefaultDataSourceId$?.(this.uiSettings) ?? + of(null)) as Observable, + ]) + .pipe( + map(([selectedDataSourceId, defaultDataSourceId]) => { + if (selectedDataSourceId !== null) { + return selectedDataSourceId; + } + return defaultDataSourceId; + }) + ) + .pipe(distinctUntilChanged()); } setup({ @@ -104,6 +107,10 @@ export class DataSourceService { this.uiSettings = uiSettings; this.dataSourceManagement = dataSourceManagement; this.init(); + this.getDataSourceIdSubscription = this.getDataSourceId$().subscribe((finalDataSourceId) => { + this.finalDataSourceId = finalDataSourceId; + this.dataSourceIdUpdates$.next(finalDataSourceId); + }); return { setDataSourceId: (newDataSourceId: string | null) => { @@ -122,6 +129,8 @@ export class DataSourceService { public stop() { this.dataSourceSelectionSubscription?.unsubscribe(); + this.getDataSourceIdSubscription?.unsubscribe(); + this.dataSourceIdUpdates$.complete(); this.dataSourceId$.complete(); } } diff --git a/public/tabs/history/__tests__/chat_history_page.test.tsx b/public/tabs/history/__tests__/chat_history_page.test.tsx index 944fedf6..2f87af01 100644 --- a/public/tabs/history/__tests__/chat_history_page.test.tsx +++ b/public/tabs/history/__tests__/chat_history_page.test.tsx @@ -6,6 +6,7 @@ import React from 'react'; import { act, fireEvent, render, waitFor } from '@testing-library/react'; import { I18nProvider } from '@osd/i18n/react'; +import { BehaviorSubject, Subject } from 'rxjs'; import { coreMock } from '../../../../../../src/core/public/mocks'; import { HttpStart } from '../../../../../../src/core/public'; @@ -14,7 +15,6 @@ import * as useChatStateExports from '../../../hooks/use_chat_state'; import * as chatContextExports from '../../../contexts/chat_context'; import * as coreContextExports from '../../../contexts/core_context'; import { ConversationsService } from '../../../services/conversations_service'; -import { DataSourceServiceMock } from '../../../services/data_source_service.mock'; import { ChatHistoryPage } from '../chat_history_page'; @@ -27,7 +27,7 @@ const mockGetConversationsHttp = () => { title: 'foo', }, ], - total: 1, + total: 100, })); return http; }; @@ -35,11 +35,16 @@ const mockGetConversationsHttp = () => { const setup = ({ http = mockGetConversationsHttp(), chatContext = {}, + shouldRefresh = false, }: { http?: HttpStart; chatContext?: { flyoutFullScreen?: boolean }; + shouldRefresh?: boolean; } = {}) => { - const dataSourceMock = new DataSourceServiceMock(); + const dataSourceMock = { + dataSourceIdUpdates$: new Subject(), + getDataSourceQuery: jest.fn(() => ({ dataSourceId: 'foo' })), + }; const useCoreMock = { services: { ...coreMock.createStart(), @@ -65,7 +70,7 @@ const setup = ({ const renderResult = render( - + ); @@ -73,6 +78,7 @@ const setup = ({ useCoreMock, useChatStateMock, useChatContextMock, + dataSourceMock, renderResult, }; }; @@ -240,4 +246,100 @@ describe('', () => { expect(abortMock).toHaveBeenCalled(); }); }); + + it('should call conversations.reload after data source changed', async () => { + const { useCoreMock, dataSourceMock } = setup({ shouldRefresh: true }); + + jest.spyOn(useCoreMock.services.conversations, 'load'); + + expect(useCoreMock.services.conversations.load).not.toHaveBeenCalled(); + + act(() => { + dataSourceMock.dataSourceIdUpdates$.next('bar'); + }); + + await waitFor(() => { + expect(useCoreMock.services.conversations.load).toHaveBeenCalledTimes(1); + }); + }); + + it('should not call conversations.load after unmount', async () => { + const { useCoreMock, dataSourceMock, renderResult } = setup({ shouldRefresh: true }); + + jest.spyOn(useCoreMock.services.conversations, 'reload'); + + expect(useCoreMock.services.conversations.reload).not.toHaveBeenCalled(); + renderResult.unmount(); + + dataSourceMock.dataSourceIdUpdates$.next('bar'); + expect(useCoreMock.services.conversations.reload).not.toHaveBeenCalled(); + }); + + it('should load conversations with empty search after data source changed', async () => { + const { useCoreMock, dataSourceMock, renderResult } = setup({ shouldRefresh: true }); + + jest.spyOn(useCoreMock.services.conversations, 'load'); + + fireEvent.change(renderResult.getByPlaceholderText('Search by conversation name'), { + target: { + value: 'bar', + }, + }); + + await waitFor(() => { + expect(useCoreMock.services.conversations.load).toHaveBeenLastCalledWith( + expect.objectContaining({ + search: 'bar', + }) + ); + }); + + act(() => { + dataSourceMock.dataSourceIdUpdates$.next('baz'); + }); + + await waitFor(() => { + expect(useCoreMock.services.conversations.load).toHaveBeenLastCalledWith({ + fields: expect.any(Array), + page: 1, + perPage: 10, + sortField: 'updatedTimeMs', + sortOrder: 'DESC', + searchFields: ['title'], + }); + expect(useCoreMock.services.conversations.load).toHaveBeenCalledTimes(2); + }); + }); + + it('should load conversations with first page after data source changed', async () => { + const { useCoreMock, dataSourceMock, renderResult } = setup({ shouldRefresh: true }); + + jest.spyOn(useCoreMock.services.conversations, 'load'); + + await waitFor(() => { + expect(renderResult.getByTestId('pagination-button-1')).toBeInTheDocument(); + }); + + fireEvent.click(renderResult.getByTestId('pagination-button-1')); + + await waitFor(() => { + expect(useCoreMock.services.conversations.load).toHaveBeenLastCalledWith( + expect.objectContaining({ + page: 2, + }) + ); + }); + + act(() => { + dataSourceMock.dataSourceIdUpdates$.next('baz'); + }); + + await waitFor(() => { + expect(useCoreMock.services.conversations.load).toHaveBeenLastCalledWith( + expect.objectContaining({ + page: 1, + }) + ); + }); + }); }); diff --git a/public/tabs/history/chat_history_page.tsx b/public/tabs/history/chat_history_page.tsx index 98662891..92b175c8 100644 --- a/public/tabs/history/chat_history_page.tsx +++ b/public/tabs/history/chat_history_page.tsx @@ -18,8 +18,9 @@ import { } from '@elastic/eui'; import React, { useCallback, useEffect, useMemo, useState } from 'react'; import { FormattedMessage } from '@osd/i18n/react'; -import { useDebounce, useObservable } from 'react-use'; +import { useDebounce, useObservable, useUpdateEffect } from 'react-use'; import cs from 'classnames'; + import { useChatActions, useChatState } from '../../hooks'; import { useChatContext, useCore } from '../../contexts'; import { TAB_ID } from '../../utils/constants'; @@ -41,34 +42,36 @@ export const ChatHistoryPage: React.FC = React.memo((props setConversationId, setTitle, } = useChatContext(); - const [pageIndex, setPageIndex] = useState(0); - const [pageSize, setPageSize] = useState(10); const [searchName, setSearchName] = useState(''); - const [debouncedSearchName, setDebouncedSearchName] = useState(''); - const bulkGetOptions = useMemo( - () => ({ - page: pageIndex + 1, - perPage: pageSize, - fields: ['createdTimeMs', 'updatedTimeMs', 'title'], - sortField: 'updatedTimeMs', - sortOrder: 'DESC', - ...(debouncedSearchName ? { search: debouncedSearchName, searchFields: ['title'] } : {}), - }), - [pageIndex, pageSize, debouncedSearchName] - ); + const [bulkGetOptions, setBulkGetOptions] = useState<{ + page: number; + perPage: number; + fields: string[]; + sortField: string; + sortOrder: string; + searchFields: string[]; + search?: string; + }>({ + page: 1, + perPage: 10, + fields: ['createdTimeMs', 'updatedTimeMs', 'title'], + sortField: 'updatedTimeMs', + sortOrder: 'DESC', + searchFields: ['title'], + }); const conversations = useObservable(services.conversations.conversations$); const loading = useObservable(services.conversations.status$) === 'loading'; const chatHistories = useMemo(() => conversations?.objects || [], [conversations]); const hasNoConversations = - !debouncedSearchName && !!conversations && conversations.total === 0 && !loading; + !bulkGetOptions.search && !!conversations && conversations.total === 0 && !loading; + const dataSourceUpdate = useObservable(services.dataSource.dataSourceIdUpdates$); const handleSearchChange = useCallback((e) => { setSearchName(e.target.value); }, []); const handleItemsPerPageChange = useCallback((itemsPerPage: number) => { - setPageIndex(0); - setPageSize(itemsPerPage); + setBulkGetOptions((prevOptions) => ({ ...prevOptions, page: 1, perPage: itemsPerPage })); }, []); const handleBack = useCallback(() => { @@ -87,19 +90,49 @@ export const ChatHistoryPage: React.FC = React.memo((props [conversationId, setConversationId, setTitle, chatStateDispatch] ); + const handlePageChange = useCallback((newPage) => { + setBulkGetOptions((prevOptions) => ({ + ...prevOptions, + page: newPage + 1, + })); + }, []); + useDebounce( () => { - setPageIndex(0); - setDebouncedSearchName(searchName); + setBulkGetOptions((prevOptions) => { + if (prevOptions.search === searchName || (!prevOptions.search && searchName === '')) { + return prevOptions; + } + const { search, ...rest } = prevOptions; + return { + ...rest, + page: 1, + ...(searchName ? { search: searchName } : {}), + }; + }); }, 150, [searchName] ); - useEffect(() => { - if (props.shouldRefresh) services.conversations.reload(); + useUpdateEffect(() => { + if (!props.shouldRefresh) { + return; + } + services.conversations.reload(); + return () => { + services.conversations.abortController?.abort(); + }; }, [props.shouldRefresh, services.conversations]); + useUpdateEffect(() => { + setSearchName(''); + setBulkGetOptions(({ search, page, ...rest }) => ({ + ...rest, + page: 1, + })); + }, [dataSourceUpdate]); + useEffect(() => { services.conversations.load(bulkGetOptions); return () => { @@ -150,11 +183,13 @@ export const ChatHistoryPage: React.FC = React.memo((props onLoadChat={loadChat} onRefresh={services.conversations.reload} histories={chatHistories} - activePage={pageIndex} - itemsPerPage={pageSize} + activePage={bulkGetOptions.page - 1} + itemsPerPage={bulkGetOptions.perPage} onChangeItemsPerPage={handleItemsPerPageChange} - onChangePage={setPageIndex} - {...(conversations ? { pageCount: Math.ceil(conversations.total / pageSize) } : {})} + onChangePage={handlePageChange} + {...(conversations + ? { pageCount: Math.ceil(conversations.total / bulkGetOptions.perPage) } + : {})} onHistoryDeleted={handleHistoryDeleted} /> )} diff --git a/public/types.ts b/public/types.ts index ca6505b2..4b05b4ad 100644 --- a/public/types.ts +++ b/public/types.ts @@ -21,6 +21,7 @@ export type ActionExecutor = (params: Record) => void; export interface AssistantActions { send: (input: IMessage) => Promise; loadChat: (conversationId?: string, title?: string) => Promise; + resetChat: () => void; openChatUI: (conversationId?: string) => void; executeAction: (suggestedAction: ISuggestedAction, message: IMessage) => Promise; abortAction: (conversationId?: string) => Promise;