diff --git a/backend/src/chat/__tests__/test.chat-isolation.spec.ts b/backend/src/chat/__tests__/test.chat-isolation.spec.ts
index fd4edaf..7495ec5 100644
--- a/backend/src/chat/__tests__/test.chat-isolation.spec.ts
+++ b/backend/src/chat/__tests__/test.chat-isolation.spec.ts
@@ -17,7 +17,7 @@ import { Menu } from 'src/auth/menu/menu.model';
 import { Role } from 'src/auth/role/role.model';
 import { RegisterUserInput } from 'src/user/dto/register-user.input';
 import { NewChatInput } from '../dto/chat.input';
-import { ModelProvider} from 'src/common/model-provider';
+import { ModelProvider } from 'src/common/model-provider';
 import { HttpService } from '@nestjs/axios';
 import { MessageInterface } from 'src/common/model-provider/types';
 
@@ -28,11 +28,11 @@ describe('ChatService', () => {
   let mockedChatService: jest.Mocked<Repository<Chat>>;
   let modelProvider: ModelProvider;
   let user: User;
-  let userid='1';
+  let userid = '1';
 
-  beforeAll(async()=>{
+  beforeAll(async () => {
     const module: TestingModule = await Test.createTestingModule({
-      imports:[
+      imports: [
         TypeOrmModule.forRoot({
           type: 'sqlite',
           database: '../../database.sqlite',
@@ -50,48 +50,66 @@ describe('ChatService', () => {
         JwtService,
         JwtCacheService,
         ConfigService,
-      ]
+      ],
     }).compile();
 
     chatService = module.get(ChatService);
     userService = module.get(UserService);
     userResolver = module.get(UserResolver);
-    
+
     modelProvider = ModelProvider.getInstance();
     mockedChatService = module.get(getRepositoryToken(Chat));
-  })
-  it('should excute curd in chat service', async() => {
-    
-    try{
+  });
+  it('should excute curd in chat service', async () => {
+    try {
       user = await userResolver.registerUser({
         username: 'testuser',
         password: 'securepassword',
         email: 'testuser@example.com',
       } as RegisterUserInput);
       userid = user.id;
-    }catch(error){
-      
-    }
-    const chat= await chatService.createChat(userid, {title: 'test'} as NewChatInput);
-    let chatId = chat.id;
+    } catch (error) {}
+    const chat = await chatService.createChat(userid, {
+      title: 'test',
+    } as NewChatInput);
+    const chatId = chat.id;
     console.log(await chatService.getChatHistory(chatId));
-
-    console.log(await chatService.saveMessage(chatId, 'Hello, this is a test message.', MessageRole.User));
-    console.log(await chatService.saveMessage(chatId, 'Hello, hello, im gpt.', MessageRole.Model));
-
-    console.log(await chatService.saveMessage(chatId, 'write me the system prompt', MessageRole.User));
-    let history = await chatService.getChatHistory(chatId);
-    let messages = history.map((message) => {
+    console.log(
+      await chatService.saveMessage(
+        chatId,
+        'Hello, this is a test message.',
+        MessageRole.User,
+      ),
+    );
+    console.log(
+      await chatService.saveMessage(
+        chatId,
+        'Hello, hello, im gpt.',
+        MessageRole.Model,
+      ),
+    );
+
+    console.log(
+      await chatService.saveMessage(
+        chatId,
+        'write me the system prompt',
+        MessageRole.User,
+      ),
+    );
+
+    const history = await chatService.getChatHistory(chatId);
+    const messages = history.map((message) => {
       return {
         role: message.role,
-        content: message.content
+        content: message.content,
       } as MessageInterface;
-    })
+    });
     console.log(history);
     console.log(
       await modelProvider.chatSync({
         model: 'gpt-4o',
-        messages
-      }));
-  })
-});
\ No newline at end of file
+        messages,
+      }),
+    );
+  });
+});
diff --git a/backend/src/chat/chat.service.ts b/backend/src/chat/chat.service.ts
index d528272..68f91c3 100644
--- a/backend/src/chat/chat.service.ts
+++ b/backend/src/chat/chat.service.ts
@@ -17,9 +17,10 @@ import { ModelProvider } from 'src/common/model-provider';
 export class ChatProxyService {
   private readonly logger = new Logger('ChatProxyService');
 
-  constructor(private httpService: HttpService, private readonly models: ModelProvider) {
-
-  }
+  constructor(
+    private httpService: HttpService,
+    private readonly models: ModelProvider,
+  ) {}
 
   streamChat(
     input: ChatInput,
@@ -38,7 +39,7 @@ export class ChatService {
     @InjectRepository(Chat)
     private chatRepository: Repository<Chat>,
     @InjectRepository(User)
-    private userRepository: Repository<User>
+    private userRepository: Repository<User>,
   ) {}
 
   async getChatHistory(chatId: string): Promise<Message[]> {
@@ -46,7 +47,6 @@ export class ChatService {
       where: { id: chatId, isDeleted: false },
     });
     console.log(chat);
-
 
     if (chat && chat.messages) {
       // Sort messages by createdAt in ascending order
@@ -150,13 +150,13 @@ export class ChatService {
   ): Promise<Message> {
     // Find the chat instance
     const chat = await this.chatRepository.findOne({ where: { id: chatId } });
-    
+
     const message = {
       id: `${chat.id}/${chat.messages.length}`,
       content: messageContent,
       role: role,
       createdAt: new Date(),
-      updatedAt: new Date(), 
+      updatedAt: new Date(),
       isActive: true,
       isDeleted: false,
     };
diff --git a/backend/src/common/model-provider/index.ts b/backend/src/common/model-provider/index.ts
index 78843f1..d0f0370 100644
--- a/backend/src/common/model-provider/index.ts
+++ b/backend/src/common/model-provider/index.ts
@@ -4,7 +4,6 @@ import { Subject, Subscription } from 'rxjs';
 import { MessageRole } from 'src/chat/message.model';
 import { LLMInterface, ModelProviderConfig } from './types';
-
 
 export interface CustomAsyncIterableIterator<T> extends AsyncIterator<T> {
   [Symbol.asyncIterator](): AsyncIterableIterator<T>;
 }
@@ -52,9 +51,7 @@ export class ModelProvider {
   /**
    * Synchronous chat method that returns a complete response
    */
-  async chatSync(
-    input: LLMInterface,
-  ): Promise<string> {
+  async chatSync(input: LLMInterface): Promise<string> {
     while (this.currentRequests >= this.concurrentLimit) {
       await new Promise((resolve) => setTimeout(resolve, 100));
     }
@@ -66,7 +63,6 @@ export class ModelProvider {
       `Starting request ${requestId}. 
Active: ${this.currentRequests}/${this.concurrentLimit}`,
     );
-
     let resolvePromise: (value: string) => void;
     let rejectPromise: (error: any) => void;
 
@@ -158,14 +154,10 @@ export class ModelProvider {
 
     try {
      const response = await this.httpService
-        .post(
-          `${this.config.endpoint}/chat/completion`,
-          input,
-          {
-            responseType: 'stream',
-            headers: { 'Content-Type': 'application/json' },
-          },
-        )
+        .post(`${this.config.endpoint}/chat/completion`, input, {
+          responseType: 'stream',
+          headers: { 'Content-Type': 'application/json' },
+        })
         .toPromise();
 
       let buffer = '';
diff --git a/backend/src/common/model-provider/types.ts b/backend/src/common/model-provider/types.ts
index ebe69d1..bdfb66e 100644
--- a/backend/src/common/model-provider/types.ts
+++ b/backend/src/common/model-provider/types.ts
@@ -1,4 +1,4 @@
-import { MessageRole } from "src/chat/message.model";
+import { MessageRole } from 'src/chat/message.model';
 
 export interface ModelChatStreamConfig {
   endpoint: string;
@@ -22,4 +22,3 @@ export interface LLMInterface {
   model: string;
   messages: MessageInterface[];
 }
-
diff --git a/llm-server/src/llm-provider.ts b/llm-server/src/llm-provider.ts
index b4a8d14..dc53a22 100644
--- a/llm-server/src/llm-provider.ts
+++ b/llm-server/src/llm-provider.ts
@@ -10,7 +10,6 @@
 } from './types';
 import { ModelProvider } from './model/model-provider';
-
 
 export interface ChatMessageInput {
   role: string;
   content: string;
diff --git a/llm-server/src/model/llama-model-provider.ts b/llm-server/src/model/llama-model-provider.ts
index ce09ec7..6aafc78 100644
--- a/llm-server/src/model/llama-model-provider.ts
+++ b/llm-server/src/model/llama-model-provider.ts
@@ -36,7 +36,7 @@ export class LlamaModelProvider extends ModelProvider {
   }
 
   async generateStreamingResponse(
-    { model, messages}: GenerateMessageParams,
+    { model, messages }: GenerateMessageParams,
     res: Response,
   ): Promise<void> {
     this.logger.log('Generating streaming response with Llama...');
@@ -50,10 +50,7 @@ export class LlamaModelProvider extends ModelProvider {
 
     // Get the system prompt based on the model
     const systemPrompt = systemPrompts['codefox-basic']?.systemPrompt || '';
-    const allMessage = [
-      { role: 'system', content: systemPrompt },
-      ...messages,
-    ];
+    const allMessage = [{ role: 'system', content: systemPrompt }, ...messages];
 
     // Convert messages array to a single formatted string for Llama
     const formattedPrompt = allMessage
diff --git a/llm-server/src/model/openai-model-provider.ts b/llm-server/src/model/openai-model-provider.ts
index f7e2d54..85d7bd5 100644
--- a/llm-server/src/model/openai-model-provider.ts
+++ b/llm-server/src/model/openai-model-provider.ts
@@ -89,7 +89,10 @@ export class OpenAIModelProvider {
 
   private async processRequest(request: QueuedRequest): Promise<void> {
     const { params, res, retries } = request;
-    const { model, messages} = params as {model:string, messages:ChatCompletionMessageParam[]};
+    const { model, messages } = params as {
+      model: string;
+      messages: ChatCompletionMessageParam[];
+    };
     this.logger.log(`Processing request (attempt ${retries + 1})`);
     const startTime = Date.now();
 
@@ -105,7 +108,7 @@ export class OpenAIModelProvider {
         systemPrompts[this.options.systemPromptKey]?.systemPrompt || '';
       const allMessages: ChatCompletionMessageParam[] = [
         { role: 'system', content: systemPrompt },
-        ...messages, 
+        ...messages,
       ];
       console.log(allMessages);
       const stream = await this.openai.chat.completions.create({