Skip to content

Commit

Permalink
Expose a general function for agent execution (#268) (#277)
Browse files Browse the repository at this point in the history
* feat: add assistant client to public and server



* tweaks



* fix path



* update changelog



* feat: support both execute agent by name and by id



* fix(public): support execute agent by name or id



---------

Signed-off-by: Yulong Ruan <[email protected]>
Signed-off-by: gaobinlong <[email protected]>
Co-authored-by: gaobinlong <[email protected]>
  • Loading branch information
ruanyl and gaobinlong authored Sep 6, 2024
1 parent e80d85e commit 94dc4a2
Show file tree
Hide file tree
Showing 12 changed files with 268 additions and 39 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@

Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)

### Unreleased
- fix: make sure $schema always added to LLM generated vega json object([252](https://github.com/opensearch-project/dashboards-assistant/pull/252))
- feat: expose a general function for agent execution([268](https://github.com/opensearch-project/dashboards-assistant/pull/268))
- Fix CVE-2024-4067 ([#269](https://github.com/opensearch-project/dashboards-assistant/pull/269))

### 📈 Features/Enhancements

- Add support for registerMessageParser ([#5](https://github.com/opensearch-project/dashboards-assistant/pull/5))
Expand Down
4 changes: 4 additions & 0 deletions common/constants/llm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ export const TEXT2VIZ_API = {
TEXT2VEGA: `${API_BASE}/text2vega`,
};

export const AGENT_API = {
EXECUTE: `${API_BASE}/agent/_execute`,
};

export const NOTEBOOK_API = {
CREATE_NOTEBOOK: `${NOTEBOOK_PREFIX}/note`,
SET_PARAGRAPH: `${NOTEBOOK_PREFIX}/set_paragraphs/`,
Expand Down
7 changes: 7 additions & 0 deletions public/plugin.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ import {
import { ConfigSchema } from '../common/types/config';
import { DataSourceService } from './services/data_source_service';
import { ASSISTANT_API, DEFAULT_USER_NAME } from '../common/constants/llm';
import { IncontextInsightProps } from './components/incontext_insight';
import { AssistantService } from './services/assistant_service';

export const [getCoreStart, setCoreStart] = createGetterSetter<CoreStart>('CoreStart');

Expand Down Expand Up @@ -71,6 +73,7 @@ export class AssistantPlugin
incontextInsightRegistry: IncontextInsightRegistry | undefined;
private dataSourceService: DataSourceService;
private resetChatSubscription: Subscription | undefined;
private assistantService = new AssistantService();

constructor(initializerContext: PluginInitializerContext) {
this.config = initializerContext.config.get<ConfigSchema>();
Expand All @@ -81,6 +84,7 @@ export class AssistantPlugin
core: CoreSetup<AssistantPluginStartDependencies>,
setupDeps: AssistantPluginSetupDependencies
): AssistantSetup {
this.assistantService.setup();
this.incontextInsightRegistry = new IncontextInsightRegistry();
setIncontextInsightRegistry(this.incontextInsightRegistry);
const messageRenderers: Record<string, MessageRenderer> = {};
Expand Down Expand Up @@ -211,17 +215,20 @@ export class AssistantPlugin
}

public start(core: CoreStart): AssistantStart {
const assistantServiceStart = this.assistantService.start(core.http);
setCoreStart(core);
setChrome(core.chrome);
setNotifications(core.notifications);

return {
dataSource: this.dataSourceService.start(),
assistantClient: assistantServiceStart.client,
};
}

public stop() {
this.dataSourceService.stop();
this.assistantService.stop();
this.resetChatSubscription?.unsubscribe();
}
}
35 changes: 35 additions & 0 deletions public/services/assistant_client.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

import { API_BASE } from '../../common/constants/llm';
import { HttpSetup } from '../../../../src/core/public';

interface Options {
dataSourceId?: string;
}

export class AssistantClient {
constructor(private http: HttpSetup) {}

// eslint-disable-next-line @typescript-eslint/no-explicit-any
executeAgent = (agentId: string, parameters: Record<string, any>, options?: Options) => {
return this.http.fetch({
method: 'POST',
path: `${API_BASE}/agent/_execute`,
body: JSON.stringify(parameters),
query: { dataSourceId: options?.dataSourceId, agentId },
});
};

// eslint-disable-next-line @typescript-eslint/no-explicit-any
executeAgentByName = (agentName: string, parameters: Record<string, any>, options?: Options) => {
return this.http.fetch({
method: 'POST',
path: `${API_BASE}/agent/_execute`,
body: JSON.stringify(parameters),
query: { dataSourceId: options?.dataSourceId, agentName },
});
};
}
26 changes: 26 additions & 0 deletions public/services/assistant_service.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

import { HttpSetup } from '../../../../src/core/public';
import { AssistantClient } from './assistant_client';

export interface AssistantServiceStart {
client: AssistantClient;
}

export class AssistantService {
constructor() {}

setup() {}

start(http: HttpSetup): AssistantServiceStart {
const assistantClient = new AssistantClient(http);
return {
client: assistantClient,
};
}

stop() {}
}
2 changes: 2 additions & 0 deletions public/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import {
} from '../../../src/plugins/visualizations/public';
import { DataPublicPluginSetup, DataPublicPluginStart } from '../../../src/plugins/data/public';
import { AppMountParameters, CoreStart } from '../../../src/core/public';
import { AssistantClient } from './services/assistant_client';

export interface RenderProps {
props: MessageContentProps;
Expand Down Expand Up @@ -67,6 +68,7 @@ export interface AssistantSetup {

export interface AssistantStart {
dataSource: DataSourceServiceContract;
assistantClient: AssistantClient;
}

export type StartServices = CoreStart &
Expand Down
15 changes: 13 additions & 2 deletions server/plugin.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,13 @@ import { BasicInputOutputParser } from './parsers/basic_input_output_parser';
import { VisualizationCardParser } from './parsers/visualization_card_parser';
import { registerChatRoutes } from './routes/chat_routes';
import { registerText2VizRoutes } from './routes/text2viz_routes';
import { AssistantService } from './services/assistant_service';
import { registerAgentRoutes } from './routes/agent_routes';

export class AssistantPlugin implements Plugin<AssistantPluginSetup, AssistantPluginStart> {
private readonly logger: Logger;
private messageParsers: MessageParser[] = [];
private assistantService = new AssistantService();

constructor(private readonly initializerContext: PluginInitializerContext) {
this.logger = initializerContext.logger.get();
Expand All @@ -33,6 +36,8 @@ export class AssistantPlugin implements Plugin<AssistantPluginSetup, AssistantPl
.pipe(first())
.toPromise();

const assistantServiceSetup = this.assistantService.setup();

const router = core.http.createRouter();

core.http.registerRouteHandlerContext('assistant_plugin', () => {
Expand All @@ -42,6 +47,8 @@ export class AssistantPlugin implements Plugin<AssistantPluginSetup, AssistantPl
};
});

registerAgentRoutes(router, assistantServiceSetup);

// Register server side APIs
registerChatRoutes(router, {
messageParsers: this.messageParsers,
Expand All @@ -50,7 +57,7 @@ export class AssistantPlugin implements Plugin<AssistantPluginSetup, AssistantPl

// Register router for text to visualization
if (config.next.enabled) {
registerText2VizRoutes(router);
registerText2VizRoutes(router, assistantServiceSetup);
}

core.capabilities.registerProvider(() => ({
Expand All @@ -72,6 +79,7 @@ export class AssistantPlugin implements Plugin<AssistantPluginSetup, AssistantPl
registerMessageParser(VisualizationCardParser);

return {
assistantService: assistantServiceSetup,
registerMessageParser,
removeMessageParser: (parserId: MessageParser['id']) => {
const findIndex = this.messageParsers.findIndex((item) => item.id === parserId);
Expand All @@ -86,8 +94,11 @@ export class AssistantPlugin implements Plugin<AssistantPluginSetup, AssistantPl

public start(core: CoreStart) {
this.logger.debug('Assistant: Started');
this.assistantService.start();
return {};
}

public stop() {}
public stop() {
this.assistantService.stop();
}
}
43 changes: 43 additions & 0 deletions server/routes/agent_routes.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

import { schema } from '@osd/config-schema';
import { IRouter } from '../../../../src/core/server';
import { AGENT_API } from '../../common/constants/llm';
import { AssistantServiceSetup } from '../services/assistant_service';

export function registerAgentRoutes(router: IRouter, assistantService: AssistantServiceSetup) {
router.post(
{
path: AGENT_API.EXECUTE,
validate: {
body: schema.any(),
query: schema.oneOf([
schema.object({
dataSourceId: schema.maybe(schema.string()),
agentId: schema.string(),
}),
schema.object({
dataSourceId: schema.maybe(schema.string()),
agentName: schema.string(),
}),
]),
},
},
router.handleLegacyErrors(async (context, req, res) => {
try {
const assistantClient = assistantService.getScopedClient(req, context);
if ('agentId' in req.query) {
const response = await assistantClient.executeAgent(req.query.agentId, req.body);
return res.ok({ body: response });
}
const response = await assistantClient.executeAgentByName(req.query.agentName, req.body);
return res.ok({ body: response });
} catch (e) {
return res.internalError();
}
})
);
}
53 changes: 16 additions & 37 deletions server/routes/text2viz_routes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,12 @@
import { schema } from '@osd/config-schema';
import { IRouter } from '../../../../src/core/server';
import { TEXT2VIZ_API } from '../../common/constants/llm';
import { getOpenSearchClientTransport } from '../utils/get_opensearch_client_transport';
import { ML_COMMONS_BASE_API } from '../utils/constants';
import { getAgent } from './get_agent';
import { AssistantServiceSetup } from '../services/assistant_service';

const TEXT2VEGA_AGENT_CONFIG_ID = 'text2vega';
const TEXT2PPL_AGENT_CONFIG_ID = 'text2ppl';

export function registerText2VizRoutes(router: IRouter) {
export function registerText2VizRoutes(router: IRouter, assistantService: AssistantServiceSetup) {
router.post(
{
path: TEXT2VIZ_API.TEXT2VEGA,
Expand All @@ -30,25 +28,15 @@ export function registerText2VizRoutes(router: IRouter) {
},
},
router.handleLegacyErrors(async (context, req, res) => {
const client = await getOpenSearchClientTransport({
context,
dataSourceId: req.query.dataSourceId,
});
const agentId = await getAgent(TEXT2VEGA_AGENT_CONFIG_ID, client);
const response = await client.request({
method: 'POST',
path: `${ML_COMMONS_BASE_API}/agents/${agentId}/_execute`,
body: {
parameters: {
input: req.body.input,
ppl: req.body.ppl,
dataSchema: req.body.dataSchema,
sampleData: req.body.sampleData,
},
},
});

const assistantClient = assistantService.getScopedClient(req, context);
try {
const response = await assistantClient.executeAgentByName(TEXT2VEGA_AGENT_CONFIG_ID, {
input: req.body.input,
ppl: req.body.ppl,
dataSchema: req.body.dataSchema,
sampleData: req.body.sampleData,
});

// let result = response.body.inference_results[0].output[0].dataAsMap;
let result = JSON.parse(response.body.inference_results[0].output[0].result);
// sometimes llm returns {response: <schema>} instead of <schema>
Expand Down Expand Up @@ -80,22 +68,13 @@ export function registerText2VizRoutes(router: IRouter) {
},
},
router.handleLegacyErrors(async (context, req, res) => {
const client = await getOpenSearchClientTransport({
context,
dataSourceId: req.query.dataSourceId,
});
const agentId = await getAgent(TEXT2PPL_AGENT_CONFIG_ID, client);
const response = await client.request({
method: 'POST',
path: `${ML_COMMONS_BASE_API}/agents/${agentId}/_execute`,
body: {
parameters: {
question: req.body.question,
index: req.body.index,
},
},
});
const assistantClient = assistantService.getScopedClient(req, context);
try {
const response = await assistantClient.executeAgentByName(TEXT2PPL_AGENT_CONFIG_ID, {
question: req.body.question,
index: req.body.index,
});

const result = JSON.parse(response.body.inference_results[0].output[0].result);
return res.ok({ body: result });
} catch (e) {
Expand Down
Loading

0 comments on commit 94dc4a2

Please sign in to comment.