From c464d07ecf0e4c51d84de1effeb0dd3160595004 Mon Sep 17 00:00:00 2001
From: gilcu3 <828241+gilcu3@users.noreply.github.com>
Date: Wed, 15 Nov 2023 17:43:39 +0100
Subject: [PATCH] Support image editing using DALL·E 2

---
 README.md            |   2 +
 bot/main.py          |   2 +
 bot/openai_helper.py |  35 ++++++++-
 bot/telegram_bot.py  | 181 +++++++++++++++++++++++++++++++++++++++++++
 bot/utils.py         |  45 ++++++++++-
 translations.json    |   5 +-
 6 files changed, 267 insertions(+), 3 deletions(-)

diff --git a/README.md b/README.md
index 34e23b38..85ee9dbb 100644
--- a/README.md
+++ b/README.md
@@ -80,6 +80,7 @@ Check out the [Budget Manual](https://github.com/n3d1117/chatgpt-telegram-bot/di
 |------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-------------------------------------|
 | `ENABLE_QUOTING` | Whether to enable message quoting in private chats | `true` |
 | `ENABLE_IMAGE_GENERATION` | Whether to enable image generation via the `/image` command | `true` |
+| `ENABLE_IMAGE_EDITING` | Whether to enable image editing via the `/edit` command (sent as the caption of an image) | `true` |
 | `ENABLE_TRANSCRIPTION` | Whether to enable transcriptions of audio and video messages | `true` |
 | `PROXY` | Proxy to be used for OpenAI and Telegram bot (e.g. `http://localhost:8080`) | - |
 | `OPENAI_MODEL` | The OpenAI model to use for generating responses. You can find all available models [here](https://platform.openai.com/docs/models/) | `gpt-3.5-turbo` |
@@ -98,6 +99,7 @@ Check out the [Budget Manual](https://github.com/n3d1117/chatgpt-telegram-bot/di
 | `IMAGE_SIZE` | The DALL·E generated image size. Allowed values: `256x256`, `512x512` or `1024x1024` | `512x512` |
 | `GROUP_TRIGGER_KEYWORD` | If set, the bot in group chats will only respond to messages that start with this keyword | - |
 | `IGNORE_GROUP_TRANSCRIPTIONS` | If set to true, the bot will not process transcriptions in group chats | `true` |
+| `IGNORE_GROUP_IMAGE_EDITING` | If set to true, the bot will not process image editing requests in group chats | `true` |
 | `BOT_LANGUAGE` | Language of general bot messages. Currently available: `en`, `de`, `ru`, `tr`, `it`, `fi`, `es`, `id`, `nl`, `zh-cn`, `zh-tw`, `vi`, `fa`, `pt-br`, `uk`, `ms`. [Contribute with additional translations](https://github.com/n3d1117/chatgpt-telegram-bot/discussions/219) | `en` |
 | `WHISPER_PROMPT` | To improve the accuracy of Whisper's transcription service, especially for specific names or terms, you can set up a custom message. [Speech to text - Prompting](https://platform.openai.com/docs/guides/speech-to-text/prompting) | `-` |
diff --git a/bot/main.py b/bot/main.py
index d7605fd5..430184a9 100644
--- a/bot/main.py
+++ b/bot/main.py
@@ -69,6 +69,7 @@ def main():
         'allowed_user_ids': os.environ.get('ALLOWED_TELEGRAM_USER_IDS', '*'),
         'enable_quoting': os.environ.get('ENABLE_QUOTING', 'true').lower() == 'true',
         'enable_image_generation': os.environ.get('ENABLE_IMAGE_GENERATION', 'true').lower() == 'true',
+        'enable_image_editing': os.environ.get('ENABLE_IMAGE_EDITING', 'true').lower() == 'true',
         'enable_transcription': os.environ.get('ENABLE_TRANSCRIPTION', 'true').lower() == 'true',
         'budget_period': os.environ.get('BUDGET_PERIOD', 'monthly').lower(),
         'user_budgets': os.environ.get('USER_BUDGETS', os.environ.get('MONTHLY_USER_BUDGETS', '*')),
@@ -78,6 +79,7 @@ def main():
         'voice_reply_transcript': os.environ.get('VOICE_REPLY_WITH_TRANSCRIPT_ONLY', 'false').lower() == 'true',
         'voice_reply_prompts': os.environ.get('VOICE_REPLY_PROMPTS', '').split(';'),
         'ignore_group_transcriptions': os.environ.get('IGNORE_GROUP_TRANSCRIPTIONS', 'true').lower() == 'true',
+        'ignore_group_image_editing': os.environ.get('IGNORE_GROUP_IMAGE_EDITING', 'true').lower() == 'true',
         'group_trigger_keyword': os.environ.get('GROUP_TRIGGER_KEYWORD', ''),
         'token_price': float(os.environ.get('TOKEN_PRICE', 0.002)),
         'image_prices': [float(i) for i in os.environ.get('IMAGE_PRICES', "0.016,0.018,0.02").split(",")],
diff --git a/bot/openai_helper.py b/bot/openai_helper.py
index 4fe39bf4..9714fc4a 100644
--- a/bot/openai_helper.py
+++ b/bot/openai_helper.py
@@ -15,7 +15,7 @@
 from tenacity import retry, stop_after_attempt, wait_fixed, retry_if_exception_type
 
-from utils import is_direct_result
+from utils import is_direct_result, compute_image_diff
 from plugin_manager import PluginManager
 
 # Models can be found here: https://platform.openai.com/docs/models/overview
@@ -333,6 +333,39 @@ async def generate_image(self, prompt: str) -> tuple[str, str]:
             return response.data[0].url, self.config['image_size']
         except Exception as e:
             raise Exception(f"⚠️ _{localized_text('error', bot_language)}._ ⚠️\n{str(e)}") from e
+
+    async def edit_image(self, chat_id, orig_image, modified_image, prompt):
+        """
+        Edits a given PNG image using the DALL·E 2 model. The mask is computed from the
+        difference between the original image and the modified image.
+        """
+        try:
+            mask_image = compute_image_diff(orig_image, modified_image)
+
+            args = {
+                'model': 'dall-e-2',  # image editing is currently only supported by DALL·E 2
+                'n': 1,
+                'size': self.config['image_size'],
+                'image': orig_image,
+                'prompt': prompt,
+                'mask': mask_image
+            }
+
+            response = await self.client.images.edit(**args)
+
+            image_urls = [image.url for image in response.data]
+            return image_urls, self.config['image_size']
+
+        except openai.RateLimitError as e:
+            raise e
+        except openai.BadRequestError as e:
+            raise Exception(f"⚠️ _{localized_text('openai_invalid', self.config['bot_language'])}._ ⚠️\n{str(e)}") from e
+        except Exception as e:
+            logging.exception(e)
+            raise Exception(f"⚠️ _{localized_text('error', self.config['bot_language'])}._ ⚠️\n{str(e)}") from e
 
     async def transcribe(self, filename):
         """
diff --git a/bot/telegram_bot.py b/bot/telegram_bot.py
index 57a3f0ae..24290f7e 100644
--- a/bot/telegram_bot.py
+++ b/bot/telegram_bot.py
@@ -3,16 +3,19 @@
 import asyncio
 import logging
 import os
+import io
 from uuid import uuid4
 
 from telegram import BotCommandScopeAllGroupChats, Update, constants
 from telegram import InlineKeyboardMarkup, InlineKeyboardButton, InlineQueryResultArticle
 from telegram import InputTextMessageContent, BotCommand
+from telegram import Document, InputMediaDocument
 from telegram.error import RetryAfter, TimedOut
 from telegram.ext import ApplicationBuilder, CommandHandler, MessageHandler, \
     filters, InlineQueryHandler, CallbackQueryHandler, Application, ContextTypes, CallbackContext
 from pydub import AudioSegment
+from PIL import Image
 
 from utils import is_group_chat, get_thread_id, message_text, wrap_with_indicator, split_into_chunks, \
     edit_message_with_retry, get_stream_cutoff_values, is_allowed, get_remaining_budget, is_admin, is_within_budget, \
@@ -370,6 +373,183 @@ async def _execute():
 
         await wrap_with_indicator(update, context, _execute, constants.ChatAction.TYPING)
 
+    async def edit_image(self, update: Update, context: ContextTypes.DEFAULT_TYPE):
+        """
+        Edits an image using DALL·E 2. The image is first sent with a caption starting with
+        /edit and is echoed back as a PNG document; replying to that document with a
+        modified PNG triggers the actual edit request.
+ """ + if not self.config['enable_image_editing'] or not await self.check_allowed_and_within_budget(update, context): + return + + chat_id = update.effective_chat.id + + if is_group_chat(update): + if self.config['ignore_group_image_editing']: + logging.info(f'Image edit coming from group chat, ignoring...') + return + else: + trigger_keyword = self.config['group_trigger_keyword'] + if (prompt is None and trigger_keyword != '') or \ + (prompt is not None and not prompt.lower().startswith(trigger_keyword.lower())): + logging.info(f'Image edit coming from group chat with wrong keyword, ignoring...') + return + + + + + + async def _execute(): + bot_language = self.config['bot_language'] + + + + if update.message.reply_to_message: + + + + + logging.info(f'New mask for image edit request received from user {update.message.from_user.name} ' + f'(id: {update.message.from_user.id})') + + try: + rmes = update.message.reply_to_message + prompt = rmes.caption.split('\n')[0][len('prompt: '):] + + if isinstance(update.message.effective_attachment, tuple): + image_file = await update.message.effective_attachment[-1].get_file() + elif isinstance(update.message.effective_attachment, Document): + image_file = await update.message.effective_attachment.get_file() + else: + await update.effective_message.reply_text( + message_thread_id=get_thread_id(update), + reply_to_message_id=get_reply_to_message_id(self.config, update), + text=localized_text('wrong_image_mask', bot_language) + ) + return + + image = io.BytesIO(await image_file.download_as_bytearray()) + + orig_image_file = await rmes.effective_attachment.get_file() + orig_image = io.BytesIO(await orig_image_file.download_as_bytearray()) + + except Exception as e: + logging.exception(e) + await update.effective_message.reply_text( + message_thread_id=get_thread_id(update), + reply_to_message_id=get_reply_to_message_id(self.config, update), + text=( + f"{localized_text('media_download_fail', bot_language)[0]}: " + f"{str(e)}. 
{localized_text('media_download_fail', bot_language)[1]}" + ) + ) + return + + user_id = update.message.from_user.id + if user_id not in self.usage: + self.usage[user_id] = UsageTracker(user_id, update.message.from_user.name) + + try: + edited_image_urls, image_size = await self.openai.edit_image(chat_id, orig_image, image, prompt) + except Exception as e: + logging.exception(e) + await update.effective_message.reply_text( + message_thread_id=get_thread_id(update), + reply_to_message_id=get_reply_to_message_id(self.config, update), + text=str(e) + ) + + # add image request to users usage tracker + user_id = update.message.from_user.id + self.usage[user_id].add_image_request(image_size, self.config['image_prices']) + # add guest chat request to guest usage tracker + if str(user_id) not in self.config['allowed_user_ids'].split(',') and 'guests' in self.usage: + self.usage["guests"].add_image_request(image_size, self.config['image_prices']) + + edited_images = [InputMediaDocument(media=url, filename=f'image{i+1}.png') for i, url in enumerate(edited_image_urls)] + + await update.effective_message.reply_media_group( + reply_to_message_id=get_reply_to_message_id(self.config, update), + media=edited_images + ) + + else: + + try: + + if isinstance(update.message.effective_attachment, tuple): + image_file = await update.message.effective_attachment[-1].get_file() + elif isinstance(update.message.effective_attachment, Document): + image_file = await update.message.effective_attachment.get_file() + else: + await update.effective_message.reply_text( + message_thread_id=get_thread_id(update), + reply_to_message_id=get_reply_to_message_id(self.config, update), + text=localized_text('missing_image', bot_language) + ) + return + + logging.info(f'New image edit request received from user {update.message.from_user.name} ' + f'(id: {update.message.from_user.id})') + + image = io.BytesIO(await image_file.download_as_bytearray()) + + + except Exception as e: + logging.exception(e) + await update.effective_message.reply_text( + message_thread_id=get_thread_id(update), + reply_to_message_id=get_reply_to_message_id(self.config, update), + text=( + f"{localized_text('media_download_fail', bot_language)[0]}: " + f"{str(e)}. 
{localized_text('media_download_fail', bot_language)[1]}" + ) + ) + return + + # convert jpg from telegram to png as understood by openai + + image_png = io.BytesIO() + + try: + original_image = Image.open(image) + + original_image.save(image_png, format='PNG') + image_png.seek(0) + + + + except Exception as e: + logging.exception(e) + await update.effective_message.reply_text( + message_thread_id=get_thread_id(update), + reply_to_message_id=get_reply_to_message_id(self.config, update), + text=localized_text('media_type_fail', bot_language) + ) + + + + + + try: + + prompt = update.message.caption[len('/edit '):].replace('\n', ' ') + caption = f'prompt: {prompt}\nReply to this message with the mask in png format' + + await update.effective_message.reply_document( + reply_to_message_id=get_reply_to_message_id(self.config, update), + document=image_png, filename="image.png", caption=caption + ) + + except Exception as e: + logging.exception(e) + await update.effective_message.reply_text( + message_thread_id=get_thread_id(update), + reply_to_message_id=get_reply_to_message_id(self.config, update), + text=f"{localized_text('edit_image_fail', bot_language)}: {str(e)}", + parse_mode=constants.ParseMode.MARKDOWN + ) + + await wrap_with_indicator(update, context, _execute, constants.ChatAction.TYPING) + async def prompt(self, update: Update, context: ContextTypes.DEFAULT_TYPE): """ React to incoming messages and respond accordingly. @@ -786,6 +966,7 @@ def run(self): application.add_handler(CommandHandler('reset', self.reset)) application.add_handler(CommandHandler('help', self.help)) application.add_handler(CommandHandler('image', self.image)) + application.add_handler(MessageHandler(((filters.PHOTO|filters.Document.IMAGE)&filters.CaptionRegex('^/edit.*')), self.edit_image)) application.add_handler(CommandHandler('start', self.help)) application.add_handler(CommandHandler('stats', self.stats)) application.add_handler(CommandHandler('resend', self.resend)) diff --git a/bot/utils.py b/bot/utils.py index 6ce2e98e..947ce848 100644 --- a/bot/utils.py +++ b/bot/utils.py @@ -5,6 +5,8 @@ import json import logging import os +import io +from PIL import Image, ImageChops import telegram from telegram import Message, MessageEntity, Update, ChatMember, constants @@ -376,4 +378,45 @@ def cleanup_intermediate_files(response: any): if format == 'path': if os.path.exists(value): - os.remove(value) \ No newline at end of file + os.remove(value) + +def compute_image_diff(im1, im2): + + im1 = Image.open(im1) + im2 = Image.open(im2) + + if im1.size != im2.size: + raise ValueError("The image and the mask must be of the same size.") + + pixels1 = im1.load() + pixels2 = im2.load() + + def pixel_difference(pixel1, pixel2): + channel_diff = sum(tuple(abs(c1 - c2) for c1, c2 in zip(pixel1, pixel2))) + return channel_diff + + transparent_box = im1.convert('RGBA') + + xtop, xbottom = im1.size[0], 0 + ytop, ybottom = im1.size[1], 0 + threshold = 256 + for y in range(im1.size[1]): + + for x in range(im1.size[0]): + if pixel_difference(pixels1[x, y], pixels2[x, y]) > threshold: + xtop = min(xtop, x) + xbottom = max(xbottom, x) + ytop = min(ytop, y) + ybottom = max(ybottom, y) + + if xbottom >= xtop and ybottom >= ytop: + for x in range(xtop, xbottom + 1): + for y in range(ytop, ybottom + 1): + transparent_box.putpixel((x, y), (255, 255, 255, 0)) + else: + raise('No difference detected in the images') + + res = io.BytesIO() + transparent_box.save(res, format='PNG') + res.seek(0) + return res diff --git a/translations.json 
index f603efb9..21e45ea7 100644
--- a/translations.json
+++ b/translations.json
@@ -25,7 +25,10 @@
         "reset_done":"Done!",
         "image_no_prompt":"Please provide a prompt! (e.g. /image cat)",
         "image_fail":"Failed to generate image",
-        "media_download_fail":["Failed to download audio file", "Make sure the file is not too large. (max 20MB)"],
+        "edit_image_fail":"Failed to edit image",
+        "missing_image":"Missing image to edit",
+        "wrong_image_mask":"Wrong image mask. It must be a PNG file sent as a document",
+        "media_download_fail":["Failed to download audio/image/video file", "Make sure the file is not too large. (max 20MB)"],
         "media_type_fail":"Unsupported file type",
         "transcript":"Transcript",
         "answer":"Answer",
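
For reviewers, a minimal standalone sketch of the flow this patch implements: compute the transparency mask from the difference between an original and a modified image, then send both to the DALL·E 2 edit endpoint. This is illustrative only, not part of the patch; the file names, prompt and size are placeholders, it assumes the openai v1 Python client and Pillow are installed, and it assumes it is run from the bot/ directory so that compute_image_diff can be imported from utils.py.

# edit_sketch.py - illustrative only, not part of the patch
import asyncio

from openai import AsyncOpenAI

from utils import compute_image_diff  # helper added by this patch in bot/utils.py


async def main():
    client = AsyncOpenAI()  # reads OPENAI_API_KEY from the environment

    # placeholder inputs: modified.png is original.png with the region to be regenerated
    # painted over, so the diff bounding box becomes the transparent area of the mask
    with open('original.png', 'rb') as orig, open('modified.png', 'rb') as modified:
        mask = compute_image_diff(orig, modified)
        orig.seek(0)  # compute_image_diff consumed the stream via Image.open

        response = await client.images.edit(
            model='dall-e-2',
            image=orig,
            mask=mask,
            prompt='a sunlit indoor lounge area with a pool',  # placeholder prompt
            n=1,
            size='512x512',
        )

    print(response.data[0].url)


if __name__ == '__main__':
    asyncio.run(main())

The images.edit call above mirrors the one issued by OpenAIHelper.edit_image, with the two input images coming from the Telegram photo and the PNG document reply in the bot flow.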