Skip to content

Commit

Permalink
Merge pull request #72 from Zain-ul-din/prompt_optimization
Browse files Browse the repository at this point in the history
(#71)
  • Loading branch information
Zain-ul-din authored Nov 22, 2024
2 parents fefba58 + 7928733 commit 10d2fdb
Show file tree
Hide file tree
Showing 6 changed files with 34 additions and 9 deletions.
2 changes: 2 additions & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
"build": "npx yarn"
},
"devDependencies": {
"@types/invariant": "^2.2.37",
"@types/node": "^18.14.0",
"@types/node-fetch": "^2.6.2",
"prettier": "^2.8.4",
Expand All @@ -24,6 +25,7 @@
"@whiskeysockets/baileys": "^6.7.7",
"@whiskeysockets/libsignal-node": "github:WhiskeySockets/libsignal-node",
"dotenv": "^16.0.3",
"invariant": "^2.2.4",
"mongo-baileys": "^1.0.1",
"mongodb": "^6.8.0",
"openai": "^4.56.0",
Expand Down
2 changes: 1 addition & 1 deletion src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@

// whatsappClient.messageEvent.on('self', welcomeUser);

import { connectToWhatsApp } from "./baileys";
import { connectToWhatsApp } from './baileys';
connectToWhatsApp();
11 changes: 9 additions & 2 deletions src/models/CustomModel.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,14 @@ class CustomAIModel extends AIModel<AIArguments, AIHandle> {
this.chatGPTModel = new ChatGPTModel();
}

private static constructInstructAblePrompt({
private constructInstructAblePrompt({
prompt,
instructions
}: {
prompt: string;
instructions: string;
}) {
if (!this.self.dangerouslyAllowFewShotApproach) return prompt;
return `
<context>
${instructions}
Expand All @@ -46,16 +47,22 @@ prompt:
public async sendMessage({ prompt, ...rest }: AIArguments, handle: AIHandle) {
try {
const instructions = await CustomAIModel.readContext(this.self);
const promptWithInstructions = CustomAIModel.constructInstructAblePrompt({
const promptWithInstructions = this.constructInstructAblePrompt({
prompt,
instructions: instructions
});

switch (this.selectedBaseModel) {
case 'ChatGPT':
this.chatGPTModel.instructions = !this.self.dangerouslyAllowFewShotApproach
? instructions
: undefined;
await this.chatGPTModel.sendMessage({ prompt: promptWithInstructions, ...rest }, handle);
break;
case 'Gemini':
this.geminiModel.instructions = !this.self.dangerouslyAllowFewShotApproach
? instructions
: undefined;
await this.geminiModel.sendMessage({ prompt: promptWithInstructions, ...rest }, handle);
break;
}
Expand Down
20 changes: 15 additions & 5 deletions src/models/GeminiModel.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,28 +10,27 @@ import { downloadMediaMessage } from '@whiskeysockets/baileys';
/* Local modules */
import { AIModel, AIArguments, AIHandle, AIMetaData } from './BaseAiModel';
import { ENV } from '../baileys/env';
import invariant from 'invariant';

/* Gemini Model */
class GeminiModel extends AIModel<AIArguments, AIHandle> {
/* Variables */
private generativeModel: GenerativeModel;
private generativeModel?: GenerativeModel;
private Gemini: GoogleGenerativeAI;
public chats: { [from: string]: ChatSession };
public instructions: string | undefined;

public constructor() {
super(ENV.API_KEY_GEMINI, 'Gemini', ENV.GEMINI_ICON_PREFIX);
this.Gemini = new GoogleGenerativeAI(ENV.API_KEY_GEMINI as string);

// https://ai.google.dev/gemini-api/docs/models/gemini
this.generativeModel = this.Gemini.getGenerativeModel({ model: 'gemini-1.5-flash' });
this.chats = {};
}

/* Methods */
public async generateCompletion(user: string, prompt: string): Promise<string> {
if (!this.sessionExists(user)) {
this.sessionCreate(user);
this.chats[user] = this.generativeModel.startChat();
this.chats[user] = this.generativeModel!.startChat();
}

const chat = this.chats[user];
Expand All @@ -47,7 +46,18 @@ class GeminiModel extends AIModel<AIArguments, AIHandle> {
};
}

private initGenerativeModel() {
// https://ai.google.dev/gemini-api/docs/models/gemini
this.generativeModel = this.Gemini.getGenerativeModel({
model: 'gemini-1.5-flash',
systemInstruction: this.instructions
});
}

public async generateImageCompletion(prompt: string, metadata: AIMetaData): Promise<string> {
this.initGenerativeModel();
invariant(this.generativeModel, 'Unable to initialize Gemini Generative model');

const { mimeType } = metadata.quoteMetaData.imgMetaData;
if (mimeType === 'image/jpeg') {
const buffer = await downloadMediaMessage(
Expand Down
7 changes: 6 additions & 1 deletion src/models/OpenAIModel.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ class ChatGPTModel extends AIModel<AIArguments, AIHandle> {
private OpenAI: OpenAI;
public DalleSize: DalleSizeImage;

public instructions: string | undefined = undefined;

public constructor() {
super(ENV.API_KEY_OPENAI, 'ChatGPT');

Expand All @@ -42,7 +44,10 @@ class ChatGPTModel extends AIModel<AIArguments, AIHandle> {
/* Methods */
public async generateCompletion(user: string): Promise<ChatCompletionMessage> {
const completion = await this.OpenAI.chat.completions.create({
messages: this.history[user],
messages: [
...(this.instructions ? [{ role: 'system', content: this.instructions }] : []),
...this.history[user]
],
model: ENV.OPENAI_MODEL
});

Expand Down
1 change: 1 addition & 0 deletions src/types/Config.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ export interface IModelType extends IModelConfig {
context: string;
includeSender?: boolean;
baseModel: SupportedBaseModels;
dangerouslyAllowFewShotApproach?: boolean;
}

export interface IDefaultConfig {
Expand Down

0 comments on commit 10d2fdb

Please sign in to comment.