From b9c9ddafe38caa6469018d3d51d04f0005d9e9fa Mon Sep 17 00:00:00 2001 From: Jipok Date: Mon, 18 Dec 2023 02:06:11 +0500 Subject: [PATCH] Add auto_dalle plugin --- bot/openai_helper.py | 22 ++-- bot/plugin_manager.py | 7 +- bot/plugins/auto_dalle.py | 40 ++++++ bot/plugins/auto_tts.py | 24 ++-- bot/plugins/crypto.py | 3 +- bot/plugins/ddg_image_search.py | 3 +- bot/plugins/ddg_translate.py | 3 +- bot/plugins/ddg_web_search.py | 3 +- bot/plugins/deepl.py | 3 +- bot/plugins/dice.py | 3 +- bot/plugins/gtts_text_to_speech.py | 3 +- bot/plugins/plugin.py | 3 +- bot/plugins/spotify.py | 3 +- bot/plugins/weather.py | 3 +- bot/plugins/webshot.py | 3 +- bot/plugins/whois_.py | 3 +- bot/plugins/wolfram_alpha.py | 3 +- bot/plugins/worldtimeapi.py | 3 +- bot/plugins/youtube_audio_extractor.py | 3 +- bot/telegram_bot.py | 162 ++++++++++++++----------- bot/utils.py | 17 --- 21 files changed, 192 insertions(+), 125 deletions(-) create mode 100644 bot/plugins/auto_dalle.py diff --git a/bot/openai_helper.py b/bot/openai_helper.py index 951d745b..0d300fd8 100644 --- a/bot/openai_helper.py +++ b/bot/openai_helper.py @@ -122,7 +122,7 @@ def get_conversation_stats(self, chat_id: int) -> tuple[int, int]: self.reset_chat_history(chat_id) return len(self.conversations[chat_id]), self.__count_tokens(self.conversations[chat_id]) - async def get_chat_response(self, chat_id: int, query: str) -> tuple[str, str]: + async def get_chat_response(self, bot: ChatGPTTelegramBot, tg_upd: telegram.Update, chat_id: int, query: str) -> tuple[str, str]: """ Gets a full response from the GPT model. :param chat_id: The chat ID @@ -132,7 +132,7 @@ async def get_chat_response(self, chat_id: int, query: str) -> tuple[str, str]: plugins_used = () response = await self.__common_get_chat_response(chat_id, query) if self.config['enable_functions'] and not self.conversations_vision[chat_id]: - response, plugins_used = await self.__handle_function_call(chat_id, response) + response, plugins_used = await self.__handle_function_call(bot, tg_upd, chat_id, response) if is_direct_result(response): return response, '0' @@ -165,17 +165,19 @@ async def get_chat_response(self, chat_id: int, query: str) -> tuple[str, str]: return answer, response.usage.total_tokens - async def get_chat_response_stream(self, chat_id: int, query: str): + async def get_chat_response_stream(self, bot: ChatGPTTelegramBot, tg_upd: telegram.Update, chat_id: int, query: str): """ Stream response from the GPT model. :param chat_id: The chat ID :param query: The query to send to the model :return: The answer from the model and the number of tokens used, or 'not_finished' """ + import telegram_bot plugins_used = () response = await self.__common_get_chat_response(chat_id, query, stream=True) - if self.config['enable_functions'] and not self.conversations_vision[chat_id]: - response, plugins_used = await self.__handle_function_call(chat_id, response, stream=True) + + if self.config['enable_functions']: + response, plugins_used = await self.__handle_function_call(bot, tg_upd, chat_id, response, stream=True) if is_direct_result(response): yield response, '0' return @@ -269,7 +271,7 @@ async def __common_get_chat_response(self, chat_id: int, query: str, stream=Fals except Exception as e: raise Exception(f"⚠️ _{localized_text('error', bot_language)}._ ⚠️\n{str(e)}") from e - async def __handle_function_call(self, chat_id, response, stream=False, times=0, plugins_used=()): + async def __handle_function_call(self, bot: ChatGPTTelegramBot, tg_upd: telegram.Update, chat_id, response, stream=False, times=0, plugins_used=()): function_name = '' arguments = '' if stream: @@ -301,11 +303,15 @@ async def __handle_function_call(self, chat_id, response, stream=False, times=0, return response, plugins_used logging.info(f'Calling function {function_name} with arguments {arguments}') - function_response = await self.plugin_manager.call_function(function_name, self, arguments) + function_response, function_response_dict = await self.plugin_manager.call_function(bot, tg_upd, chat_id, function_name, arguments) if function_name not in plugins_used: plugins_used += (function_name,) + # if "result" in function_response_dict and function_response_dict["result"] == "Success": + # self.__add_function_call_to_history(chat_id=chat_id, function_name=function_name, content=function_response) + # return response, plugins_used + if is_direct_result(function_response): self.__add_function_call_to_history(chat_id=chat_id, function_name=function_name, content=json.dumps({'result': 'Done, the content has been sent' @@ -320,7 +326,7 @@ async def __handle_function_call(self, chat_id, response, stream=False, times=0, function_call='auto' if times < self.config['functions_max_consecutive_calls'] else 'none', stream=stream ) - return await self.__handle_function_call(chat_id, response, stream, times + 1, plugins_used) + return await self.__handle_function_call(bot, tg_upd, chat_id, response, stream, times + 1, plugins_used) async def generate_image(self, prompt: str) -> tuple[str, str]: """ diff --git a/bot/plugin_manager.py b/bot/plugin_manager.py index 370b1c3e..d0b07e0f 100644 --- a/bot/plugin_manager.py +++ b/bot/plugin_manager.py @@ -2,6 +2,7 @@ from plugins.gtts_text_to_speech import GTTSTextToSpeech from plugins.auto_tts import AutoTextToSpeech +from plugins.auto_dalle import AutoDalle from plugins.dice import DicePlugin from plugins.youtube_audio_extractor import YouTubeAudioExtractorPlugin from plugins.ddg_image_search import DDGImageSearchPlugin @@ -38,6 +39,7 @@ def __init__(self, config): 'deepl_translate': DeeplTranslatePlugin, 'gtts_text_to_speech': GTTSTextToSpeech, 'auto_tts': AutoTextToSpeech, + 'auto_dalle': AutoDalle, 'whois': WhoisPlugin, 'webshot': WebshotPlugin, } @@ -49,14 +51,15 @@ def get_functions_specs(self): """ return [spec for specs in map(lambda plugin: plugin.get_spec(), self.plugins) for spec in specs] - async def call_function(self, function_name, helper, arguments): + async def call_function(self, bot, tg_upd, chat_id, function_name, arguments): """ Call a function based on the name and parameters provided """ plugin = self.__get_plugin_by_function_name(function_name) if not plugin: return json.dumps({'error': f'Function {function_name} not found'}) - return json.dumps(await plugin.execute(function_name, helper, **json.loads(arguments)), default=str) + result = await plugin.execute(function_name, bot, tg_upd, chat_id, **json.loads(arguments)) + return json.dumps(result, default=str), result def get_plugin_source_name(self, function_name) -> str: """ diff --git a/bot/plugins/auto_dalle.py b/bot/plugins/auto_dalle.py new file mode 100644 index 00000000..9342906f --- /dev/null +++ b/bot/plugins/auto_dalle.py @@ -0,0 +1,40 @@ +import asyncio +import datetime +import tempfile +import traceback +from typing import Dict +import telegram + +from .plugin import Plugin + + +class AutoDalle(Plugin): + """ + A plugin to generate image using Openai image generation API + """ + + def get_source_name(self) -> str: + return "DALLE" + + def get_spec(self) -> [Dict]: + return [{ + "name": "dalle_image", + "description": "Create image from scratch based on a text prompt (DALL·E 3 and DALL·E 2). Send to user.", + "parameters": { + "type": "object", + "properties": { + "prompt": {"type": "string", "prompt": "Image description. Use English language for better results."}, + }, + "required": ["prompt"], + }, + }] + + async def execute(self, function_name, bot, tg_upd: telegram.Update, chat_id, **kwargs) -> Dict: + await bot.wrap_with_indicator(tg_upd, bot.image_gen(tg_upd, kwargs['prompt']), "upload_photo") + return { + 'direct_result': { + 'kind': 'none', + 'format': '', + 'value': 'none', + } + } \ No newline at end of file diff --git a/bot/plugins/auto_tts.py b/bot/plugins/auto_tts.py index 7118b963..876dc951 100644 --- a/bot/plugins/auto_tts.py +++ b/bot/plugins/auto_tts.py @@ -1,6 +1,7 @@ import datetime import tempfile from typing import Dict +import telegram from .plugin import Plugin @@ -15,8 +16,8 @@ def get_source_name(self) -> str: def get_spec(self) -> [Dict]: return [{ - "name": "translate_text_to_speech", - "description": "Translate text to speech using OpenAI API", + "name": "translate_text_to_speech_and_send", + "description": "Translate text to speech using OpenAI API and send result to user.", "parameters": { "type": "object", "properties": { @@ -26,19 +27,12 @@ def get_spec(self) -> [Dict]: }, }] - async def execute(self, function_name, helper, **kwargs) -> Dict: - try: - bytes, text_length = await helper.generate_speech(text=kwargs['text']) - with tempfile.NamedTemporaryFile(delete=False, suffix='.opus') as temp_file: - temp_file.write(bytes.getvalue()) - temp_file_path = temp_file.name - except Exception as e: - logging.exception(e) - return {"Result": "Exception: " + str(e)} + async def execute(self, function_name, bot, tg_upd: telegram.Update, chat_id, **kwargs) -> Dict: + await bot.wrap_with_indicator(tg_upd, bot.tts_gen(tg_upd, kwargs['text']), "record_voice") return { 'direct_result': { - 'kind': 'file', - 'format': 'path', - 'value': temp_file_path + 'kind': 'none', + 'format': '', + 'value': 'none', } - } + } \ No newline at end of file diff --git a/bot/plugins/crypto.py b/bot/plugins/crypto.py index 5559a22c..cb957a20 100644 --- a/bot/plugins/crypto.py +++ b/bot/plugins/crypto.py @@ -1,3 +1,4 @@ +import telegram from typing import Dict import requests @@ -26,5 +27,5 @@ def get_spec(self) -> [Dict]: }, }] - async def execute(self, function_name, helper, **kwargs) -> Dict: + async def execute(self, function_name, bot, tg_upd: telegram.Update, chat_id, **kwargs) -> Dict: return requests.get(f"https://api.coincap.io/v2/rates/{kwargs['asset']}").json() diff --git a/bot/plugins/ddg_image_search.py b/bot/plugins/ddg_image_search.py index a1db0353..0202e684 100644 --- a/bot/plugins/ddg_image_search.py +++ b/bot/plugins/ddg_image_search.py @@ -1,5 +1,6 @@ import os import random +import telegram from itertools import islice from typing import Dict @@ -49,7 +50,7 @@ def get_spec(self) -> [Dict]: }, }] - async def execute(self, function_name, helper, **kwargs) -> Dict: + async def execute(self, function_name, bot, tg_upd: telegram.Update, chat_id, **kwargs) -> Dict: with DDGS() as ddgs: image_type = kwargs.get('type', 'photo') ddgs_images_gen = ddgs.images( diff --git a/bot/plugins/ddg_translate.py b/bot/plugins/ddg_translate.py index 294ff05e..c7de3238 100644 --- a/bot/plugins/ddg_translate.py +++ b/bot/plugins/ddg_translate.py @@ -1,3 +1,4 @@ +import telegram from typing import Dict from duckduckgo_search import DDGS @@ -26,6 +27,6 @@ def get_spec(self) -> [Dict]: }, }] - async def execute(self, function_name, helper, **kwargs) -> Dict: + async def execute(self, function_name, bot, tg_upd: telegram.Update, chat_id, **kwargs) -> Dict: with DDGS() as ddgs: return ddgs.translate(kwargs['text'], to=kwargs['to_language']) diff --git a/bot/plugins/ddg_web_search.py b/bot/plugins/ddg_web_search.py index 07060a90..fa1e247c 100644 --- a/bot/plugins/ddg_web_search.py +++ b/bot/plugins/ddg_web_search.py @@ -1,5 +1,6 @@ import os from itertools import islice +import telegram from typing import Dict from duckduckgo_search import DDGS @@ -46,7 +47,7 @@ def get_spec(self) -> [Dict]: }, }] - async def execute(self, function_name, helper, **kwargs) -> Dict: + async def execute(self, function_name, bot, tg_upd: telegram.Update, chat_id, **kwargs) -> Dict: with DDGS() as ddgs: ddgs_gen = ddgs.text( kwargs['query'], diff --git a/bot/plugins/deepl.py b/bot/plugins/deepl.py index d2236c20..692d3750 100644 --- a/bot/plugins/deepl.py +++ b/bot/plugins/deepl.py @@ -1,6 +1,7 @@ import os from typing import Dict +import telegram import requests from .plugin import Plugin @@ -33,7 +34,7 @@ def get_spec(self) -> [Dict]: }, }] - async def execute(self, function_name, helper, **kwargs) -> Dict: + async def execute(self, function_name, bot, tg_upd: telegram.Update, chat_id, **kwargs) -> Dict: if self.api_key.endswith(':fx'): url = "https://api-free.deepl.com/v2/translate" else: diff --git a/bot/plugins/dice.py b/bot/plugins/dice.py index 1f2a0908..f9862ea4 100644 --- a/bot/plugins/dice.py +++ b/bot/plugins/dice.py @@ -1,3 +1,4 @@ +import telegram from typing import Dict from .plugin import Plugin @@ -28,7 +29,7 @@ def get_spec(self) -> [Dict]: }, }] - async def execute(self, function_name, **kwargs) -> Dict: + async def execute(self, function_name, bot, tg_upd: telegram.Update, chat_id, **kwargs) -> Dict: return { 'direct_result': { 'kind': 'dice', diff --git a/bot/plugins/gtts_text_to_speech.py b/bot/plugins/gtts_text_to_speech.py index 544de9f6..78329c71 100644 --- a/bot/plugins/gtts_text_to_speech.py +++ b/bot/plugins/gtts_text_to_speech.py @@ -1,4 +1,5 @@ import datetime +import telegram from typing import Dict from gtts import gTTS @@ -31,7 +32,7 @@ def get_spec(self) -> [Dict]: }, }] - async def execute(self, function_name, helper, **kwargs) -> Dict: + async def execute(self, function_name, bot, tg_upd: telegram.Update, chat_id, **kwargs) -> Dict: tts = gTTS(kwargs['text'], lang=kwargs.get('lang', 'en')) output = f'gtts_{datetime.datetime.now().timestamp()}.mp3' tts.save(output) diff --git a/bot/plugins/plugin.py b/bot/plugins/plugin.py index c9c734dd..dcd62a7d 100644 --- a/bot/plugins/plugin.py +++ b/bot/plugins/plugin.py @@ -1,3 +1,4 @@ +import telegram from abc import abstractmethod, ABC from typing import Dict @@ -23,7 +24,7 @@ def get_spec(self) -> [Dict]: pass @abstractmethod - async def execute(self, function_name, helper, **kwargs) -> Dict: + async def execute(self, function_name, bot, tg_upd: telegram.Update, chat_id, **kwargs) -> Dict: """ Execute the plugin and return a JSON serializable response """ diff --git a/bot/plugins/spotify.py b/bot/plugins/spotify.py index a578ed2c..f6f18bbb 100644 --- a/bot/plugins/spotify.py +++ b/bot/plugins/spotify.py @@ -1,4 +1,5 @@ import os +import telegram from typing import Dict import spotipy @@ -111,7 +112,7 @@ def get_spec(self) -> [Dict]: } ] - async def execute(self, function_name, helper, **kwargs) -> Dict: + async def execute(self, function_name, bot, tg_upd: telegram.Update, chat_id, **kwargs) -> Dict: time_range = kwargs.get('time_range', 'short_term') limit = kwargs.get('limit', 5) diff --git a/bot/plugins/weather.py b/bot/plugins/weather.py index 7b2b1f29..3efc21a5 100644 --- a/bot/plugins/weather.py +++ b/bot/plugins/weather.py @@ -1,3 +1,4 @@ +import telegram from datetime import datetime from typing import Dict @@ -57,7 +58,7 @@ def get_spec(self) -> [Dict]: } ] - async def execute(self, function_name, helper, **kwargs) -> Dict: + async def execute(self, function_name, bot, tg_upd: telegram.Update, chat_id, **kwargs) -> Dict: url = f'https://api.open-meteo.com/v1/forecast' \ f'?latitude={kwargs["latitude"]}' \ f'&longitude={kwargs["longitude"]}' \ diff --git a/bot/plugins/webshot.py b/bot/plugins/webshot.py index fa925629..a5f81dae 100644 --- a/bot/plugins/webshot.py +++ b/bot/plugins/webshot.py @@ -1,4 +1,5 @@ import os, requests, random, string +import telegram from typing import Dict from .plugin import Plugin @@ -26,7 +27,7 @@ def generate_random_string(self, length): characters = string.ascii_letters + string.digits return ''.join(random.choice(characters) for _ in range(length)) - async def execute(self, function_name, helper, **kwargs) -> Dict: + async def execute(self, function_name, bot, tg_upd: telegram.Update, chat_id, **kwargs) -> Dict: try: image_url = f'https://image.thum.io/get/maxAge/12/width/720/{kwargs["url"]}' diff --git a/bot/plugins/whois_.py b/bot/plugins/whois_.py index 91b81b5f..99866988 100644 --- a/bot/plugins/whois_.py +++ b/bot/plugins/whois_.py @@ -1,3 +1,4 @@ +import telegram from typing import Dict from .plugin import Plugin @@ -24,7 +25,7 @@ def get_spec(self) -> [Dict]: }, }] - async def execute(self, function_name, helper, **kwargs) -> Dict: + async def execute(self, function_name, bot, tg_upd: telegram.Update, chat_id, **kwargs) -> Dict: try: whois_result = whois.query(kwargs['domain']) if whois_result is None: diff --git a/bot/plugins/wolfram_alpha.py b/bot/plugins/wolfram_alpha.py index b151ae03..9c5e870f 100644 --- a/bot/plugins/wolfram_alpha.py +++ b/bot/plugins/wolfram_alpha.py @@ -1,4 +1,5 @@ import os +import telegram from typing import Dict import wolframalpha @@ -32,7 +33,7 @@ def get_spec(self) -> [Dict]: } }] - async def execute(self, function_name, helper, **kwargs) -> Dict: + async def execute(self, function_name, bot, tg_upd: telegram.Update, chat_id, **kwargs) -> Dict: client = wolframalpha.Client(self.app_id) res = client.query(kwargs['query']) try: diff --git a/bot/plugins/worldtimeapi.py b/bot/plugins/worldtimeapi.py index 2e09e366..7fc72339 100644 --- a/bot/plugins/worldtimeapi.py +++ b/bot/plugins/worldtimeapi.py @@ -1,4 +1,5 @@ import os, requests +import telegram from typing import Dict from datetime import datetime @@ -35,7 +36,7 @@ def get_spec(self) -> [Dict]: }, }] - async def execute(self, function_name, helper, **kwargs) -> Dict: + async def execute(self, function_name, bot, tg_upd: telegram.Update, chat_id, **kwargs) -> Dict: timezone = kwargs.get('timezone', self.default_timezone) url = f'https://worldtimeapi.org/api/timezone/{timezone}' diff --git a/bot/plugins/youtube_audio_extractor.py b/bot/plugins/youtube_audio_extractor.py index 804663a8..d7a546b4 100644 --- a/bot/plugins/youtube_audio_extractor.py +++ b/bot/plugins/youtube_audio_extractor.py @@ -1,5 +1,6 @@ import logging import re +import telegram from typing import Dict from pytube import YouTube @@ -28,7 +29,7 @@ def get_spec(self) -> [Dict]: }, }] - async def execute(self, function_name, helper, **kwargs) -> Dict: + async def execute(self, function_name, bot, tg_upd: telegram.Update, chat_id, **kwargs) -> Dict: link = kwargs['youtube_link'] try: video = YouTube(link) diff --git a/bot/telegram_bot.py b/bot/telegram_bot.py index 7a536b1f..17e5fd6f 100644 --- a/bot/telegram_bot.py +++ b/bot/telegram_bot.py @@ -16,7 +16,7 @@ from pydub import AudioSegment from PIL import Image -from utils import is_group_chat, get_thread_id, message_text, wrap_with_indicator, split_into_chunks, \ +from utils import is_group_chat, get_thread_id, message_text, split_into_chunks, \ edit_message_with_retry, get_stream_cutoff_values, is_allowed, get_remaining_budget, is_admin, is_within_budget, \ get_reply_to_message_id, add_chat_request_to_usage_tracker, error_handler, is_direct_result, handle_direct_result, \ cleanup_intermediate_files @@ -233,10 +233,45 @@ async def reset(self, update: Update, context: ContextTypes.DEFAULT_TYPE): text=localized_text('reset_done', self.config['bot_language']) ) - async def image(self, update: Update, context: ContextTypes.DEFAULT_TYPE): + async def image_gen(self, update, image_query): """ Generates an image for the given prompt using DALL·E APIs """ + try: + image_url, image_size = await self.openai.generate_image(prompt=image_query) + if self.config['image_receive_mode'] == 'photo': + await update.effective_message.reply_photo( + reply_to_message_id=get_reply_to_message_id(self.config, update), + photo=image_url + ) + elif self.config['image_receive_mode'] == 'document': + await update.effective_message.reply_document( + reply_to_message_id=get_reply_to_message_id(self.config, update), + document=image_url + ) + else: + raise Exception(f"env variable IMAGE_RECEIVE_MODE has invalid value {self.config['image_receive_mode']}") + # add image request to users usage tracker + user_id = update.message.from_user.id + self.usage[user_id].add_image_request(image_size, self.config['image_prices']) + # add guest chat request to guest usage tracker + if str(user_id) not in self.config['allowed_user_ids'].split(',') and 'guests' in self.usage: + self.usage["guests"].add_image_request(image_size, self.config['image_prices']) + + except Exception as e: + logging.exception(e) + await update.effective_message.reply_text( + message_thread_id=get_thread_id(update), + reply_to_message_id=get_reply_to_message_id(self.config, update), + text=f"{localized_text('image_fail', self.config['bot_language'])}: {str(e)}", + parse_mode=constants.ParseMode.MARKDOWN + ) + + + async def image(self, update: Update, context: ContextTypes.DEFAULT_TYPE): + """ + User command wrapper. Generates an image for the given prompt using DALL·E APIs + """ if not self.config['enable_image_generation'] \ or not await self.check_allowed_and_within_budget(update, context): return @@ -252,42 +287,41 @@ async def image(self, update: Update, context: ContextTypes.DEFAULT_TYPE): logging.info(f'New image generation request received from user {update.message.from_user.name} ' f'(id: {update.message.from_user.id})') - async def _generate(): - try: - image_url, image_size = await self.openai.generate_image(prompt=image_query) - if self.config['image_receive_mode'] == 'photo': - await update.effective_message.reply_photo( - reply_to_message_id=get_reply_to_message_id(self.config, update), - photo=image_url - ) - elif self.config['image_receive_mode'] == 'document': - await update.effective_message.reply_document( - reply_to_message_id=get_reply_to_message_id(self.config, update), - document=image_url - ) - else: - raise Exception(f"env variable IMAGE_RECEIVE_MODE has invalid value {self.config['image_receive_mode']}") - # add image request to users usage tracker - user_id = update.message.from_user.id - self.usage[user_id].add_image_request(image_size, self.config['image_prices']) - # add guest chat request to guest usage tracker - if str(user_id) not in self.config['allowed_user_ids'].split(',') and 'guests' in self.usage: - self.usage["guests"].add_image_request(image_size, self.config['image_prices']) + await self.wrap_with_indicator(update, self.image_gen(update, image_query), constants.ChatAction.UPLOAD_PHOTO) - except Exception as e: - logging.exception(e) - await update.effective_message.reply_text( - message_thread_id=get_thread_id(update), - reply_to_message_id=get_reply_to_message_id(self.config, update), - text=f"{localized_text('image_fail', self.config['bot_language'])}: {str(e)}", - parse_mode=constants.ParseMode.MARKDOWN - ) - await wrap_with_indicator(update, context, _generate, constants.ChatAction.UPLOAD_PHOTO) + async def tts_gen(self, update: Update, tts_query: str): + """ + Generates an speech for the given input using TTS APIs + """ + try: + speech_file, text_length = await self.openai.generate_speech(text=tts_query) + + await update.effective_message.reply_voice( + reply_to_message_id=get_reply_to_message_id(self.config, update), + voice=speech_file + ) + speech_file.close() + # add image request to users usage tracker + user_id = update.message.from_user.id + self.usage[user_id].add_tts_request(text_length, self.config['tts_model'], self.config['tts_prices']) + # add guest chat request to guest usage tracker + if str(user_id) not in self.config['allowed_user_ids'].split(',') and 'guests' in self.usage: + self.usage["guests"].add_tts_request(text_length, self.config['tts_model'], self.config['tts_prices']) + + except Exception as e: + logging.exception(e) + await update.effective_message.reply_text( + message_thread_id=get_thread_id(update), + reply_to_message_id=get_reply_to_message_id(self.config, update), + text=f"{localized_text('tts_fail', self.config['bot_language'])}: {str(e)}", + parse_mode=constants.ParseMode.MARKDOWN + ) + async def tts(self, update: Update, context: ContextTypes.DEFAULT_TYPE): """ - Generates an speech for the given input using TTS APIs + User command wrapper. Generates an speech for the given input using TTS APIs """ if not self.config['enable_tts_generation'] \ or not await self.check_allowed_and_within_budget(update, context): @@ -304,32 +338,7 @@ async def tts(self, update: Update, context: ContextTypes.DEFAULT_TYPE): logging.info(f'New speech generation request received from user {update.message.from_user.name} ' f'(id: {update.message.from_user.id})') - async def _generate(): - try: - speech_file, text_length = await self.openai.generate_speech(text=tts_query) - - await update.effective_message.reply_voice( - reply_to_message_id=get_reply_to_message_id(self.config, update), - voice=speech_file - ) - speech_file.close() - # add image request to users usage tracker - user_id = update.message.from_user.id - self.usage[user_id].add_tts_request(text_length, self.config['tts_model'], self.config['tts_prices']) - # add guest chat request to guest usage tracker - if str(user_id) not in self.config['allowed_user_ids'].split(',') and 'guests' in self.usage: - self.usage["guests"].add_tts_request(text_length, self.config['tts_model'], self.config['tts_prices']) - - except Exception as e: - logging.exception(e) - await update.effective_message.reply_text( - message_thread_id=get_thread_id(update), - reply_to_message_id=get_reply_to_message_id(self.config, update), - text=f"{localized_text('tts_fail', self.config['bot_language'])}: {str(e)}", - parse_mode=constants.ParseMode.MARKDOWN - ) - - await wrap_with_indicator(update, context, _generate, constants.ChatAction.UPLOAD_VOICE) + await self.wrap_with_indicator(update, self.tts_gen(update, tts_query), constants.ChatAction.UPLOAD_VOICE) async def transcribe(self, update: Update, context: ContextTypes.DEFAULT_TYPE): """ @@ -414,7 +423,7 @@ async def _execute(): ) else: # Get the response of the transcript - response, total_tokens = await self.openai.get_chat_response(chat_id=chat_id, query=transcript) + response, total_tokens = await self.openai.get_chat_response(chat_id=chat_id, bot=self, tg_upd=update, query=transcript) self.usage[user_id].add_chat_tokens(total_tokens, self.config['token_price']) if str(user_id) not in allowed_user_ids and 'guests' in self.usage: @@ -449,7 +458,7 @@ async def _execute(): if os.path.exists(filename): os.remove(filename) - await wrap_with_indicator(update, context, _execute, constants.ChatAction.TYPING) + await self.wrap_with_indicator(update, _execute(), constants.ChatAction.TYPING) async def vision(self, update: Update, context: ContextTypes.DEFAULT_TYPE): """ @@ -641,7 +650,7 @@ async def _execute(): if str(user_id) not in allowed_user_ids and 'guests' in self.usage: self.usage["guests"].add_vision_tokens(total_tokens, vision_token_price) - await wrap_with_indicator(update, context, _execute, constants.ChatAction.TYPING) + await self.wrap_with_indicator(update, _execute(), constants.ChatAction.TYPING) async def prompt(self, update: Update, context: ContextTypes.DEFAULT_TYPE): """ @@ -687,7 +696,7 @@ async def prompt(self, update: Update, context: ContextTypes.DEFAULT_TYPE): message_thread_id=get_thread_id(update) ) - stream_response = self.openai.get_chat_response_stream(chat_id=chat_id, query=prompt) + stream_response = self.openai.get_chat_response_stream(chat_id=chat_id, bot=self, tg_upd = update, query=prompt) i = 0 prev = '' sent_message = None @@ -767,7 +776,7 @@ async def prompt(self, update: Update, context: ContextTypes.DEFAULT_TYPE): else: async def _reply(): nonlocal total_tokens - response, total_tokens = await self.openai.get_chat_response(chat_id=chat_id, query=prompt) + response, total_tokens = await self.openai.get_chat_response(chat_id=chat_id, bot=self, tg_upd=update, query=prompt) if is_direct_result(response): return await handle_direct_result(self.config, update, response) @@ -795,7 +804,7 @@ async def _reply(): except Exception as exception: raise exception - await wrap_with_indicator(update, context, _reply, constants.ChatAction.TYPING) + await self.wrap_with_indicator(update, _reply(), constants.ChatAction.TYPING) add_chat_request_to_usage_tracker(self.usage, self.config, user_id, total_tokens) @@ -887,7 +896,7 @@ async def handle_callback_inline_query(self, update: Update, context: CallbackCo unavailable_message = localized_text("function_unavailable_in_inline_mode", bot_language) if self.config['stream']: - stream_response = self.openai.get_chat_response_stream(chat_id=user_id, query=query) + stream_response = self.openai.get_chat_response_stream(chat_id=user_id, bot=self, tg_upd = update, query=query) i = 0 prev = '' backoff = 0 @@ -955,7 +964,7 @@ async def _send_inline_query_response(): parse_mode=constants.ParseMode.MARKDOWN) logging.info(f'Generating response for inline query by {name}') - response, total_tokens = await self.openai.get_chat_response(chat_id=user_id, query=query) + response, total_tokens = await self.openai.get_chat_response(chat_id=user_id, bot=self, tg_upd=update, query=query) if is_direct_result(response): cleanup_intermediate_files(response) @@ -974,7 +983,7 @@ async def _send_inline_query_response(): await edit_message_with_retry(context, chat_id=None, message_id=inline_message_id, text=text_content, is_inline=True) - await wrap_with_indicator(update, context, _send_inline_query_response, + await self.wrap_with_indicator(update, _send_inline_query_response(), constants.ChatAction.TYPING, is_inline=True) add_chat_request_to_usage_tracker(self.usage, self.config, user_id, total_tokens) @@ -1037,6 +1046,21 @@ async def send_budget_reached_message(self, update: Update, _: ContextTypes.DEFA result_id = str(uuid4()) await self.send_inline_query_result(update, result_id, message_content=self.budget_limit_message) + async def wrap_with_indicator(self, update: Update, coroutine, chat_action: constants.ChatAction = "", is_inline=False): + """ + Wraps a coroutine while repeatedly sending a chat action to the user. + """ + task = self.tg_app.create_task(coroutine, update=update) + while not task.done(): + if not is_inline: + self.tg_app.create_task( + update.effective_chat.send_action(chat_action, message_thread_id=get_thread_id(update)) + ) + try: + await asyncio.wait_for(asyncio.shield(task), 4.5) + except asyncio.TimeoutError: + pass + async def post_init(self, application: Application) -> None: """ Post initialization hook for the bot. @@ -1055,6 +1079,8 @@ def run(self): .post_init(self.post_init) \ .concurrent_updates(True) \ .build() + self.tg_app = application + self.tg = application.bot application.add_handler(CommandHandler('reset', self.reset)) application.add_handler(CommandHandler('help', self.help)) diff --git a/bot/utils.py b/bot/utils.py index d306dc6c..26e975f9 100644 --- a/bot/utils.py +++ b/bot/utils.py @@ -85,23 +85,6 @@ def split_into_chunks(text: str, chunk_size: int = 4096) -> list[str]: return [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)] -async def wrap_with_indicator(update: Update, context: CallbackContext, coroutine, - chat_action: constants.ChatAction = "", is_inline=False): - """ - Wraps a coroutine while repeatedly sending a chat action to the user. - """ - task = context.application.create_task(coroutine(), update=update) - while not task.done(): - if not is_inline: - context.application.create_task( - update.effective_chat.send_action(chat_action, message_thread_id=get_thread_id(update)) - ) - try: - await asyncio.wait_for(asyncio.shield(task), 4.5) - except asyncio.TimeoutError: - pass - - async def edit_message_with_retry(context: ContextTypes.DEFAULT_TYPE, chat_id: int | None, message_id: str, text: str, markdown: bool = True, is_inline: bool = False): """