Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: some optimization for customized render function #94

Merged
merged 12 commits into from
Jan 11, 2024
8 changes: 3 additions & 5 deletions .github/PULL_REQUEST_TEMPLATE.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,9 @@

### Check List
- [ ] New functionality includes testing.
- [ ] All tests pass, including unit test, integration test and doctest
- [ ] New functionality has been documented.
- [ ] New functionality has javadoc added
- [ ] New functionality has user manual doc added
- [ ] Commits are signed per the DCO using --signoff
- [ ] All tests pass, including unit test, integration test.
- [ ] New functionality has user manual doc added.
- [ ] Commits are signed per the DCO using --signoff.

By submitting this pull request, I confirm that my contribution is made under the terms of the Apache 2.0 license.
For more information on following Developer Certificate of Origin and signing off your commits, please check [here](https://github.com/opensearch-project/OpenSearch/blob/main/CONTRIBUTING.md#developer-certificate-of-origin).
4 changes: 2 additions & 2 deletions common/types/chat_saved_object_attributes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,11 @@ export interface IOutput {
type: 'output';
traceId?: string; // used for tracing agent calls
toolsUsed?: string[];
// TODO: ppl_visualization type may need to be removed in the PR which replaces ppl query render from visualization to data grid. @suzhou
contentType: 'error' | 'markdown' | 'visualization' | 'ppl_visualization' | 'ppl_data_grid';
contentType: 'error' | 'markdown' | 'visualization' | string;
content: string;
suggestedActions?: ISuggestedAction[];
messageId?: string;
fullWidth?: boolean;
}
export type IMessage = IInput | IOutput;

Expand Down
12 changes: 6 additions & 6 deletions public/chat_header_button.test.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ describe('<HeaderChatButton />', () => {
<HeaderChatButton
application={applicationStart}
userHasAccess={true}
contentRenderers={{}}
messageRenderers={{}}
actionExecutors={{}}
assistantActions={{} as AssistantActions}
currentAccount={{ username: 'test_user', tenant: 'test_tenant' }}
Expand Down Expand Up @@ -102,7 +102,7 @@ describe('<HeaderChatButton />', () => {
<HeaderChatButton
application={applicationServiceMock.createStartContract()}
userHasAccess={true}
contentRenderers={{}}
messageRenderers={{}}
actionExecutors={{}}
assistantActions={{} as AssistantActions}
currentAccount={{ username: 'test_user', tenant: 'test_tenant' }}
Expand All @@ -120,7 +120,7 @@ describe('<HeaderChatButton />', () => {
<HeaderChatButton
application={applicationServiceMock.createStartContract()}
userHasAccess={true}
contentRenderers={{}}
messageRenderers={{}}
actionExecutors={{}}
assistantActions={{} as AssistantActions}
currentAccount={{ username: 'test_user', tenant: 'test_tenant' }}
Expand All @@ -144,7 +144,7 @@ describe('<HeaderChatButton />', () => {
<HeaderChatButton
application={applicationServiceMock.createStartContract()}
userHasAccess={true}
contentRenderers={{}}
messageRenderers={{}}
actionExecutors={{}}
assistantActions={{} as AssistantActions}
currentAccount={{ username: 'test_user', tenant: 'test_tenant' }}
Expand All @@ -165,7 +165,7 @@ describe('<HeaderChatButton />', () => {
<HeaderChatButton
application={applicationServiceMock.createStartContract()}
userHasAccess={false}
contentRenderers={{}}
messageRenderers={{}}
actionExecutors={{}}
assistantActions={{} as AssistantActions}
currentAccount={{ username: 'test_user', tenant: 'test_tenant' }}
Expand All @@ -179,7 +179,7 @@ describe('<HeaderChatButton />', () => {
<HeaderChatButton
application={applicationServiceMock.createStartContract()}
userHasAccess={false}
contentRenderers={{}}
messageRenderers={{}}
actionExecutors={{}}
assistantActions={{} as AssistantActions}
currentAccount={{ username: 'test_user', tenant: 'test_tenant' }}
Expand Down
8 changes: 4 additions & 4 deletions public/chat_header_button.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@ import { SetContext } from './contexts/set_context';
import { ChatStateProvider } from './hooks';
import './index.scss';
import chatIcon from './assets/chat.svg';
import { ActionExecutor, AssistantActions, ContentRenderer, UserAccount, TabId } from './types';
import { ActionExecutor, AssistantActions, MessageRenderer, UserAccount, TabId } from './types';
import { TAB_ID } from './utils/constants';

interface HeaderChatButtonProps {
application: ApplicationStart;
userHasAccess: boolean;
contentRenderers: Record<string, ContentRenderer>;
messageRenderers: Record<string, MessageRenderer>;
actionExecutors: Record<string, ActionExecutor>;
assistantActions: AssistantActions;
currentAccount: UserAccount;
Expand Down Expand Up @@ -70,7 +70,7 @@ export const HeaderChatButton = (props: HeaderChatButtonProps) => {
setFlyoutVisible,
setFlyoutComponent,
userHasAccess: props.userHasAccess,
contentRenderers: props.contentRenderers,
messageRenderers: props.messageRenderers,
actionExecutors: props.actionExecutors,
currentAccount: props.currentAccount,
title,
Expand All @@ -86,7 +86,7 @@ export const HeaderChatButton = (props: HeaderChatButtonProps) => {
selectedTabId,
preSelectedTabId,
props.userHasAccess,
props.contentRenderers,
props.messageRenderers,
props.actionExecutors,
props.currentAccount,
title,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,23 @@ HTMLCollection [
class="euiMarkdownFormat"
>
<div>
<h1>
<h1
id="how-was-this-generated"
>
How was this generated
</h1>


<h4>
<h4
id="question"
>
Question
</h4>


<h4>
<h4
id="result"
>
Result
</h4>

Expand Down Expand Up @@ -66,7 +72,11 @@ HTMLCollection [
viewBox="0 0 16 16"
width="16"
xmlns="http://www.w3.org/2000/svg"
/>
>
<path
d="M5.277 10.088c.02.014.04.03.057.047.582.55 1.134.812 1.666.812.586 0 1.84-.293 3.713-.88L9 6.212V2H7v4.212l-1.723 3.876Zm-.438.987L3.539 14h8.922l-1.32-2.969C9.096 11.677 7.733 12 7 12c-.74 0-1.463-.315-2.161-.925ZM6 2H5V1h6v1h-1v4l3.375 7.594A1 1 0 0 1 12.461 15H3.54a1 1 0 0 1-.914-1.406L6 6V2Z"
/>
</svg>
</span>
<span
class="euiIEFlexWrapFix"
Expand Down
2 changes: 1 addition & 1 deletion public/contexts/__tests__/chat_context.test.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ describe('useChatContext', () => {
setFlyoutVisible: jest.fn(),
setFlyoutComponent: jest.fn(),
userHasAccess: true,
contentRenderers: {},
messageRenderers: {},
actionExecutors: {},
currentAccount: { username: 'foo', tenant: '' },
setTitle: jest.fn(),
Expand Down
4 changes: 2 additions & 2 deletions public/contexts/chat_context.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
*/

import React, { useContext } from 'react';
import { ActionExecutor, ContentRenderer, UserAccount, TabId } from '../types';
import { ActionExecutor, UserAccount, TabId, MessageRenderer } from '../types';

export interface IChatContext {
appId?: string;
Expand All @@ -18,7 +18,7 @@ export interface IChatContext {
setFlyoutVisible: React.Dispatch<React.SetStateAction<boolean>>;
setFlyoutComponent: React.Dispatch<React.SetStateAction<React.ReactNode | null>>;
userHasAccess: boolean;
contentRenderers: Record<string, ContentRenderer>;
messageRenderers: Record<string, MessageRenderer>;
actionExecutors: Record<string, ActionExecutor>;
currentAccount: UserAccount;
title?: string;
Expand Down
2 changes: 1 addition & 1 deletion public/hooks/use_chat_actions.test.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ describe('useChatActions hook', () => {
flyoutVisible: false,
flyoutFullScreen: false,
userHasAccess: false,
contentRenderers: {},
messageRenderers: {},
currentAccount: {
username: '',
tenant: '',
Expand Down
3 changes: 2 additions & 1 deletion public/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,5 @@ export function plugin(initializerContext: PluginInitializerContext) {
return new AssistantPlugin(initializerContext);
}

export { AssistantSetup } from './types';
export { AssistantSetup, RenderProps } from './types';
export { IMessage } from '../common/types/chat_saved_object_attributes';
12 changes: 6 additions & 6 deletions public/plugin.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import {
AssistantActions,
AssistantSetup,
AssistantStart,
ContentRenderer,
MessageRenderer,
SetupDependencies,
} from './types';

Expand Down Expand Up @@ -49,7 +49,7 @@ export class AssistantPlugin
core: CoreSetup<AppPluginStartDependencies>,
setupDeps: SetupDependencies
): AssistantSetup {
const contentRenderers: Record<string, ContentRenderer> = {};
const messageRenderers: Record<string, MessageRenderer> = {};
const actionExecutors: Record<string, ActionExecutor> = {};
const assistantActions: AssistantActions = {} as AssistantActions;
/**
Expand Down Expand Up @@ -95,7 +95,7 @@ export class AssistantPlugin
<HeaderChatButton
application={coreStart.application}
userHasAccess={checkAccess(account)}
contentRenderers={contentRenderers}
messageRenderers={messageRenderers}
actionExecutors={actionExecutors}
assistantActions={assistantActions}
currentAccount={{ username, tenant }}
Expand All @@ -107,10 +107,10 @@ export class AssistantPlugin
}

return {
registerContentRenderer: (contentType, render) => {
if (contentType in contentRenderers)
registerMessageRenderer: (contentType, render) => {
if (contentType in messageRenderers)
console.warn(`Content renderer type ${contentType} is already registered.`);
contentRenderers[contentType] = render;
messageRenderers[contentType] = render;
},
registerActionExecutor: (actionType, execute) => {
if (actionType in actionExecutors)
Expand Down
2 changes: 2 additions & 0 deletions public/tabs/chat/messages/message_bubble.test.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ describe('<MessageBubble />', () => {
type: 'output',
contentType: 'visualization',
content: 'vis_id_mock',
fullWidth: true,
}}
/>
);
Expand All @@ -114,6 +115,7 @@ describe('<MessageBubble />', () => {
type: 'output',
contentType: 'ppl_visualization',
content: 'vis_id_mock',
fullWidth: true,
}}
/>
);
Expand Down
16 changes: 7 additions & 9 deletions public/tabs/chat/messages/message_bubble.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -133,13 +133,7 @@ export const MessageBubble: React.FC<MessageBubbleProps> = React.memo((props) =>
);
}

// if (['visualization', 'ppl_visualization'].includes(props.contentType)) {
// return <>{props.children}</>;
// }

const isVisualization = ['visualization', 'ppl_visualization'].includes(
props.message.contentType
);
const fullWidth = props.message.fullWidth;

return (
<EuiFlexGroup
Expand All @@ -154,7 +148,11 @@ export const MessageBubble: React.FC<MessageBubbleProps> = React.memo((props) =>
</EuiFlexItem>
<EuiFlexItem className="llm-chat-bubble-wrapper">
<EuiPanel
style={isVisualization ? { minWidth: '100%' } : {}}
/**
* When using minWidth the content width inside may be larger than the container itself,
* especially in data grid case that the content will change its size according to fullScreen or not.
*/
style={fullWidth ? { width: '100%' } : {}}
hasShadow={false}
hasBorder={false}
paddingSize="l"
Expand All @@ -177,7 +175,7 @@ export const MessageBubble: React.FC<MessageBubbleProps> = React.memo((props) =>
justifyContent="flexStart"
style={{ paddingLeft: 10 }}
>
{!isVisualization && (
{!fullWidth && (
<EuiFlexItem grow={false}>
<EuiCopy textToCopy={props.message.content ?? ''}>
{(copy) => (
Expand Down
37 changes: 19 additions & 18 deletions public/tabs/chat/messages/message_content.test.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
*/

import React from 'react';
import { render, screen, fireEvent } from '@testing-library/react';
import { render, screen } from '@testing-library/react';
import { MessageContent } from './message_content';
import * as chatContextExports from '../../../contexts/chat_context';

Expand All @@ -15,13 +15,11 @@ jest.mock('../../../components/core_visualization', () => {
});

describe('<MessageContent />', () => {
const pplVisualizationRenderMock = jest.fn();
const customizedRenderMock = jest.fn();

beforeEach(() => {
jest.spyOn(chatContextExports, 'useChatContext').mockReturnValue({
contentRenderers: {
ppl_visualization: pplVisualizationRenderMock,
messageRenderers: {
customized_content_type: customizedRenderMock,
},
});
Expand Down Expand Up @@ -79,19 +77,6 @@ describe('<MessageContent />', () => {
expect(screen.queryAllByText('title')).toHaveLength(1);
});

it('should render ppl visualization', () => {
render(
<MessageContent
message={{
type: 'output',
contentType: 'ppl_visualization',
content: 'mock ppl query',
}}
/>
);
expect(pplVisualizationRenderMock).toHaveBeenCalledWith({ query: 'mock ppl query' });
});

it('should render customized render content', () => {
render(
<MessageContent
Expand All @@ -102,6 +87,22 @@ describe('<MessageContent />', () => {
}}
/>
);
expect(customizedRenderMock).toHaveBeenCalledWith('mock customized content');
expect(customizedRenderMock.mock.calls[0][0]).toMatchInlineSnapshot(`
Object {
"content": "mock customized content",
"contentType": "customized_content_type",
"type": "output",
}
`);
expect(customizedRenderMock.mock.calls[0][1].props).toMatchInlineSnapshot(`
Object {
"message": Object {
"content": "mock customized content",
"contentType": "customized_content_type",
"type": "output",
},
}
`);
expect(customizedRenderMock.mock.calls[0][1].chatContext).not.toBeUndefined();
});
});
17 changes: 7 additions & 10 deletions public/tabs/chat/messages/message_content.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import { IMessage } from '../../../../common/types/chat_saved_object_attributes'
import { CoreVisualization } from '../../../components/core_visualization';
import { useChatContext } from '../../../contexts/chat_context';

interface MessageContentProps {
export interface MessageContentProps {
message: IMessage;
}

Expand Down Expand Up @@ -37,18 +37,15 @@ export const MessageContent: React.FC<MessageContentProps> = React.memo((props)
</div>
);

case 'ppl_visualization': {
const render = chatContext.contentRenderers[props.message.contentType];
if (!render) return null;
return (
<div className="llm-chat-visualizations">{render({ query: props.message.content })}</div>
);
}

// content types registered by plugins unknown to assistant
default: {
const message = props.message as IMessage;
return chatContext.contentRenderers[message.contentType]?.(message.content) ?? null;
return (
chatContext.messageRenderers[message.contentType]?.(message, {
props,
chatContext,
}) ?? null
);
}
}
});
Loading
Loading