-
Notifications
You must be signed in to change notification settings - Fork 1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Feat(prompt) system prompt #40
Changes from all commits
02d773b
c22001d
2baa14b
c24518f
4451534
d6eee3c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,7 @@ | ||
import { Logger } from '@nestjs/common'; | ||
import { ChatMessageInput, LLMProvider } from './llm-provider'; | ||
import express, { Express, Request, Response } from 'express'; | ||
import { GenerateMessageParams } from './type/GenerateMessage'; | ||
|
||
export class App { | ||
private readonly logger = new Logger(App.name); | ||
|
@@ -27,13 +28,22 @@ export class App { | |
this.logger.log('Received chat request.'); | ||
try { | ||
this.logger.debug(JSON.stringify(req.body)); | ||
const { content } = req.body as ChatMessageInput; | ||
const { content, model } = req.body as ChatMessageInput & { | ||
model: string; | ||
}; | ||
|
||
const params: GenerateMessageParams = { | ||
model: model || 'gpt-3.5-turbo', // Default to 'gpt-3.5-turbo' if model is not provided | ||
message: content, | ||
role: 'user', | ||
}; | ||
Comment on lines
+35
to
+39
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion Refactor model handling and role assignment. Several concerns with the current implementation:
Consider these improvements:
+ const supportedModels = ['gpt-3.5-turbo', 'gpt-4']; // Move to config
+ if (model && !supportedModels.includes(model)) {
+ throw new Error(`Unsupported model: ${model}`);
+ }
const params: GenerateMessageParams = {
- model: model || 'gpt-3.5-turbo', // Default to 'gpt-3.5-turbo' if model is not provided
+ model: model || process.env.DEFAULT_MODEL || 'gpt-3.5-turbo',
message: content,
- role: 'user',
+ role: req.body.role || 'user', // Allow system prompts
};
|
||
|
||
this.logger.debug(`Request content: "${content}"`); | ||
res.setHeader('Content-Type', 'text/event-stream'); | ||
res.setHeader('Cache-Control', 'no-cache'); | ||
res.setHeader('Connection', 'keep-alive'); | ||
this.logger.debug('Response headers set for streaming.'); | ||
await this.llmProvider.generateStreamingResponse(content, res); | ||
await this.llmProvider.generateStreamingResponse(params, res); | ||
} catch (error) { | ||
this.logger.error('Error in chat endpoint:', error); | ||
res.status(500).json({ error: 'Internal server error' }); | ||
|
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -8,6 +8,9 @@ import { | |||||||||||||||||||||||||||||||||||||||||||||||||||||||
} from 'node-llama-cpp'; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
import { ModelProvider } from './model-provider.js'; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
import { Logger } from '@nestjs/common'; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
import { systemPrompts } from '../prompt/systemPrompt'; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
import { ChatCompletionMessageParam } from 'openai/resources/chat/completions'; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
import { GenerateMessageParams } from '../type/GenerateMessage'; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
//TODO: using protocol class | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
export class LlamaModelProvider extends ModelProvider { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -33,7 +36,7 @@ export class LlamaModelProvider extends ModelProvider { | |||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
async generateStreamingResponse( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
content: string, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
{ model, message, role = 'user' }: GenerateMessageParams, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Unused 'model' parameter in method signature. The destructured |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
res: Response, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
): Promise<void> { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+39
to
41
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion Add input validation and implement system prompt support. Several improvements are needed:
Consider this implementation: async generateStreamingResponse(
{ model, message, role = 'user' }: GenerateMessageParams,
res: Response,
): Promise<void> {
+ if (!message?.trim()) {
+ throw new Error('Message cannot be empty');
+ }
+
+ const systemPrompt = 'You are a helpful AI assistant.'; // Consider making this configurable
+ const formattedMessage = role === 'system' ? message : `${systemPrompt}\n\nUser: ${message}`;
+
this.logger.log('Generating streaming response with Llama...');
const session = new LlamaChatSession({
contextSequence: this.context.getSequence(),
});
this.logger.log('LlamaChatSession created.');
let chunkCount = 0;
const startTime = Date.now();
try {
- await session.prompt(message, {
+ await session.prompt(formattedMessage, { Also applies to: 48-48 |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
this.logger.log('Generating streaming response with Llama...'); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -43,8 +46,22 @@ export class LlamaModelProvider extends ModelProvider { | |||||||||||||||||||||||||||||||||||||||||||||||||||||||
this.logger.log('LlamaChatSession created.'); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
let chunkCount = 0; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
const startTime = Date.now(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
// Get the system prompt based on the model | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
const systemPrompt = systemPrompts['codefox-basic']?.systemPrompt || ''; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
const messages = [ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
{ role: 'system', content: systemPrompt }, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
{ role: role as 'user' | 'system' | 'assistant', content: message }, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
]; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
// Convert messages array to a single formatted string for Llama | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
const formattedPrompt = messages | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
.map(({ role, content }) => `${role}: ${content}`) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
.join('\n'); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+50
to
+61
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion Enhance system prompt handling with proper validation. The current implementation has several potential issues:
Consider this improved implementation: - // Get the system prompt based on the model
- const systemPrompt = systemPrompts['codefox-basic']?.systemPrompt || '';
+ // Get the system prompt based on the provided model
+ if (!systemPrompts[model]) {
+ throw new Error(`System prompt not found for model: ${model}`);
+ }
+ const systemPrompt = systemPrompts[model].systemPrompt; 📝 Committable suggestion
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
try { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
await session.prompt(content, { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
await session.prompt(formattedPrompt, { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
onTextChunk: chunk => { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
chunkCount++; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
this.logger.debug(`Sending chunk #${chunkCount}: "${chunk}"`); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -2,6 +2,10 @@ import { Response } from 'express'; | |||||||||||||||||||||||||||
import OpenAI from 'openai'; | ||||||||||||||||||||||||||||
import { ModelProvider } from './model-provider'; | ||||||||||||||||||||||||||||
import { Logger } from '@nestjs/common'; | ||||||||||||||||||||||||||||
import { systemPrompts } from '../prompt/systemPrompt'; | ||||||||||||||||||||||||||||
import { ChatCompletionMessageParam } from 'openai/resources/chat/completions'; | ||||||||||||||||||||||||||||
import { GenerateMessageParams } from '../type/GenerateMessage'; | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
export class OpenAIModelProvider extends ModelProvider { | ||||||||||||||||||||||||||||
private readonly logger = new Logger(OpenAIModelProvider.name); | ||||||||||||||||||||||||||||
private openai: OpenAI; | ||||||||||||||||||||||||||||
|
@@ -15,23 +19,34 @@ export class OpenAIModelProvider extends ModelProvider { | |||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
async generateStreamingResponse( | ||||||||||||||||||||||||||||
content: string, | ||||||||||||||||||||||||||||
{ model, message, role = 'user' }: GenerateMessageParams, | ||||||||||||||||||||||||||||
res: Response, | ||||||||||||||||||||||||||||
): Promise<void> { | ||||||||||||||||||||||||||||
this.logger.log('Generating streaming response with OpenAI...'); | ||||||||||||||||||||||||||||
const startTime = Date.now(); | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
// Set SSE headers | ||||||||||||||||||||||||||||
res.writeHead(200, { | ||||||||||||||||||||||||||||
'Content-Type': 'text/event-stream', | ||||||||||||||||||||||||||||
'Cache-Control': 'no-cache', | ||||||||||||||||||||||||||||
Connection: 'keep-alive', | ||||||||||||||||||||||||||||
}); | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
// Get the system prompt based on the model | ||||||||||||||||||||||||||||
const systemPrompt = systemPrompts['codefox-basic']?.systemPrompt || ''; | ||||||||||||||||||||||||||||
Comment on lines
+35
to
+36
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion Suggestion: Use the Currently, the system prompt is fetched using the hardcoded key Apply this diff to utilize the -const systemPrompt = systemPrompts['codefox-basic']?.systemPrompt || '';
+const systemPrompt = systemPrompts[model]?.systemPrompt || ''; 📝 Committable suggestion
Suggested change
|
||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
const messages: ChatCompletionMessageParam[] = [ | ||||||||||||||||||||||||||||
{ role: 'system', content: systemPrompt }, | ||||||||||||||||||||||||||||
{ role: role as 'user' | 'system' | 'assistant', content: message }, | ||||||||||||||||||||||||||||
]; | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
try { | ||||||||||||||||||||||||||||
const stream = await this.openai.chat.completions.create({ | ||||||||||||||||||||||||||||
model: 'gpt-3.5-turbo', | ||||||||||||||||||||||||||||
messages: [{ role: 'user', content: content }], | ||||||||||||||||||||||||||||
model, | ||||||||||||||||||||||||||||
messages, | ||||||||||||||||||||||||||||
Comment on lines
+45
to
+46
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion Add model validation before API call Consider validating the model parameter against the list of available models to fail fast and provide better error messages. + // Validate model before making the API call
+ try {
+ await this.openai.models.retrieve(model);
+ } catch (error) {
+ this.logger.error(`Invalid model: ${model}`);
+ throw new Error(`Invalid model: ${model}`);
+ }
+
const stream = await this.openai.chat.completions.create({
model,
messages, 📝 Committable suggestion
Suggested change
|
||||||||||||||||||||||||||||
stream: true, | ||||||||||||||||||||||||||||
}); | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
let chunkCount = 0; | ||||||||||||||||||||||||||||
for await (const chunk of stream) { | ||||||||||||||||||||||||||||
const content = chunk.choices[0]?.delta?.content || ''; | ||||||||||||||||||||||||||||
|
@@ -41,6 +56,7 @@ export class OpenAIModelProvider extends ModelProvider { | |||||||||||||||||||||||||||
res.write(`data: ${JSON.stringify(chunk)}\n\n`); | ||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
const endTime = Date.now(); | ||||||||||||||||||||||||||||
this.logger.log( | ||||||||||||||||||||||||||||
`Response generation completed. Total chunks: ${chunkCount}`, | ||||||||||||||||||||||||||||
|
@@ -59,20 +75,18 @@ export class OpenAIModelProvider extends ModelProvider { | |||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
async getModelTagsResponse(res: Response): Promise<void> { | ||||||||||||||||||||||||||||
this.logger.log('Fetching available models from OpenAI...'); | ||||||||||||||||||||||||||||
// Set SSE headers | ||||||||||||||||||||||||||||
res.writeHead(200, { | ||||||||||||||||||||||||||||
'Content-Type': 'text/event-stream', | ||||||||||||||||||||||||||||
'Cache-Control': 'no-cache', | ||||||||||||||||||||||||||||
Connection: 'keep-alive', | ||||||||||||||||||||||||||||
}); | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
try { | ||||||||||||||||||||||||||||
const startTime = Date.now(); | ||||||||||||||||||||||||||||
const models = await this.openai.models.list(); | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
const response = { | ||||||||||||||||||||||||||||
models: models, // Wrap the models in the required structure | ||||||||||||||||||||||||||||
models: models, | ||||||||||||||||||||||||||||
}; | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
const endTime = Date.now(); | ||||||||||||||||||||||||||||
this.logger.log( | ||||||||||||||||||||||||||||
`Model fetching completed. Total models: ${models.data.length}`, | ||||||||||||||||||||||||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
// Define and export the system prompts object | ||
export const systemPrompts = { | ||
'codefox-basic': { | ||
systemPrompt: `You are CodeFox, an advanced and powerful AI specialized in code generation and software engineering. | ||
Your purpose is to help developers build complete and efficient applications by providing well-structured, optimized, and maintainable code.`, | ||
}, | ||
}; |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
export interface GenerateMessageParams { | ||
model: string; // Model to use, e.g., 'gpt-3.5-turbo' | ||
message: string; // User's message or query | ||
role?: 'user' | 'system' | 'assistant' | 'tool' | 'function'; // Optional role | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🛠️ Refactor suggestion
Add type validation for request body.
The type assertion using
as
bypasses runtime type checking. Consider adding validation to ensure the request body matches the expected structure.📝 Committable suggestion