Skip to content

Commit

Permalink
feat: retrieve conversation metadata when loading conversation (#27)
Browse files Browse the repository at this point in the history
* feat: retrieve conversation metadata when loading conversation

Signed-off-by: SuZhou-Joe <[email protected]>

* feat: change /_plugins/_ml to a constant

Signed-off-by: SuZhou-Joe <[email protected]>

* feat: change import

Signed-off-by: SuZhou-Joe <[email protected]>

* feat: optimize code

Signed-off-by: SuZhou-Joe <[email protected]>

---------

Signed-off-by: SuZhou-Joe <[email protected]>
  • Loading branch information
SuZhou-Joe committed Dec 5, 2023
1 parent ca6ab44 commit 2680dce
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 35 deletions.
2 changes: 1 addition & 1 deletion common/types/chat_saved_object_attributes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ export interface Interaction {

export interface ISession {
title: string;
version: number;
version?: number;
createdTimeMs: number;
updatedTimeMs: number;
messages: IMessage[];
Expand Down
4 changes: 2 additions & 2 deletions public/hooks/use_sessions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,9 @@ export const usePatchSession = () => {
dispatch({ type: 'request' });
return core.services.http
.put(`${ASSISTANT_API.SESSION}/${sessionId}`, {
query: {
body: JSON.stringify({
title,
},
}),
signal: abortControllerRef.current.signal,
})
.then((payload) => dispatch({ type: 'success', payload }))
Expand Down
4 changes: 2 additions & 2 deletions server/routes/chat_routes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ const updateSessionRoute = {
params: schema.object({
sessionId: schema.string(),
}),
query: schema.object({
body: schema.object({
title: schema.string(),
}),
},
Expand Down Expand Up @@ -225,7 +225,7 @@ export function registerChatRoutes(router: IRouter, routeOptions: RoutesOptions)
try {
const getResponse = await storageService.updateSession(
request.params.sessionId,
request.query.title
request.body.title
);
return response.ok({ body: getResponse });
} catch (error) {
Expand Down
3 changes: 2 additions & 1 deletion server/services/chat/olly_chat_service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import { LLMModelFactory } from '../../olly/models/llm_model_factory';
import { PPLTools } from '../../olly/tools/tool_sets/ppl';
import { PPLGenerationRequestSchema } from '../../routes/langchain_routes';
import { ChatService } from './chat_service';
import { ML_COMMONS_BASE_API } from '../../olly/models/constants';

const MEMORY_ID_FIELD = 'memory_id';

Expand Down Expand Up @@ -49,7 +50,7 @@ export class OllyChatService implements ChatService {
}
const agentFrameworkResponse = (await opensearchClient.transport.request({
method: 'POST',
path: `/_plugins/_ml/agents/${rootAgentId}/_execute`,
path: `${ML_COMMONS_BASE_API}/agents/${rootAgentId}/_execute`,
body: {
parameters: parametersPayload,
},
Expand Down
60 changes: 31 additions & 29 deletions server/services/storage/agent_framework_storage_service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
* SPDX-License-Identifier: Apache-2.0
*/

import { ApiResponse } from '@opensearch-project/opensearch/.';
import { TransportRequestPromise, ApiResponse } from '@opensearch-project/opensearch/lib/Transport';
import { AgentFrameworkTrace } from '../../../common/utils/llm_chat/traces';
import { OpenSearchClient } from '../../../../../src/core/server';
import {
Expand All @@ -16,6 +16,7 @@ import { GetSessionsSchema } from '../../routes/chat_routes';
import { StorageService } from './storage_service';
import { MessageParser } from '../../types';
import { MessageParserRunner } from '../../utils/message_parser_runner';
import { ML_COMMONS_BASE_API } from '../../olly/models/constants';

export interface SessionOptResponse {
success: boolean;
Expand All @@ -29,37 +30,38 @@ export class AgentFrameworkStorageService implements StorageService {
private readonly messageParsers: MessageParser[] = []
) {}
async getSession(sessionId: string): Promise<ISession> {
const session = (await this.client.transport.request({
method: 'GET',
path: `/_plugins/_ml/memory/conversation/${sessionId}/_list`,
})) as ApiResponse<{
interactions: Interaction[];
}>;
const [interactionsResp, conversation] = await Promise.all([
this.client.transport.request({
method: 'GET',
path: `${ML_COMMONS_BASE_API}/memory/conversation/${sessionId}/_list`,
}) as TransportRequestPromise<
ApiResponse<{
interactions: Interaction[];
}>
>,
this.client.transport.request({
method: 'GET',
path: `${ML_COMMONS_BASE_API}/memory/conversation/${sessionId}`,
}) as TransportRequestPromise<
ApiResponse<{
conversation_id: string;
create_time: string;
updated_time: string;
name: string;
}>
>,
]);
const messageParserRunner = new MessageParserRunner(this.messageParsers);
const finalInteractions: Interaction[] = [...session.body.interactions];
const finalInteractions = interactionsResp.body.interactions;

/**
* Sort interactions according to create_time
*/
finalInteractions.sort((interactionA, interactionB) => {
const { create_time: createTimeA } = interactionA;
const { create_time: createTimeB } = interactionB;
const createTimeMSA = +new Date(createTimeA);
const createTimeMSB = +new Date(createTimeB);
if (isNaN(createTimeMSA) || isNaN(createTimeMSB)) {
return 0;
}
return createTimeMSA - createTimeMSB;
});
let finalMessages: IMessage[] = [];
for (const interaction of finalInteractions) {
finalMessages = [...finalMessages, ...(await messageParserRunner.run(interaction))];
}
return {
title: 'test',
version: 1,
createdTimeMs: Date.now(),
updatedTimeMs: Date.now(),
title: conversation.body.name,
createdTimeMs: +new Date(conversation.body.create_time),
updatedTimeMs: +new Date(conversation.body.updated_time),
messages: finalMessages,
interactions: finalInteractions,
};
Expand Down Expand Up @@ -101,7 +103,7 @@ export class AgentFrameworkStorageService implements StorageService {

const sessions = await this.client.transport.request({
method: 'GET',
path: `/_plugins/_ml/memory/conversation/_search`,
path: `${ML_COMMONS_BASE_API}/memory/conversation/_search`,
body: requestParams,
});

Expand Down Expand Up @@ -140,7 +142,7 @@ export class AgentFrameworkStorageService implements StorageService {
try {
const response = await this.client.transport.request({
method: 'DELETE',
path: `/_plugins/_ml/memory/conversation/${sessionId}/_delete`,
path: `${ML_COMMONS_BASE_API}/memory/conversation/${sessionId}/_delete`,
});
if (response.statusCode === 200) {
return {
Expand All @@ -162,7 +164,7 @@ export class AgentFrameworkStorageService implements StorageService {
try {
const response = await this.client.transport.request({
method: 'PUT',
path: `/_plugins/_ml/memory/conversation/${sessionId}/_update`,
path: `${ML_COMMONS_BASE_API}/memory/conversation/${sessionId}/_update`,
body: {
name: title,
},
Expand All @@ -187,7 +189,7 @@ export class AgentFrameworkStorageService implements StorageService {
try {
const response = (await this.client.transport.request({
method: 'GET',
path: `/_plugins/_ml/memory/trace/${interactionId}/_list`,
path: `${ML_COMMONS_BASE_API}/memory/trace/${interactionId}/_list`,
})) as ApiResponse<{
traces: Array<{
conversation_id: string;
Expand Down

0 comments on commit 2680dce

Please sign in to comment.