Image edit using DALL·E #462

Open · wants to merge 6 commits into base: main
2 changes: 2 additions & 0 deletions README.md
@@ -87,6 +87,7 @@ Check out the [Budget Manual](https://github.com/n3d1117/chatgpt-telegram-bot/di
|-------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|------------------------------------|
| `ENABLE_QUOTING` | Whether to enable message quoting in private chats | `true` |
| `ENABLE_IMAGE_GENERATION` | Whether to enable image generation via the `/image` command | `true` |
| `ENABLE_IMAGE_EDITING` | Whether to enable image editing via the `/edit` command | `true` |
| `ENABLE_TRANSCRIPTION` | Whether to enable transcriptions of audio and video messages | `true` |
| `ENABLE_TTS_GENERATION` | Whether to enable text to speech generation via the `/tts` | `true` |
| `ENABLE_VISION` | Whether to enable vision capabilities in supported models | `true` |
@@ -120,6 +121,7 @@ Check out the [Budget Manual](https://github.com/n3d1117/chatgpt-telegram-bot/di
| `GROUP_TRIGGER_KEYWORD` | If set, the bot in group chats will only respond to messages that start with this keyword | - |
| `IGNORE_GROUP_TRANSCRIPTIONS` | If set to true, the bot will not process transcriptions in group chats | `true` |
| `IGNORE_GROUP_VISION` | If set to true, the bot will not process vision queries in group chats | `true` |
| `IGNORE_GROUP_IMAGE_EDITING` | If set to true, the bot will not process image editing in group chats | `true` |
| `BOT_LANGUAGE` | Language of general bot messages. Currently available: `en`, `de`, `ru`, `tr`, `it`, `fi`, `es`, `id`, `nl`, `zh-cn`, `zh-tw`, `vi`, `fa`, `pt-br`, `uk`, `ms`, `uz`, `ar`. [Contribute with additional translations](https://github.com/n3d1117/chatgpt-telegram-bot/discussions/219) | `en` |
| `WHISPER_PROMPT` | To improve the accuracy of Whisper's transcription service, especially for specific names or terms, you can set up a custom message. [Speech to text - Prompting](https://platform.openai.com/docs/guides/speech-to-text/prompting) | `-` |
| `TTS_VOICE` | The Text to Speech voice to use. Allowed values: `alloy`, `echo`, `fable`, `onyx`, `nova`, or `shimmer` | `alloy` |
2 changes: 2 additions & 0 deletions bot/main.py
@@ -79,6 +79,7 @@ def main():
'allowed_user_ids': os.environ.get('ALLOWED_TELEGRAM_USER_IDS', '*'),
'enable_quoting': os.environ.get('ENABLE_QUOTING', 'true').lower() == 'true',
'enable_image_generation': os.environ.get('ENABLE_IMAGE_GENERATION', 'true').lower() == 'true',
'enable_image_editing': os.environ.get('ENABLE_IMAGE_EDITING', 'true').lower() == 'true',
'enable_transcription': os.environ.get('ENABLE_TRANSCRIPTION', 'true').lower() == 'true',
'enable_vision': os.environ.get('ENABLE_VISION', 'true').lower() == 'true',
'enable_tts_generation': os.environ.get('ENABLE_TTS_GENERATION', 'true').lower() == 'true',
@@ -91,6 +92,7 @@ def main():
'voice_reply_prompts': os.environ.get('VOICE_REPLY_PROMPTS', '').split(';'),
'ignore_group_transcriptions': os.environ.get('IGNORE_GROUP_TRANSCRIPTIONS', 'true').lower() == 'true',
'ignore_group_vision': os.environ.get('IGNORE_GROUP_VISION', 'true').lower() == 'true',
'ignore_group_image_editing': os.environ.get('IGNORE_GROUP_IMAGE_EDITING', 'true').lower() == 'true',
'group_trigger_keyword': os.environ.get('GROUP_TRIGGER_KEYWORD', ''),
'token_price': float(os.environ.get('TOKEN_PRICE', 0.002)),
'image_prices': [float(i) for i in os.environ.get('IMAGE_PRICES', "0.016,0.018,0.02").split(",")],
34 changes: 33 additions & 1 deletion bot/openai_helper.py
@@ -17,7 +17,7 @@

from tenacity import retry, stop_after_attempt, wait_fixed, retry_if_exception_type

from utils import is_direct_result, encode_image, decode_image
from utils import is_direct_result, encode_image, decode_image, compute_image_diff
from plugin_manager import PluginManager

# Models can be found here: https://platform.openai.com/docs/models/overview
@@ -355,6 +355,38 @@ async def generate_image(self, prompt: str) -> tuple[str, str]:
return response.data[0].url, self.config['image_size']
except Exception as e:
raise Exception(f"⚠️ _{localized_text('error', bot_language)}._ ⚠️\n{str(e)}") from e

async def edit_image(self, chat_id, orig_image, modified_image, prompt):
"""
Edits a given PNG image using the OpenAI image edit endpoint. The mask is
computed from the differences between the original image and the modified image.
"""
try:
mask_image = compute_image_diff(orig_image, modified_image)

args = {
'n': 1,
'size': self.config['image_size'],
'image': orig_image,
'prompt': prompt,
'mask': mask_image
}

response = await self.client.images.edit(**args)

image_urls = [item.url for item in response.data]

return image_urls, self.config['image_size']

except openai.RateLimitError as e:
raise e
except openai.BadRequestError as e:
raise Exception(f"⚠️ _{localized_text('openai_invalid', self.config['bot_language'])}._ ⚠️\n{str(e)}") from e
except Exception as e:
logging.exception(e)
raise Exception(f"⚠️ _{localized_text('error', self.config['bot_language'])}._ ⚠️\n{str(e)}") from e

async def generate_speech(self, text: str) -> tuple[any, int]:
"""
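For reference, a minimal standalone sketch of the call that `edit_image` wraps, assuming the official `openai` Python client v1.x: the image edit endpoint takes a PNG plus a same-size mask whose transparent pixels mark the region to regenerate. The file names and prompt below are placeholders, and error handling is omitted.

```python
import asyncio
from openai import AsyncOpenAI

async def main():
    client = AsyncOpenAI()  # reads OPENAI_API_KEY from the environment
    # "original.png" and "mask.png" are placeholder file names; the mask must be a
    # PNG of the same size as the image, transparent where the edit should happen.
    with open("original.png", "rb") as image, open("mask.png", "rb") as mask:
        response = await client.images.edit(
            image=image,
            mask=mask,
            prompt="replace the masked area with a clear blue sky",
            n=1,
            size="1024x1024",
        )
    print(response.data[0].url)

asyncio.run(main())
```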
180 changes: 180 additions & 0 deletions bot/telegram_bot.py
@@ -9,12 +9,14 @@
from telegram import BotCommandScopeAllGroupChats, Update, constants
from telegram import InlineKeyboardMarkup, InlineKeyboardButton, InlineQueryResultArticle
from telegram import InputTextMessageContent, BotCommand
from telegram import PhotoSize, Document, InputMediaDocument
from telegram.error import RetryAfter, TimedOut, BadRequest
from telegram.ext import ApplicationBuilder, CommandHandler, MessageHandler, \
filters, InlineQueryHandler, CallbackQueryHandler, Application, ContextTypes, CallbackContext

from pydub import AudioSegment
from PIL import Image

from utils import is_group_chat, get_thread_id, message_text, wrap_with_indicator, split_into_chunks, \
edit_message_with_retry, get_stream_cutoff_values, is_allowed, get_remaining_budget, is_admin, is_within_budget, \
@@ -643,6 +645,183 @@ async def _execute():

await wrap_with_indicator(update, context, _execute, constants.ChatAction.TYPING)

async def edit_image(self, update: Update, context: ContextTypes.DEFAULT_TYPE):
"""
Edit an image using the OpenAI image edit endpoint.
"""
if not self.config['enable_image_editing'] or not await self.check_allowed_and_within_budget(update, context):
return

chat_id = update.effective_chat.id
# the prompt arrives as the caption of the photo message (e.g. '/edit make the sky blue')
prompt = update.message.caption

if is_group_chat(update):
if self.config['ignore_group_image_editing']:
logging.info('Image edit request coming from group chat, ignoring...')
return
else:
trigger_keyword = self.config['group_trigger_keyword']
if (prompt is None and trigger_keyword != '') or \
(prompt is not None and not prompt.lower().startswith(trigger_keyword.lower())):
logging.info('Image edit request coming from group chat with wrong keyword, ignoring...')
return

async def _execute():
bot_language = self.config['bot_language']

if update.message.reply_to_message:
# second step: the user replied to the bot's PNG with a masked copy of the image
logging.info(f'New mask for image edit request received from user {update.message.from_user.name} '
f'(id: {update.message.from_user.id})')

try:
rmes = update.message.reply_to_message
prompt = rmes.caption.split('\n')[0][len('prompt: '):]

if isinstance(update.message.effective_attachment, tuple):
image_file = await update.message.effective_attachment[-1].get_file()
elif isinstance(update.message.effective_attachment, Document):
image_file = await update.message.effective_attachment.get_file()
else:
await update.effective_message.reply_text(
message_thread_id=get_thread_id(update),
reply_to_message_id=get_reply_to_message_id(self.config, update),
text=localized_text('wrong_image_mask', bot_language)
)
return

image = io.BytesIO(await image_file.download_as_bytearray())

orig_image_file = await rmes.effective_attachment.get_file()
orig_image = io.BytesIO(await orig_image_file.download_as_bytearray())

except Exception as e:
logging.exception(e)
await update.effective_message.reply_text(
message_thread_id=get_thread_id(update),
reply_to_message_id=get_reply_to_message_id(self.config, update),
text=(
f"{localized_text('media_download_fail', bot_language)[0]}: "
f"{str(e)}. {localized_text('media_download_fail', bot_language)[1]}"
)
)
return

user_id = update.message.from_user.id
if user_id not in self.usage:
self.usage[user_id] = UsageTracker(user_id, update.message.from_user.name)

try:
edited_image_urls, image_size = await self.openai.edit_image(chat_id, orig_image, image, prompt)
except Exception as e:
logging.exception(e)
await update.effective_message.reply_text(
message_thread_id=get_thread_id(update),
reply_to_message_id=get_reply_to_message_id(self.config, update),
text=str(e)
)
return

# add image edit request to the user's usage tracker
self.usage[user_id].add_image_request(image_size, self.config['image_prices'])
# add guest chat request to guest usage tracker
if str(user_id) not in self.config['allowed_user_ids'].split(',') and 'guests' in self.usage:
self.usage["guests"].add_image_request(image_size, self.config['image_prices'])

edited_images = [InputMediaDocument(media=url, filename=f'image{i+1}.png') for i, url in enumerate(edited_image_urls)]

await update.effective_message.reply_media_group(
reply_to_message_id=get_reply_to_message_id(self.config, update),
media=edited_images
)

else:
# first step: the user sent an image with an '/edit <prompt>' caption; convert it to PNG
# and send it back so they can mask the region to change and reply with the result
try:
if isinstance(update.message.effective_attachment, tuple):
image_file = await update.message.effective_attachment[-1].get_file()
elif isinstance(update.message.effective_attachment, Document):
image_file = await update.message.effective_attachment.get_file()
else:
await update.effective_message.reply_text(
message_thread_id=get_thread_id(update),
reply_to_message_id=get_reply_to_message_id(self.config, update),
text=localized_text('missing_image', bot_language)
)
return

logging.info(f'New image edit request received from user {update.message.from_user.name} '
f'(id: {update.message.from_user.id})')

image = io.BytesIO(await image_file.download_as_bytearray())


except Exception as e:
logging.exception(e)
await update.effective_message.reply_text(
message_thread_id=get_thread_id(update),
reply_to_message_id=get_reply_to_message_id(self.config, update),
text=(
f"{localized_text('media_download_fail', bot_language)[0]}: "
f"{str(e)}. {localized_text('media_download_fail', bot_language)[1]}"
)
)
return

# convert the JPEG received from Telegram to PNG, as expected by the OpenAI edit endpoint
image_png = io.BytesIO()
try:
original_image = Image.open(image)
original_image.save(image_png, format='PNG')
image_png.seek(0)
except Exception as e:
logging.exception(e)
await update.effective_message.reply_text(
message_thread_id=get_thread_id(update),
reply_to_message_id=get_reply_to_message_id(self.config, update),
text=localized_text('media_type_fail', bot_language)
)
return

try:

prompt = update.message.caption[len('/edit '):].replace('\n', ' ')
caption = f'prompt: {prompt}\nReply to this message with the masked image'

await update.effective_message.reply_document(
reply_to_message_id=get_reply_to_message_id(self.config, update),
document=image_png, filename="image.png", caption=caption
)

except Exception as e:
logging.exception(e)
await update.effective_message.reply_text(
message_thread_id=get_thread_id(update),
reply_to_message_id=get_reply_to_message_id(self.config, update),
text=f"{localized_text('edit_image_fail', bot_language)}: {str(e)}",
parse_mode=constants.ParseMode.MARKDOWN
)

await wrap_with_indicator(update, context, _execute, constants.ChatAction.TYPING)

async def prompt(self, update: Update, context: ContextTypes.DEFAULT_TYPE):
"""
React to incoming messages and respond accordingly.
@@ -1059,6 +1238,7 @@ def run(self):
application.add_handler(CommandHandler('reset', self.reset))
application.add_handler(CommandHandler('help', self.help))
application.add_handler(CommandHandler('image', self.image))
application.add_handler(MessageHandler((filters.PHOTO | filters.Document.IMAGE) & filters.CaptionRegex('^/edit.*'), self.edit_image))
application.add_handler(CommandHandler('tts', self.tts))
application.add_handler(CommandHandler('start', self.help))
application.add_handler(CommandHandler('stats', self.stats))
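The new handler is registered with a `MessageHandler` filter rather than a `CommandHandler`, because `/edit` arrives as the caption of a photo or image document, not as a text message. A minimal standalone sketch of that routing, assuming python-telegram-bot v20 and a placeholder bot token:

```python
from telegram import Update
from telegram.ext import ApplicationBuilder, ContextTypes, MessageHandler, filters

async def on_edit(update: Update, context: ContextTypes.DEFAULT_TYPE):
    # fires for photos or image documents whose caption starts with /edit
    await update.effective_message.reply_text(f"edit request: {update.message.caption}")

app = ApplicationBuilder().token("YOUR_BOT_TOKEN").build()  # placeholder token
app.add_handler(MessageHandler(
    (filters.PHOTO | filters.Document.IMAGE) & filters.CaptionRegex('^/edit.*'),
    on_edit,
))
app.run_polling()
```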
44 changes: 44 additions & 0 deletions bot/utils.py
@@ -6,6 +6,8 @@
import logging
import os
import base64
import io
from PIL import Image, ImageChops

import telegram
from telegram import Message, MessageEntity, Update, ChatMember, constants
@@ -388,3 +390,45 @@ def encode_image(fileobj):
def decode_image(imgbase64):
image = imgbase64[len('data:image/jpeg;base64,'):]
return base64.b64decode(image)


def compute_image_diff(im1, im2):
"""
Builds an edit mask: returns a PNG copy of the first image in which the bounding box
of the pixels that differ between the two images is made transparent.
"""
im1 = Image.open(im1)
im2 = Image.open(im2)

if im1.size != im2.size:
raise ValueError("The original image and the modified image must be the same size.")

pixels1 = im1.load()
pixels2 = im2.load()

def pixel_difference(pixel1, pixel2):
return sum(abs(c1 - c2) for c1, c2 in zip(pixel1, pixel2))

transparent_box = im1.convert('RGBA')

# find the bounding box of all pixels whose summed channel difference exceeds the threshold
xtop, xbottom = im1.size[0], 0
ytop, ybottom = im1.size[1], 0
threshold = 256
for y in range(im1.size[1]):
for x in range(im1.size[0]):
if pixel_difference(pixels1[x, y], pixels2[x, y]) > threshold:
xtop = min(xtop, x)
xbottom = max(xbottom, x)
ytop = min(ytop, y)
ybottom = max(ybottom, y)

if xbottom >= xtop and ybottom >= ytop:
# make the changed region fully transparent so the edit endpoint regenerates it
for x in range(xtop, xbottom + 1):
for y in range(ytop, ybottom + 1):
transparent_box.putpixel((x, y), (255, 255, 255, 0))
else:
raise ValueError('No difference detected in the images')

res = io.BytesIO()
transparent_box.save(res, format='PNG')
res.seek(0)
return res
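
Note that `ImageChops` is imported in `bot/utils.py` but not used by `compute_image_diff`. For comparison, a sketch of the same bounding-box mask built with `ImageChops.difference` and `getbbox`; unlike the loop above it applies no per-pixel threshold (any nonzero difference counts), and the function name below is illustrative only:

```python
import io
from PIL import Image, ImageChops

def compute_image_diff_bbox(im1_file, im2_file):
    """Return a PNG of the first image with the differing bounding box made transparent."""
    im1 = Image.open(im1_file).convert('RGB')
    im2 = Image.open(im2_file).convert('RGB')
    if im1.size != im2.size:
        raise ValueError("The original image and the modified image must be the same size.")

    # getbbox() returns (left, upper, right, lower) of all non-zero (i.e. differing) pixels
    bbox = ImageChops.difference(im1, im2).getbbox()
    if bbox is None:
        raise ValueError('No difference detected in the images')

    mask = Image.new('L', im1.size, 255)  # fully opaque alpha channel
    mask.paste(0, bbox)                   # transparent over the changed region
    out = im1.convert('RGBA')
    out.putalpha(mask)

    res = io.BytesIO()
    out.save(res, format='PNG')
    res.seek(0)
    return res
```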