Skip to content

Commit

Permalink
Merge pull request #70 from Zain-ul-din/custom-model
Browse files Browse the repository at this point in the history
Custom model
  • Loading branch information
Zain-ul-din authored Nov 20, 2024
2 parents 80dc43e + 94ef287 commit a9d7259
Show file tree
Hide file tree
Showing 9 changed files with 230 additions and 53 deletions.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@
/session
/auth_info_baileys

# meta files

pnpm-lock.yaml
package-lock.json
yarn.lock

.vscode
1 change: 1 addition & 0 deletions docs/wa-ai-bot.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
The WhatsApp AI Bot is a chatbot that uses AI model APIs to generate responses to user input. The bot supports several AI models, including ChatGPT, DALL-E, and Stability AI, and users can also create their own models to customize the bot's behavior.
44 changes: 31 additions & 13 deletions src/baileys/handlers/message.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,19 @@ import { ChatGPTModel } from './../../models/OpenAIModel';
import { GeminiModel } from './../../models/GeminiModel';
import { FluxModel } from './../../models/FluxModel';
import { ENV } from '../env';

interface ModelByPrefix {
modelName: AIModels;
prefix: string;
}
import config from '../../whatsapp-ai.config';
import { CustomAIModel } from '../../models/CustomModel';

/* Declare models */
const modelTable: Record<AIModels, any> = {
ChatGPT: ENV.OPENAI_ENABLED ? new ChatGPTModel() : null,
Gemini: ENV.GEMINI_ENABLED ? new GeminiModel() : null,
FLUX: ENV.HF_ENABLED ? new FluxModel() : null,
Stability: ENV.STABILITY_ENABLED ? new StabilityModel() : null,
Dalle: null
Dalle: null,
Custom: config.models.Custom
? config.models.Custom.map((model) => new CustomAIModel(model))
: null
};

if (ENV.DALLE_ENABLED && ENV.OPENAI_ENABLED) {
Expand All @@ -32,21 +32,34 @@ if (ENV.DALLE_ENABLED && ENV.OPENAI_ENABLED) {

// handles message
export async function handleMessage({ client, msg, metadata }: MessageHandlerParams) {
const modelInfo: ModelByPrefix | undefined = Util.getModelByPrefix(
metadata.text,
metadata.fromMe
);
const modelInfo = Util.getModelByPrefix(metadata.text);

if (!modelInfo) {
if (ENV.Debug) {
console.log("[Debug] Model '" + modelInfo + "' not found");
}
return;
}

const model = modelTable[modelInfo.modelName];
let model = modelTable[modelInfo.name];
let prefix = modelInfo.name !== 'Custom' ? modelInfo.meta.prefix : '';

if (modelInfo.name === 'Custom') {
if (!modelInfo.customMeta) return;
const customModels = model as Array<CustomAIModel>;
const potentialCustomModel = customModels.find(
(model) => model.modelName === modelInfo.customMeta?.meta.modelName
);

model = potentialCustomModel;
prefix = modelInfo.customMeta.meta.prefix;
}

if (!model) {
if (ENV.Debug) {
console.log("[Debug] Model '" + modelInfo.modelName + "' is disabled or not found");
console.log(
"[Debug] Model '" + JSON.stringify(modelInfo, null, 2) + "' is disabled or not found"
);
}
return;
}
Expand All @@ -59,7 +72,12 @@ export async function handleMessage({ client, msg, metadata }: MessageHandlerPar
);

model.sendMessage(
{ sender: metadata.sender, prompt: prompt, metadata: metadata, prefix: modelInfo.prefix },
{
sender: metadata.sender,
prompt: prompt,
metadata: metadata,
prefix: prefix
},
async (res: any, err: any) => {
if (err) {
client.sendMessage(metadata.remoteJid, {
Expand Down
35 changes: 28 additions & 7 deletions src/models/BaseAiModel.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,16 @@ type AIMetaData = {
sender: string;
senderName?: string;
fromMe: boolean;
msgType: 'unknown' | 'text' | 'extendedText' | 'image' | 'video' | 'document' | 'contact' | 'location' | 'audio';
msgType:
| 'unknown'
| 'text'
| 'extendedText'
| 'image'
| 'video'
| 'document'
| 'contact'
| 'location'
| 'audio';
type: MessageUpsertType;
isQuoted: boolean;
quoteMetaData: {
Expand Down Expand Up @@ -63,12 +72,24 @@ abstract class AIModel<AIArguments, CallBack> {
this.iconPrefix = icon === undefined ? '' : '[' + icon + '] ';
}

public getApiKey(): string { return this.apiKey };
public sessionCreate(user: string): void { this.history[user] = [] };
public sessionRemove(user: string): void { delete this.history[user] };
public sessionExists(user: string): boolean { return this.history[user] !== undefined };
public sessionAddMessage(user: string, args: any): void { this.history[user].push(args) };
public addPrefixIcon(text: string): string { return this.iconPrefix + text };
public getApiKey(): string {
return this.apiKey;
}
public sessionCreate(user: string): void {
this.history[user] = [];
}
public sessionRemove(user: string): void {
delete this.history[user];
}
public sessionExists(user: string): boolean {
return this.history[user] !== undefined;
}
public sessionAddMessage(user: string, args: any): void {
this.history[user].push(args);
}
public addPrefixIcon(text: string): string {
return this.iconPrefix + text;
}

abstract sendMessage(args: AIArguments, handle: CallBack): Promise<any>;
}
Expand Down
97 changes: 97 additions & 0 deletions src/models/CustomModel.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
/* Local modules */
import { AIModel, AIArguments, AIHandle } from './BaseAiModel';
import { ENV } from '../baileys/env';
import { GeminiModel } from './GeminiModel';
import { ChatGPTModel } from './OpenAIModel';
import { IModelType, SupportedBaseModels } from '../types/Config';
import { readFile } from 'fs/promises';

class CustomAIModel extends AIModel<AIArguments, AIHandle> {
  private geminiModel: GeminiModel;
  private chatGPTModel: ChatGPTModel;
  private selectedBaseModel: SupportedBaseModels;
  private self: IModelType;

  /**
   * A user-defined model that wraps one of the supported base models
   * (ChatGPT or Gemini) and injects the custom model's context/instructions
   * into every prompt before delegating.
   */
  public constructor(model: IModelType) {
    // Resolve the API key for whichever base model this custom model wraps.
    const apiKeys: { [key in SupportedBaseModels]: string | undefined } = {
      ChatGPT: ENV.API_KEY_OPENAI,
      Gemini: ENV.API_KEY_GEMINI
    };
    super(apiKeys[model.baseModel], model.modelName as any);

    this.self = model;
    this.selectedBaseModel = model.baseModel;
    // NOTE(review): both base models are constructed eagerly although only
    // `selectedBaseModel` is ever used; lazy construction would avoid the
    // unused client — confirm the base-model constructors are side-effect
    // free before changing this.
    this.geminiModel = new GeminiModel();
    this.chatGPTModel = new ChatGPTModel();
  }

  /**
   * Wraps the user prompt with the custom model's instructions so the base
   * model is told to answer only from the given context.
   */
  private static constructInstructAblePrompt({
    prompt,
    instructions
  }: {
    prompt: string;
    instructions: string;
  }) {
    return `
    <context>
      ${instructions}
    </context>
    note!
    only answer from the given context
    prompt:
    ${prompt}
    `;
  }

  /**
   * Sends `prompt` (wrapped with this model's resolved context) to the
   * configured base model and forwards the response to `handle`.
   * On failure, invokes `handle` with an empty response and the error.
   */
  public async sendMessage({ prompt, ...rest }: AIArguments, handle: AIHandle) {
    try {
      const instructions = await CustomAIModel.readContext(this.self);
      const promptWithInstructions = CustomAIModel.constructInstructAblePrompt({
        prompt,
        instructions
      });

      switch (this.selectedBaseModel) {
        case 'ChatGPT':
          await this.chatGPTModel.sendMessage({ prompt: promptWithInstructions, ...rest }, handle);
          break;
        case 'Gemini':
          await this.geminiModel.sendMessage({ prompt: promptWithInstructions, ...rest }, handle);
          break;
      }
    } catch (err) {
      await handle('', err as string);
    }
  }

  /**
   * Resolves the model's `context` field to instruction text. The context may
   * be a URL (fetched over the network), a path to a local text file
   * (read from disk), or plain inline text (returned as-is).
   */
  private static async readContext(model: IModelType) {
    const supportedFiles = ['.txt', '.text', '.md'];
    // NOTE(review): fetch() only supports http(s); the other protocols
    // listed here will be matched but then fail at fetch time — confirm
    // whether they should be matched at all.
    const httpProtocols = [
      'https://',
      'http://',
      'ftp://',
      'ftps://',
      'sftp://',
      'ssh://',
      'git://',
      'svn://',
      'ws://',
      'wss://'
    ];

    const context = model.context.trim();

    // Remote context: fetch the document body.
    if (httpProtocols.some((protocol) => context.startsWith(protocol))) {
      const res = await fetch(model.context);
      return res.text();
    }

    // Local file context: read the file contents.
    if (supportedFiles.some((fileExt) => context.endsWith(fileExt))) {
      return readFile(model.context, { encoding: 'utf-8' });
    }

    // Plain inline text.
    return model.context;
  }
}

export { CustomAIModel };
2 changes: 1 addition & 1 deletion src/types/AiModels.d.ts
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
export type AIModels = 'ChatGPT' | 'Gemini' | 'FLUX' | 'Stability' | 'Dalle' ;
export type AIModels = 'ChatGPT' | 'Gemini' | 'FLUX' | 'Stability' | 'Dalle' | 'Custom';
export type AIModelsName = Exclude<AIModels, 'Custom'>;
10 changes: 9 additions & 1 deletion src/types/Config.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,19 @@ export interface IModelConfig {
enable: boolean;
}

export type SupportedBaseModels = keyof Pick<
{
[key in AIModels]: string;
},
'ChatGPT' | 'Gemini'
>;

export interface IModelType extends IModelConfig {
modelName: string;
prefix: string;
context: string;
includeSender?: boolean;
baseModel: SupportedBaseModels;
}

export interface IDefaultConfig {
Expand All @@ -20,7 +28,7 @@ export interface IDefaultConfig {

export type Config = {
models: {
[key in AIModels]?: key extends 'Custom' ? Array<IModelType> | [] : IModelConfig | null;
[key in AIModels]?: key extends 'Custom' ? Array<IModelType> : IModelConfig;
};
} & {
prefix: IDefaultConfig;
Expand Down
68 changes: 49 additions & 19 deletions src/util/Util.ts
Original file line number Diff line number Diff line change
@@ -1,29 +1,60 @@
import config from '../whatsapp-ai.config';
import { existsSync, readFileSync } from 'fs';
import { IModelConfig } from '../types/Config';
import { AIModels } from '../types/AiModels';
import config from '../whatsapp-ai.config';
import { IModelConfig, IModelType } from '../types/Config';

export class Util {
public static getModelByPrefix(
message: string,
fromMe: Boolean
): { modelName: AIModels; prefix: string } | undefined {
if (fromMe) {
const defaultModelName = config.prefix.defaultModel;
const defaultModel = config.models[defaultModelName];
if (defaultModel && defaultModel.enable) {
return { modelName: defaultModelName as AIModels, prefix: defaultModel.prefix as string };
public static getModelByPrefix(message: string):
| {
name: Exclude<AIModels, 'Custom'>;
meta: IModelConfig;
}
}

| {
name: 'Custom';
customMeta: { name: string; meta: IModelType } | undefined;
}
| undefined {
// models
for (let [modelName, model] of Object.entries(config.models)) {
const currentModel = model as IModelConfig;
if (!currentModel.enable) continue;

if (
message.toLocaleLowerCase().startsWith((currentModel.prefix as string).toLocaleLowerCase())
!(model as IModelConfig).enable &&
(modelName as AIModels) != 'Custom' // ignore array
)
continue;

if ((modelName as AIModels) == 'Custom') {
return {
name: 'Custom',
customMeta: Util.getModelByCustomPrefix(message)
};
} else if (
model &&
message
.toLocaleLowerCase()
.startsWith(((model as IModelConfig)?.prefix || '').toLocaleLowerCase())
) {
return { modelName: modelName as AIModels, prefix: currentModel.prefix as string };
return {
name: modelName as Exclude<AIModels, 'Custom'>,
meta: config.models[modelName as AIModels] as IModelConfig
};
}
}

return undefined;
}

private static getModelByCustomPrefix(
message: string
): { name: string; meta: IModelType } | undefined {
if (!config.models.Custom) return undefined;
for (let model of config.models.Custom) {
if (!(model as IModelType).enable) continue;

if (message.toLocaleLowerCase().startsWith(model.prefix.toLocaleLowerCase())) {
return {
name: model.modelName,
meta: config.models.Custom.find((m) => m.modelName === model.modelName) as IModelType
};
}
}

Expand All @@ -32,7 +63,6 @@ export class Util {

public static readFile(filePath: string) {
if (!existsSync(filePath)) throw new Error(`File at path ${filePath} not found`);

return readFileSync(filePath, 'utf-8');
}
}
Loading

0 comments on commit a9d7259

Please sign in to comment.