diff --git a/common/constants/llm.ts b/common/constants/llm.ts index 48be672f..48da38fa 100644 --- a/common/constants/llm.ts +++ b/common/constants/llm.ts @@ -43,3 +43,6 @@ export const NOTEBOOK_API = { export const DEFAULT_USER_NAME = 'User'; export const TEXT2VEGA_INPUT_SIZE_LIMIT = 400; + +export const TEXT2VEGA_AGENT_CONFIG_ID = 'os_text2vega'; +export const TEXT2PPL_AGENT_CONFIG_ID = 'os_query_assist_ppl'; diff --git a/public/components/visualization/source_selector.tsx b/public/components/visualization/source_selector.tsx index 09dc5aca..f053a074 100644 --- a/public/components/visualization/source_selector.tsx +++ b/public/components/visualization/source_selector.tsx @@ -14,6 +14,10 @@ import { DataSourceOption, } from '../../../../../src/plugins/data/public'; import { StartServices } from '../../types'; +import { TEXT2VEGA_AGENT_CONFIG_ID } from '../../../common/constants/llm'; +import { getAssistantService } from '../../services'; + +const DEFAULT_DATA_SOURCE_TYPE = 'DEFAULT_INDEX_PATTERNS'; export const SourceSelector = ({ selectedSourceId, @@ -71,6 +75,64 @@ export const SourceSelector = ({ [onChange] ); + const onSetDataSourceOptions = useCallback( + async (options: DataSourceGroup[]) => { + // Only support opensearch default data source + const indexPatternOptions = options.find( + (item) => item.groupType === DEFAULT_DATA_SOURCE_TYPE + ); + const supportedDataSources = currentDataSources.filter( + (dataSource) => dataSource.getType() === DEFAULT_DATA_SOURCE_TYPE + ); + + if (!indexPatternOptions || supportedDataSources.length === 0) { + return; + } + + // Group index pattern ids by data source id + const dataSourceIdToIndexPatternIds: Record = {}; + const promises = supportedDataSources.map(async (dataSource) => { + const { dataSets } = await dataSource.getDataSet(); + if (Array.isArray(dataSets)) { + /** + * id: the index pattern id + * dataSourceId: the data source id + */ + for (const { id, dataSourceId = 'DEFAULT' } of dataSets) { + if (!dataSourceIdToIndexPatternIds[dataSourceId]) { + dataSourceIdToIndexPatternIds[dataSourceId] = []; + } + dataSourceIdToIndexPatternIds[dataSourceId].push(id); + } + } + }); + await Promise.allSettled(promises); + + const assistantService = getAssistantService(); + /** + * Check each data source to see if text to vega agent is configured or not + * If not configured, disable the corresponding index pattern from the selection list + */ + Object.keys(dataSourceIdToIndexPatternIds).forEach(async (key) => { + const res = await assistantService.client.agentConfigExists(TEXT2VEGA_AGENT_CONFIG_ID, { + dataSourceId: key !== 'DEFAULT' ? key : undefined, + }); + if (!res.exists) { + dataSourceIdToIndexPatternIds[key].forEach((indexPatternId) => { + indexPatternOptions.options.forEach((option) => { + if (option.value === indexPatternId) { + option.disabled = true; + } + }); + }); + } + }); + + setDataSourceOptions([indexPatternOptions]); + }, + [currentDataSources] + ); + const handleGetDataSetError = useCallback( () => (error: Error) => { toasts.addError(error, { @@ -91,7 +153,7 @@ export const SourceSelector = ({ { const getInputSection = () => { return ( <> - + setSelectedSource(ds.value)} diff --git a/server/routes/agent_routes.ts b/server/routes/agent_routes.ts index 8fb7abd4..abbbd087 100644 --- a/server/routes/agent_routes.ts +++ b/server/routes/agent_routes.ts @@ -50,7 +50,7 @@ export function registerAgentRoutes(router: IRouter, assistantService: Assistant validate: { query: schema.oneOf([ schema.object({ - dataSourceId: schema.string(), + dataSourceId: schema.maybe(schema.string()), agentConfigName: schema.string(), }), ]), diff --git a/server/routes/text2viz_routes.ts b/server/routes/text2viz_routes.ts index 314f0ce2..bf6b9cf9 100644 --- a/server/routes/text2viz_routes.ts +++ b/server/routes/text2viz_routes.ts @@ -5,12 +5,14 @@ import { schema } from '@osd/config-schema'; import { IRouter } from '../../../../src/core/server'; -import { TEXT2VEGA_INPUT_SIZE_LIMIT, TEXT2VIZ_API } from '../../common/constants/llm'; +import { + TEXT2PPL_AGENT_CONFIG_ID, + TEXT2VEGA_AGENT_CONFIG_ID, + TEXT2VEGA_INPUT_SIZE_LIMIT, + TEXT2VIZ_API, +} from '../../common/constants/llm'; import { AssistantServiceSetup } from '../services/assistant_service'; -const TEXT2VEGA_AGENT_CONFIG_ID = 'os_text2vega'; -const TEXT2PPL_AGENT_CONFIG_ID = 'os_query_assist_ppl'; - const inputSchema = schema.string({ maxLength: TEXT2VEGA_INPUT_SIZE_LIMIT, validate(value) { @@ -48,22 +50,39 @@ export function registerText2VizRoutes(router: IRouter, assistantService: Assist sampleData: req.body.sampleData, }); - // let result = response.body.inference_results[0].output[0].dataAsMap; - let result = JSON.parse(response.body.inference_results[0].output[0].result); - // sometimes llm returns {response: } instead of - if (result.response) { - result = JSON.parse(result.response); + let textContent = response.body.inference_results[0].output[0].result; + + // extra content between tag + const startTag = ''; + const endTag = ''; + + const startIndex = textContent.indexOf(startTag); + const endIndex = textContent.indexOf(endTag); + + if (startIndex !== -1 && endIndex !== -1 && startIndex < endIndex) { + // Extract the content between the tags + textContent = textContent.substring(startIndex + startTag.length, endIndex).trim(); } - // Sometimes the response contains width and height which is not needed, here delete the these fields - delete result.width; - delete result.height; - // make sure $schema field always been added, sometimes, LLM 'forgot' to add this field - result.$schema = 'https://vega.github.io/schema/vega-lite/v5.json'; + // extract json object + const jsonMatch = textContent.match(/\{.*\}/s); + if (jsonMatch) { + let result = JSON.parse(jsonMatch[0]); + // sometimes llm returns {response: } instead of + if (result.response) { + result = JSON.parse(result.response); + } + // Sometimes the response contains width and height which is not needed, here delete the these fields + delete result.width; + delete result.height; - return res.ok({ body: result }); + // make sure $schema field always been added, sometimes, LLM 'forgot' to add this field + result.$schema = 'https://vega.github.io/schema/vega-lite/v5.json'; + return res.ok({ body: result }); + } + return res.badRequest(); } catch (e) { - return res.internalError(); + return res.badRequest(); } }) ); @@ -92,7 +111,7 @@ export function registerText2VizRoutes(router: IRouter, assistantService: Assist const result = JSON.parse(response.body.inference_results[0].output[0].result); return res.ok({ body: result }); } catch (e) { - return res.internalError(); + return res.badRequest(); } }) );