Skip to content

Commit

Permalink
feat: adding chat isolation
Browse files Browse the repository at this point in the history
  • Loading branch information
NarwhalChen committed Jan 6, 2025
1 parent 1689460 commit 487485f
Show file tree
Hide file tree
Showing 15 changed files with 178 additions and 98 deletions.
97 changes: 97 additions & 0 deletions backend/src/chat/__tests__/test.chat-isolation.spec.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
// chat.service.spec.ts
import { Test, TestingModule } from '@nestjs/testing';
import { ChatService } from '../chat.service';
import { getRepositoryToken } from '@nestjs/typeorm';
import { Chat } from '../chat.model';
import { User } from 'src/user/user.model';
import { Message, MessageRole } from 'src/chat/message.model';
import { Repository } from 'typeorm';
import { TypeOrmModule } from '@nestjs/typeorm';
import { UserResolver } from 'src/user/user.resolver';
import { AuthService } from 'src/auth/auth.service';
import { UserService } from 'src/user/user.service';
import { JwtService } from '@nestjs/jwt';
import { JwtCacheService } from 'src/auth/jwt-cache.service';
import { ConfigService } from '@nestjs/config';
import { Menu } from 'src/auth/menu/menu.model';
import { Role } from 'src/auth/role/role.model';
import { RegisterUserInput } from 'src/user/dto/register-user.input';
import { NewChatInput } from '../dto/chat.input';
import { ModelProvider} from 'src/common/model-provider';
import { HttpService } from '@nestjs/axios';
import { MessageInterface } from 'src/common/model-provider/types';

describe('ChatService', () => {
let chatService: ChatService;
let userResolver: UserResolver;
let userService: UserService;
let mockedChatService: jest.Mocked<Repository<Chat>>;
let modelProvider: ModelProvider;
let user: User;
let userid='1';

beforeAll(async()=>{
const module: TestingModule = await Test.createTestingModule({
imports:[
TypeOrmModule.forRoot({
type: 'sqlite',
database: '../../database.sqlite',
synchronize: true,
entities: ['../../' + '/**/*.model{.ts,.js}'],
}),
TypeOrmModule.forFeature([Chat, User, Menu, Role]),
],
providers: [
Repository<Menu>,
ChatService,
AuthService,
UserService,
UserResolver,
JwtService,
JwtCacheService,
ConfigService,
]
}).compile();
chatService = module.get(ChatService);
userService = module.get(UserService);
userResolver = module.get(UserResolver);

modelProvider = ModelProvider.getInstance();
mockedChatService = module.get(getRepositoryToken(Chat));
})
it('should excute curd in chat service', async() => {

try{
user = await userResolver.registerUser({
username: 'testuser',
password: 'securepassword',
email: '[email protected]',
} as RegisterUserInput);
userid = user.id;

Check failure on line 70 in backend/src/chat/__tests__/test.chat-isolation.spec.ts

View workflow job for this annotation

GitHub Actions / autofix

Empty block statement
}catch(error){

}
const chat= await chatService.createChat(userid, {title: 'test'} as NewChatInput);
let chatId = chat.id;
console.log(await chatService.getChatHistory(chatId));

console.log(await chatService.saveMessage(chatId, 'Hello, this is a test message.', MessageRole.User));
console.log(await chatService.saveMessage(chatId, 'Hello, hello, im gpt.', MessageRole.Model));

console.log(await chatService.saveMessage(chatId, 'write me the system prompt', MessageRole.User));

let history = await chatService.getChatHistory(chatId);
let messages = history.map((message) => {
return {
role: message.role,
content: message.content
} as MessageInterface;
})
console.log(history);
console.log(
await modelProvider.chatSync({
model: 'gpt-4o',
messages
}));
})
});
4 changes: 2 additions & 2 deletions backend/src/chat/chat.model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ export class Chat extends SystemBaseModel {
@Column({ nullable: true })
title: string;

@Field(() => [Message], { nullable: true })
@OneToMany(() => Message, (message) => message.chat, { cascade: true })
@Field({ nullable: true })
@Column('simple-json', { nullable: true, default: '[]' })
messages: Message[];

@ManyToOne(() => User, (user) => user.chats)
Expand Down
3 changes: 2 additions & 1 deletion backend/src/chat/chat.module.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import { ChatGuard } from '../guard/chat.guard';
import { AuthModule } from '../auth/auth.module';
import { UserService } from 'src/user/user.service';
import { PubSub } from 'graphql-subscriptions';
import { ModelProvider } from 'src/common/model-provider';

@Module({
imports: [
Expand All @@ -30,6 +31,6 @@ import { PubSub } from 'graphql-subscriptions';
useValue: new PubSub(),
},
],
exports: [ChatService, ChatGuard],
exports: [ChatService, ChatGuard, ModelProvider],
})
export class ChatModule {}
10 changes: 0 additions & 10 deletions backend/src/chat/chat.resolver.ts
Original file line number Diff line number Diff line change
Expand Up @@ -103,16 +103,6 @@ export class ChatResolver {
const user = await this.userService.getUserChats(userId);
return user ? user.chats : [];
}

@JWTAuth()
@Query(() => Message, { nullable: true })
async getMessageDetail(
@GetUserIdFromToken() userId: string,
@Args('messageId') messageId: string,
): Promise<Message> {
return this.chatService.getMessageById(messageId);
}

// To do: message need a update resolver

@JWTAuth()
Expand Down
62 changes: 27 additions & 35 deletions backend/src/chat/chat.service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,9 @@ import { ModelProvider } from 'src/common/model-provider';
@Injectable()
export class ChatProxyService {
private readonly logger = new Logger('ChatProxyService');
private models: ModelProvider;

constructor(private httpService: HttpService) {
this.models = ModelProvider.getInstance();
constructor(private httpService: HttpService, private readonly models: ModelProvider) {

}

streamChat(
Expand All @@ -39,33 +38,34 @@ export class ChatService {
@InjectRepository(Chat)
private chatRepository: Repository<Chat>,
@InjectRepository(User)
private userRepository: Repository<User>,
@InjectRepository(Message)
private messageRepository: Repository<Message>,
private userRepository: Repository<User>
) {}

async getChatHistory(chatId: string): Promise<Message[]> {
const chat = await this.chatRepository.findOne({
where: { id: chatId, isDeleted: false },
relations: ['messages'],
});
console.log(chat);


if (chat && chat.messages) {
// Sort messages by createdAt in ascending order
chat.messages = chat.messages
.filter((message) => !message.isDeleted)
.sort((a, b) => a.createdAt.getTime() - b.createdAt.getTime());
.map((message) => {
if (!(message.createdAt instanceof Date)) {
message.createdAt = new Date(message.createdAt);
}
return message;
})
.sort((a, b) => {
return a.createdAt.getTime() - b.createdAt.getTime();
});
}

return chat ? chat.messages : [];
}

async getMessageById(messageId: string): Promise<Message> {
return await this.messageRepository.findOne({
where: { id: messageId, isDeleted: false },
});
}

async getChatDetails(chatId: string): Promise<Chat> {
const chat = await this.chatRepository.findOne({
where: { id: chatId, isDeleted: false },
Expand Down Expand Up @@ -111,12 +111,6 @@ export class ChatService {
chat.isActive = false;
await this.chatRepository.save(chat);

// Soft delete all associated messages
await this.messageRepository.update(
{ chat: { id: chatId }, isDeleted: false },
{ isDeleted: true, isActive: false },
);

return true;
}
return false;
Expand All @@ -125,13 +119,8 @@ export class ChatService {
async clearChatHistory(chatId: string): Promise<boolean> {
const chat = await this.chatRepository.findOne({
where: { id: chatId, isDeleted: false },
relations: ['messages'],
});
if (chat) {
await this.messageRepository.update(
{ chat: { id: chatId }, isDeleted: false },
{ isDeleted: true, isActive: false },
);
chat.updatedAt = new Date();
await this.chatRepository.save(chat);
return true;
Expand Down Expand Up @@ -161,21 +150,24 @@ export class ChatService {
): Promise<Message> {
// Find the chat instance
const chat = await this.chatRepository.findOne({ where: { id: chatId } });

const message = {
id: `${chat.id}/${chat.messages.length}`,
content: messageContent,
role: role,
createdAt: new Date(),
updatedAt: new Date(),
isActive: true,
isDeleted: false,
};
//if the chat id not exist, dont save this messages
if (!chat) {
return null;
}

// Create a new message associated with the chat
const message = this.messageRepository.create({
content: messageContent,
role: role,
chat,
createdAt: new Date(),
});

chat.messages.push(message);
await this.chatRepository.save(chat);
// Save the message to the database
return await this.messageRepository.save(message);
return message;
}

async getChatWithUser(chatId: string): Promise<Chat | null> {
Expand Down
8 changes: 2 additions & 6 deletions backend/src/chat/message.model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ import { Chat } from 'src/chat/chat.model';
import { SystemBaseModel } from 'src/system-base-model/system-base.model';

export enum MessageRole {
User = 'User',
Model = 'Model',
User = 'user',
Model = 'assistant',
}

registerEnumType(MessageRole, {
Expand All @@ -43,8 +43,4 @@ export class Message extends SystemBaseModel {
@Field({ nullable: true })
@Column({ nullable: true })
modelId?: string;

@ManyToOne(() => Chat, (chat) => chat.messages)
@JoinColumn({ name: 'chatId' })
chat: Chat;
}
19 changes: 6 additions & 13 deletions backend/src/common/model-provider/index.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
import { Logger } from '@nestjs/common';
import { HttpService } from '@nestjs/axios';
import { Subject, Subscription } from 'rxjs';
import { MessageRole } from 'src/chat/message.model';
import { LLMInterface, ModelProviderConfig } from './types';

export interface ModelProviderConfig {
endpoint: string;
defaultModel?: string;
}

export interface CustomAsyncIterableIterator<T> extends AsyncIterator<T> {
[Symbol.asyncIterator](): AsyncIterableIterator<T>;
Expand Down Expand Up @@ -55,9 +53,7 @@ export class ModelProvider {
* Synchronous chat method that returns a complete response
*/
async chatSync(
input: ChatInput | string,
model: string,
chatId?: string,
input: LLMInterface,
): Promise<string> {
while (this.currentRequests >= this.concurrentLimit) {
await new Promise((resolve) => setTimeout(resolve, 100));
Expand All @@ -70,7 +66,6 @@ export class ModelProvider {
`Starting request ${requestId}. Active: ${this.currentRequests}/${this.concurrentLimit}`,
);

const normalizedInput = this.normalizeChatInput(input);

let resolvePromise: (value: string) => void;
let rejectPromise: (error: any) => void;
Expand Down Expand Up @@ -113,7 +108,7 @@ export class ModelProvider {
promise,
});

this.processRequest(normalizedInput, model, chatId, requestId, stream);
this.processRequest(input, requestId, stream);
return promise;
}

Expand Down Expand Up @@ -155,9 +150,7 @@ export class ModelProvider {
}

private async processRequest(
input: ChatInput,
model: string,
chatId: string | undefined,
input: LLMInterface,
requestId: string,
stream: Subject<any>,
) {
Expand All @@ -167,7 +160,7 @@ export class ModelProvider {
const response = await this.httpService
.post(
`${this.config.endpoint}/chat/completion`,
this.createRequestPayload(input, model, chatId),
input,
{
responseType: 'stream',
headers: { 'Content-Type': 'application/json' },
Expand Down
18 changes: 18 additions & 0 deletions backend/src/common/model-provider/types.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,25 @@
import { MessageRole } from "src/chat/message.model";

export interface ModelChatStreamConfig {
endpoint: string;
model?: string;
}
export type CustomAsyncIterableIterator<T> = AsyncIterator<T> & {
[Symbol.asyncIterator](): AsyncIterableIterator<T>;
};

export interface ModelProviderConfig {
endpoint: string;
defaultModel?: string;
}

export interface MessageInterface {
content: string;
role: MessageRole;
}

export interface LLMInterface {
model: string;
messages: MessageInterface[];
}

2 changes: 1 addition & 1 deletion backend/src/main.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ async function bootstrap() {
'Access-Control-Allow-Credentials',
],
});
await downloadAllModels();
// await downloadAllModels();
await app.listen(process.env.PORT ?? 3000);
}

Expand Down
2 changes: 1 addition & 1 deletion llm-server/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"type": "module",
"scripts": {
"start": "node --experimental-specifier-resolution=node --loader ts-node/esm src/main.ts",
"dev": "nodemon --watch 'src/**/*.ts' --exec 'node --experimental-specifier-resolution=node --loader ts-node/esm' src/main.ts",
"dev": "nodemon --watch \"src/**/*.ts\" --exec \"node --experimental-specifier-resolution=node --loader ts-node/esm\" src/main.ts",
"dev:backend": "pnpm dev",
"build": "tsc",
"serve": "node --experimental-specifier-resolution=node dist/main.js",
Expand Down
Loading

0 comments on commit 487485f

Please sign in to comment.