Feat(prompt) system prompt #40

Merged: 6 commits, Nov 4, 2024
backend/src/chat/chat.resolver.ts (1 addition, 1 deletion)
@@ -53,7 +53,7 @@ export class ChatResolver {
MessageRole.User,
);

-      const iterator = this.chatProxyService.streamChat(input.message);
+      const iterator = this.chatProxyService.streamChat(input);
let accumulatedContent = '';

for await (const chunk of iterator) {
backend/src/chat/chat.service.ts (12 additions, 4 deletions)
@@ -5,7 +5,11 @@ import { Message, MessageRole } from 'src/chat/message.model';
import { InjectRepository } from '@nestjs/typeorm';
import { Repository } from 'typeorm';
import { User } from 'src/user/user.model';
-import { NewChatInput, UpdateChatTitleInput } from 'src/chat/dto/chat.input';
+import {
+  ChatInput,
+  NewChatInput,
+  UpdateChatTitleInput,
+} from 'src/chat/dto/chat.input';

type CustomAsyncIterableIterator<T> = AsyncIterator<T> & {
[Symbol.asyncIterator](): AsyncIterableIterator<T>;
@@ -17,8 +21,12 @@ export class ChatProxyService {

constructor(private httpService: HttpService) {}

-  streamChat(input: string): CustomAsyncIterableIterator<ChatCompletionChunk> {
-    this.logger.debug('request chat input: ' + input);
+  streamChat(
+    input: ChatInput,
+  ): CustomAsyncIterableIterator<ChatCompletionChunk> {
+    this.logger.debug(
+      `Request chat input: ${input.message} with model: ${input.model}`,
+    );
let isDone = false;
let responseSubscription: any;
const chunkQueue: ChatCompletionChunk[] = [];
@@ -60,7 +68,7 @@
responseSubscription = this.httpService
.post(
'http://localhost:3001/chat/completion',
-        { content: input },
+        { content: input.message, model: input.model },
{ responseType: 'stream' },
)
.subscribe({
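Note: the ChatInput DTO itself is not part of this diff. Based on how the resolver and streamChat use it above, it presumably looks something like the sketch below; the decorators and exact field set are assumptions for illustration, not code from this PR.

// Hypothetical sketch of the ChatInput DTO in src/chat/dto/chat.input.ts.
// Only `message` and `model` are evidenced by this diff; the real DTO may
// carry additional fields (for example a chat id).
import { Field, InputType } from '@nestjs/graphql';

@InputType()
export class ChatInput {
  @Field()
  message: string; // the user's prompt text

  @Field()
  model: string; // model identifier forwarded to llm-server, e.g. 'gpt-3.5-turbo'
}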
llm-server/src/llm-provider.ts (3 additions, 2 deletions)
@@ -3,6 +3,7 @@ import { ModelProvider } from './model/model-provider';
import { OpenAIModelProvider } from './model/openai-model-provider';
import { LlamaModelProvider } from './model/llama-model-provider';
import { Logger } from '@nestjs/common';
+import { GenerateMessageParams } from './type/GenerateMessage';

export interface ChatMessageInput {
content: string;
@@ -32,10 +33,10 @@ export class LLMProvider {
}

async generateStreamingResponse(
-    content: string,
+    params: GenerateMessageParams,
res: Response,
): Promise<void> {
-    await this.modelProvider.generateStreamingResponse(content, res);
+    await this.modelProvider.generateStreamingResponse(params, res);
}

async getModelTags(res: Response): Promise<void> {
llm-server/src/main.ts (12 additions, 2 deletions)
@@ -1,6 +1,7 @@
import { Logger } from '@nestjs/common';
import { ChatMessageInput, LLMProvider } from './llm-provider';
import express, { Express, Request, Response } from 'express';
+import { GenerateMessageParams } from './type/GenerateMessage';

export class App {
private readonly logger = new Logger(App.name);
@@ -27,13 +28,22 @@ export class App {
this.logger.log('Received chat request.');
try {
this.logger.debug(JSON.stringify(req.body));
-      const { content } = req.body as ChatMessageInput;
+      const { content, model } = req.body as ChatMessageInput & {
+        model: string;
+      };
Comment on lines +31 to +33
🛠️ Refactor suggestion

Add type validation for request body.

The type assertion using as bypasses runtime type checking. Consider adding validation to ensure the request body matches the expected structure.

-      const { content, model } = req.body as ChatMessageInput & {
-        model: string;
-      };
+      if (!req.body?.content || typeof req.body.content !== 'string') {
+        throw new Error('Invalid content in request body');
+      }
+      const { content, model = 'gpt-3.5-turbo' } = req.body;

+      const params: GenerateMessageParams = {
+        model: model || 'gpt-3.5-turbo', // Default to 'gpt-3.5-turbo' if model is not provided
+        message: content,
+        role: 'user',
+      };
Comment on lines +35 to +39
🛠️ Refactor suggestion

Refactor model handling and role assignment.

Several concerns with the current implementation:

  1. Model defaulting should be handled at a configuration level
  2. No validation for supported model types
  3. Hardcoded 'user' role might limit the system prompt functionality mentioned in PR objectives

Consider these improvements:

  1. Move model configuration to a central location
  2. Add model validation
  3. Allow role to be specified in the request for system prompts
+      const supportedModels = ['gpt-3.5-turbo', 'gpt-4']; // Move to config
+      if (model && !supportedModels.includes(model)) {
+        throw new Error(`Unsupported model: ${model}`);
+      }
       const params: GenerateMessageParams = {
-        model: model || 'gpt-3.5-turbo', // Default to 'gpt-3.5-turbo' if model is not provided
+        model: model || process.env.DEFAULT_MODEL || 'gpt-3.5-turbo',
         message: content,
-        role: 'user',
+        role: req.body.role || 'user', // Allow system prompts
       };

Committable suggestion skipped: line range outside the PR's diff.


this.logger.debug(`Request content: "${content}"`);
res.setHeader('Content-Type', 'text/event-stream');
res.setHeader('Cache-Control', 'no-cache');
res.setHeader('Connection', 'keep-alive');
this.logger.debug('Response headers set for streaming.');
-      await this.llmProvider.generateStreamingResponse(content, res);
+      await this.llmProvider.generateStreamingResponse(params, res);
} catch (error) {
this.logger.error('Error in chat endpoint:', error);
res.status(500).json({ error: 'Internal server error' });
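To make the reviewer's first point concrete, a central model configuration module could look roughly like the sketch below. The module path, environment variable, and allow-list contents are assumptions for illustration only; nothing like this exists in the PR.

// Hypothetical llm-server/src/config/models.ts; names and values are illustrative.
export const DEFAULT_MODEL = process.env.DEFAULT_MODEL || 'gpt-3.5-turbo';

// Assumed allow-list; the PR does not define one.
export const SUPPORTED_MODELS = new Set<string>([DEFAULT_MODEL, 'gpt-4']);

// Resolve and validate the requested model in one place,
// instead of defaulting inline in the request handler.
export function resolveModel(requested?: string): string {
  const model = requested || DEFAULT_MODEL;
  if (!SUPPORTED_MODELS.has(model)) {
    throw new Error(`Unsupported model: ${model}`);
  }
  return model;
}

The chat endpoint could then build GenerateMessageParams with model: resolveModel(model), and respond with a 400 for unknown models instead of falling through to the generic 500 handler, if that behaviour is wanted.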
llm-server/src/model/llama-model-provider.ts (19 additions, 2 deletions)
@@ -8,6 +8,9 @@ import {
} from 'node-llama-cpp';
import { ModelProvider } from './model-provider.js';
import { Logger } from '@nestjs/common';
+import { systemPrompts } from '../prompt/systemPrompt';
+import { ChatCompletionMessageParam } from 'openai/resources/chat/completions';
+import { GenerateMessageParams } from '../type/GenerateMessage';

//TODO: using protocol class
export class LlamaModelProvider extends ModelProvider {
@@ -33,7 +36,7 @@ export class LlamaModelProvider extends ModelProvider {
}

async generateStreamingResponse(
-    content: string,
+    { model, message, role = 'user' }: GenerateMessageParams,
⚠️ Potential issue

Unused 'model' parameter in method signature.

The destructured model parameter is not being utilized within the method implementation. Either remove it if unnecessary or implement model-specific logic.

res: Response,
): Promise<void> {
Comment on lines +39 to 41
🛠️ Refactor suggestion

Add input validation and implement system prompt support.

Several improvements are needed:

  1. Add validation for required parameters
  2. Implement system prompt support as per PR objectives

Consider this implementation:

 async generateStreamingResponse(
   { model, message, role = 'user' }: GenerateMessageParams,
   res: Response,
 ): Promise<void> {
+  if (!message?.trim()) {
+    throw new Error('Message cannot be empty');
+  }
+
+  const systemPrompt = 'You are a helpful AI assistant.'; // Consider making this configurable
+  const formattedMessage = role === 'system' ? message : `${systemPrompt}\n\nUser: ${message}`;
+
   this.logger.log('Generating streaming response with Llama...');
   const session = new LlamaChatSession({
     contextSequence: this.context.getSequence(),
   });
   this.logger.log('LlamaChatSession created.');
   let chunkCount = 0;
   const startTime = Date.now();
   try {
-    await session.prompt(message, {
+    await session.prompt(formattedMessage, {

Also applies to: 48-48

this.logger.log('Generating streaming response with Llama...');
@@ -43,8 +46,22 @@ export class LlamaModelProvider extends ModelProvider {
this.logger.log('LlamaChatSession created.');
let chunkCount = 0;
const startTime = Date.now();

+    // Get the system prompt based on the model
+    const systemPrompt = systemPrompts['codefox-basic']?.systemPrompt || '';
+
+    const messages = [
+      { role: 'system', content: systemPrompt },
+      { role: role as 'user' | 'system' | 'assistant', content: message },
+    ];
+
+    // Convert messages array to a single formatted string for Llama
+    const formattedPrompt = messages
+      .map(({ role, content }) => `${role}: ${content}`)
+      .join('\n');
Comment on lines +50 to +61
🛠️ Refactor suggestion

Enhance system prompt handling with proper validation.

The current implementation has several potential issues:

  1. Hardcoded model name 'codefox-basic'
  2. No validation if the model exists in systemPrompts
  3. Silent fallback to empty string could mask configuration issues

Consider this improved implementation:

-    // Get the system prompt based on the model
-    const systemPrompt = systemPrompts['codefox-basic']?.systemPrompt || '';
+    // Get the system prompt based on the provided model
+    if (!systemPrompts[model]) {
+      throw new Error(`System prompt not found for model: ${model}`);
+    }
+    const systemPrompt = systemPrompts[model].systemPrompt;

try {
-      await session.prompt(content, {
+      await session.prompt(formattedPrompt, {
onTextChunk: chunk => {
chunkCount++;
this.logger.debug(`Sending chunk #${chunkCount}: "${chunk}"`);
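For clarity on what the flattening above produces: with the 'codefox-basic' prompt and a sample user message, formattedPrompt is just a newline-joined "role: content" transcript, roughly as sketched here (the sample message is invented for illustration).

// Illustration only: what formattedPrompt evaluates to for a sample input.
const sampleMessages = [
  { role: 'system', content: 'You are CodeFox, an advanced and powerful AI ...' },
  { role: 'user', content: 'Generate a NestJS resolver for a Chat type.' },
];

const sampleFormattedPrompt = sampleMessages
  .map(({ role, content }) => `${role}: ${content}`)
  .join('\n');

// sampleFormattedPrompt is now:
// system: You are CodeFox, an advanced and powerful AI ...
// user: Generate a NestJS resolver for a Chat type.

Note that LlamaChatSession may apply its own chat template on top of this string; if the installed node-llama-cpp version supports a systemPrompt option on the session, passing the system prompt there could avoid double role formatting, but that is outside this diff.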
llm-server/src/model/model-provider.ts (2 additions, 1 deletion)
@@ -1,9 +1,10 @@
import { Response } from 'express';
+import { GenerateMessageParams } from '../type/GenerateMessage';

export abstract class ModelProvider {
abstract initialize(): Promise<void>;
abstract generateStreamingResponse(
-    content: string,
+    params: GenerateMessageParams,
res: Response,
): Promise<void>;

llm-server/src/model/openai-model-provider.ts (21 additions, 7 deletions)
@@ -2,6 +2,10 @@ import { Response } from 'express';
import OpenAI from 'openai';
import { ModelProvider } from './model-provider';
import { Logger } from '@nestjs/common';
+import { systemPrompts } from '../prompt/systemPrompt';
+import { ChatCompletionMessageParam } from 'openai/resources/chat/completions';
+import { GenerateMessageParams } from '../type/GenerateMessage';

export class OpenAIModelProvider extends ModelProvider {
private readonly logger = new Logger(OpenAIModelProvider.name);
private openai: OpenAI;
@@ -15,23 +19,34 @@ export class OpenAIModelProvider extends ModelProvider {
}

async generateStreamingResponse(
-    content: string,
+    { model, message, role = 'user' }: GenerateMessageParams,
res: Response,
): Promise<void> {
this.logger.log('Generating streaming response with OpenAI...');
const startTime = Date.now();

// Set SSE headers
res.writeHead(200, {
'Content-Type': 'text/event-stream',
'Cache-Control': 'no-cache',
Connection: 'keep-alive',
});

+    // Get the system prompt based on the model
+    const systemPrompt = systemPrompts['codefox-basic']?.systemPrompt || '';
Comment on lines +35 to +36
🛠️ Refactor suggestion

Suggestion: Use the model parameter to select the system prompt

Currently, the system prompt is fetched using the hardcoded key 'codefox-basic'. To support dynamic selection of system prompts based on the model being used, consider using the model parameter as the key.

Apply this diff to utilize the model parameter:

-const systemPrompt = systemPrompts['codefox-basic']?.systemPrompt || '';
+const systemPrompt = systemPrompts[model]?.systemPrompt || '';

+    const messages: ChatCompletionMessageParam[] = [
+      { role: 'system', content: systemPrompt },
+      { role: role as 'user' | 'system' | 'assistant', content: message },
+    ];

try {
const stream = await this.openai.chat.completions.create({
-        model: 'gpt-3.5-turbo',
-        messages: [{ role: 'user', content: content }],
+        model,
+        messages,
Comment on lines +45 to +46
🛠️ Refactor suggestion

Add model validation before API call

Consider validating the model parameter against the list of available models to fail fast and provide better error messages.

+    // Validate model before making the API call
+    try {
+      await this.openai.models.retrieve(model);
+    } catch (error) {
+      this.logger.error(`Invalid model: ${model}`);
+      throw new Error(`Invalid model: ${model}`);
+    }
+
     const stream = await this.openai.chat.completions.create({
       model,
       messages,

stream: true,
});

let chunkCount = 0;
for await (const chunk of stream) {
const content = chunk.choices[0]?.delta?.content || '';
Expand All @@ -41,6 +56,7 @@ export class OpenAIModelProvider extends ModelProvider {
res.write(`data: ${JSON.stringify(chunk)}\n\n`);
}
}

const endTime = Date.now();
this.logger.log(
`Response generation completed. Total chunks: ${chunkCount}`,
@@ -59,20 +75,18 @@ export class OpenAIModelProvider extends ModelProvider {

async getModelTagsResponse(res: Response): Promise<void> {
this.logger.log('Fetching available models from OpenAI...');
// Set SSE headers
res.writeHead(200, {
'Content-Type': 'text/event-stream',
'Cache-Control': 'no-cache',
Connection: 'keep-alive',
});

try {
const startTime = Date.now();
const models = await this.openai.models.list();

const response = {
-        models: models, // Wrap the models in the required structure
+        models: models,
};

const endTime = Date.now();
this.logger.log(
`Model fetching completed. Total models: ${models.data.length}`,
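For reference, a minimal sketch of how a caller could consume the SSE stream this provider writes: each event arrives as a data: <json> line containing an OpenAI chunk, which the backend's ChatProxyService reads the same way. This is illustrative only, simplifies buffering, and is not code from the PR; the axios usage here is an assumption.

// Illustrative SSE consumer for llm-server's /chat/completion endpoint; not part of this PR.
import axios from 'axios';

async function collectChatStream(content: string, model: string): Promise<string> {
  const response = await axios.post(
    'http://localhost:3001/chat/completion',
    { content, model },
    { responseType: 'stream' },
  );

  let text = '';
  for await (const buffer of response.data) {
    // Simplification: assumes each buffer contains whole `data: <json>\n\n` events.
    for (const line of buffer.toString().split('\n')) {
      if (!line.startsWith('data: ')) continue;
      try {
        const chunk = JSON.parse(line.slice('data: '.length));
        text += chunk.choices?.[0]?.delta?.content ?? '';
      } catch {
        // Ignore any non-JSON control lines the server may send.
      }
    }
  }
  return text;
}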
llm-server/src/prompt/systemPrompt.ts (7 additions, 0 deletions)
@@ -0,0 +1,7 @@
+// Define and export the system prompts object
+export const systemPrompts = {
+  'codefox-basic': {
+    systemPrompt: `You are CodeFox, an advanced and powerful AI specialized in code generation and software engineering.
+Your purpose is to help developers build complete and efficient applications by providing well-structured, optimized, and maintainable code.`,
+  },
+};
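Several comments above flag the hardcoded 'codefox-basic' lookup and the silent fallback to an empty string. A small helper along these lines (a sketch, not part of this PR) would centralize the lookup and fail loudly; whether the key should be the model name or a separate prompt id is left open.

// Illustrative helper, not part of this PR: resolve a prompt by key and
// throw instead of silently returning an empty string.
export function getSystemPrompt(key: string = 'codefox-basic'): string {
  const entry = (systemPrompts as Record<string, { systemPrompt: string } | undefined>)[key];
  if (!entry) {
    throw new Error(`System prompt not found for key: ${key}`);
  }
  return entry.systemPrompt;
}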
llm-server/src/type/GenerateMessage.ts (5 additions, 0 deletions)
@@ -0,0 +1,5 @@
+export interface GenerateMessageParams {
+  model: string; // Model to use, e.g., 'gpt-3.5-turbo'
+  message: string; // User's message or query
+  role?: 'user' | 'system' | 'assistant' | 'tool' | 'function'; // Optional role
+}
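Tying back to the review note about the as assertion in main.ts, a runtime guard for this interface could look roughly like the sketch below (illustrative; not part of the PR).

// Illustrative runtime guard for GenerateMessageParams; not part of this PR.
const VALID_ROLES: readonly string[] = ['user', 'system', 'assistant', 'tool', 'function'];

export function isGenerateMessageParams(value: unknown): value is GenerateMessageParams {
  if (typeof value !== 'object' || value === null) return false;
  const v = value as Record<string, unknown>;
  return (
    typeof v.model === 'string' &&
    typeof v.message === 'string' &&
    (v.role === undefined || (typeof v.role === 'string' && VALID_ROLES.includes(v.role)))
  );
}

The chat endpoint could then reject invalid bodies with a 400 before building params, rather than relying on the type assertion.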