diff --git a/gemini_pro_bot/bot.py b/gemini_pro_bot/bot.py
index 3bf816d..409ecce 100644
--- a/gemini_pro_bot/bot.py
+++ b/gemini_pro_bot/bot.py
@@ -1,9 +1,10 @@
 import os
-from telegram import Update
+from telegram import Update, BotCommand
 from telegram.ext import (
     CommandHandler,
     MessageHandler,
     Application,
+    CallbackQueryHandler,
 )
 from gemini_pro_bot.filters import AuthFilter, MessageFilter, PhotoFilter
 from dotenv import load_dotenv
@@ -13,26 +14,53 @@
     newchat_command,
     handle_message,
     handle_image,
+    model_command,
+    model_callback,
 )
+import logging
 
+# Load environment variables
 load_dotenv()
 
+# Configure logging
+logging.basicConfig(
+    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
+    level=logging.INFO,
+)
+
+
+async def setup_commands(application: Application) -> None:
+    """Register the bot's command menu with Telegram."""
+    commands = [
+        BotCommand(command='start', description='Start the bot'),
+        BotCommand(command='help', description='Show help'),
+        BotCommand(command='new', description='Start a new chat session'),
+        BotCommand(command='model', description='Select the AI model'),
+    ]
+    await application.bot.set_my_commands(commands)
+
 
 def start_bot() -> None:
     """Start the bot."""
-    # Create the Application and pass it your bot's token.
-    application = Application.builder().token(os.getenv("BOT_TOKEN")).build()
+    try:
+        # Create the application; post_init registers the command menu on startup
+        application = (
+            Application.builder()
+            .token(os.getenv("BOT_TOKEN"))
+            .post_init(setup_commands)
+            .build()
+        )
+
+        # Command handlers
+        application.add_handler(CommandHandler("start", start, filters=AuthFilter))
+        application.add_handler(CommandHandler("help", help_command, filters=AuthFilter))
+        application.add_handler(CommandHandler("new", newchat_command, filters=AuthFilter))
 
-    # on different commands - answer in Telegram
-    application.add_handler(CommandHandler("start", start, filters=AuthFilter))
-    application.add_handler(CommandHandler("help", help_command, filters=AuthFilter))
-    application.add_handler(CommandHandler("new", newchat_command, filters=AuthFilter))
+        # Any text message is sent to the LLM to generate a response
+        application.add_handler(MessageHandler(MessageFilter, handle_message))
 
-    # Any text message is sent to LLM to generate a response
-    application.add_handler(MessageHandler(MessageFilter, handle_message))
+        # Any image is sent to the LLM to generate a response
+        application.add_handler(MessageHandler(PhotoFilter, handle_image))
 
-    # Any image is sent to LLM to generate a response
-    application.add_handler(MessageHandler(PhotoFilter, handle_image))
+        # Model selection command
+        application.add_handler(CommandHandler("model", model_command, filters=AuthFilter))
 
-    # Run the bot until the user presses Ctrl-C
-    application.run_polling(allowed_updates=Update.ALL_TYPES)
+        # Callback handler for the model selection keyboard
+        application.add_handler(CallbackQueryHandler(model_callback, pattern="^model_"))
+
+        # Run the bot until the user presses Ctrl-C
+        application.run_polling(allowed_updates=Update.ALL_TYPES)
+    except Exception as e:
+        logging.error(f"Error while starting the bot: {e}")
+        raise
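The `CallbackQueryHandler` above only fires for callback queries whose `callback_data` matches the `^model_` regex, so this handler and the keyboard built in `handlers.py` share a `model_<id>` convention. A minimal sketch of that routing in isolation (the IDs are illustrative, not taken from the repo):

```python
import re

# python-telegram-bot compiles a string pattern and applies it to
# callback_data with re.match, as sketched here.
pattern = re.compile(r"^model_")

for data in ["model_gemini-1.5-pro", "newchat", "model_gemini-1.5-flash"]:
    target = "model_callback" if pattern.match(data) else "(no handler)"
    print(f"{data!r} -> {target}")
```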
diff --git a/gemini_pro_bot/handlers.py b/gemini_pro_bot/handlers.py
index 4d50644..94d922b 100644
--- a/gemini_pro_bot/handlers.py
+++ b/gemini_pro_bot/handlers.py
@@ -1,18 +1,21 @@
 import asyncio
-from gemini_pro_bot.llm import model, img_model
+from gemini_pro_bot.llm import model, llm_manager
 from google.generativeai.types.generation_types import (
     StopCandidateException,
     BlockedPromptException,
 )
-from telegram import Update
+import google.generativeai as genai
+from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup
 from telegram.ext import (
     ContextTypes,
 )
 from telegram.error import NetworkError, BadRequest
 from telegram.constants import ChatAction, ParseMode
 from gemini_pro_bot.html_format import format_message
-import PIL.Image as load_image
-from io import BytesIO
+from datetime import datetime
+import os
 
 
 def new_chat(context: ContextTypes.DEFAULT_TYPE) -> None:
@@ -34,6 +37,7 @@
 Basic commands:
 /start - Start the bot
 /help - Get help. Shows this message
+/model - Select the LLM model to use
 
 Chat commands:
 /new - Start a new chat session (model will forget previously generated messages)
@@ -65,7 +69,7 @@
         new_chat(context)
     text = update.message.text
     init_msg = await update.message.reply_text(
-        text="Generating...", reply_to_message_id=update.message.message_id
+        text="Please wait...", reply_to_message_id=update.message.message_id
     )
     await update.message.chat.send_action(ChatAction.TYPING)
     # Generate a response using the text-generation pipeline
@@ -133,63 +137,125 @@
     await asyncio.sleep(0.1)
 
 
-async def handle_image(update: Update, _: ContextTypes.DEFAULT_TYPE) -> None:
+async def handle_image(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
     """Handle incoming images with captions and generate a response."""
     init_msg = await update.message.reply_text(
-        text="Generating...", reply_to_message_id=update.message.message_id
+        text="Please wait...",
+        reply_to_message_id=update.message.message_id,
     )
-    images = update.message.photo
-    unique_images: dict = {}
-    for img in images:
-        file_id = img.file_id[:-7]
-        if file_id not in unique_images:
-            unique_images[file_id] = img
-        elif img.file_size > unique_images[file_id].file_size:
-            unique_images[file_id] = img
-    file_list = list(unique_images.values())
-    file = await file_list[0].get_file()
-    a_img = load_image.open(BytesIO(await file.download_as_bytearray()))
-    prompt = None
-    if update.message.caption:
-        prompt = update.message.caption
+    try:
+        # Get the photo sizes attached to the message
+        images = update.message.photo
+        if not images:
+            await init_msg.edit_text("No image found in the message.")
+            return
+
+        # Pick the largest available size
+        image = max(images, key=lambda x: x.file_size or 0)
+        file = await image.get_file()
+
+        # Download the image data
+        image_data = await file.download_as_bytearray()
+
+        # Upload the image to Gemini's File API
+        gemini_file = upload_to_gemini(image_data)
+
+        # Use the caption as the prompt, with a default fallback
+        prompt = update.message.caption or "Analyse this image and generate a response"
+
+        if context.chat_data.get("chat") is None:
+            new_chat(context)
+
+        await update.message.chat.send_action(ChatAction.TYPING)
+
+        # Send the uploaded file and the prompt as one message to the chat session
+        chat_session = context.chat_data.get("chat")
+        response = await chat_session.send_message_async(
+            [gemini_file, prompt],
+            stream=True,
+        )
+
+        # Stream the response into the placeholder message
+        full_plain_message = ""
+        async for chunk in response:
+            try:
+                if chunk.text:
+                    full_plain_message += chunk.text
+                    message = format_message(full_plain_message)
+                    init_msg = await init_msg.edit_text(
+                        text=message,
+                        parse_mode=ParseMode.HTML,
+                        disable_web_page_preview=True,
+                    )
+            except Exception as e:
+                print(f"Error in response streaming: {e}")
+                if not full_plain_message:
+                    await init_msg.edit_text(f"Error generating response: {e}")
+                break
+            await asyncio.sleep(0.1)
+
+    except Exception as e:
+        print(f"Error processing image: {e}")
+        await init_msg.edit_text(f"Error processing image: {e}")
+
+
+def upload_to_gemini(image_data, mime_type="image/jpeg"):
+    """Upload the given image data to Gemini and return the file handle."""
+    # Telegram serves photos as JPEG, hence the .jpg temp file
+    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+    temp_filename = f"temp_image_{timestamp}.jpg"
+
+    try:
+        # genai.upload_file() takes a path, so stage the bytes on disk first
+        with open(temp_filename, 'wb') as f:
+            f.write(image_data)
+
+        # Upload to Gemini
+        file = genai.upload_file(temp_filename, mime_type=mime_type)
+        print(f"Uploaded file '{file.display_name}' as: {file.uri}")
+        return file
+    finally:
+        # Remove the temporary file
+        if os.path.exists(temp_filename):
+            os.remove(temp_filename)
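For context, this is the documented `google.generativeai` flow the handler leans on: upload a file, then pass the returned `File` object alongside text as the parts of a single chat message. A minimal sketch, assuming a configured API key and a local `photo.jpg` (both placeholders):

```python
import google.generativeai as genai

genai.configure(api_key="YOUR_GOOGLE_API_KEY")  # placeholder

# Upload once; the returned File object is referenced by URI server-side.
uploaded = genai.upload_file("photo.jpg", mime_type="image/jpeg")

model = genai.GenerativeModel("gemini-1.5-flash")
chat = model.start_chat()

# A list mixes the uploaded file and text into one user message.
response = chat.send_message([uploaded, "Describe this photo."])
print(response.text)
```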
+
+
+async def model_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
+    """Handle the /model command - show the model selection menu."""
+    keyboard = []
+    models = llm_manager.get_available_models()
+
+    for model_id, model_info in models.items():
+        # One button per model; mark the active model with a check
+        keyboard.append([InlineKeyboardButton(
+            f"{model_info['name']} {'✓' if model_id == llm_manager.current_model else ''}",
+            callback_data=f"model_{model_id}"
+        )])
+
+    reply_markup = InlineKeyboardMarkup(keyboard)
+    await update.message.reply_text(
+        "Select a model to use:",
+        reply_markup=reply_markup
+    )
+
+
+async def model_callback(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
+    """Handle the model selection callback."""
+    query = update.callback_query
+    await query.answer()
+
+    # Extract the model ID from the callback data
+    model_id = query.data.replace("model_", "")
+
+    if llm_manager.switch_model(model_id):
+        models = llm_manager.get_available_models()
+        await query.edit_message_text(
+            f"Switched to the {models[model_id]['name']} model"
+        )
     else:
-        prompt = "Analyse this image and generate response"
-    response = await img_model.generate_content_async([prompt, a_img], stream=True)
-    full_plain_message = ""
-    async for chunk in response:
-        try:
-            if chunk.text:
-                full_plain_message += chunk.text
-                message = format_message(full_plain_message)
-                init_msg = await init_msg.edit_text(
-                    text=message,
-                    parse_mode=ParseMode.HTML,
-                    disable_web_page_preview=True,
-                )
-        except StopCandidateException:
-            await init_msg.edit_text("The model unexpectedly stopped generating.")
-        except BadRequest:
-            await response.resolve()
-            continue
-        except NetworkError:
-            raise NetworkError(
-                "Looks like you're network is down. Please try again later."
-            )
-        except IndexError:
-            await init_msg.reply_text(
-                "Some index error occurred. This response is not supported."
-            )
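The menu is a one-button-per-row inline keyboard whose `callback_data` round-trips the model ID. `InlineKeyboardButton` and `InlineKeyboardMarkup` are plain data objects, so the construction can be sanity-checked without a running bot; a small sketch with illustrative stand-ins for `llm_manager.get_available_models()`:

```python
from telegram import InlineKeyboardButton, InlineKeyboardMarkup

# Illustrative model registry; the real one lives in llm.py.
models = {
    "gemini-1.5-pro": {"name": "gemini-1.5-pro"},
    "gemini-1.5-flash": {"name": "gemini-1.5-flash"},
}
current_model = "gemini-1.5-pro"

keyboard = [
    [InlineKeyboardButton(
        f"{info['name']} {'✓' if model_id == current_model else ''}".strip(),
        callback_data=f"model_{model_id}",
    )]
    for model_id, info in models.items()
]

# to_dict() shows exactly what would be sent to the Bot API.
print(InlineKeyboardMarkup(keyboard).to_dict())
```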
-            )
-            await response.resolve()
-            continue
-        except Exception as e:
-            print(e)
-            if chunk.text:
-                full_plain_message = chunk.text
-                message = format_message(full_plain_message)
-                init_msg = await update.message.reply_text(
-                    text=message,
-                    parse_mode=ParseMode.HTML,
-                    reply_to_message_id=init_msg.message_id,
-                    disable_web_page_preview=True,
-                )
-        await asyncio.sleep(0.1)
+        await query.edit_message_text("Failed to switch the model")
diff --git a/gemini_pro_bot/llm.py b/gemini_pro_bot/llm.py
index eaacd1f..7c69563 100644
--- a/gemini_pro_bot/llm.py
+++ b/gemini_pro_bot/llm.py
@@ -13,9 +13,49 @@
     HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
 }
 
+# All three Gemini 1.5 models accept both text and images
+MODELS = {
+    "gemini-1.5-pro": {
+        "name": "gemini-1.5-pro",
+        "model": "gemini-1.5-pro",
+        "type": "multimodal",
+    },
+    "gemini-1.5-flash": {
+        "name": "gemini-1.5-flash",
+        "model": "gemini-1.5-flash",
+        "type": "multimodal",
+    },
+    "gemini-1.5-flash-8b": {
+        "name": "gemini-1.5-flash-8b",
+        "model": "gemini-1.5-flash-8b",
+        "type": "multimodal",
+    },
+}
+
+
+class LLMManager:
+    def __init__(self):
+        self.current_model = "gemini-1.5-pro"
+        genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))
+        self._init_models()
+
+    def _init_models(self):
+        # Instantiate one GenerativeModel per entry in MODELS
+        self.models = {}
+        for model_id, config in MODELS.items():
+            self.models[model_id] = genai.GenerativeModel(
+                config["model"],
+                safety_settings=SAFETY_SETTINGS
+            )
+
+    def get_current_model(self):
+        return self.models[self.current_model]
 
-genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))
+    def switch_model(self, model_id):
+        if model_id in MODELS:
+            self.current_model = model_id
+            return True
+        return False
 
+    def get_available_models(self):
+        return MODELS
 
-model = genai.GenerativeModel("gemini-pro", safety_settings=SAFETY_SETTINGS)
-img_model = genai.GenerativeModel("gemini-pro-vision", safety_settings=SAFETY_SETTINGS)
+
+llm_manager = LLMManager()
+
+# Bound once at import time: chats created from this reference keep using
+# the model that was current when the module was imported.
+model = llm_manager.get_current_model()
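As the comment on the last line notes, `model` is resolved once at import, so a `/model` switch does not affect sessions created through it. If `new_chat` builds its session from this binding (an assumption; that function's body is outside this diff), asking the manager at call time would make switches apply to new chats:

```python
from gemini_pro_bot.llm import llm_manager
from telegram.ext import ContextTypes


def new_chat(context: ContextTypes.DEFAULT_TYPE) -> None:
    # Look up the current model at call time instead of import time,
    # so a /model switch takes effect for every chat started afterwards.
    context.chat_data["chat"] = llm_manager.get_current_model().start_chat()
```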