Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add auto_dalle plugin #505

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 14 additions & 8 deletions bot/openai_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def get_conversation_stats(self, chat_id: int) -> tuple[int, int]:
self.reset_chat_history(chat_id)
return len(self.conversations[chat_id]), self.__count_tokens(self.conversations[chat_id])

async def get_chat_response(self, chat_id: int, query: str) -> tuple[str, str]:
async def get_chat_response(self, bot: ChatGPTTelegramBot, tg_upd: telegram.Update, chat_id: int, query: str) -> tuple[str, str]:
"""
Gets a full response from the GPT model.
:param chat_id: The chat ID
Expand All @@ -132,7 +132,7 @@ async def get_chat_response(self, chat_id: int, query: str) -> tuple[str, str]:
plugins_used = ()
response = await self.__common_get_chat_response(chat_id, query)
if self.config['enable_functions'] and not self.conversations_vision[chat_id]:
response, plugins_used = await self.__handle_function_call(chat_id, response)
response, plugins_used = await self.__handle_function_call(bot, tg_upd, chat_id, response)
if is_direct_result(response):
return response, '0'

Expand Down Expand Up @@ -165,17 +165,19 @@ async def get_chat_response(self, chat_id: int, query: str) -> tuple[str, str]:

return answer, response.usage.total_tokens

async def get_chat_response_stream(self, chat_id: int, query: str):
async def get_chat_response_stream(self, bot: ChatGPTTelegramBot, tg_upd: telegram.Update, chat_id: int, query: str):
"""
Stream response from the GPT model.
:param chat_id: The chat ID
:param query: The query to send to the model
:return: The answer from the model and the number of tokens used, or 'not_finished'
"""
import telegram_bot
plugins_used = ()
response = await self.__common_get_chat_response(chat_id, query, stream=True)
if self.config['enable_functions'] and not self.conversations_vision[chat_id]:
response, plugins_used = await self.__handle_function_call(chat_id, response, stream=True)

if self.config['enable_functions']:
response, plugins_used = await self.__handle_function_call(bot, tg_upd, chat_id, response, stream=True)
if is_direct_result(response):
yield response, '0'
return
Expand Down Expand Up @@ -269,7 +271,7 @@ async def __common_get_chat_response(self, chat_id: int, query: str, stream=Fals
except Exception as e:
raise Exception(f"⚠️ _{localized_text('error', bot_language)}._ ⚠️\n{str(e)}") from e

async def __handle_function_call(self, chat_id, response, stream=False, times=0, plugins_used=()):
async def __handle_function_call(self, bot: ChatGPTTelegramBot, tg_upd: telegram.Update, chat_id, response, stream=False, times=0, plugins_used=()):
function_name = ''
arguments = ''
if stream:
Expand Down Expand Up @@ -301,11 +303,15 @@ async def __handle_function_call(self, chat_id, response, stream=False, times=0,
return response, plugins_used

logging.info(f'Calling function {function_name} with arguments {arguments}')
function_response = await self.plugin_manager.call_function(function_name, self, arguments)
function_response, function_response_dict = await self.plugin_manager.call_function(bot, tg_upd, chat_id, function_name, arguments)

if function_name not in plugins_used:
plugins_used += (function_name,)

# if "result" in function_response_dict and function_response_dict["result"] == "Success":
# self.__add_function_call_to_history(chat_id=chat_id, function_name=function_name, content=function_response)
# return response, plugins_used

if is_direct_result(function_response):
self.__add_function_call_to_history(chat_id=chat_id, function_name=function_name,
content=json.dumps({'result': 'Done, the content has been sent'
Expand All @@ -320,7 +326,7 @@ async def __handle_function_call(self, chat_id, response, stream=False, times=0,
function_call='auto' if times < self.config['functions_max_consecutive_calls'] else 'none',
stream=stream
)
return await self.__handle_function_call(chat_id, response, stream, times + 1, plugins_used)
return await self.__handle_function_call(bot, tg_upd, chat_id, response, stream, times + 1, plugins_used)

async def generate_image(self, prompt: str) -> tuple[str, str]:
"""
Expand Down
7 changes: 5 additions & 2 deletions bot/plugin_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from plugins.gtts_text_to_speech import GTTSTextToSpeech
from plugins.auto_tts import AutoTextToSpeech
from plugins.auto_dalle import AutoDalle
from plugins.dice import DicePlugin
from plugins.youtube_audio_extractor import YouTubeAudioExtractorPlugin
from plugins.ddg_image_search import DDGImageSearchPlugin
Expand Down Expand Up @@ -38,6 +39,7 @@ def __init__(self, config):
'deepl_translate': DeeplTranslatePlugin,
'gtts_text_to_speech': GTTSTextToSpeech,
'auto_tts': AutoTextToSpeech,
'auto_dalle': AutoDalle,
'whois': WhoisPlugin,
'webshot': WebshotPlugin,
}
Expand All @@ -49,14 +51,15 @@ def get_functions_specs(self):
"""
return [spec for specs in map(lambda plugin: plugin.get_spec(), self.plugins) for spec in specs]

async def call_function(self, function_name, helper, arguments):
async def call_function(self, bot, tg_upd, chat_id, function_name, arguments):
"""
Call a function based on the name and parameters provided
"""
plugin = self.__get_plugin_by_function_name(function_name)
if not plugin:
return json.dumps({'error': f'Function {function_name} not found'})
return json.dumps(await plugin.execute(function_name, helper, **json.loads(arguments)), default=str)
result = await plugin.execute(function_name, bot, tg_upd, chat_id, **json.loads(arguments))
return json.dumps(result, default=str), result

def get_plugin_source_name(self, function_name) -> str:
"""
Expand Down
40 changes: 40 additions & 0 deletions bot/plugins/auto_dalle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import asyncio
import datetime
import tempfile
import traceback
from typing import Dict
import telegram

from .plugin import Plugin


class AutoDalle(Plugin):
"""
A plugin to generate image using Openai image generation API
"""

def get_source_name(self) -> str:
return "DALLE"

def get_spec(self) -> [Dict]:
return [{
"name": "dalle_image",
"description": "Create image from scratch based on a text prompt (DALL·E 3 and DALL·E 2). Send to user.",
"parameters": {
"type": "object",
"properties": {
"prompt": {"type": "string", "prompt": "Image description. Use English language for better results."},
},
"required": ["prompt"],
},
}]

async def execute(self, function_name, bot, tg_upd: telegram.Update, chat_id, **kwargs) -> Dict:
await bot.wrap_with_indicator(tg_upd, bot.image_gen(tg_upd, kwargs['prompt']), "upload_photo")
return {
'direct_result': {
'kind': 'none',
'format': '',
'value': 'none',
}
}
24 changes: 9 additions & 15 deletions bot/plugins/auto_tts.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import datetime
import tempfile
from typing import Dict
import telegram

from .plugin import Plugin

Expand All @@ -15,8 +16,8 @@ def get_source_name(self) -> str:

def get_spec(self) -> [Dict]:
return [{
"name": "translate_text_to_speech",
"description": "Translate text to speech using OpenAI API",
"name": "translate_text_to_speech_and_send",
"description": "Translate text to speech using OpenAI API and send result to user.",
"parameters": {
"type": "object",
"properties": {
Expand All @@ -26,19 +27,12 @@ def get_spec(self) -> [Dict]:
},
}]

async def execute(self, function_name, helper, **kwargs) -> Dict:
try:
bytes, text_length = await helper.generate_speech(text=kwargs['text'])
with tempfile.NamedTemporaryFile(delete=False, suffix='.opus') as temp_file:
temp_file.write(bytes.getvalue())
temp_file_path = temp_file.name
except Exception as e:
logging.exception(e)
return {"Result": "Exception: " + str(e)}
async def execute(self, function_name, bot, tg_upd: telegram.Update, chat_id, **kwargs) -> Dict:
await bot.wrap_with_indicator(tg_upd, bot.tts_gen(tg_upd, kwargs['text']), "record_voice")
return {
'direct_result': {
'kind': 'file',
'format': 'path',
'value': temp_file_path
'kind': 'none',
'format': '',
'value': 'none',
}
}
}
3 changes: 2 additions & 1 deletion bot/plugins/crypto.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import telegram
from typing import Dict

import requests
Expand Down Expand Up @@ -26,5 +27,5 @@ def get_spec(self) -> [Dict]:
},
}]

async def execute(self, function_name, helper, **kwargs) -> Dict:
async def execute(self, function_name, bot, tg_upd: telegram.Update, chat_id, **kwargs) -> Dict:
return requests.get(f"https://api.coincap.io/v2/rates/{kwargs['asset']}").json()
3 changes: 2 additions & 1 deletion bot/plugins/ddg_image_search.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import random
import telegram
from itertools import islice
from typing import Dict

Expand Down Expand Up @@ -49,7 +50,7 @@ def get_spec(self) -> [Dict]:
},
}]

async def execute(self, function_name, helper, **kwargs) -> Dict:
async def execute(self, function_name, bot, tg_upd: telegram.Update, chat_id, **kwargs) -> Dict:
with DDGS() as ddgs:
image_type = kwargs.get('type', 'photo')
ddgs_images_gen = ddgs.images(
Expand Down
3 changes: 2 additions & 1 deletion bot/plugins/ddg_translate.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import telegram
from typing import Dict

from duckduckgo_search import DDGS
Expand Down Expand Up @@ -26,6 +27,6 @@ def get_spec(self) -> [Dict]:
},
}]

async def execute(self, function_name, helper, **kwargs) -> Dict:
async def execute(self, function_name, bot, tg_upd: telegram.Update, chat_id, **kwargs) -> Dict:
with DDGS() as ddgs:
return ddgs.translate(kwargs['text'], to=kwargs['to_language'])
3 changes: 2 additions & 1 deletion bot/plugins/ddg_web_search.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
from itertools import islice
import telegram
from typing import Dict

from duckduckgo_search import DDGS
Expand Down Expand Up @@ -46,7 +47,7 @@ def get_spec(self) -> [Dict]:
},
}]

async def execute(self, function_name, helper, **kwargs) -> Dict:
async def execute(self, function_name, bot, tg_upd: telegram.Update, chat_id, **kwargs) -> Dict:
with DDGS() as ddgs:
ddgs_gen = ddgs.text(
kwargs['query'],
Expand Down
3 changes: 2 additions & 1 deletion bot/plugins/deepl.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
from typing import Dict

import telegram
import requests

from .plugin import Plugin
Expand Down Expand Up @@ -33,7 +34,7 @@ def get_spec(self) -> [Dict]:
},
}]

async def execute(self, function_name, helper, **kwargs) -> Dict:
async def execute(self, function_name, bot, tg_upd: telegram.Update, chat_id, **kwargs) -> Dict:
if self.api_key.endswith(':fx'):
url = "https://api-free.deepl.com/v2/translate"
else:
Expand Down
3 changes: 2 additions & 1 deletion bot/plugins/dice.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import telegram
from typing import Dict

from .plugin import Plugin
Expand Down Expand Up @@ -28,7 +29,7 @@ def get_spec(self) -> [Dict]:
},
}]

async def execute(self, function_name, **kwargs) -> Dict:
async def execute(self, function_name, bot, tg_upd: telegram.Update, chat_id, **kwargs) -> Dict:
return {
'direct_result': {
'kind': 'dice',
Expand Down
3 changes: 2 additions & 1 deletion bot/plugins/gtts_text_to_speech.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import datetime
import telegram
from typing import Dict

from gtts import gTTS
Expand Down Expand Up @@ -31,7 +32,7 @@ def get_spec(self) -> [Dict]:
},
}]

async def execute(self, function_name, helper, **kwargs) -> Dict:
async def execute(self, function_name, bot, tg_upd: telegram.Update, chat_id, **kwargs) -> Dict:
tts = gTTS(kwargs['text'], lang=kwargs.get('lang', 'en'))
output = f'gtts_{datetime.datetime.now().timestamp()}.mp3'
tts.save(output)
Expand Down
3 changes: 2 additions & 1 deletion bot/plugins/plugin.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import telegram
from abc import abstractmethod, ABC
from typing import Dict

Expand All @@ -23,7 +24,7 @@ def get_spec(self) -> [Dict]:
pass

@abstractmethod
async def execute(self, function_name, helper, **kwargs) -> Dict:
async def execute(self, function_name, bot, tg_upd: telegram.Update, chat_id, **kwargs) -> Dict:
"""
Execute the plugin and return a JSON serializable response
"""
Expand Down
3 changes: 2 additions & 1 deletion bot/plugins/spotify.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import telegram
from typing import Dict

import spotipy
Expand Down Expand Up @@ -111,7 +112,7 @@ def get_spec(self) -> [Dict]:
}
]

async def execute(self, function_name, helper, **kwargs) -> Dict:
async def execute(self, function_name, bot, tg_upd: telegram.Update, chat_id, **kwargs) -> Dict:
time_range = kwargs.get('time_range', 'short_term')
limit = kwargs.get('limit', 5)

Expand Down
3 changes: 2 additions & 1 deletion bot/plugins/weather.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import telegram
from datetime import datetime
from typing import Dict

Expand Down Expand Up @@ -57,7 +58,7 @@ def get_spec(self) -> [Dict]:
}
]

async def execute(self, function_name, helper, **kwargs) -> Dict:
async def execute(self, function_name, bot, tg_upd: telegram.Update, chat_id, **kwargs) -> Dict:
url = f'https://api.open-meteo.com/v1/forecast' \
f'?latitude={kwargs["latitude"]}' \
f'&longitude={kwargs["longitude"]}' \
Expand Down
3 changes: 2 additions & 1 deletion bot/plugins/webshot.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os, requests, random, string
import telegram
from typing import Dict
from .plugin import Plugin

Expand Down Expand Up @@ -26,7 +27,7 @@ def generate_random_string(self, length):
characters = string.ascii_letters + string.digits
return ''.join(random.choice(characters) for _ in range(length))

async def execute(self, function_name, helper, **kwargs) -> Dict:
async def execute(self, function_name, bot, tg_upd: telegram.Update, chat_id, **kwargs) -> Dict:
try:
image_url = f'https://image.thum.io/get/maxAge/12/width/720/{kwargs["url"]}'

Expand Down
3 changes: 2 additions & 1 deletion bot/plugins/whois_.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import telegram
from typing import Dict
from .plugin import Plugin

Expand All @@ -24,7 +25,7 @@ def get_spec(self) -> [Dict]:
},
}]

async def execute(self, function_name, helper, **kwargs) -> Dict:
async def execute(self, function_name, bot, tg_upd: telegram.Update, chat_id, **kwargs) -> Dict:
try:
whois_result = whois.query(kwargs['domain'])
if whois_result is None:
Expand Down
3 changes: 2 additions & 1 deletion bot/plugins/wolfram_alpha.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import telegram
from typing import Dict

import wolframalpha
Expand Down Expand Up @@ -32,7 +33,7 @@ def get_spec(self) -> [Dict]:
}
}]

async def execute(self, function_name, helper, **kwargs) -> Dict:
async def execute(self, function_name, bot, tg_upd: telegram.Update, chat_id, **kwargs) -> Dict:
client = wolframalpha.Client(self.app_id)
res = client.query(kwargs['query'])
try:
Expand Down
Loading