diff --git a/README.md b/README.md index 61b7a41a..3f48a551 100644 --- a/README.md +++ b/README.md @@ -87,6 +87,7 @@ Check out the [Budget Manual](https://github.com/n3d1117/chatgpt-telegram-bot/di |-------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|------------------------------------| | `ENABLE_QUOTING` | Whether to enable message quoting in private chats | `true` | | `ENABLE_IMAGE_GENERATION` | Whether to enable image generation via the `/image` command | `true` | +| `ENABLE_IMAGE_EDITING` | Whether to enable image editing via the `/edit` command | `true` | | `ENABLE_TRANSCRIPTION` | Whether to enable transcriptions of audio and video messages | `true` | | `ENABLE_TTS_GENERATION` | Whether to enable text to speech generation via the `/tts` | `true` | | `ENABLE_VISION` | Whether to enable vision capabilities in supported models | `true` | @@ -120,6 +121,7 @@ Check out the [Budget Manual](https://github.com/n3d1117/chatgpt-telegram-bot/di | `GROUP_TRIGGER_KEYWORD` | If set, the bot in group chats will only respond to messages that start with this keyword | - | | `IGNORE_GROUP_TRANSCRIPTIONS` | If set to true, the bot will not process transcriptions in group chats | `true` | | `IGNORE_GROUP_VISION` | If set to true, the bot will not process vision queries in group chats | `true` | +| `IGNORE_GROUP_IMAGE_EDITING` | If set to true, the bot will not process image editing in group chats | `true` | | `BOT_LANGUAGE` | Language of general bot messages. Currently available: `en`, `de`, `ru`, `tr`, `it`, `fi`, `es`, `id`, `nl`, `zh-cn`, `zh-tw`, `vi`, `fa`, `pt-br`, `uk`, `ms`, `uz`, `ar`. 
[Contribute with additional translations](https://github.com/n3d1117/chatgpt-telegram-bot/discussions/219) | `en` | | `WHISPER_PROMPT` | To improve the accuracy of Whisper's transcription service, especially for specific names or terms, you can set up a custom message. [Speech to text - Prompting](https://platform.openai.com/docs/guides/speech-to-text/prompting) | `-` | | `TTS_VOICE` | The Text to Speech voice to use. Allowed values: `alloy`, `echo`, `fable`, `onyx`, `nova`, or `shimmer` | `alloy` | diff --git a/bot/main.py b/bot/main.py index 8e0118d2..c4d53b5e 100644 --- a/bot/main.py +++ b/bot/main.py @@ -79,6 +79,7 @@ def main(): 'allowed_user_ids': os.environ.get('ALLOWED_TELEGRAM_USER_IDS', '*'), 'enable_quoting': os.environ.get('ENABLE_QUOTING', 'true').lower() == 'true', 'enable_image_generation': os.environ.get('ENABLE_IMAGE_GENERATION', 'true').lower() == 'true', + 'enable_image_editing': os.environ.get('ENABLE_IMAGE_EDITING', 'true').lower() == 'true', 'enable_transcription': os.environ.get('ENABLE_TRANSCRIPTION', 'true').lower() == 'true', 'enable_vision': os.environ.get('ENABLE_VISION', 'true').lower() == 'true', 'enable_tts_generation': os.environ.get('ENABLE_TTS_GENERATION', 'true').lower() == 'true', @@ -91,6 +92,7 @@ def main(): 'voice_reply_prompts': os.environ.get('VOICE_REPLY_PROMPTS', '').split(';'), 'ignore_group_transcriptions': os.environ.get('IGNORE_GROUP_TRANSCRIPTIONS', 'true').lower() == 'true', 'ignore_group_vision': os.environ.get('IGNORE_GROUP_VISION', 'true').lower() == 'true', + 'ignore_group_image_editing': os.environ.get('IGNORE_GROUP_IMAGE_EDITING', 'true').lower() == 'true', 'group_trigger_keyword': os.environ.get('GROUP_TRIGGER_KEYWORD', ''), 'token_price': float(os.environ.get('TOKEN_PRICE', 0.002)), 'image_prices': [float(i) for i in os.environ.get('IMAGE_PRICES', "0.016,0.018,0.02").split(",")], diff --git a/bot/openai_helper.py b/bot/openai_helper.py index 5a1896cf..3de6c12b 100644 --- a/bot/openai_helper.py +++ 
async def edit_image(self, chat_id, orig_image, modified_image, prompt):
    """
    Edits a PNG image through the OpenAI image edit endpoint.

    The mask sent along with the request is derived by ``compute_image_diff`` from the
    original image and the user-modified copy. No ``model`` argument is passed, so the
    endpoint's default edit model is used.
    :param chat_id: The chat ID (not used by this method; kept so existing callers keep working)
    :param orig_image: Binary file-like object holding the original PNG image
    :param modified_image: Binary file-like object holding the modified copy of the image
    :param prompt: The text describing the desired edit
    :return: A list with the URLs of the edited images and the configured image size
    """
    try:
        mask_image = compute_image_diff(orig_image, modified_image)

        # compute_image_diff has already read from orig_image. Rewind it so the upload
        # below starts at the first byte instead of wherever the reader stopped.
        if hasattr(orig_image, 'seek'):
            orig_image.seek(0)

        response = await self.client.images.edit(
            n=1,
            size=self.config['image_size'],
            image=orig_image,
            prompt=prompt,
            mask=mask_image
        )

        image_urls = [item.url for item in response.data]
        return image_urls, self.config['image_size']

    except openai.RateLimitError:
        # Propagated untouched so callers can tell rate limiting apart from other errors.
        raise
    except openai.BadRequestError as e:
        raise Exception(f"⚠️ _{localized_text('openai_invalid', self.config['bot_language'])}._ ⚠️\n{str(e)}") from e
    except Exception as e:
        logging.exception(e)
        raise Exception(f"⚠️ _{localized_text('error', self.config['bot_language'])}._ ⚠️\n{str(e)}") from e
async def edit_image(self, update: Update, context: ContextTypes.DEFAULT_TYPE):
    """
    Two-step image editing flow for photo/document messages captioned with /edit.

    Step 1 (the message is not a reply): the attached image is converted to PNG and sent
    back as a document whose caption stores the prompt on its first line.
    Step 2 (the message is a reply): the attachment of the replied-to message is used as
    the original image and the new attachment as its modified copy; both are handed to the
    OpenAI helper and the resulting image URLs are sent back as a media group.
    """
    if not self.config['enable_image_editing'] or not await self.check_allowed_and_within_budget(update, context):
        return

    chat_id = update.effective_chat.id
    # Text following the /edit command. It is needed both for the group trigger-keyword
    # check below and as the edit prompt in step 1.
    prompt = (update.message.caption or '')[len('/edit'):].strip()

    if is_group_chat(update):
        if self.config['ignore_group_image_editing']:
            logging.info('Image edit coming from group chat, ignoring...')
            return
        trigger_keyword = self.config['group_trigger_keyword']
        # An empty keyword matches every prompt, so this only filters when a keyword is set.
        if not prompt.lower().startswith(trigger_keyword.lower()):
            logging.info('Image edit coming from group chat with wrong keyword, ignoring...')
            return

    async def _reply(text, **kwargs):
        """Send a text reply in the current thread, replying to the triggering message."""
        await update.effective_message.reply_text(
            message_thread_id=get_thread_id(update),
            reply_to_message_id=get_reply_to_message_id(self.config, update),
            text=text,
            **kwargs
        )

    async def _reply_download_failed(error, bot_language):
        """Report a failed media download using the two-part localized message."""
        await _reply(
            f"{localized_text('media_download_fail', bot_language)[0]}: "
            f"{str(error)}. {localized_text('media_download_fail', bot_language)[1]}"
        )

    async def _fetch_attachment(message):
        """Return the Telegram file of a photo or document attachment, or None for other kinds."""
        attachment = message.effective_attachment
        if isinstance(attachment, tuple):
            # Photos arrive as a tuple of sizes; the last entry is used
            # (presumably the largest one - confirm against the Telegram PhotoSize docs).
            return await attachment[-1].get_file()
        if isinstance(attachment, Document):
            return await attachment.get_file()
        return None

    async def _execute():
        bot_language = self.config['bot_language']
        message = update.message

        if message.reply_to_message:
            logging.info(f'New mask for image edit request received from user {message.from_user.name} '
                         f'(id: {message.from_user.id})')

            try:
                source_message = message.reply_to_message
                # Step 1 stored the prompt on the first caption line as "prompt: <text>".
                stored_prompt = source_message.caption.split('\n')[0][len('prompt: '):]

                modified_file = await _fetch_attachment(message)
                if modified_file is None:
                    await _reply(localized_text('wrong_image_mask', bot_language))
                    return
                modified_image = io.BytesIO(await modified_file.download_as_bytearray())

                original_file = await _fetch_attachment(source_message)
                if original_file is None:
                    await _reply(localized_text('missing_image', bot_language))
                    return
                orig_image = io.BytesIO(await original_file.download_as_bytearray())

            except Exception as e:
                logging.exception(e)
                await _reply_download_failed(e, bot_language)
                return

            user_id = message.from_user.id
            if user_id not in self.usage:
                self.usage[user_id] = UsageTracker(user_id, message.from_user.name)

            try:
                edited_image_urls, image_size = await self.openai.edit_image(
                    chat_id, orig_image, modified_image, stored_prompt
                )
            except Exception as e:
                logging.exception(e)
                await _reply(str(e))
                # Nothing was generated: stop here instead of tracking usage and sending
                # results that do not exist.
                return

            # add image request to users usage tracker
            self.usage[user_id].add_image_request(image_size, self.config['image_prices'])
            # add guest chat request to guest usage tracker
            if str(user_id) not in self.config['allowed_user_ids'].split(',') and 'guests' in self.usage:
                self.usage["guests"].add_image_request(image_size, self.config['image_prices'])

            edited_images = [
                InputMediaDocument(media=url, filename=f'image{i+1}.png')
                for i, url in enumerate(edited_image_urls)
            ]

            await update.effective_message.reply_media_group(
                reply_to_message_id=get_reply_to_message_id(self.config, update),
                media=edited_images
            )
            return

        try:
            image_file = await _fetch_attachment(message)
            if image_file is None:
                await _reply(localized_text('missing_image', bot_language))
                return

            logging.info(f'New image edit request received from user {message.from_user.name} '
                         f'(id: {message.from_user.id})')

            image = io.BytesIO(await image_file.download_as_bytearray())

        except Exception as e:
            logging.exception(e)
            await _reply_download_failed(e, bot_language)
            return

        # convert jpg from telegram to png as understood by openai
        image_png = io.BytesIO()
        try:
            Image.open(image).save(image_png, format='PNG')
            image_png.seek(0)
        except Exception as e:
            logging.exception(e)
            await _reply(localized_text('media_type_fail', bot_language))
            # Without a valid PNG there is nothing to send back.
            return

        try:
            # The prompt must stay on a single line: step 2 reads it back from the
            # first line of this caption.
            single_line_prompt = prompt.replace('\n', ' ')
            caption = f'prompt: {single_line_prompt}\nReply to this message with the masked image'

            await update.effective_message.reply_document(
                reply_to_message_id=get_reply_to_message_id(self.config, update),
                document=image_png, filename="image.png", caption=caption
            )

        except Exception as e:
            logging.exception(e)
            await _reply(
                f"{localized_text('edit_image_fail', bot_language)}: {str(e)}",
                parse_mode=constants.ParseMode.MARKDOWN
            )

    await wrap_with_indicator(update, context, _execute, constants.ChatAction.TYPING)
def compute_image_diff(im1, im2, threshold=256):
    """
    Build an edit mask from an image and a modified copy of it.

    Both inputs are opened with PIL and compared pixel by pixel. The bounding box of all
    pixels whose summed per-channel absolute difference exceeds ``threshold`` is made fully
    transparent in an RGBA copy of the first image.

    :param im1: The original image (anything ``Image.open`` accepts)
    :param im2: The modified image, same dimensions as ``im1``
    :param threshold: Minimum summed channel difference for a pixel to count as changed
    :return: An ``io.BytesIO`` positioned at 0 containing the mask as PNG
    :raises ValueError: If the images differ in size or no pixel exceeds the threshold
    """
    original = Image.open(im1)
    modified = Image.open(im2)

    if original.size != modified.size:
        raise ValueError("The image and the mask must be of the same size.")

    # Bring both images to RGBA so every pixel is a 4-tuple: grayscale/palette pixels are
    # plain ints and RGB vs RGBA pixels have different lengths, which would break or
    # silently truncate the channel-wise comparison below.
    mask = original.convert('RGBA')
    original_pixels = mask.load()
    modified_pixels = modified.convert('RGBA').load()

    width, height = original.size
    left, right = width, 0
    top, bottom = height, 0
    for y in range(height):
        for x in range(width):
            delta = sum(abs(a - b) for a, b in zip(original_pixels[x, y], modified_pixels[x, y]))
            if delta > threshold:
                left = min(left, x)
                right = max(right, x)
                top = min(top, y)
                bottom = max(bottom, y)

    if right < left or bottom < top:
        raise ValueError('No difference detected in the images')

    # Fill the whole bounding box in one call; the box's right and lower edges are exclusive.
    mask.paste((255, 255, 255, 0), (left, top, right + 1, bottom + 1))

    # The pixel data of both inputs has been fully loaded by now, so rewind the streams
    # to let callers reuse them (e.g. to upload the original image afterwards).
    for stream in (im1, im2):
        if hasattr(stream, 'seek'):
            stream.seek(0)

    res = io.BytesIO()
    mask.save(res, format='PNG')
    res.seek(0)
    return res