diff --git a/opensearch_dashboards.json b/opensearch_dashboards.json index d3fc739c..e96de077 100644 --- a/opensearch_dashboards.json +++ b/opensearch_dashboards.json @@ -12,5 +12,5 @@ "dashboard", "opensearchUiShared" ], - "optionalPlugins": [] -} \ No newline at end of file + "optionalPlugins": ["dataSource", "dataSourceManagement"] +} diff --git a/public/apis/connector.ts b/public/apis/connector.ts index dc742dd1..cce182a4 100644 --- a/public/apis/connector.ts +++ b/public/apis/connector.ts @@ -23,13 +23,22 @@ interface GetAllInternalConnectorResponse { } export class Connector { - public getAll() { - return InnerHttpProvider.getHttp().get(CONNECTOR_API_ENDPOINT); + public getAll({ dataSourceId }: { dataSourceId?: string }) { + return InnerHttpProvider.getHttp().get(CONNECTOR_API_ENDPOINT, { + query: { + data_source_id: dataSourceId, + }, + }); } - public getAllInternal() { + public getAllInternal({ dataSourceId }: { dataSourceId?: string }) { return InnerHttpProvider.getHttp().get( - INTERNAL_CONNECTOR_API_ENDPOINT + INTERNAL_CONNECTOR_API_ENDPOINT, + { + query: { + data_source_id: dataSourceId, + }, + } ); } } diff --git a/public/apis/model.ts b/public/apis/model.ts index 459ce4bb..75d78191 100644 --- a/public/apis/model.ts +++ b/public/apis/model.ts @@ -36,10 +36,13 @@ export class Model { states?: MODEL_STATE[]; nameOrId?: string; extraQuery?: Record; + dataSourceId?: string; }) { - const { extraQuery, ...restQuery } = query; + const { extraQuery, dataSourceId, ...restQuery } = query; return InnerHttpProvider.getHttp().get(MODEL_API_ENDPOINT, { - query: extraQuery ? { ...restQuery, extra_query: JSON.stringify(extraQuery) } : restQuery, + query: extraQuery + ? { ...restQuery, extra_query: JSON.stringify(extraQuery), data_source_id: dataSourceId } + : { ...restQuery, data_source_id: dataSourceId }, }); } } diff --git a/public/apis/profile.ts b/public/apis/profile.ts index 253bbba3..f24e0c93 100644 --- a/public/apis/profile.ts +++ b/public/apis/profile.ts @@ -14,9 +14,14 @@ export interface ModelDeploymentProfile { } export class Profile { - public getModel(modelId: string) { + public getModel(modelId: string, { dataSourceId }: { dataSourceId?: string }) { return InnerHttpProvider.getHttp().get( - `${DEPLOYED_MODEL_PROFILE_API_ENDPOINT}/${modelId}` + `${DEPLOYED_MODEL_PROFILE_API_ENDPOINT}/${modelId}`, + { + query: { + data_source_id: dataSourceId, + }, + } ); } } diff --git a/public/application.tsx b/public/application.tsx index 28b69ca6..149b83b6 100644 --- a/public/application.tsx +++ b/public/application.tsx @@ -14,7 +14,7 @@ import { APIProvider } from './apis/api_provider'; import { OpenSearchDashboardsContextProvider } from '../../../src/plugins/opensearch_dashboards_react/public'; export const renderApp = ( - { element, history, appBasePath }: AppMountParameters, + { element, history, appBasePath, setHeaderActionMenu }: AppMountParameters, services: MLServices ) => { InnerHttpProvider.setHttp(services.http); @@ -31,6 +31,10 @@ export const renderApp = ( chrome={services.chrome} data={services.data} uiSettingsClient={services.uiSettings} + savedObjects={services.savedObjects} + setActionMenu={setHeaderActionMenu} + dataSource={services.dataSource} + dataSourceManagement={services.dataSourceManagement} /> diff --git a/public/components/__tests__/data_source_top_nav_menu.test.tsx b/public/components/__tests__/data_source_top_nav_menu.test.tsx new file mode 100644 index 00000000..dab568b8 --- /dev/null +++ b/public/components/__tests__/data_source_top_nav_menu.test.tsx @@ -0,0 +1,112 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React, { useContext } from 'react'; +import userEvent from '@testing-library/user-event'; +import { render, screen, waitFor } from '../../../test/test_utils'; +import { DataSourceTopNavMenu, DataSourceTopNavMenuProps } from '../data_source_top_nav_menu'; +import { coreMock } from '../../../../../src/core/public/mocks'; +import { DataSourceContext } from '../../contexts'; + +function setup(options: Partial = {}) { + const user = userEvent.setup({}); + const coreStart = coreMock.createStart(); + const DataSourceMenu = ({ componentConfig: { onSelectedDataSources } }) => ( +
+
Data Source Menu
+
+ + +
+
+ ); + + const DataSourceConsumer = () => { + const { selectedDataSourceOption } = useContext(DataSourceContext); + + return ( +
+ {}} + /> +
+ ); + }; + + const renderResult = render( + <> + null, + getDataSourceMenu: () => DataSourceMenu, + }, + }} + setActionMenu={jest.fn()} + {...options} + /> + + + ); + return { user, renderResult }; +} + +describe('', () => { + it('should not render data source menu when data source management not defined', () => { + setup({ + dataSourceManagement: undefined, + }); + expect(screen.queryByText('Data Source Menu')).not.toBeInTheDocument(); + }); + + it('should render data source menu and data source context', () => { + setup(); + expect(screen.getByText('Data Source Menu')).toBeInTheDocument(); + expect(screen.getByLabelText('selectedDataSourceOption')).toHaveValue('null'); + }); + + it('should set selected data source option to undefined', async () => { + const { user } = setup(); + expect(screen.getByText('Data Source Menu')).toBeInTheDocument(); + await user.click(screen.getByLabelText('invalidDataSource')); + await waitFor(() => { + expect(screen.getByLabelText('selectedDataSourceOption')).toHaveValue('undefined'); + }); + }); + + it('should set selected data source option to valid data source', async () => { + const { user } = setup(); + expect(screen.getByText('Data Source Menu')).toBeInTheDocument(); + await user.click(screen.getByLabelText('validDataSource')); + await waitFor(() => { + expect(screen.getByLabelText('selectedDataSourceOption')).toHaveValue( + JSON.stringify({ id: 'ds1', label: 'Data Source 1' }) + ); + }); + }); +}); diff --git a/public/components/app.tsx b/public/components/app.tsx index 715e0174..e8d38fb6 100644 --- a/public/components/app.tsx +++ b/public/components/app.tsx @@ -10,11 +10,20 @@ import { EuiPage, EuiPageBody } from '@elastic/eui'; import { ROUTES } from '../../common/router'; import { routerPaths } from '../../common/router_paths'; -import { CoreStart, IUiSettingsClient } from '../../../../src/core/public'; +import { + CoreStart, + IUiSettingsClient, + MountPoint, + SavedObjectsStart, +} from '../../../../src/core/public'; import { NavigationPublicPluginStart } from '../../../../src/plugins/navigation/public'; import { DataPublicPluginStart } from '../../../../src/plugins/data/public'; +import type { DataSourceManagementPluginSetup } from '../../../../src/plugins/data_source_management/public'; +import type { DataSourcePluginSetup } from '../../../../src/plugins/data_source/public'; +import { DataSourceContextProvider } from '../contexts/data_source_context'; import { GlobalBreadcrumbs } from './global_breadcrumbs'; +import { DataSourceTopNavMenu } from './data_source_top_nav_menu'; interface MlCommonsPluginAppDeps { basename: string; @@ -24,6 +33,10 @@ interface MlCommonsPluginAppDeps { chrome: CoreStart['chrome']; data: DataPublicPluginStart; uiSettingsClient: IUiSettingsClient; + savedObjects: SavedObjectsStart; + dataSource?: DataSourcePluginSetup; + dataSourceManagement?: DataSourceManagementPluginSetup; + setActionMenu: (menuMount: MountPoint | undefined) => void; } export interface ComponentsCommonProps { @@ -38,27 +51,48 @@ export const MlCommonsPluginApp = ({ http, chrome, data, + dataSource, + dataSourceManagement, + savedObjects, + setActionMenu, }: MlCommonsPluginAppDeps) => { + const dataSourceEnabled = !!dataSource; return ( - <> - - - - {ROUTES.map(({ path, Component, exact }) => ( - } - exact={exact ?? false} - /> - ))} - - - - - - + + <> + + + + {ROUTES.map(({ path, Component, exact }) => ( + ( + + )} + exact={exact ?? false} + /> + ))} + + + + + + {dataSourceEnabled && ( + + )} + + ); }; diff --git a/public/components/data_source_top_nav_menu.tsx b/public/components/data_source_top_nav_menu.tsx new file mode 100644 index 00000000..7194696a --- /dev/null +++ b/public/components/data_source_top_nav_menu.tsx @@ -0,0 +1,60 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React, { useMemo, useContext, useCallback } from 'react'; + +import type { CoreStart, MountPoint, SavedObjectsStart } from '../../../../src/core/public'; +import type { + DataSourceManagementPluginSetup, + DataSourceSelectableConfig, +} from '../../../../src/plugins/data_source_management/public'; +import { DataSourceContext } from '../contexts/data_source_context'; + +export interface DataSourceTopNavMenuProps { + notifications: CoreStart['notifications']; + savedObjects: SavedObjectsStart; + dataSourceManagement?: DataSourceManagementPluginSetup; + setActionMenu: (menuMount: MountPoint | undefined) => void; +} + +export const DataSourceTopNavMenu = ({ + savedObjects, + notifications, + setActionMenu, + dataSourceManagement, +}: DataSourceTopNavMenuProps) => { + const DataSourceMenu = useMemo(() => dataSourceManagement?.ui.getDataSourceMenu(), [ + dataSourceManagement, + ]); + const { selectedDataSourceOption, setSelectedDataSourceOption } = useContext(DataSourceContext); + const activeOption = useMemo(() => (selectedDataSourceOption ? [selectedDataSourceOption] : []), [ + selectedDataSourceOption, + ]); + + const handleDataSourcesSelected = useCallback< + DataSourceSelectableConfig['onSelectedDataSources'] + >( + (dataSourceOptions) => { + setSelectedDataSourceOption(dataSourceOptions[0]); + }, + [setSelectedDataSourceOption] + ); + + if (!DataSourceMenu) { + return null; + } + return ( + + ); +}; diff --git a/public/components/monitoring/tests/index.test.tsx b/public/components/monitoring/__tests__/index.test.tsx similarity index 100% rename from public/components/monitoring/tests/index.test.tsx rename to public/components/monitoring/__tests__/index.test.tsx diff --git a/public/components/monitoring/tests/model_connector_filter.test.tsx b/public/components/monitoring/__tests__/model_connector_filter.test.tsx similarity index 75% rename from public/components/monitoring/tests/model_connector_filter.test.tsx rename to public/components/monitoring/__tests__/model_connector_filter.test.tsx index 270c504a..7e47387e 100644 --- a/public/components/monitoring/tests/model_connector_filter.test.tsx +++ b/public/components/monitoring/__tests__/model_connector_filter.test.tsx @@ -7,10 +7,16 @@ import React from 'react'; import userEvent from '@testing-library/user-event'; import { render, screen, waitFor, within } from '../../../../test/test_utils'; import { ModelConnectorFilter } from '../model_connector_filter'; +import { + DATA_SOURCE_FETCHING_ID, + DATA_SOURCE_INVALID_ID, + DataSourceId, +} from '../../../utils/data_source'; +import { Connector } from '../../../apis/connector'; jest.mock('../../../apis/connector'); -async function setup(value: string[]) { +async function setup(value: string[], dataSourceId?: DataSourceId) { const onChangeMock = jest.fn(); const user = userEvent.setup({}); render( @@ -21,6 +27,7 @@ async function setup(value: string[]) { ]} value={value} onChange={onChangeMock} + dataSourceId={dataSourceId} /> ); await user.click(screen.getByText('Connector name')); @@ -65,7 +72,7 @@ describe('', () => { }); it('should render all connectors in the option list', async () => { - await setup(['External Connector 1']); + await setup(['External Connector 1'], 'foo'); await waitFor(() => { expect( within(screen.getByRole('dialog')).getByText('Internal Connector 1') @@ -84,4 +91,16 @@ describe('', () => { expect(onChangeMock).toHaveBeenLastCalledWith(['External Connector 1', 'Common Connector']); }); + + it('should not call getAllInternal when data source id is fetching', async () => { + jest.spyOn(Connector.prototype, 'getAllInternal'); + await setup(['External Connector 1'], DATA_SOURCE_FETCHING_ID); + expect(Connector.prototype.getAllInternal).not.toHaveBeenCalled(); + }); + + it('should not call getAllInternal when data source id is invalid', async () => { + jest.spyOn(Connector.prototype, 'getAllInternal'); + await setup(['External Connector 1'], DATA_SOURCE_INVALID_ID); + expect(Connector.prototype.getAllInternal).not.toHaveBeenCalled(); + }); }); diff --git a/public/components/monitoring/tests/model_deployment_table.test.tsx b/public/components/monitoring/__tests__/model_deployment_table.test.tsx similarity index 100% rename from public/components/monitoring/tests/model_deployment_table.test.tsx rename to public/components/monitoring/__tests__/model_deployment_table.test.tsx diff --git a/public/components/monitoring/tests/model_source_filter.test.tsx b/public/components/monitoring/__tests__/model_source_filter.test.tsx similarity index 100% rename from public/components/monitoring/tests/model_source_filter.test.tsx rename to public/components/monitoring/__tests__/model_source_filter.test.tsx diff --git a/public/components/monitoring/tests/search_bar.test.tsx b/public/components/monitoring/__tests__/search_bar.test.tsx similarity index 100% rename from public/components/monitoring/tests/search_bar.test.tsx rename to public/components/monitoring/__tests__/search_bar.test.tsx diff --git a/public/components/monitoring/tests/use_monitoring.test.ts b/public/components/monitoring/__tests__/use_monitoring.test.tsx similarity index 75% rename from public/components/monitoring/tests/use_monitoring.test.ts rename to public/components/monitoring/__tests__/use_monitoring.test.tsx index 3c97e4fa..c338769c 100644 --- a/public/components/monitoring/tests/use_monitoring.test.ts +++ b/public/components/monitoring/__tests__/use_monitoring.test.tsx @@ -2,12 +2,52 @@ * Copyright OpenSearch Contributors * SPDX-License-Identifier: Apache-2.0 */ - +import React, { useContext } from 'react'; import { act, renderHook } from '@testing-library/react-hooks'; import { Model, ModelSearchResponse } from '../../../apis/model'; import { Connector } from '../../../apis/connector'; import { useMonitoring } from '../use_monitoring'; +import { + DataSourceContext, + DataSourceContextProvider, + DataSourceContextProviderProps, + DataSourceOption, +} from '../../../contexts'; + +const setup = ({ + initDataSourceContextValue, +}: { + initDataSourceContextValue?: Partial; +} = {}) => { + let setSelectedDataSourceOption: ( + dataSourceOption: DataSourceOption | null | undefined + ) => void = () => {}; + const DataSourceConsumer = () => { + const context = useContext(DataSourceContext); + setSelectedDataSourceOption = context.setSelectedDataSourceOption; + return null; + }; + const renderHookResult = renderHook(() => useMonitoring(), { + wrapper: ({ children }) => ( + + {children} + + + ), + }); + + return { + renderHookResult, + setSelectedDataSourceOption, + }; +}; jest.mock('../../../apis/connector'); @@ -45,7 +85,9 @@ describe('useMonitoring', () => { await waitFor(() => result.current.pageStatus === 'normal'); - result.current.searchByNameOrId('foo'); + act(() => { + result.current.searchByNameOrId('foo'); + }); await waitFor(() => expect(Model.prototype.search).toHaveBeenCalledWith( expect.objectContaining({ @@ -55,7 +97,9 @@ describe('useMonitoring', () => { ) ); - result.current.searchByStatus(['responding']); + act(() => { + result.current.searchByStatus(['responding']); + }); await waitFor(() => expect(Model.prototype.search).toHaveBeenCalledWith( expect.objectContaining({ @@ -65,10 +109,14 @@ describe('useMonitoring', () => { ) ); - result.current.resetSearch(); + act(() => { + result.current.resetSearch(); + }); await waitFor(() => result.current.pageStatus === 'normal'); - result.current.searchByStatus(['partial-responding']); + act(() => { + result.current.searchByStatus(['partial-responding']); + }); await waitFor(() => expect(Model.prototype.search).toHaveBeenCalledWith( expect.objectContaining({ @@ -91,9 +139,11 @@ describe('useMonitoring', () => { ) ); - result.current.handleTableChange({ - sort: { field: 'name', direction: 'desc' }, - pagination: { currentPage: 2, pageSize: 10 }, + act(() => { + result.current.handleTableChange({ + sort: { field: 'name', direction: 'desc' }, + pagination: { currentPage: 2, pageSize: 10 }, + }); }); await waitFor(() => expect(Model.prototype.search).toHaveBeenCalledWith( @@ -111,7 +161,9 @@ describe('useMonitoring', () => { await waitFor(() => expect(Model.prototype.search).toHaveBeenCalledTimes(1)); - result.current.reload(); + act(() => { + result.current.reload(); + }); await waitFor(() => expect(Model.prototype.search).toHaveBeenCalledTimes(2)); }); @@ -261,18 +313,22 @@ describe('useMonitoring', () => { it('should call searchByNameOrId with from 0 after page changed', async () => { const { result, waitFor } = renderHook(() => useMonitoring()); - result.current.handleTableChange({ - pagination: { - currentPage: 2, - pageSize: 15, - }, + act(() => { + result.current.handleTableChange({ + pagination: { + currentPage: 2, + pageSize: 15, + }, + }); }); await waitFor(() => { expect(result.current.pagination?.currentPage).toBe(2); }); - result.current.searchByNameOrId('foo'); + act(() => { + result.current.searchByNameOrId('foo'); + }); await waitFor(() => { expect(Model.prototype.search).toHaveBeenCalledTimes(3); @@ -287,18 +343,22 @@ describe('useMonitoring', () => { it('should call searchByStatus with from 0 after page changed', async () => { const { result, waitFor } = renderHook(() => useMonitoring()); - result.current.handleTableChange({ - pagination: { - currentPage: 2, - pageSize: 15, - }, + act(() => { + result.current.handleTableChange({ + pagination: { + currentPage: 2, + pageSize: 15, + }, + }); }); await waitFor(() => { expect(result.current.pagination?.currentPage).toBe(2); }); - result.current.searchByStatus(['responding']); + act(() => { + result.current.searchByStatus(['responding']); + }); await waitFor(() => { expect(Model.prototype.search).toHaveBeenCalledTimes(3); @@ -315,7 +375,9 @@ describe('useMonitoring', () => { await waitFor(() => result.current.pageStatus === 'normal'); - result.current.searchBySource(['local']); + act(() => { + result.current.searchBySource(['local']); + }); await waitFor(() => expect(Model.prototype.search).toHaveBeenLastCalledWith( expect.objectContaining({ @@ -336,7 +398,9 @@ describe('useMonitoring', () => { ) ); - result.current.searchBySource(['external']); + act(() => { + result.current.searchBySource(['external']); + }); await waitFor(() => expect(Model.prototype.search).toHaveBeenLastCalledWith( expect.objectContaining({ @@ -357,7 +421,9 @@ describe('useMonitoring', () => { ) ); - result.current.searchBySource(['external', 'local']); + act(() => { + result.current.searchBySource(['external', 'local']); + }); await waitFor(() => expect(Model.prototype.search).toHaveBeenLastCalledWith( expect.objectContaining({ @@ -371,8 +437,10 @@ describe('useMonitoring', () => { const { result, waitFor } = renderHook(() => useMonitoring()); await waitFor(() => result.current.pageStatus === 'normal'); + act(() => { + result.current.searchByConnector(['External Connector 1']); + }); - result.current.searchByConnector(['External Connector 1']); await waitFor(() => expect(Model.prototype.search).toHaveBeenLastCalledWith( expect.objectContaining({ @@ -407,6 +475,61 @@ describe('useMonitoring', () => { await waitFor(() => result.current.pageStatus === 'normal'); }); + + it('should not call model search if selected data source is null', async () => { + const { + renderHookResult: { waitFor }, + } = setup({ + initDataSourceContextValue: { + selectedDataSourceOption: null, + }, + }); + await waitFor(() => { + expect(Model.prototype.search).not.toHaveBeenCalled(); + }); + }); + + it('should call model search and connector get all with data source id', async () => { + const getAllConnectorMock = jest.spyOn(Connector.prototype, 'getAll'); + const { + renderHookResult: { waitFor }, + } = setup({ + initDataSourceContextValue: { + selectedDataSourceOption: { id: 'foo' }, + }, + }); + const dataSourceIdExpect = expect.objectContaining({ + dataSourceId: 'foo', + }); + await waitFor(() => { + expect(Model.prototype.search).toHaveBeenCalledWith(dataSourceIdExpect); + expect(getAllConnectorMock).toHaveBeenCalledWith(dataSourceIdExpect); + }); + }); + + it('should reset connector filter after selected data source option change', async () => { + const { + renderHookResult: { result, waitFor }, + setSelectedDataSourceOption, + } = setup({ + initDataSourceContextValue: { + selectedDataSourceOption: { id: 'foo' }, + }, + }); + act(() => { + result.current.searchByConnector(['connector-1']); + }); + await waitFor(() => { + expect(Model.prototype.search).toHaveBeenCalledTimes(2); + }); + act(() => { + setSelectedDataSourceOption({ id: 'bar' }); + }); + await waitFor(() => { + expect(Model.prototype.search).toHaveBeenCalledTimes(3); + expect(result.current.params.connector).toEqual([]); + }); + }); }); describe('useMonitoring.pageStatus', () => { @@ -452,7 +575,9 @@ describe('useMonitoring.pageStatus', () => { // Mock search function to return empty result mockEmptyRecords(); - result.current.searchByNameOrId('foo'); + act(() => { + result.current.searchByNameOrId('foo'); + }); await waitFor(() => expect(result.current.pageStatus).toBe('reset-filter')); }); @@ -464,7 +589,9 @@ describe('useMonitoring.pageStatus', () => { // assume result is empty mockEmptyRecords(); - result.current.searchByStatus(['responding']); + act(() => { + result.current.searchByStatus(['responding']); + }); await waitFor(() => expect(result.current.pageStatus).toBe('reset-filter')); }); @@ -476,7 +603,9 @@ describe('useMonitoring.pageStatus', () => { // assume result is empty mockEmptyRecords(); - result.current.searchBySource(['local']); + act(() => { + result.current.searchBySource(['local']); + }); await waitFor(() => expect(result.current.pageStatus).toBe('reset-filter')); }); @@ -488,7 +617,9 @@ describe('useMonitoring.pageStatus', () => { // assume result is empty mockEmptyRecords(); - result.current.searchByConnector([{ name: 'Sagemaker', ids: [] }]); + act(() => { + result.current.searchByConnector(['Sagemaker']); + }); await waitFor(() => expect(result.current.pageStatus).toBe('reset-filter')); }); @@ -498,4 +629,14 @@ describe('useMonitoring.pageStatus', () => { await waitFor(() => expect(result.current.pageStatus).toBe('empty')); }); + + it('should return "loading" and not call model search when data source id is fetching', async () => { + const { + renderHookResult: { result, waitFor }, + } = setup(); + + await waitFor(() => { + expect(result.current.pageStatus).toBe('loading'); + }); + }); }); diff --git a/public/components/monitoring/index.tsx b/public/components/monitoring/index.tsx index e2cb81da..83e755e7 100644 --- a/public/components/monitoring/index.tsx +++ b/public/components/monitoring/index.tsx @@ -41,7 +41,10 @@ export const Monitoring = () => { searchByConnector, allExternalConnectors, } = useMonitoring(); - const [previewModel, setPreviewModel] = useState(null); + const [preview, setPreview] = useState<{ + model: ModelDeploymentItem; + dataSourceId: string | undefined; + } | null>(null); const searchInputRef = useRef(); const setInputRef = useCallback((node: HTMLInputElement | null) => { @@ -55,22 +58,28 @@ export const Monitoring = () => { resetSearch(); }, [resetSearch]); - const handleViewDetail = useCallback((modelPreviewItem: ModelDeploymentItem) => { - setPreviewModel(modelPreviewItem); - }, []); + const handleViewDetail = useCallback( + (modelPreviewItem: ModelDeploymentItem) => { + // This check is for type safe, the data source id won't be invalid or fetching if model can be previewed. + if (typeof params.dataSourceId !== 'symbol') { + setPreview({ model: modelPreviewItem, dataSourceId: params.dataSourceId }); + } + }, + [params.dataSourceId] + ); const onCloseModelPreview = useCallback( (modelProfile: ModelDeploymentProfile | null) => { if ( modelProfile !== null && - (previewModel?.planningNodesCount !== modelProfile.target_worker_nodes?.length || - previewModel?.respondingNodesCount !== modelProfile.worker_nodes?.length) + (preview?.model?.planningNodesCount !== modelProfile.target_worker_nodes?.length || + preview?.model?.respondingNodesCount !== modelProfile.worker_nodes?.length) ) { reload(); } - setPreviewModel(null); + setPreview(null); }, - [previewModel, reload] + [preview, reload] ); return ( @@ -108,6 +117,7 @@ export const Monitoring = () => { value={params.connector} onChange={searchByConnector} allExternalConnectors={allExternalConnectors} + dataSourceId={params.dataSourceId} /> @@ -127,7 +137,13 @@ export const Monitoring = () => { onViewDetail={handleViewDetail} onResetSearchClick={onResetSearch} /> - {previewModel && } + {preview && ( + + )} ); diff --git a/public/components/monitoring/model_connector_filter.tsx b/public/components/monitoring/model_connector_filter.tsx index 42d5434e..61043bd8 100644 --- a/public/components/monitoring/model_connector_filter.tsx +++ b/public/components/monitoring/model_connector_filter.tsx @@ -5,8 +5,9 @@ import React, { useMemo } from 'react'; import { OptionsFilter, OptionsFilterProps } from '../common/options_filter'; -import { useFetcher } from '../../hooks'; +import { DO_NOT_FETCH, useFetcher } from '../../hooks'; import { APIProvider } from '../../apis/api_provider'; +import { DataSourceId } from '../../utils/data_source'; interface ModelConnectorFilterProps extends Omit< @@ -16,14 +17,17 @@ interface ModelConnectorFilterProps allExternalConnectors?: Array<{ id: string; name: string }>; value: string[]; onChange: (value: string[]) => void; + dataSourceId: DataSourceId; } export const ModelConnectorFilter = ({ + dataSourceId, allExternalConnectors, ...restProps }: ModelConnectorFilterProps) => { const { data: internalConnectorsResult } = useFetcher( - APIProvider.getAPI('connector').getAllInternal + APIProvider.getAPI('connector').getAllInternal, + typeof dataSourceId === 'symbol' ? DO_NOT_FETCH : { dataSourceId } ); const options = useMemo( () => diff --git a/public/components/monitoring/use_monitoring.ts b/public/components/monitoring/use_monitoring.ts index b792255a..c95f5837 100644 --- a/public/components/monitoring/use_monitoring.ts +++ b/public/components/monitoring/use_monitoring.ts @@ -3,14 +3,15 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { useMemo, useCallback, useState } from 'react'; +import { useMemo, useCallback, useState, useContext, useEffect } from 'react'; import { APIProvider } from '../../apis/api_provider'; import { GetAllConnectorResponse } from '../../apis/connector'; -import { useFetcher } from '../../hooks/use_fetcher'; +import { DO_NOT_FETCH, useFetcher } from '../../hooks/use_fetcher'; import { MODEL_STATE } from '../../../common'; - +import { DataSourceContext } from '../../contexts'; import { ModelDeployStatus } from './types'; +import { DATA_SOURCE_FETCHING_ID, DataSourceId, getDataSourceId } from '../../utils/data_source'; interface Params { nameOrId?: string; @@ -20,6 +21,7 @@ interface Params { currentPage: number; pageSize: number; sort: { field: 'name' | 'model_state' | 'id'; direction: 'asc' | 'desc' }; + dataSourceId?: DataSourceId; } const generateExtraQuery = ({ @@ -82,7 +84,9 @@ const checkFilterExists = (params: Params) => params.connector.length > 0 || params.source.length > 0; -const fetchDeployedModels = async (params: Params) => { +const fetchDeployedModels = async ( + params: Omit & { dataSourceId?: string } +) => { const states = params.status?.map((status) => { switch (status) { case 'not-responding': @@ -95,7 +99,9 @@ const fetchDeployedModels = async (params: Params) => { }); let externalConnectorsData: GetAllConnectorResponse; try { - externalConnectorsData = await APIProvider.getAPI('connector').getAll(); + externalConnectorsData = await APIProvider.getAPI('connector').getAll({ + dataSourceId: params.dataSourceId, + }); } catch (_e) { externalConnectorsData = { data: [], total_connectors: 0 }; } @@ -120,6 +126,7 @@ const fetchDeployedModels = async (params: Params) => { })) : [], }), + dataSourceId: params.dataSourceId, }); const externalConnectorMap = externalConnectorsData.data.reduce<{ [key: string]: { @@ -168,14 +175,21 @@ const fetchDeployedModels = async (params: Params) => { }; export const useMonitoring = () => { + const { dataSourceEnabled, selectedDataSourceOption } = useContext(DataSourceContext); const [params, setParams] = useState({ currentPage: 1, pageSize: 10, sort: { field: 'model_state', direction: 'asc' }, source: [], connector: [], + dataSourceId: getDataSourceId(dataSourceEnabled, selectedDataSourceOption), }); - const { data, loading, reload } = useFetcher(fetchDeployedModels, params); + const { data, loading, reload } = useFetcher( + fetchDeployedModels, + typeof params.dataSourceId === 'symbol' + ? DO_NOT_FETCH + : { ...params, dataSourceId: params.dataSourceId } + ); const filterExists = checkFilterExists(params); const totalRecords = data?.pagination.totalRecords; const deployedModels = useMemo(() => data?.data ?? [], [data]); @@ -188,22 +202,24 @@ export const useMonitoring = () => { * "empty" is for no deployed models in current system */ const pageStatus = useMemo(() => { - if (loading) { + if (loading || params.dataSourceId === DATA_SOURCE_FETCHING_ID) { return 'loading' as const; } if (totalRecords) { return 'normal' as const; } return filterExists ? ('reset-filter' as const) : ('empty' as const); - }, [loading, totalRecords, filterExists]); + }, [loading, totalRecords, filterExists, params.dataSourceId]); const resetSearch = useCallback(() => { setParams((previousValue) => ({ + ...previousValue, currentPage: previousValue.currentPage, pageSize: previousValue.pageSize, sort: previousValue.sort, source: [], connector: [], + status: undefined, })); }, []); @@ -263,6 +279,20 @@ export const useMonitoring = () => { [] ); + useEffect(() => { + setParams((previousParams) => { + const dataSourceId = getDataSourceId(dataSourceEnabled, selectedDataSourceOption); + if (previousParams.dataSourceId === dataSourceId) { + return previousParams; + } + return { + ...previousParams, + dataSourceId, + connector: [], + }; + }); + }, [dataSourceEnabled, selectedDataSourceOption]); + return { params, pageStatus, diff --git a/public/components/preview_panel/__tests__/preview_panel.test.tsx b/public/components/preview_panel/__tests__/preview_panel.test.tsx index 7da93f4b..55e9309a 100644 --- a/public/components/preview_panel/__tests__/preview_panel.test.tsx +++ b/public/components/preview_panel/__tests__/preview_panel.test.tsx @@ -17,7 +17,7 @@ const MODEL = { function setup({ model = MODEL, onClose = jest.fn() }) { const user = userEvent.setup({}); - render(); + render(); return { user }; } @@ -39,14 +39,14 @@ describe('', () => { }); it('source should be external and should not render nodes details when connector params passed', async () => { - const modelWithConntector = { + const modelWithConnector = { ...MODEL, connector: { name: 'connector', }, }; setup({ - model: modelWithConntector, + model: modelWithConnector, }); expect(screen.getByText('External')).toBeInTheDocument(); expect(screen.queryByText('Status by node')).not.toBeInTheDocument(); @@ -137,4 +137,19 @@ describe('', () => { expect(screen.getByText('node-3')).toBeInTheDocument(); }); }); + + it('should call get model with passed data source id', async () => { + const getModelMock = jest + .spyOn(APIProvider.getAPI('profile'), 'getModel') + .mockResolvedValueOnce({}); + setup({}); + await waitFor(() => + expect(getModelMock).toHaveBeenCalledWith( + 'id1', + expect.objectContaining({ + dataSourceId: 'foo', + }) + ) + ); + }); }); diff --git a/public/components/preview_panel/index.tsx b/public/components/preview_panel/index.tsx index e41e56c1..6c27206a 100644 --- a/public/components/preview_panel/index.tsx +++ b/public/components/preview_panel/index.tsx @@ -20,10 +20,10 @@ import { } from '@elastic/eui'; import { APIProvider } from '../../apis/api_provider'; import { useFetcher } from '../../hooks/use_fetcher'; -import { NodesTable } from './nodes_table'; import { CopyableText } from '../common'; import { ModelDeploymentProfile } from '../../apis/profile'; import { ConnectorDetails } from './connector_details'; +import { NodesTable } from './nodes_table'; export interface INode { id: string; @@ -44,11 +44,14 @@ export interface PreviewModel { interface Props { onClose: (data: ModelDeploymentProfile | null) => void; model: PreviewModel; + dataSourceId: string | undefined; } -export const PreviewPanel = ({ onClose, model }: Props) => { +export const PreviewPanel = ({ onClose, model, dataSourceId }: Props) => { const { id, name, connector } = model; - const { data, loading } = useFetcher(APIProvider.getAPI('profile').getModel, id); + const { data, loading } = useFetcher(APIProvider.getAPI('profile').getModel, id, { + dataSourceId, + }); const nodes = useMemo(() => { if (loading) { return []; diff --git a/public/contexts/data_source_context.tsx b/public/contexts/data_source_context.tsx new file mode 100644 index 00000000..2abaa752 --- /dev/null +++ b/public/contexts/data_source_context.tsx @@ -0,0 +1,62 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; +import { Dispatch, SetStateAction, createContext, useMemo, useState } from 'react'; +import type { DataSourceSelectableConfig } from '../../../../src/plugins/data_source_management/public'; + +export type DataSourceOption = Parameters< + DataSourceSelectableConfig['onSelectedDataSources'] +>[0][0]; + +export const DataSourceContext = createContext<{ + /** + * null for default state + * undefined for invalid state + * DataSourceOption for valid state + */ + selectedDataSourceOption: DataSourceOption | null | undefined; + setSelectedDataSourceOption: Dispatch>; + dataSourceEnabled: boolean | null; + setDataSourceEnabled: Dispatch>; +}>({ + selectedDataSourceOption: null, + setSelectedDataSourceOption: () => null, + dataSourceEnabled: null, + setDataSourceEnabled: () => null, +}); + +const { Provider, Consumer } = DataSourceContext; + +export type DataSourceContextProviderProps = React.PropsWithChildren<{ + initialValue?: { + selectedDataSourceOption?: DataSourceOption | null | undefined; + dataSourceEnabled?: boolean; + }; +}>; + +export const DataSourceContextProvider = ({ + children, + initialValue, +}: DataSourceContextProviderProps) => { + const [selectedDataSourceOption, setSelectedDataSourceOption] = useState< + DataSourceOption | undefined | null + >(initialValue?.selectedDataSourceOption ?? null); + const [dataSourceEnabled, setDataSourceEnabled] = useState( + initialValue?.dataSourceEnabled ?? null + ); + const value = useMemo( + () => ({ + selectedDataSourceOption, + setSelectedDataSourceOption, + dataSourceEnabled, + setDataSourceEnabled, + }), + [selectedDataSourceOption, setSelectedDataSourceOption, dataSourceEnabled, setDataSourceEnabled] + ); + return {children}; +}; + +export const DataSourceContextConsumer = Consumer; diff --git a/public/contexts/index.ts b/public/contexts/index.ts new file mode 100644 index 00000000..ec02ca3e --- /dev/null +++ b/public/contexts/index.ts @@ -0,0 +1,12 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +export { + DataSourceOption, + DataSourceContext, + DataSourceContextConsumer, + DataSourceContextProvider, + DataSourceContextProviderProps, +} from './data_source_context'; diff --git a/public/hooks/tests/use_fetcher.test.ts b/public/hooks/tests/use_fetcher.test.ts index ff5c5808..64e20d30 100644 --- a/public/hooks/tests/use_fetcher.test.ts +++ b/public/hooks/tests/use_fetcher.test.ts @@ -4,7 +4,7 @@ */ import { act, renderHook } from '@testing-library/react-hooks'; -import { useFetcher } from '../use_fetcher'; +import { DO_NOT_FETCH, useFetcher } from '../use_fetcher'; describe('useFetcher', () => { it('should call fetcher with consistent params and return consistent result', async () => { @@ -18,7 +18,7 @@ describe('useFetcher', () => { expect(fetcher).toHaveBeenCalledWith('foo'); }); - it('should call fetcher only onece if params content not change', async () => { + it('should call fetcher only once if params content not change', async () => { const fetcher = jest.fn((_arg1: any) => Promise.resolve()); const { result, waitFor, rerender } = renderHook(({ params }) => useFetcher(fetcher, params), { initialProps: { params: { foo: 'bar' } }, @@ -130,4 +130,42 @@ describe('useFetcher', () => { expect(result.current.data).toBe('bar'); }); + + it('should not call fetcher when first parameter is DO_NOT_FETCH', () => { + const fetcher = jest.fn(); + const { result } = renderHook(() => useFetcher(fetcher, DO_NOT_FETCH)); + + expect(result.current.loading).toBe(false); + expect(fetcher).not.toHaveBeenCalledWith(); + }); + + it('should not call fetcher after reload called when first parameter is DO_NOT_FETCH', () => { + const fetcher = jest.fn(); + const { result } = renderHook(() => useFetcher(fetcher, DO_NOT_FETCH)); + + result.current.reload(); + expect(result.current.loading).toBe(false); + expect(fetcher).not.toHaveBeenCalledWith(); + }); + + it('should call fetcher after first parameter changed from DO_NOT_FETCH', async () => { + const fetcher = jest.fn(async (...params) => params); + const { result, rerender, waitFor } = renderHook( + ({ params }) => useFetcher(fetcher, ...params), + { + initialProps: { + params: [DO_NOT_FETCH], + }, + } + ); + + rerender({ params: [] }); + expect(result.current.loading).toBe(true); + expect(fetcher).toHaveBeenCalled(); + + await waitFor(() => { + expect(result.current.loading).toBe(false); + expect(result.current.data).toEqual([]); + }); + }); }); diff --git a/public/hooks/use_fetcher.ts b/public/hooks/use_fetcher.ts index 4ee69709..f3e5d684 100644 --- a/public/hooks/use_fetcher.ts +++ b/public/hooks/use_fetcher.ts @@ -5,13 +5,26 @@ import { useCallback, useEffect, useRef, useState } from 'react'; +/** + * + * This symbol is for prevent fetcher be executed when component mount, + * the fetcher won't be executed if second parameters of useFetcher hook is DO_NOT_FETCH. + * + */ +export const DO_NOT_FETCH = Symbol('DO_NOT_FETCH'); + +export type DoNotFetchParams = [typeof DO_NOT_FETCH]; + +const isDoNotFetch = (test: string | any[] | DoNotFetchParams): test is DoNotFetchParams => + test[0] === DO_NOT_FETCH; + export const useFetcher = ( fetcher: (...params: TParams) => Promise, - ...params: TParams + ...params: TParams | [typeof DO_NOT_FETCH] ) => { const [, setCount] = useState(0); const dataRef = useRef(null); - const loadingRef = useRef(true); + const loadingRef = useRef(!isDoNotFetch(params)); const errorRef = useRef(null); const usedRef = useRef({ data: false, @@ -20,7 +33,7 @@ export const useFetcher = ( }); const paramsRef = useRef(params); paramsRef.current = params; - const stringifyParams = JSON.stringify(params); + const paramsKey = isDoNotFetch(params) ? params : JSON.stringify(params); const forceUpdate = useCallback(() => { setCount((prevCount) => (prevCount === Number.MAX_SAFE_INTEGER ? 0 : prevCount + 1)); @@ -61,7 +74,9 @@ export const useFetcher = ( ); const reload = useCallback(() => { - loadData(paramsRef.current); + if (!isDoNotFetch(paramsRef.current)) { + loadData(paramsRef.current); + } }, [loadData]); const update = useCallback( @@ -80,12 +95,15 @@ export const useFetcher = ( ); useEffect(() => { + if (isDoNotFetch(paramsKey)) { + return; + } let changed = false; - loadData(JSON.parse(stringifyParams), () => !changed); + loadData(JSON.parse(paramsKey), () => !changed); return () => { changed = true; }; - }, [stringifyParams, loadData]); + }, [paramsKey, loadData]); return Object.defineProperties( { diff --git a/public/plugin.ts b/public/plugin.ts index 91177a3d..20069db0 100644 --- a/public/plugin.ts +++ b/public/plugin.ts @@ -9,13 +9,15 @@ import { MlCommonsPluginPluginStart, AppPluginStartDependencies, MLServices, + MlCommonsPluginPluginSetupDependencies, } from './types'; import { PLUGIN_NAME, PLUGIN_ID } from '../common'; export class MlCommonsPluginPlugin implements Plugin { public setup( - core: CoreSetup + core: CoreSetup, + { dataSource, dataSourceManagement }: MlCommonsPluginPluginSetupDependencies ): MlCommonsPluginPluginSetup { // Register an application into the side navigation menu core.application.register({ @@ -37,6 +39,8 @@ export class MlCommonsPluginPlugin data, navigation, history: params.history, + dataSource, + dataSourceManagement, setHeaderActionMenu: params.setHeaderActionMenu, }; // Render the application diff --git a/public/types.ts b/public/types.ts index df68d094..06776354 100644 --- a/public/types.ts +++ b/public/types.ts @@ -7,18 +7,25 @@ import { History } from 'history'; import { DataPublicPluginStart } from '../../../src/plugins/data/public'; import { NavigationPublicPluginStart } from '../../../src/plugins/navigation/public'; import { AppMountParameters, CoreStart } from '../../../src/core/public'; +import type { DataSourceManagementPluginSetup } from '../../../src/plugins/data_source_management/public'; +import type { DataSourcePluginSetup } from '../../../src/plugins/data_source/public'; // eslint-disable-next-line @typescript-eslint/no-empty-interface export interface MlCommonsPluginPluginSetup {} // eslint-disable-next-line @typescript-eslint/no-empty-interface export interface MlCommonsPluginPluginStart {} +export interface MlCommonsPluginPluginSetupDependencies { + dataSource?: DataSourcePluginSetup; + dataSourceManagement?: DataSourceManagementPluginSetup; +} + export interface AppPluginStartDependencies { navigation: NavigationPublicPluginStart; data: DataPublicPluginStart; } -export interface MLServices extends CoreStart { +export interface MLServices extends CoreStart, MlCommonsPluginPluginSetupDependencies { setHeaderActionMenu: AppMountParameters['setHeaderActionMenu']; navigation: NavigationPublicPluginStart; data: DataPublicPluginStart; diff --git a/public/utils/__tests__/data_source.test.ts b/public/utils/__tests__/data_source.test.ts new file mode 100644 index 00000000..51b6c3f2 --- /dev/null +++ b/public/utils/__tests__/data_source.test.ts @@ -0,0 +1,29 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { DATA_SOURCE_FETCHING_ID, DATA_SOURCE_INVALID_ID, getDataSourceId } from '../data_source'; + +describe('getDataSourceId', () => { + it('should return undefined when data source not enabled', () => { + expect(getDataSourceId(false, null)).toBe(undefined); + expect(getDataSourceId(null, null)).toBe(undefined); + }); + + it('should return fetching id when selected data source option is null', () => { + expect(getDataSourceId(true, null)).toBe(DATA_SOURCE_FETCHING_ID); + }); + + it('should return invalid id when selected data source option is undefined', () => { + expect(getDataSourceId(true, undefined)).toBe(DATA_SOURCE_INVALID_ID); + }); + + it('should return undefined when selected data source id is empty', () => { + expect(getDataSourceId(true, { id: '' })).toBe(undefined); + }); + + it('should return selected data source id', () => { + expect(getDataSourceId(true, { id: 'foo' })).toBe('foo'); + }); +}); diff --git a/public/utils/data_source.ts b/public/utils/data_source.ts new file mode 100644 index 00000000..ba236bd7 --- /dev/null +++ b/public/utils/data_source.ts @@ -0,0 +1,31 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { DataSourceOption } from '../contexts'; + +export const DATA_SOURCE_FETCHING_ID = Symbol('DATA_SOURCE_FETCHING_ID'); +export const DATA_SOURCE_INVALID_ID = Symbol('DATA_SOURCE_INVALID_ID'); + +export const getDataSourceId = ( + dataSourceEnabled: boolean | null, + selectedDataSourceOption: DataSourceOption | null | undefined +) => { + if (!dataSourceEnabled) { + return undefined; + } + if (selectedDataSourceOption === null) { + return DATA_SOURCE_FETCHING_ID; + } + if (selectedDataSourceOption === undefined) { + return DATA_SOURCE_INVALID_ID; + } + // If selected data source is local cluster, the data source id should be undefined + if (selectedDataSourceOption.id === '') { + return undefined; + } + return selectedDataSourceOption.id; +}; + +export type DataSourceId = ReturnType; diff --git a/server/__tests__/plugin.test.ts b/server/__tests__/plugin.test.ts new file mode 100644 index 00000000..b27eb587 --- /dev/null +++ b/server/__tests__/plugin.test.ts @@ -0,0 +1,43 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { MlCommonsPlugin } from '../plugin'; +import { coreMock, httpServiceMock } from '../../../../src/core/server/mocks'; +import * as modelRouterExports from '../routes/model_router'; +import * as connectorRouterExports from '../routes/connector_router'; +import * as profileRouterExports from '../routes/profile_router'; + +describe('MlCommonsPlugin', () => { + describe('setup', () => { + let mockCoreSetup: ReturnType; + let initContext: ReturnType; + let routerMock: ReturnType; + + beforeEach(() => { + mockCoreSetup = coreMock.createSetup(); + routerMock = httpServiceMock.createRouter(); + mockCoreSetup.http.createRouter.mockReturnValue(routerMock); + initContext = coreMock.createPluginInitializerContext(); + }); + + it('should register model routers', () => { + jest.spyOn(modelRouterExports, 'modelRouter'); + new MlCommonsPlugin(initContext).setup(mockCoreSetup); + expect(modelRouterExports.modelRouter).toHaveBeenCalledWith(routerMock); + }); + + it('should register connector routers', () => { + jest.spyOn(connectorRouterExports, 'connectorRouter'); + new MlCommonsPlugin(initContext).setup(mockCoreSetup); + expect(connectorRouterExports.connectorRouter).toHaveBeenCalledWith(routerMock); + }); + + it('should register profile routers', () => { + jest.spyOn(profileRouterExports, 'profileRouter'); + new MlCommonsPlugin(initContext).setup(mockCoreSetup); + expect(profileRouterExports.profileRouter).toHaveBeenCalledWith(routerMock); + }); + }); +}); diff --git a/server/routes/__tests__/connector_router.test.ts b/server/routes/__tests__/connector_router.test.ts new file mode 100644 index 00000000..d106d2e7 --- /dev/null +++ b/server/routes/__tests__/connector_router.test.ts @@ -0,0 +1,117 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { ResponseObject } from '@hapi/hapi'; + +import { Router } from '../../../../../src/core/server/http/router'; +import { triggerHandler, createDataSourceEnhancedRouter } from '../router.mock'; +import { httpServerMock } from '../../../../../src/core/server/http/http_server.mocks'; +import { loggerMock } from '../../../../../src/core/server/logging/logger.mock'; +import { connectorRouter } from '../connector_router'; +import { CONNECTOR_API_ENDPOINT, INTERNAL_CONNECTOR_API_ENDPOINT } from '../constants'; +import { ConnectorService } from '../../services/connector_service'; + +const setupRouter = () => { + const mockedLogger = loggerMock.create(); + const { + router, + dataSourceTransportMock, + getLatestCurrentUserTransport, + } = createDataSourceEnhancedRouter(mockedLogger); + + connectorRouter(router); + return { + router, + dataSourceTransportMock, + getLatestCurrentUserTransport, + }; +}; + +const triggerGetAllConnectors = (router: Router, dataSourceId?: string) => + triggerHandler(router, { + method: 'GET', + path: CONNECTOR_API_ENDPOINT, + req: httpServerMock.createRawRequest({ query: { data_source_id: dataSourceId } }), + }); +const triggerGetAllInternalConnectors = (router: Router, dataSourceId?: string) => + triggerHandler(router, { + method: 'GET', + path: INTERNAL_CONNECTOR_API_ENDPOINT, + req: httpServerMock.createRawRequest({ query: { data_source_id: dataSourceId } }), + }); + +jest.mock('../../services/connector_service'); + +describe('connector routers', () => { + beforeEach(() => { + jest.spyOn(ConnectorService, 'search'); + jest.spyOn(ConnectorService, 'getUniqueInternalConnectorNames'); + }); + afterEach(() => { + jest.resetAllMocks(); + }); + + describe('get all connector', () => { + it('should call connector search and return consistent result', async () => { + expect(ConnectorService.search).not.toHaveBeenCalled(); + const { router, getLatestCurrentUserTransport } = setupRouter(); + const result = (await triggerGetAllConnectors(router)) as ResponseObject; + expect(ConnectorService.search).toHaveBeenCalledWith({ + transport: getLatestCurrentUserTransport(), + from: 0, + size: 10000, + }); + expect(result.source).toMatchInlineSnapshot(` + Object { + "data": Array [ + "connector 1", + "connector 2", + ], + "total_connectors": 2, + } + `); + }); + + it('should call connector search with data source transport', async () => { + expect(ConnectorService.search).not.toHaveBeenCalled(); + const { router, dataSourceTransportMock } = setupRouter(); + + await triggerGetAllConnectors(router, 'foo'); + expect(ConnectorService.search).toHaveBeenCalledWith({ + transport: dataSourceTransportMock, + from: 0, + size: 10000, + }); + }); + }); + + describe('get all internal connector', () => { + it('should call connector getUniqueInternalConnectorNames and return consistent result', async () => { + expect(ConnectorService.getUniqueInternalConnectorNames).not.toHaveBeenCalled(); + const { router, getLatestCurrentUserTransport } = setupRouter(); + const result = (await triggerGetAllInternalConnectors(router)) as ResponseObject; + expect(ConnectorService.getUniqueInternalConnectorNames).toHaveBeenCalledWith({ + transport: getLatestCurrentUserTransport(), + size: 10000, + }); + expect(result.source).toMatchInlineSnapshot(` + Object { + "data": undefined, + } + `); + }); + + it('should call connector getUniqueInternalConnectorNames with data source transport', async () => { + expect(ConnectorService.search).not.toHaveBeenCalled(); + const { router, dataSourceTransportMock } = setupRouter(); + + await triggerGetAllInternalConnectors(router, 'foo'); + expect(ConnectorService.getUniqueInternalConnectorNames).toHaveBeenCalledWith({ + transport: dataSourceTransportMock, + size: 10000, + }); + }); + }); +}); diff --git a/server/routes/__tests__/model_router.test.ts b/server/routes/__tests__/model_router.test.ts new file mode 100644 index 00000000..7126e2f6 --- /dev/null +++ b/server/routes/__tests__/model_router.test.ts @@ -0,0 +1,115 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { ResponseObject } from '@hapi/hapi'; +import { Boom } from '@hapi/boom'; + +import { Router } from '../../../../../src/core/server/http/router'; +import { triggerHandler, createDataSourceEnhancedRouter } from '../router.mock'; +import { httpServerMock } from '../../../../../src/core/server/http/http_server.mocks'; +import { loggerMock } from '../../../../../src/core/server/logging/logger.mock'; +import { MODEL_API_ENDPOINT } from '../constants'; +import { modelRouter } from '../model_router'; +import { ModelService } from '../../services'; + +const setupRouter = () => { + const mockedLogger = loggerMock.create(); + const { + router, + dataSourceTransportMock, + getLatestCurrentUserTransport, + } = createDataSourceEnhancedRouter(mockedLogger); + + modelRouter(router); + return { + router, + dataSourceTransportMock, + getLatestCurrentUserTransport, + }; +}; + +const triggerModelSearch = ( + router: Router, + { dataSourceId, from, size }: { dataSourceId?: string; from?: number; size?: number } +) => + triggerHandler(router, { + method: 'GET', + path: MODEL_API_ENDPOINT, + req: httpServerMock.createRawRequest({ + query: { data_source_id: dataSourceId, from, size }, + }), + }); + +jest.mock('../../services/model_service'); + +describe('model routers', () => { + beforeEach(() => { + jest.spyOn(ModelService, 'search'); + }); + afterEach(() => { + jest.resetAllMocks(); + }); + + describe('model search', () => { + it('should call connector search and return consistent result', async () => { + expect(ModelService.search).not.toHaveBeenCalled(); + const { router, getLatestCurrentUserTransport } = setupRouter(); + + const result = (await triggerModelSearch(router, { from: 0, size: 50 })) as ResponseObject; + expect(ModelService.search).toHaveBeenCalledWith( + expect.objectContaining({ + transport: getLatestCurrentUserTransport(), + from: 0, + size: 50, + }) + ); + expect(result.source).toMatchInlineSnapshot(` + Object { + "data": Array [ + Object { + "name": "Model 1", + }, + ], + "total_models": 1, + } + `); + }); + + it('should call model search with data source transport', async () => { + expect(ModelService.search).not.toHaveBeenCalled(); + const { router, dataSourceTransportMock } = setupRouter(); + + await triggerModelSearch(router, { dataSourceId: 'foo', from: 0, size: 50 }); + expect(ModelService.search).toHaveBeenCalledWith({ + transport: dataSourceTransportMock, + from: 0, + size: 50, + }); + }); + + it('should response error message after model search throw error', async () => { + jest.spyOn(ModelService, 'search').mockImplementationOnce(() => { + throw new Error('foo'); + }); + const { router, getLatestCurrentUserTransport } = setupRouter(); + + const result = (await triggerModelSearch(router, { from: 0, size: 50 })) as Boom; + expect(ModelService.search).toHaveBeenCalledWith( + expect.objectContaining({ + transport: getLatestCurrentUserTransport(), + from: 0, + size: 50, + }) + ); + expect(result.output.payload).toMatchInlineSnapshot(` + Object { + "error": "Bad Request", + "message": "foo", + "statusCode": 400, + } + `); + }); + }); +}); diff --git a/server/routes/__tests__/profile_router.test.ts b/server/routes/__tests__/profile_router.test.ts new file mode 100644 index 00000000..c5c4f057 --- /dev/null +++ b/server/routes/__tests__/profile_router.test.ts @@ -0,0 +1,128 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { ResponseObject } from '@hapi/hapi'; +import { Boom } from '@hapi/boom'; + +import { Router } from '../../../../../src/core/server/http/router'; +import { triggerHandler, createDataSourceEnhancedRouter } from '../router.mock'; +import { httpServerMock } from '../../../../../src/core/server/http/http_server.mocks'; +import { loggerMock } from '../../../../../src/core/server/logging/logger.mock'; +import { DEPLOYED_MODEL_PROFILE_API_ENDPOINT } from '../constants'; +import { ProfileService } from '../../services/profile_service'; +import { profileRouter } from '../profile_router'; + +const setupRouter = () => { + const mockedLogger = loggerMock.create(); + const { + router, + dataSourceTransportMock, + getLatestCurrentUserTransport, + } = createDataSourceEnhancedRouter(mockedLogger); + + profileRouter(router); + return { + router, + dataSourceTransportMock, + getLatestCurrentUserTransport, + }; +}; + +const triggerGetModelProfile = ( + router: Router, + { dataSourceId, modelId }: { dataSourceId?: string; modelId: string } +) => + triggerHandler(router, { + method: 'GET', + path: `${DEPLOYED_MODEL_PROFILE_API_ENDPOINT}/{modelId}`, + req: httpServerMock.createRawRequest({ + query: { data_source_id: dataSourceId }, + params: { modelId }, + }), + }); + +jest.mock('../../services/profile_service'); + +describe('profile routers', () => { + beforeEach(() => { + jest.spyOn(ProfileService, 'getModel'); + }); + afterEach(() => { + jest.resetAllMocks(); + }); + + describe('get model profile', () => { + it('should call get model profile and return consistent result', async () => { + expect(ProfileService.getModel).not.toHaveBeenCalled(); + const { router, getLatestCurrentUserTransport } = setupRouter(); + + const result = (await triggerGetModelProfile(router, { + modelId: 'model-1', + })) as ResponseObject; + expect(ProfileService.getModel).toHaveBeenCalledWith({ + modelId: 'model-1', + transport: getLatestCurrentUserTransport(), + }); + expect(result.source).toMatchInlineSnapshot(` + Object { + "id": "model-1", + "not_worker_nodes": Array [], + "target_worker_nodes": Array [ + "node-1", + ], + "worker_nodes": Array [ + "node-1", + ], + } + `); + }); + + it('should call get model profile with data source transport', async () => { + expect(ProfileService.getModel).not.toHaveBeenCalled(); + const { router, dataSourceTransportMock } = setupRouter(); + + await triggerGetModelProfile(router, { dataSourceId: 'foo', modelId: 'model-1' }); + expect(ProfileService.getModel).toHaveBeenCalledWith({ + modelId: 'model-1', + transport: dataSourceTransportMock, + }); + }); + + it('should response consistent error message after get model profile throw error', async () => { + jest.spyOn(ProfileService, 'getModel').mockImplementationOnce(() => { + throw new Error('foo'); + }); + const { router } = setupRouter(); + + const result = (await triggerGetModelProfile(router, { + dataSourceId: 'foo', + modelId: 'model-1', + })) as Boom; + expect(result.output.payload).toMatchInlineSnapshot(` + Object { + "error": "Bad Request", + "message": "foo", + "statusCode": 400, + } + `); + }); + + it('should response invalid model id', async () => { + const { router } = setupRouter(); + + const result = (await triggerGetModelProfile(router, { + dataSourceId: 'foo', + modelId: 'foo~!', + })) as Boom; + expect(result.output.payload).toMatchInlineSnapshot(` + Object { + "error": "Bad Request", + "message": "[request params.modelId]: Invalid model id", + "statusCode": 400, + } + `); + }); + }); +}); diff --git a/server/routes/__tests__/utils.test.ts b/server/routes/__tests__/utils.test.ts new file mode 100644 index 00000000..d8d19938 --- /dev/null +++ b/server/routes/__tests__/utils.test.ts @@ -0,0 +1,35 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { getOpenSearchClientTransport } from '../utils'; +import { coreMock } from '../../../../../src/core/server/mocks'; + +describe('getOpenSearchClientTransport', () => { + it('should return current user opensearch transport', async () => { + const core = coreMock.createRequestHandlerContext(); + + expect(await getOpenSearchClientTransport({ context: { core } })).toBe( + core.opensearch.client.asCurrentUser.transport + ); + }); + it('should data source id related opensearch transport', async () => { + const transportMock = {}; + const core = coreMock.createRequestHandlerContext(); + const context = { + core, + dataSource: { + opensearch: { + getClient: async (_dataSourceId: string) => ({ + transport: transportMock, + }), + }, + }, + }; + + expect(await getOpenSearchClientTransport({ context, dataSourceId: 'foo' })).toBe( + transportMock + ); + }); +}); diff --git a/server/routes/connector_router.ts b/server/routes/connector_router.ts index eb71cbdf..9e36e2d8 100644 --- a/server/routes/connector_router.ts +++ b/server/routes/connector_router.ts @@ -3,19 +3,28 @@ * SPDX-License-Identifier: Apache-2.0 */ +import { schema } from '@osd/config-schema'; import { IRouter } from '../../../../src/core/server'; import { CONNECTOR_API_ENDPOINT, INTERNAL_CONNECTOR_API_ENDPOINT } from './constants'; import { ConnectorService } from '../services/connector_service'; +import { getOpenSearchClientTransport } from './utils'; export const connectorRouter = (router: IRouter) => { router.get( { path: CONNECTOR_API_ENDPOINT, - validate: {}, + validate: { + query: schema.object({ + data_source_id: schema.maybe(schema.string()), + }), + }, }, - router.handleLegacyErrors(async (context, _req, res) => { + router.handleLegacyErrors(async (context, request, res) => { const payload = await ConnectorService.search({ - client: context.core.opensearch.client, + transport: await getOpenSearchClientTransport({ + dataSourceId: request.query.data_source_id, + context, + }), from: 0, size: 10000, }); @@ -25,11 +34,18 @@ export const connectorRouter = (router: IRouter) => { router.get( { path: INTERNAL_CONNECTOR_API_ENDPOINT, - validate: {}, + validate: { + query: schema.object({ + data_source_id: schema.maybe(schema.string()), + }), + }, }, - router.handleLegacyErrors(async (context, _req, res) => { + router.handleLegacyErrors(async (context, request, res) => { const data = await ConnectorService.getUniqueInternalConnectorNames({ - client: context.core.opensearch.client, + transport: await getOpenSearchClientTransport({ + dataSourceId: request.query.data_source_id, + context, + }), size: 10000, }); return res.ok({ body: { data } }); diff --git a/server/routes/model_router.ts b/server/routes/model_router.ts index 0adbd537..4a352dd6 100644 --- a/server/routes/model_router.ts +++ b/server/routes/model_router.ts @@ -5,9 +5,10 @@ import { schema } from '@osd/config-schema'; import { MODEL_STATE } from '../../common'; -import { IRouter, opensearchDashboardsResponseFactory } from '../../../../src/core/server'; +import { IRouter } from '../../../../src/core/server'; import { ModelService } from '../services'; import { MODEL_API_ENDPOINT } from './constants'; +import { getOpenSearchClientTransport } from './utils'; const modelSortQuerySchema = schema.oneOf([ schema.literal('name-asc'), @@ -43,14 +44,26 @@ export const modelRouter = (router: IRouter) => { states: schema.maybe(schema.oneOf([schema.arrayOf(modelStateSchema), modelStateSchema])), nameOrId: schema.maybe(schema.string()), extra_query: schema.maybe(schema.recordOf(schema.string(), schema.any())), + data_source_id: schema.maybe(schema.string()), }), }, }, - async (context, request) => { - const { from, size, sort, states, nameOrId, extra_query: extraQuery } = request.query; + async (context, request, response) => { + const { + from, + size, + sort, + states, + nameOrId, + extra_query: extraQuery, + data_source_id: dataSourceId, + } = request.query; try { const payload = await ModelService.search({ - client: context.core.opensearch.client, + transport: await getOpenSearchClientTransport({ + dataSourceId, + context, + }), from, size, sort: typeof sort === 'string' ? [sort] : sort, @@ -58,9 +71,9 @@ export const modelRouter = (router: IRouter) => { nameOrId, extraQuery, }); - return opensearchDashboardsResponseFactory.ok({ body: payload }); + return response.ok({ body: payload }); } catch (err) { - return opensearchDashboardsResponseFactory.badRequest({ body: err.message }); + return response.badRequest({ body: err.message }); } } ); diff --git a/server/routes/profile_router.ts b/server/routes/profile_router.ts index 34062ead..bb10afe8 100644 --- a/server/routes/profile_router.ts +++ b/server/routes/profile_router.ts @@ -4,9 +4,11 @@ */ import { schema } from '@osd/config-schema'; -import { IRouter, opensearchDashboardsResponseFactory } from '../../../../src/core/server'; + +import { IRouter } from '../../../../src/core/server'; import { ProfileService } from '../services/profile_service'; import { DEPLOYED_MODEL_PROFILE_API_ENDPOINT } from './constants'; +import { getOpenSearchClientTransport } from './utils'; export const profileRouter = (router: IRouter) => { router.get( @@ -22,17 +24,23 @@ export const profileRouter = (router: IRouter) => { }, }), }), + query: schema.object({ + data_source_id: schema.maybe(schema.string()), + }), }, }, - async (context, request) => { + async (context, request, response) => { try { const payload = await ProfileService.getModel({ - client: context.core.opensearch.client, + transport: await getOpenSearchClientTransport({ + dataSourceId: request.query.data_source_id, + context, + }), modelId: request.params.modelId, }); - return opensearchDashboardsResponseFactory.ok({ body: payload }); + return response.ok({ body: payload }); } catch (error) { - return opensearchDashboardsResponseFactory.badRequest({ body: error as Error }); + return response.badRequest({ body: error as Error }); } } ); diff --git a/server/routes/router.mock.ts b/server/routes/router.mock.ts new file mode 100644 index 00000000..7047f505 --- /dev/null +++ b/server/routes/router.mock.ts @@ -0,0 +1,154 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + Auth, + AuthenticationData, + Request, + ResponseObject, + ResponseToolkit, + ServerRealm, + ServerStateCookieOptions, +} from '@hapi/hapi'; +// @ts-ignore +import Response from '@hapi/hapi/lib/response'; +import { ProxyHandlerOptions } from '@hapi/h2o2'; +import { ReplyFileHandlerOptions } from '@hapi/inert'; +import { httpServerMock } from '../../../../src/core/server/http/http_server.mocks'; +import { + OpenSearchDashboardsRequest, + OpenSearchDashboardsResponseFactory, + RouteMethod, + Router, +} from '../../../../src/core/server/http/router'; +import { CoreRouteHandlerContext } from '../../../../src/core/server/core_route_handler_context'; +import { coreMock } from '../../../../src/core/server/mocks'; +import { ContextEnhancer } from '../../../../src/core/server/http/router/router'; +import { Logger, OpenSearchClient } from '../../../../src/core/server'; + +/** + * For hapi, ResponseToolkit is an internal implementation + * so we have to create a MockResponseToolkit to mock the behavior. + * This class should be put under OSD core, + */ +export class MockResponseToolkit implements ResponseToolkit { + abandon: symbol = Symbol('abandon'); + close: symbol = Symbol('close'); + context: unknown; + continue: symbol = Symbol('continue'); + realm: ServerRealm = { + modifiers: { + route: { + prefix: '', + vhost: '', + }, + }, + parent: null, + plugin: '', + pluginOptions: {}, + plugins: [], + settings: { + files: { + relativeTo: '', + }, + bind: {}, + }, + }; + request: Readonly = httpServerMock.createRawRequest(); + authenticated(): Auth { + throw new Error('Method not implemented.'); + } + entity( + options?: + | { etag?: string | undefined; modified?: string | undefined; vary?: boolean | undefined } + | undefined + ): ResponseObject | undefined { + throw new Error('Method not implemented.'); + } + redirect(uri?: string | undefined): ResponseObject { + throw new Error('Method not implemented.'); + } + state( + name: string, + value: string | object, + options?: ServerStateCookieOptions | undefined + ): void { + throw new Error('Method not implemented.'); + } + unauthenticated(error: Error, data?: AuthenticationData | undefined): void { + throw new Error('Method not implemented.'); + } + unstate(name: string, options?: ServerStateCookieOptions | undefined): void { + throw new Error('Method not implemented.'); + } + file(path: string, options?: ReplyFileHandlerOptions | undefined): ResponseObject { + throw new Error('Method not implemented.'); + } + proxy(options: ProxyHandlerOptions): Promise { + throw new Error('Method not implemented.'); + } + response(payload: unknown) { + return new Response(payload); + } +} + +const enhanceWithContext = (((coreContext: CoreRouteHandlerContext, otherContext?: object) => ( + fn: (...args: unknown[]) => unknown +) => (req: OpenSearchDashboardsRequest, res: OpenSearchDashboardsResponseFactory) => { + return fn.call( + null, + { + core: coreContext, + ...otherContext, + }, + req, + res + ); +}) as unknown) as ( + otherContext?: object +) => ContextEnhancer; + +export const triggerHandler = async ( + router: Router, + options: { + method: string; + path: string; + req: Request; + } +) => { + const allRoutes = router.getRoutes(); + const findRoute = allRoutes.find( + (item) => + item.method.toUpperCase() === options.method.toUpperCase() && item.path === options.path + ); + return await findRoute?.handler(options.req, new MockResponseToolkit()); +}; + +export const createDataSourceEnhancedRouter = (logger: Logger) => { + const dataSourceTransportMock = {}; + let latestCurrentUserTransport: OpenSearchClient['transport']; + const router = new Router('', logger, (fn) => (req, res) => { + const core = new CoreRouteHandlerContext(coreMock.createInternalStart(), req); + latestCurrentUserTransport = core.opensearch.client.asCurrentUser.transport; + return fn.call( + null, + { + core, + dataSource: { + opensearch: { + getClient: async (_dataSourceId: string) => ({ transport: dataSourceTransportMock }), + }, + }, + }, + req, + res + ); + }); + return { + router, + dataSourceTransportMock, + getLatestCurrentUserTransport: () => latestCurrentUserTransport, + }; +}; diff --git a/server/routes/utils.ts b/server/routes/utils.ts new file mode 100644 index 00000000..9f54a4ff --- /dev/null +++ b/server/routes/utils.ts @@ -0,0 +1,25 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { OpenSearchClient, RequestHandlerContext } from '../../../../src/core/server'; + +export const getOpenSearchClientTransport = async ({ + context, + dataSourceId, +}: { + context: RequestHandlerContext & { + dataSource?: { + opensearch: { + getClient: (dataSourceId: string) => Promise; + }; + }; + }; + dataSourceId?: string; +}) => { + if (dataSourceId && context.dataSource) { + return (await context.dataSource.opensearch.getClient(dataSourceId)).transport; + } + return context.core.opensearch.client.asCurrentUser.transport; +}; diff --git a/server/services/__mocks__/connector_service.ts b/server/services/__mocks__/connector_service.ts new file mode 100644 index 00000000..66ae6b1c --- /dev/null +++ b/server/services/__mocks__/connector_service.ts @@ -0,0 +1,17 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +export class ConnectorService { + public static async search() { + return { + data: ['connector 1', 'connector 2'], + total_connectors: 2, + }; + } + + public static async getUniqueInternalConnectorNames() { + return ['internal connector 1', 'internal connector 2']; + } +} diff --git a/server/services/__mocks__/model_service.ts b/server/services/__mocks__/model_service.ts new file mode 100644 index 00000000..faedeb58 --- /dev/null +++ b/server/services/__mocks__/model_service.ts @@ -0,0 +1,13 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +export class ModelService { + public static async search() { + return { + data: [{ name: 'Model 1' }], + total_models: 1, + }; + } +} diff --git a/server/services/__mocks__/profile_service.ts b/server/services/__mocks__/profile_service.ts new file mode 100644 index 00000000..e45e101e --- /dev/null +++ b/server/services/__mocks__/profile_service.ts @@ -0,0 +1,15 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +export class ProfileService { + public static async getModel() { + return { + id: 'model-1', + target_worker_nodes: ['node-1'], + worker_nodes: ['node-1'], + not_worker_nodes: [], + }; + } +} diff --git a/server/services/__tests__/connector_service.test.ts b/server/services/__tests__/connector_service.test.ts new file mode 100644 index 00000000..67959d80 --- /dev/null +++ b/server/services/__tests__/connector_service.test.ts @@ -0,0 +1,213 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { ConnectorService } from '../connector_service'; + +const CONNECTOR_SEARCH_RESULT_MOCK = { + body: { + hits: { + hits: [{ _id: 'connector-1', _source: { name: 'Connector 1' } }], + total: { + value: 1, + }, + }, + }, +}; + +const INTERNAL_CONNECTOR_AGGS_RESULT_MOCK = { + body: { + aggregations: { + unique_connector_names: { + buckets: ['Internal connector 1'], + }, + }, + }, +}; + +const createMockTransport = (result: unknown) => ({ + request: jest.fn().mockResolvedValue(result), +}); + +describe('ConnectorService', () => { + describe('search', () => { + it('should call transport request with consistent params', () => { + const mockTransport = createMockTransport(CONNECTOR_SEARCH_RESULT_MOCK); + ConnectorService.search({ + from: 0, + size: 1, + transport: mockTransport, + }); + + expect(mockTransport.request).toHaveBeenCalledTimes(1); + expect(mockTransport.request.mock.calls[0]).toMatchInlineSnapshot(` + Array [ + Object { + "body": Object { + "from": 0, + "query": Object { + "match_all": Object {}, + }, + "size": 1, + }, + "method": "POST", + "path": "/_plugins/_ml/connectors/_search", + }, + ] + `); + }); + + it('should return consistent results', async () => { + const result = await ConnectorService.search({ + from: 0, + size: 1, + transport: createMockTransport(CONNECTOR_SEARCH_RESULT_MOCK), + }); + + expect(result).toMatchInlineSnapshot(` + Object { + "data": Array [ + Object { + "id": "connector-1", + "name": "Connector 1", + }, + ], + "total_connectors": 1, + } + `); + }); + + it('should return empty results when transport request throw index_not_found_exception', async () => { + const mockTransport = createMockTransport(CONNECTOR_SEARCH_RESULT_MOCK); + mockTransport.request.mockImplementationOnce(() => { + throw new Error('index_not_found_exception'); + }); + + const result = await ConnectorService.search({ + from: 0, + size: 1, + transport: mockTransport, + }); + + expect(result).toEqual({ + data: [], + total_connectors: 0, + }); + }); + + it('should throw unexpected error', async () => { + const mockTransport = createMockTransport(CONNECTOR_SEARCH_RESULT_MOCK); + const unexpectedError = new Error('unexpected'); + mockTransport.request.mockImplementationOnce(() => { + throw unexpectedError; + }); + + let error; + try { + await ConnectorService.search({ + from: 0, + size: 1, + transport: mockTransport, + }); + } catch (e) { + error = e; + } + + expect(error).toBe(unexpectedError); + }); + }); + describe('getUniqueInternalConnectorNames', () => { + it('should call transport request with consistent params', () => { + const mockTransport = createMockTransport(CONNECTOR_SEARCH_RESULT_MOCK); + ConnectorService.getUniqueInternalConnectorNames({ + size: 1, + transport: mockTransport, + }); + + expect(mockTransport.request).toHaveBeenCalledTimes(1); + expect(mockTransport.request.mock.calls[0]).toMatchInlineSnapshot(` + Array [ + Object { + "body": Object { + "aggs": Object { + "unique_connector_names": Object { + "terms": Object { + "field": "connector.name.keyword", + "size": 1, + }, + }, + }, + "size": 0, + }, + "method": "POST", + "path": "/_plugins/_ml/models/_search", + }, + ] + `); + }); + + it('should return consistent results', async () => { + const result = await ConnectorService.getUniqueInternalConnectorNames({ + size: 1, + transport: createMockTransport(INTERNAL_CONNECTOR_AGGS_RESULT_MOCK), + }); + + expect(result).toMatchInlineSnapshot(` + Array [ + undefined, + ] + `); + }); + + it('should return empty results when no aggregations results', async () => { + const mockTransport = createMockTransport({ + body: {}, + }); + mockTransport.request.mockImplementationOnce(() => { + throw new Error('index_not_found_exception'); + }); + + const result = await ConnectorService.getUniqueInternalConnectorNames({ + size: 1, + transport: mockTransport, + }); + + expect(result).toEqual([]); + }); + + it('should return empty results when transport request throw index_not_found_exception', async () => { + const mockTransport = createMockTransport(INTERNAL_CONNECTOR_AGGS_RESULT_MOCK); + mockTransport.request.mockImplementationOnce(() => { + throw new Error('index_not_found_exception'); + }); + + const result = await ConnectorService.getUniqueInternalConnectorNames({ + size: 1, + transport: mockTransport, + }); + + expect(result).toEqual([]); + }); + + it('should throw unexpected error', async () => { + const mockTransport = createMockTransport(CONNECTOR_SEARCH_RESULT_MOCK); + const unexpectedError = new Error('unexpected'); + mockTransport.request.mockImplementationOnce(() => { + throw unexpectedError; + }); + + let error; + try { + await ConnectorService.getUniqueInternalConnectorNames({ + size: 1, + transport: mockTransport, + }); + } catch (e) { + error = e; + } + + expect(error).toBe(unexpectedError); + }); + }); +}); diff --git a/server/services/__tests__/model_service.test.ts b/server/services/__tests__/model_service.test.ts new file mode 100644 index 00000000..24381324 --- /dev/null +++ b/server/services/__tests__/model_service.test.ts @@ -0,0 +1,92 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { ModelService } from '../model_service'; + +const createTransportMock = () => ({ + request: jest.fn().mockResolvedValue({ + body: { + hits: { + hits: [{ _id: 'model-1', _source: { name: 'Model 1' } }], + total: { + value: 1, + }, + }, + }, + }), +}); + +describe('ModelService', () => { + it('should call transport request with consistent params', () => { + const mockTransport = createTransportMock(); + ModelService.search({ + from: 0, + size: 1, + transport: mockTransport, + }); + + expect(mockTransport.request).toHaveBeenCalledTimes(1); + expect(mockTransport.request.mock.calls[0]).toMatchInlineSnapshot(` + Array [ + Object { + "body": Object { + "from": 0, + "query": Object { + "bool": Object { + "must": Array [], + "must_not": Object { + "exists": Object { + "field": "chunk_number", + }, + }, + }, + }, + "size": 1, + }, + "method": "POST", + "path": "/_plugins/_ml/models/_search", + }, + ] + `); + }); + + it('should call transport request with sort params', () => { + const mockTransport = createTransportMock(); + ModelService.search({ + from: 0, + size: 1, + transport: mockTransport, + sort: ['id-asc'], + }); + + expect(mockTransport.request).toHaveBeenCalledWith( + expect.objectContaining({ + body: expect.objectContaining({ + sort: [{ _id: 'asc' }], + }), + }) + ); + }); + + it('should return consistent results', async () => { + const result = await ModelService.search({ + from: 0, + size: 1, + transport: createTransportMock(), + }); + + expect(result).toMatchInlineSnapshot(` + Object { + "data": Array [ + Object { + "id": "model-1", + "name": "Model 1", + }, + ], + "total_models": 1, + } + `); + }); +}); diff --git a/server/services/__tests__/profile_service.test.ts b/server/services/__tests__/profile_service.test.ts new file mode 100644 index 00000000..c394d8f5 --- /dev/null +++ b/server/services/__tests__/profile_service.test.ts @@ -0,0 +1,74 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { ProfileService } from '../profile_service'; + +const createTransportMock = (result?: unknown) => ({ + request: jest.fn().mockResolvedValue( + result || { + body: { + models: { + 'model-1': { + name: 'Model 1', + target_worker_nodes: ['node-1', 'node-2'], + worker_nodes: ['node-1'], + }, + }, + }, + } + ), +}); + +describe('ProfileService', () => { + it('should call transport request with consistent params', () => { + const mockTransport = createTransportMock(); + ProfileService.getModel({ + transport: mockTransport, + modelId: 'model-1', + }); + + expect(mockTransport.request).toHaveBeenCalledTimes(1); + expect(mockTransport.request.mock.calls[0]).toMatchInlineSnapshot(` + Array [ + Object { + "method": "GET", + "path": "/_plugins/_ml/profile/models/model-1?view=model", + }, + ] + `); + }); + + it('should return empty object when models not exists in response', async () => { + const result = await ProfileService.getModel({ + modelId: 'model-1', + transport: createTransportMock({ body: {} }), + }); + + expect(result).toEqual({}); + }); + + it('should return consistent results', async () => { + const result = await ProfileService.getModel({ + modelId: 'model-1', + transport: createTransportMock(), + }); + + expect(result).toMatchInlineSnapshot(` + Object { + "id": "model-1", + "not_worker_nodes": Array [ + "node-2", + ], + "target_worker_nodes": Array [ + "node-1", + "node-2", + ], + "worker_nodes": Array [ + "node-1", + ], + } + `); + }); +}); diff --git a/server/services/connector_service.ts b/server/services/connector_service.ts index 96ba00cb..3dade3e8 100644 --- a/server/services/connector_service.ts +++ b/server/services/connector_service.ts @@ -18,7 +18,7 @@ * permissions and limitations under the License. */ -import { IScopedClusterClient } from '../../../../src/core/server'; +import { OpenSearchClient } from '../../../../src/core/server'; import { CONNECTOR_SEARCH_API, MODEL_SEARCH_API } from './utils/constants'; @@ -26,15 +26,15 @@ export class ConnectorService { public static async search({ from, size, - client, + transport, }: { - client: IScopedClusterClient; + transport: OpenSearchClient['transport']; from: number; size: number; }) { let result; try { - result = await client.asCurrentUser.transport.request({ + result = await transport.request({ method: 'POST', path: CONNECTOR_SEARCH_API, body: { @@ -64,15 +64,15 @@ export class ConnectorService { } public static async getUniqueInternalConnectorNames({ - client, + transport, size, }: { - client: IScopedClusterClient; + transport: OpenSearchClient['transport']; size: number; }) { let result; try { - result = await client.asCurrentUser.transport.request({ + result = await transport.request({ method: 'POST', path: MODEL_SEARCH_API, body: { @@ -93,6 +93,9 @@ export class ConnectorService { } throw e; } + if (!result.body.aggregations) { + return []; + } return result.body.aggregations.unique_connector_names.buckets.map(({ key }) => key); } } diff --git a/server/services/model_service.ts b/server/services/model_service.ts index 062d53dd..730240f2 100644 --- a/server/services/model_service.ts +++ b/server/services/model_service.ts @@ -18,7 +18,7 @@ * permissions and limitations under the License. */ -import { IScopedClusterClient } from '../../../../src/core/server'; +import { OpenSearchClient } from '../../../../src/core/server'; import { MODEL_STATE, ModelSearchSort } from '../../common'; import { generateModelSearchQuery } from './utils/model'; @@ -34,10 +34,10 @@ export class ModelService { from, size, sort, - client, + transport, ...restParams }: { - client: IScopedClusterClient; + transport: OpenSearchClient['transport']; from: number; size: number; sort?: ModelSearchSort[]; @@ -47,7 +47,7 @@ export class ModelService { }) { const { body: { hits }, - } = await client.asCurrentUser.transport.request({ + } = await transport.request({ method: 'POST', path: `${MODEL_BASE_API}/_search`, body: { diff --git a/server/services/profile_service.ts b/server/services/profile_service.ts index dc9f9fcb..e9b59a2c 100644 --- a/server/services/profile_service.ts +++ b/server/services/profile_service.ts @@ -18,7 +18,7 @@ * permissions and limitations under the License. */ -import { IScopedClusterClient } from '../../../../src/core/server'; +import { OpenSearchClient } from '../../../../src/core/server'; import { OpenSearchMLCommonsProfile } from '../../common/profile'; import { PROFILE_BASE_API } from './utils/constants'; @@ -34,10 +34,13 @@ export class ProfileService { }; } - public static async getModel(params: { client: IScopedClusterClient; modelId: string }) { - const { client, modelId } = params; + public static async getModel(params: { + transport: OpenSearchClient['transport']; + modelId: string; + }) { + const { transport, modelId } = params; const result = ( - await client.asCurrentUser.transport.request({ + await transport.request({ method: 'GET', path: `${PROFILE_BASE_API}/models/${modelId}?view=model`, }) diff --git a/server/services/utils/__tests__/model.test.ts b/server/services/utils/__tests__/model.test.ts new file mode 100644 index 00000000..a49b62df --- /dev/null +++ b/server/services/utils/__tests__/model.test.ts @@ -0,0 +1,112 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { MODEL_STATE } from '../../../../common'; +import { generateModelSearchQuery } from '../model'; + +describe('generateModelSearchQuery', () => { + it('should generate consistent query when states provided', () => { + expect(generateModelSearchQuery({ states: [MODEL_STATE.loaded, MODEL_STATE.partiallyLoaded] })) + .toMatchInlineSnapshot(` + Object { + "bool": Object { + "must": Array [ + Object { + "terms": Object { + "model_state": Array [ + "DEPLOYED", + "PARTIALLY_DEPLOYED", + ], + }, + }, + ], + "must_not": Object { + "exists": Object { + "field": "chunk_number", + }, + }, + }, + } + `); + }); + it('should generate consistent query when nameOrId provided', () => { + expect(generateModelSearchQuery({ nameOrId: 'foo' })).toMatchInlineSnapshot(` + Object { + "bool": Object { + "must": Array [ + Object { + "bool": Object { + "should": Array [ + Object { + "wildcard": Object { + "name.keyword": Object { + "case_insensitive": true, + "value": "*foo*", + }, + }, + }, + Object { + "term": Object { + "_id": Object { + "value": "foo", + }, + }, + }, + ], + }, + }, + ], + "must_not": Object { + "exists": Object { + "field": "chunk_number", + }, + }, + }, + } + `); + }); + it('should generate consistent query when extraQuery provided', () => { + expect( + generateModelSearchQuery({ + extraQuery: { + bool: { + must: [ + { + term: { + algorithm: { value: 'REMOTE' }, + }, + }, + ], + }, + }, + }) + ).toMatchInlineSnapshot(` + Object { + "bool": Object { + "must": Array [ + Object { + "bool": Object { + "must": Array [ + Object { + "term": Object { + "algorithm": Object { + "value": "REMOTE", + }, + }, + }, + ], + }, + }, + ], + "must_not": Object { + "exists": Object { + "field": "chunk_number", + }, + }, + }, + } + `); + }); +}); diff --git a/server/services/utils/__tests__/query.test.ts b/server/services/utils/__tests__/query.test.ts new file mode 100644 index 00000000..68c9c812 --- /dev/null +++ b/server/services/utils/__tests__/query.test.ts @@ -0,0 +1,89 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { generateTermQuery, generateMustQueries } from '../query'; + +describe('generateTermQuery', () => { + it('should return consistent result when number value provided', () => { + expect(generateTermQuery('foo', 1)).toMatchInlineSnapshot(` + Object { + "term": Object { + "foo": Object { + "value": 1, + }, + }, + } + `); + }); + it('should return consistent result when string value provided', () => { + expect(generateTermQuery('foo', 'bar')).toMatchInlineSnapshot(` + Object { + "term": Object { + "foo": Object { + "value": "bar", + }, + }, + } + `); + }); + it('should return consistent result when array value provided', () => { + expect(generateTermQuery('foo', [1, 'bar'])).toMatchInlineSnapshot(` + Object { + "terms": Object { + "foo": Array [ + 1, + "bar", + ], + }, + } + `); + }); +}); + +describe('generateMustQueries', () => { + it('should return consistent result when no query provided', () => { + expect(generateMustQueries([])).toMatchInlineSnapshot(` + Object { + "match_all": Object {}, + } + `); + }); + it('should return consistent result when only one query provided', () => { + expect(generateMustQueries([generateTermQuery('foo', 'bar')])).toMatchInlineSnapshot(` + Object { + "term": Object { + "foo": Object { + "value": "bar", + }, + }, + } + `); + }); + it('should return consistent result when multi query provided', () => { + expect(generateMustQueries([generateTermQuery('foo', 'bar'), generateTermQuery('bar', 'baz')])) + .toMatchInlineSnapshot(` + Object { + "bool": Object { + "must": Array [ + Object { + "term": Object { + "foo": Object { + "value": "bar", + }, + }, + }, + Object { + "term": Object { + "bar": Object { + "value": "baz", + }, + }, + }, + ], + }, + } + `); + }); +}); diff --git a/test/jest.config.js b/test/jest.config.js index 5eca2489..ee1ca887 100644 --- a/test/jest.config.js +++ b/test/jest.config.js @@ -18,6 +18,9 @@ module.exports = { '/public/**/*.{ts,tsx}', '!/public/**/*.test.{ts,tsx}', '!/public/**/*.types.ts', + '/server/**/*.{ts,tsx}', + '!/server/**/*.test.{ts,tsx}', + '!/server/**/*.mock.{ts,tsx}', ], coverageDirectory: './coverage', coverageReporters: ['lcov', 'text', 'cobertura', 'html'], diff --git a/test/test_utils.tsx b/test/test_utils.tsx index 07523beb..fadc47f0 100644 --- a/test/test_utils.tsx +++ b/test/test_utils.tsx @@ -6,9 +6,14 @@ import React, { FC, ReactElement } from 'react'; import { I18nProvider } from '@osd/i18n/react'; import { render, RenderOptions } from '@testing-library/react'; +import { DataSourceContextProvider } from '../public/contexts'; const AllTheProviders: FC<{ children: React.ReactNode }> = ({ children }) => { - return {children}; + return ( + + {children} + + ); }; const customRender = (ui: ReactElement, options?: Omit) =>