Skip to content

Commit

Permalink
[Feature] Add rootAgentId config in dashboard yml file (#78)
Browse files Browse the repository at this point in the history
* feat: export some types

Signed-off-by: SuZhou-Joe <[email protected]>

* feat: consume the config

Signed-off-by: SuZhou-Joe <[email protected]>

* feat: add unit test

Signed-off-by: SuZhou-Joe <[email protected]>

* feat: add unit test for regenerate route

Signed-off-by: SuZhou-Joe <[email protected]>

* fix: unit test error

Signed-off-by: SuZhou-Joe <[email protected]>

* fix: typo

Signed-off-by: SuZhou-Joe <[email protected]>

* feat: add warning when chatbot enabled without a root agent id

Signed-off-by: SuZhou-Joe <[email protected]>

---------

Signed-off-by: SuZhou-Joe <[email protected]>
  • Loading branch information
SuZhou-Joe authored Dec 21, 2023
1 parent a0a958f commit a474220
Show file tree
Hide file tree
Showing 14 changed files with 303 additions and 46 deletions.
4 changes: 0 additions & 4 deletions public/chat_header_button.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,6 @@ export const HeaderChatButton: React.FC<HeaderChatButtonProps> = (props) => {
const [inputFocus, setInputFocus] = useState(false);
const flyoutFullScreen = chatSize === 'fullscreen';
const inputRef = useRef<HTMLInputElement>(null);
const [rootAgentId, setRootAgentId] = useState<string>(
new URL(window.location.href).searchParams.get('agent_id') || ''
);

if (!flyoutLoaded && flyoutVisible) flyoutLoaded = true;

Expand Down Expand Up @@ -80,7 +77,6 @@ export const HeaderChatButton: React.FC<HeaderChatButtonProps> = (props) => {
setTitle,
traceId,
setTraceId,
rootAgentId,
}),
[
appId,
Expand Down
1 change: 0 additions & 1 deletion public/contexts/chat_context.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ export interface IChatContext {
setTitle: React.Dispatch<React.SetStateAction<string | undefined>>;
traceId?: string;
setTraceId: React.Dispatch<React.SetStateAction<string | undefined>>;
rootAgentId?: string;
}
export const ChatContext = React.createContext<IChatContext | null>(null);

Expand Down
5 changes: 0 additions & 5 deletions public/hooks/use_chat_actions.test.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ describe('useChatActions hook', () => {
const setTraceIdMock = jest.fn();

const chatContextMock = {
rootAgentId: 'root_agent_id_mock',
selectedTabId: 'chat',
setSessionId: jest.fn(),
setTitle: jest.fn(),
Expand Down Expand Up @@ -109,7 +108,6 @@ describe('useChatActions hook', () => {
// it should call send message api
expect(httpMock.post).toHaveBeenCalledWith(ASSISTANT_API.SEND_MESSAGE, {
body: JSON.stringify({
rootAgentId: 'root_agent_id_mock',
messages: [],
input: INPUT_MESSAGE,
}),
Expand Down Expand Up @@ -173,7 +171,6 @@ describe('useChatActions hook', () => {
// sending message with the suggestion
expect(httpMock.post).toHaveBeenCalledWith(ASSISTANT_API.SEND_MESSAGE, {
body: JSON.stringify({
rootAgentId: 'root_agent_id_mock',
messages: [],
input: { type: 'input', content: 'message that send as input', contentType: 'text' },
}),
Expand Down Expand Up @@ -255,7 +252,6 @@ describe('useChatActions hook', () => {
expect(httpMock.put).toHaveBeenCalledWith(ASSISTANT_API.REGENERATE, {
body: JSON.stringify({
sessionId: 'session_id_mock',
rootAgentId: 'root_agent_id_mock',
interactionId: 'interaction_id_mock',
}),
});
Expand All @@ -281,7 +277,6 @@ describe('useChatActions hook', () => {
expect(httpMock.put).toHaveBeenCalledWith(ASSISTANT_API.REGENERATE, {
body: JSON.stringify({
sessionId: 'session_id_mock',
rootAgentId: 'root_agent_id_mock',
interactionId: 'interaction_id_mock',
}),
});
Expand Down
2 changes: 0 additions & 2 deletions public/hooks/use_chat_actions.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ export const useChatActions = (): AssistantActions => {
// do not send abort signal to http client to allow LLM call run in background
body: JSON.stringify({
sessionId: chatContext.sessionId,
rootAgentId: chatContext.rootAgentId,
...(!chatContext.sessionId && { messages: chatState.messages }), // include all previous messages for new chats
input,
}),
Expand Down Expand Up @@ -168,7 +167,6 @@ export const useChatActions = (): AssistantActions => {
const response = await core.services.http.put(`${ASSISTANT_API.REGENERATE}`, {
body: JSON.stringify({
sessionId: chatContext.sessionId,
rootAgentId: chatContext.rootAgentId,
interactionId,
}),
});
Expand Down
2 changes: 2 additions & 0 deletions public/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,5 @@ export { AssistantPlugin as Plugin };
export function plugin(initializerContext: PluginInitializerContext) {
return new AssistantPlugin(initializerContext);
}

export { AssistantSetup } from './types';
3 changes: 2 additions & 1 deletion public/plugin.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ interface PublicConfig {
chat: {
// whether chat feature is enabled, UI should hide if false
enabled: boolean;
rootAgentId?: string;
};
}

Expand Down Expand Up @@ -74,7 +75,7 @@ export class AssistantPlugin
const checkAccess = (account: Awaited<ReturnType<typeof getAccount>>) =>
account.data.roles.some((role) => ['all_access', 'assistant_user'].includes(role));

if (this.config.chat.enabled) {
if (this.config.chat.enabled && this.config.chat.rootAgentId) {
core.getStartServices().then(async ([coreStart, startDeps]) => {
const CoreContext = createOpenSearchDashboardsReactContext<AssistantServices>({
...coreStart,
Expand Down
3 changes: 2 additions & 1 deletion server/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,13 @@ export function plugin(initializerContext: PluginInitializerContext) {
return new AssistantPlugin(initializerContext);
}

export { AssistantPluginSetup, AssistantPluginStart } from './types';
export { AssistantPluginSetup, AssistantPluginStart, MessageParser } from './types';

const assistantConfig = {
schema: schema.object({
chat: schema.object({
enabled: schema.boolean({ defaultValue: false }),
rootAgentId: schema.maybe(schema.string()),
}),
}),
};
Expand Down
13 changes: 12 additions & 1 deletion server/plugin.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import { setupRoutes } from './routes/index';
import { AssistantPluginSetup, AssistantPluginStart, MessageParser } from './types';
import { BasicInputOutputParser } from './parsers/basic_input_output_parser';
import { VisualizationCardParser } from './parsers/visualization_card_parser';
import { AgentIdNotFoundError } from './routes/chat_routes';

export class AssistantPlugin implements Plugin<AssistantPluginSetup, AssistantPluginStart> {
private readonly logger: Logger;
Expand All @@ -25,12 +26,21 @@ export class AssistantPlugin implements Plugin<AssistantPluginSetup, AssistantPl
this.logger = initializerContext.logger.get();
}

public async setup(core: CoreSetup) {
public async setup(core: CoreSetup): Promise<AssistantPluginSetup> {
this.logger.debug('Assistant: Setup');
const config = await this.initializerContext.config
.create<AssistantConfig>()
.pipe(first())
.toPromise();

/**
* Check if user enable the chat without specifying a root agent id.
* If so, gives a warning for guidance.
*/
if (config.chat.enabled && !config.chat.rootAgentId) {
this.logger.warn(AgentIdNotFoundError);
}

const router = core.http.createRouter();

core.http.registerRouteHandlerContext('assistant_plugin', () => {
Expand All @@ -43,6 +53,7 @@ export class AssistantPlugin implements Plugin<AssistantPluginSetup, AssistantPl
// Register server side APIs
setupRoutes(router, {
messageParsers: this.messageParsers,
rootAgentId: config.chat.rootAgentId,
});

core.capabilities.registerProvider(() => ({
Expand Down
29 changes: 23 additions & 6 deletions server/routes/chat_routes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ const llmRequestRoute = {
body: schema.object({
sessionId: schema.maybe(schema.string()),
messages: schema.maybe(schema.arrayOf(schema.any())),
rootAgentId: schema.string(),
input: schema.object({
type: schema.literal('input'),
context: schema.object({
Expand All @@ -37,6 +36,9 @@ const llmRequestRoute = {
};
export type LLMRequestSchema = TypeOf<typeof llmRequestRoute.validate.body>;

export const AgentIdNotFoundError =
'rootAgentId is required, please specify one in opensearch_dashboards.yml';

const getSessionRoute = {
path: `${ASSISTANT_API.SESSION}/{sessionId}`,
validate: {
Expand All @@ -62,7 +64,6 @@ const regenerateRoute = {
validate: {
body: schema.object({
sessionId: schema.string(),
rootAgentId: schema.string(),
interactionId: schema.string(),
}),
},
Expand Down Expand Up @@ -142,7 +143,11 @@ export function registerChatRoutes(router: IRouter, routeOptions: RoutesOptions)
request,
response
): Promise<IOpenSearchDashboardsResponse<HttpResponsePayload | ResponseError>> => {
const { messages = [], input, sessionId: sessionIdInRequestBody, rootAgentId } = request.body;
if (!routeOptions.rootAgentId) {
context.assistant_plugin.logger.error(AgentIdNotFoundError);
return response.custom({ statusCode: 400, body: AgentIdNotFoundError });
}
const { messages = [], input, sessionId: sessionIdInRequestBody } = request.body;
const storageService = createStorageService(context);
const chatService = createChatService();

Expand All @@ -153,7 +158,12 @@ export function registerChatRoutes(router: IRouter, routeOptions: RoutesOptions)
*/
try {
outputs = await chatService.requestLLM(
{ messages, input, sessionId: sessionIdInRequestBody, rootAgentId },
{
messages,
input,
sessionId: sessionIdInRequestBody,
rootAgentId: routeOptions.rootAgentId,
},
context
);
} catch (error) {
Expand Down Expand Up @@ -314,7 +324,11 @@ export function registerChatRoutes(router: IRouter, routeOptions: RoutesOptions)
request,
response
): Promise<IOpenSearchDashboardsResponse<HttpResponsePayload | ResponseError>> => {
const { sessionId, rootAgentId, interactionId } = request.body;
if (!routeOptions.rootAgentId) {
context.assistant_plugin.logger.error(AgentIdNotFoundError);
return response.custom({ statusCode: 400, body: AgentIdNotFoundError });
}
const { sessionId, interactionId } = request.body;
const storageService = createStorageService(context);
const chatService = createChatService();

Expand All @@ -324,7 +338,10 @@ export function registerChatRoutes(router: IRouter, routeOptions: RoutesOptions)
* Get final answer from Agent framework
*/
try {
outputs = await chatService.regenerate({ sessionId, rootAgentId, interactionId }, context);
outputs = await chatService.regenerate(
{ sessionId, rootAgentId: routeOptions.rootAgentId, interactionId },
context
);
} catch (error) {
context.assistant_plugin.logger.error(error);
}
Expand Down
Loading

0 comments on commit a474220

Please sign in to comment.