Skip to content

Commit

Permalink
Merge branch 'vicentefelipechile-master'
Browse files Browse the repository at this point in the history
  • Loading branch information
Zain-ul-din committed Nov 24, 2024
2 parents acffc6a + 625d9c1 commit 3a2d54c
Show file tree
Hide file tree
Showing 7 changed files with 73 additions and 45 deletions.
1 change: 1 addition & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

DEBUG=False
PROCESSING="Thinking..."
IGNORE_SELF_MESSAGES=False

# Model Services
API_KEY_OPENAI=ADD_YOUR_KEY
Expand Down
2 changes: 2 additions & 0 deletions src/baileys/env.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ configEnv();
interface EnvInterface {
Debug: boolean;
Processing: string;
IGNORE_SELF_MESSAGES: boolean;

// Model Services
API_KEY_OPENAI?: string;
Expand Down Expand Up @@ -52,6 +53,7 @@ interface EnvInterface {
export const ENV: EnvInterface = {
Debug: process.env.DEBUG === 'True',
Processing: process.env.PROCESSING || 'Processing...',
IGNORE_SELF_MESSAGES: process.env.IGNORE_SELF_MESSAGES === 'True',

API_KEY_OPENAI: process.env.API_KEY_OPENAI,
API_KEY_OPENAI_DALLE: process.env.API_KEY_OPENAI_DALLE || process.env.API_KEY_OPENAI,
Expand Down
9 changes: 1 addition & 8 deletions src/baileys/handlers/message.ts
Original file line number Diff line number Diff line change
Expand Up @@ -100,11 +100,4 @@ export async function handleMessage({ client, msg, metadata }: MessageHandlerPar
}
}
);
}

// handles message from self
export async function handlerMessageFromMe({ metadata, client, msg, type }: MessageHandlerParams) {
// if (metadata.fromMe && metadata.isQuoted) return;
// if (metadata.isQuoted && Util.getModelByPrefix(metadata.text)) return;
await handleMessage({ metadata, client, msg, type });
}
}
8 changes: 4 additions & 4 deletions src/baileys/handlers/messages.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import makeWASocket, { MessageUpsertType, WAMessage } from '@whiskeysockets/baileys';
import useMessageParser from '../hooks/useMessageParser';
import { handleMessage, handlerMessageFromMe } from './message';
import { handleMessage } from './message';
import { ENV } from '../env';

export function messagesHandler({
client,
Expand All @@ -21,10 +22,9 @@ export function messagesHandler({
if (!metadata) return;
if (metadata.msgType === 'unknown') return;
if (metadata.isGroup && metadata.groupMetaData.groupIsLocked) return;
if (metadata.fromMe && ENV.IGNORE_SELF_MESSAGES) return;

await (metadata.fromMe
? handlerMessageFromMe({ client, msg, metadata, type })
: handleMessage({ client, msg, metadata, type }));
await handleMessage({ client, msg, metadata, type });
} catch (_) {}
})
);
Expand Down
25 changes: 15 additions & 10 deletions src/baileys/hooks/useMessageParser.ts
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,15 @@ export default async function useMessageParser(
senderName: message.pushName,
fromMe,
msgType,
message,
type,
isQuoted: false,
quoteMetaData: {
remoteJid: '',
message: {},
text: '',
type: 'text',
hasImage: false,
imgMetaData: {
url: '',
mimeType: '',
Expand All @@ -71,11 +73,13 @@ export default async function useMessageParser(
groupName: '',
groupIsLocked: false
},
hasImage: false,
imgMetaData: {
url: '',
mimeType: '',
caption: ''
},
hasAudio: false,
audioMetaData: {
url: '',
mimeType: ''
Expand All @@ -84,20 +88,18 @@ export default async function useMessageParser(

// Handle image messages
if (imageMessage) {
metaData.msgType = 'image';
if (imageMessage.url) metaData.imgMetaData.url = imageMessage.url;
if (imageMessage.mimetype) metaData.imgMetaData.mimeType = imageMessage.mimetype;
if (imageMessage.caption) {
metaData.imgMetaData.caption = imageMessage.caption;
metaData.text = imageMessage.caption;
}
metaData.hasImage = true;
metaData.imgMetaData.url = imageMessage.url || '';
metaData.imgMetaData.mimeType = imageMessage.mimetype || '';
metaData.imgMetaData.caption = imageMessage.caption || '';
metaData.text = imageMessage.caption || '';
}

// Handle audio messages
if (audioMessage) {
metaData.msgType = 'audio';
if (audioMessage.url) metaData.audioMetaData.url = audioMessage.url;
if (audioMessage.mimetype) metaData.audioMetaData.mimeType = audioMessage.mimetype;
metaData.hasAudio = true;
metaData.audioMetaData.url = audioMessage.url || '';
metaData.audioMetaData.mimeType = audioMessage.mimetype || '';
}

// gather context info
Expand All @@ -112,6 +114,7 @@ export default async function useMessageParser(
metaData.quoteMetaData.message = contextInfo.quotedMessage || {};

if (contextInfo.quotedMessage.imageMessage) {
metaData.quoteMetaData.hasImage = true;
metaData.quoteMetaData.type = 'image';
metaData.quoteMetaData.imgMetaData.url = contextInfo.quotedMessage.imageMessage.url || '';
metaData.quoteMetaData.imgMetaData.mimeType =
Expand All @@ -129,5 +132,7 @@ export default async function useMessageParser(
metaData.groupMetaData.groupIsLocked = groupMetadata.restrict || false;
}

metaData.message = message;

return metaData;
}
4 changes: 4 additions & 0 deletions src/models/BaseAiModel.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ type AIMetaData = {
sender: string;
senderName?: string;
fromMe: boolean;
message: any;
msgType:
| 'unknown'
| 'text'
Expand All @@ -25,6 +26,7 @@ type AIMetaData = {
message: any;
text: string;
type: 'text' | 'image';
hasImage: boolean;
imgMetaData: {
url: string;
mimeType: string;
Expand All @@ -38,11 +40,13 @@ type AIMetaData = {
groupName: string;
groupIsLocked: boolean;
};
hasImage: boolean;
imgMetaData: {
url: string;
mimeType: string;
caption: string;
};
hasAudio: boolean;
audioMetaData: {
url: string;
mimeType: string;
Expand Down
69 changes: 46 additions & 23 deletions src/models/GeminiModel.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,14 @@ import { AIModel, AIArguments, AIHandle, AIMetaData } from './BaseAiModel';
import { ENV } from '../baileys/env';
import invariant from 'invariant';

interface imgMetaData {
url: string;
mimeType: string;
caption: string;
}

const validMimeTypes = ['image/jpeg', 'image/png', 'image/jpg', 'image/webp'];

/* Gemini Model */
class GeminiModel extends AIModel<AIArguments, AIHandle> {
/* Variables */
Expand All @@ -20,6 +28,14 @@ class GeminiModel extends AIModel<AIArguments, AIHandle> {
public chats: { [from: string]: ChatSession };
public instructions: string | undefined;

private initGenerativeModel() {
// https://ai.google.dev/gemini-api/docs/models/gemini
this.generativeModel = this.Gemini.getGenerativeModel({
model: 'gemini-1.5-flash',
systemInstruction: this.instructions
});
}

public constructor() {
super(ENV.API_KEY_GEMINI, 'Gemini', ENV.GEMINI_ICON_PREFIX);
this.Gemini = new GoogleGenerativeAI(ENV.API_KEY_GEMINI as string);
Expand All @@ -46,48 +62,55 @@ class GeminiModel extends AIModel<AIArguments, AIHandle> {
};
}

private initGenerativeModel() {
// https://ai.google.dev/gemini-api/docs/models/gemini
this.generativeModel = this.Gemini.getGenerativeModel({
model: 'gemini-1.5-flash',
systemInstruction: this.instructions
});
}

public async generateImageCompletion(prompt: string, metadata: AIMetaData): Promise<string> {

const { mimeType } = metadata.quoteMetaData.imgMetaData;
if (mimeType === 'image/jpeg') {
const buffer = await downloadMediaMessage(
{ message: metadata.quoteMetaData.message } as any,
'buffer',
{}
);
public async generateImageCompletion(
prompt: string,
imgMetaData: imgMetaData,
message: any
): Promise<string> {
const { mimeType } = imgMetaData;
if (validMimeTypes.includes(mimeType)) {
const buffer = await downloadMediaMessage({ message: message } as any, 'buffer', {});
const imageParts = this.createGenerativeContent(buffer, mimeType);
const result = await this.generativeModel!.generateContent([prompt, imageParts]);
const resultText = result.response.text();

return resultText;
}

return '';
return 'The image is not a valid image type.';
}

async sendMessage({ sender, prompt, metadata }: AIArguments, handle: AIHandle) {
this.initGenerativeModel();
invariant(this.generativeModel, 'Unable to initialize Gemini Generative model');

try {
let message = '';
if (metadata.isQuoted) {
if (metadata.quoteMetaData.type === 'image') {
message = this.iconPrefix + (await this.generateImageCompletion(prompt, metadata));
if (metadata.quoteMetaData.hasImage) {
message =
this.iconPrefix +
(await this.generateImageCompletion(
prompt,
metadata.quoteMetaData.imgMetaData,
metadata.quoteMetaData.message
));
} else {
prompt = 'Quoted Message:\n' + metadata.quoteMetaData.text + '---\nMessage:\n' + prompt;
prompt = 'Quoted Message:\n' + metadata.quoteMetaData.text + '\n---\nMessage:\n' + prompt;
message = this.iconPrefix + (await this.generateCompletion(sender, prompt));
}
} else {
message = this.iconPrefix + (await this.generateCompletion(sender, prompt));
if (metadata.hasImage) {
message =
this.iconPrefix +
(await this.generateImageCompletion(
prompt,
metadata.imgMetaData,
metadata.message.message
));
} else {
message = this.iconPrefix + (await this.generateCompletion(sender, prompt));
}
}

handle({ text: message });
Expand Down

0 comments on commit 3a2d54c

Please sign in to comment.