From 03a8ac5257e1f1d917ad422fd15a4e151fe6f662 Mon Sep 17 00:00:00 2001 From: Yulong Ruan Date: Thu, 12 Sep 2024 22:10:18 +0800 Subject: [PATCH] feat: expose an API to check if a give agent config name has agent id configured + use agentConfigName instead of agentName to avoid confusing Signed-off-by: Yulong Ruan --- common/constants/llm.ts | 1 + public/plugin.tsx | 2 ++ public/services/assistant_client.ts | 27 +++++++++++++++----- public/services/index.ts | 7 +++++- server/routes/agent_routes.ts | 30 +++++++++++++++++++++-- server/routes/get_agent.ts | 19 ++++++++++---- server/routes/summary_routes.ts | 30 ++++++++++++++--------- server/routes/text2viz_routes.ts | 4 +-- server/services/assistant_client.ts | 15 ++++++++---- server/services/chat/olly_chat_service.ts | 4 +-- 10 files changed, 104 insertions(+), 35 deletions(-) diff --git a/common/constants/llm.ts b/common/constants/llm.ts index aef7e830..48be672f 100644 --- a/common/constants/llm.ts +++ b/common/constants/llm.ts @@ -26,6 +26,7 @@ export const TEXT2VIZ_API = { export const AGENT_API = { EXECUTE: `${API_BASE}/agent/_execute`, + CONFIG_EXISTS: `${API_BASE}/agent_config/_exists`, }; export const SUMMARY_ASSISTANT_API = { diff --git a/public/plugin.tsx b/public/plugin.tsx index be9db38d..8bf25378 100644 --- a/public/plugin.tsx +++ b/public/plugin.tsx @@ -43,6 +43,7 @@ import { setUiActions, setExpressions, setHttp, + setAssistantService, } from './services'; import { ConfigSchema } from '../common/types/config'; import { DataSourceService } from './services/data_source_service'; @@ -312,6 +313,7 @@ export class AssistantPlugin setNotifications(core.notifications); setConfigSchema(this.config); setUiActions(uiActions); + setAssistantService(assistantServiceStart); if (this.config.text2viz.enabled) { uiActions.addTriggerAction(AI_ASSISTANT_QUERY_EDITOR_TRIGGER, { diff --git a/public/services/assistant_client.ts b/public/services/assistant_client.ts index fd990137..624b86d8 100644 --- a/public/services/assistant_client.ts +++ b/public/services/assistant_client.ts @@ -3,7 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { API_BASE } from '../../common/constants/llm'; +import { AGENT_API } from '../../common/constants/llm'; import { HttpSetup } from '../../../../src/core/public'; interface Options { @@ -17,19 +17,34 @@ export class AssistantClient { executeAgent = (agentId: string, parameters: Record, options?: Options) => { return this.http.fetch({ method: 'POST', - path: `${API_BASE}/agent/_execute`, + path: AGENT_API.EXECUTE, body: JSON.stringify(parameters), query: { dataSourceId: options?.dataSourceId, agentId }, }); }; - // eslint-disable-next-line @typescript-eslint/no-explicit-any - executeAgentByName = (agentName: string, parameters: Record, options?: Options) => { + executeAgentByConfigName = ( + agentConfigName: string, + // eslint-disable-next-line @typescript-eslint/no-explicit-any + parameters: Record, + options?: Options + ) => { return this.http.fetch({ method: 'POST', - path: `${API_BASE}/agent/_execute`, + path: AGENT_API.EXECUTE, body: JSON.stringify(parameters), - query: { dataSourceId: options?.dataSourceId, agentName }, + query: { dataSourceId: options?.dataSourceId, agentConfigName }, + }); + }; + + /** + * Return if the given agent config name has agent id configured + */ + agentConfigExists = (agentConfigName: string, options?: Options) => { + return this.http.fetch<{ exists: boolean }>({ + method: 'GET', + path: AGENT_API.CONFIG_EXISTS, + query: { dataSourceId: options?.dataSourceId, agentConfigName }, }); }; } diff --git a/public/services/index.ts b/public/services/index.ts index e515769d..21a5fed4 100644 --- a/public/services/index.ts +++ b/public/services/index.ts @@ -4,12 +4,13 @@ */ import { createGetterSetter } from '../../../../src/plugins/opensearch_dashboards_utils/public'; -import { UiActionsSetup, UiActionsStart } from '../../../../src/plugins/ui_actions/public'; +import { UiActionsStart } from '../../../../src/plugins/ui_actions/public'; import { ChromeStart, HttpStart, NotificationsStart } from '../../../../src/core/public'; import { IncontextInsightRegistry } from './incontext_insight'; import { ConfigSchema } from '../../common/types/config'; import { IndexPatternsContract } from '../../../../src/plugins/data/public'; import { ExpressionsStart } from '../../../../src/plugins/expressions/public'; +import { AssistantServiceStart } from './assistant_service'; export * from './incontext_insight'; export { ConversationLoadService } from './conversation_load_service'; @@ -37,4 +38,8 @@ export const [getExpressions, setExpressions] = createGetterSetter('Http'); +export const [getAssistantService, setAssistantService] = createGetterSetter( + 'AssistantServiceStart' +); + export { DataSourceService, DataSourceServiceContract } from './data_source_service'; diff --git a/server/routes/agent_routes.ts b/server/routes/agent_routes.ts index 58e90533..8fb7abd4 100644 --- a/server/routes/agent_routes.ts +++ b/server/routes/agent_routes.ts @@ -21,7 +21,7 @@ export function registerAgentRoutes(router: IRouter, assistantService: Assistant }), schema.object({ dataSourceId: schema.maybe(schema.string()), - agentName: schema.string(), + agentConfigName: schema.string(), }), ]), }, @@ -33,11 +33,37 @@ export function registerAgentRoutes(router: IRouter, assistantService: Assistant const response = await assistantClient.executeAgent(req.query.agentId, req.body); return res.ok({ body: response }); } - const response = await assistantClient.executeAgentByName(req.query.agentName, req.body); + const response = await assistantClient.executeAgentByConfigName( + req.query.agentConfigName, + req.body + ); return res.ok({ body: response }); } catch (e) { return res.internalError(); } }) ); + + router.get( + { + path: AGENT_API.CONFIG_EXISTS, + validate: { + query: schema.oneOf([ + schema.object({ + dataSourceId: schema.string(), + agentConfigName: schema.string(), + }), + ]), + }, + }, + router.handleLegacyErrors(async (context, req, res) => { + try { + const assistantClient = assistantService.getScopedClient(req, context); + const agentId = await assistantClient.getAgentIdByConfigName(req.query.agentConfigName); + return res.ok({ body: { exists: Boolean(agentId) } }); + } catch (e) { + return res.ok({ body: { exists: false } }); + } + }) + ); } diff --git a/server/routes/get_agent.ts b/server/routes/get_agent.ts index 993a0dd7..51e57305 100644 --- a/server/routes/get_agent.ts +++ b/server/routes/get_agent.ts @@ -6,9 +6,15 @@ import { OpenSearchClient } from '../../../../src/core/server'; import { ML_COMMONS_BASE_API } from '../utils/constants'; -export const getAgent = async (id: string, client: OpenSearchClient['transport']) => { +/** + * + */ +export const getAgentIdByConfigName = async ( + configName: string, + client: OpenSearchClient['transport'] +): Promise => { try { - const path = `${ML_COMMONS_BASE_API}/config/${id}`; + const path = `${ML_COMMONS_BASE_API}/config/${configName}`; const response = await client.request({ method: 'GET', path, @@ -18,16 +24,19 @@ export const getAgent = async (id: string, client: OpenSearchClient['transport'] !response || !(response.body.ml_configuration?.agent_id || response.body.configuration?.agent_id) ) { - throw new Error(`cannot get agent ${id} by calling the api: ${path}`); + throw new Error(`cannot get agent ${configName} by calling the api: ${path}`); } return response.body.ml_configuration?.agent_id || response.body.configuration.agent_id; } catch (error) { const errorMessage = JSON.stringify(error.meta?.body) || error; - throw new Error(`get agent ${id} failed, reason: ${errorMessage}`); + throw new Error(`get agent ${configName} failed, reason: ${errorMessage}`); } }; -export const searchAgentByName = async (name: string, client: OpenSearchClient['transport']) => { +export const searchAgent = async ( + { name }: { name: string }, + client: OpenSearchClient['transport'] +) => { try { const requestParams = { query: { diff --git a/server/routes/summary_routes.ts b/server/routes/summary_routes.ts index 6a099d5f..c9c13ca0 100644 --- a/server/routes/summary_routes.ts +++ b/server/routes/summary_routes.ts @@ -7,7 +7,7 @@ import { schema } from '@osd/config-schema'; import { IRouter } from '../../../../src/core/server'; import { SUMMARY_ASSISTANT_API } from '../../common/constants/llm'; import { getOpenSearchClientTransport } from '../utils/get_opensearch_client_transport'; -import { getAgent, searchAgentByName } from './get_agent'; +import { getAgentIdByConfigName, searchAgent } from './get_agent'; import { AssistantServiceSetup } from '../services/assistant_service'; const SUMMARY_AGENT_CONFIG_ID = 'os_summary'; @@ -41,7 +41,7 @@ export function registerSummaryAssistantRoutes( dataSourceId: req.query.dataSourceId, }); const assistantClient = assistantService.getScopedClient(req, context); - const response = await assistantClient.executeAgentByName(SUMMARY_AGENT_CONFIG_ID, { + const response = await assistantClient.executeAgentByConfigName(SUMMARY_AGENT_CONFIG_ID, { context: req.body.context, question: req.body.question, }); @@ -53,13 +53,13 @@ export function registerSummaryAssistantRoutes( // only get it by searching on name since it is not stored in agent config. if (req.body.insightType === 'os_insight') { if (!osInsightAgentId) { - osInsightAgentId = await getAgent(OS_INSIGHT_AGENT_CONFIG_ID, client); + osInsightAgentId = await getAgentIdByConfigName(OS_INSIGHT_AGENT_CONFIG_ID, client); } insightAgentIdExists = !!osInsightAgentId; } else if (req.body.insightType === 'user_insight') { if (req.body.type === 'alerts') { if (!userInsightAgentId) { - userInsightAgentId = await searchAgentByName('KB_For_Alert_Insight', client); + userInsightAgentId = await searchAgent({ name: 'KB_For_Alert_Insight' }, client); } } insightAgentIdExists = !!userInsightAgentId; @@ -127,7 +127,10 @@ export function registerSummaryAssistantRoutes( ); } -export function registerData2SummaryRoutes(router: IRouter, assistantService: AssistantServiceSetup) { +export function registerData2SummaryRoutes( + router: IRouter, + assistantService: AssistantServiceSetup +) { router.post( { path: SUMMARY_ASSISTANT_API.DATA2SUMMARY, @@ -147,13 +150,16 @@ export function registerData2SummaryRoutes(router: IRouter, assistantService: As router.handleLegacyErrors(async (context, req, res) => { const assistantClient = assistantService.getScopedClient(req, context); try { - const response = await assistantClient.executeAgentByName(DATA2SUMMARY_AGENT_CONFIG_ID, { - sample_data: req.body.sample_data, - total_count: req.body.total_count, - sample_count: req.body.sample_count, - ppl: req.body.ppl, - question: req.body.question, - }); + const response = await assistantClient.executeAgentByConfigName( + DATA2SUMMARY_AGENT_CONFIG_ID, + { + sample_data: req.body.sample_data, + total_count: req.body.total_count, + sample_count: req.body.sample_count, + ppl: req.body.ppl, + question: req.body.question, + } + ); const result = response.body.inference_results[0].output[0].result; return res.ok({ body: result }); } catch (e) { diff --git a/server/routes/text2viz_routes.ts b/server/routes/text2viz_routes.ts index 760ecc45..314f0ce2 100644 --- a/server/routes/text2viz_routes.ts +++ b/server/routes/text2viz_routes.ts @@ -40,7 +40,7 @@ export function registerText2VizRoutes(router: IRouter, assistantService: Assist router.handleLegacyErrors(async (context, req, res) => { const assistantClient = assistantService.getScopedClient(req, context); try { - const response = await assistantClient.executeAgentByName(TEXT2VEGA_AGENT_CONFIG_ID, { + const response = await assistantClient.executeAgentByConfigName(TEXT2VEGA_AGENT_CONFIG_ID, { input_question: req.body.input_question, input_instruction: req.body.input_instruction, ppl: req.body.ppl, @@ -84,7 +84,7 @@ export function registerText2VizRoutes(router: IRouter, assistantService: Assist router.handleLegacyErrors(async (context, req, res) => { const assistantClient = assistantService.getScopedClient(req, context); try { - const response = await assistantClient.executeAgentByName(TEXT2PPL_AGENT_CONFIG_ID, { + const response = await assistantClient.executeAgentByConfigName(TEXT2PPL_AGENT_CONFIG_ID, { question: req.body.question, index: req.body.index, }); diff --git a/server/services/assistant_client.ts b/server/services/assistant_client.ts index 30593db8..d4584eac 100644 --- a/server/services/assistant_client.ts +++ b/server/services/assistant_client.ts @@ -11,7 +11,7 @@ import { RequestHandlerContext, } from '../../../../src/core/server'; import { ML_COMMONS_BASE_API } from '../utils/constants'; -import { getAgent } from '../routes/get_agent'; +import { getAgentIdByConfigName } from '../routes/get_agent'; interface AgentExecuteResponse { inference_results: Array<{ @@ -59,16 +59,21 @@ export class AssistantClient { return response as ApiResponse; }; - executeAgentByName = async ( - agentName: string, + executeAgentByConfigName = async ( + agentConfigName: string, // eslint-disable-next-line @typescript-eslint/no-explicit-any parameters: Record ) => { - const client = await this.getOpenSearchClient(); - const agentId = await getAgent(agentName, client.transport); + const agentId = await this.getAgentIdByConfigName(agentConfigName); return this.executeAgent(agentId, parameters); }; + getAgentIdByConfigName = async (agentConfigName: string) => { + const client = await this.getOpenSearchClient(); + const agentId = await getAgentIdByConfigName(agentConfigName, client.transport); + return agentId; + }; + private async getOpenSearchClient() { if (!this.client) { let client = this.context.core.opensearch.client.asCurrentUser; diff --git a/server/services/chat/olly_chat_service.ts b/server/services/chat/olly_chat_service.ts index 7d890706..49d8cf60 100644 --- a/server/services/chat/olly_chat_service.ts +++ b/server/services/chat/olly_chat_service.ts @@ -8,7 +8,7 @@ import { OpenSearchClient } from '../../../../../src/core/server'; import { IMessage, IInput } from '../../../common/types/chat_saved_object_attributes'; import { ChatService } from './chat_service'; import { ML_COMMONS_BASE_API, ROOT_AGENT_CONFIG_ID } from '../../utils/constants'; -import { getAgent } from '../../routes/get_agent'; +import { getAgentIdByConfigName } from '../../routes/get_agent'; interface AgentRunPayload { question?: string; @@ -27,7 +27,7 @@ export class OllyChatService implements ChatService { constructor(private readonly opensearchClientTransport: OpenSearchClient['transport']) {} private async getRootAgent(): Promise { - return await getAgent(ROOT_AGENT_CONFIG_ID, this.opensearchClientTransport); + return await getAgentIdByConfigName(ROOT_AGENT_CONFIG_ID, this.opensearchClientTransport); } private async requestAgentRun(payload: AgentRunPayload) {