update MODELS
shaohongwu committed Nov 11, 2024
1 parent 5b7aa32 commit 277d1ee
Showing 3 changed files with 211 additions and 77 deletions.
56 changes: 42 additions & 14 deletions gemini_pro_bot/bot.py
@@ -1,9 +1,10 @@
import os
from telegram import Update
from telegram import Update, BotCommand
from telegram.ext import (
CommandHandler,
MessageHandler,
Application,
CallbackQueryHandler,
)
from gemini_pro_bot.filters import AuthFilter, MessageFilter, PhotoFilter
from dotenv import load_dotenv
@@ -13,26 +14,53 @@
newchat_command,
handle_message,
handle_image,
model_command,
model_callback,
)
import asyncio
import logging

# Load environment variables
load_dotenv()

# Configure logging
logging.basicConfig(
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
level=logging.INFO
)

async def setup_commands(application: Application) -> None:
"""设置机器人的命令菜单"""
commands = [
BotCommand(command='start', description='Start using the bot'),
BotCommand(command='help', description='Get help'),
BotCommand(command='new', description='Start a new chat'),
BotCommand(command='model', description='Select AI model'),
]
# await application.bot.set_my_commands(commands)

def start_bot() -> None:
"""Start the bot."""
# Create the Application and pass it your bot's token.
application = Application.builder().token(os.getenv("BOT_TOKEN")).build()
"""启动机器人"""
try:
# 创建应用实例
application = Application.builder().token(os.getenv("BOT_TOKEN")).build()
# 添加命令处理器
application.add_handler(CommandHandler("start", start, filters=AuthFilter))
application.add_handler(CommandHandler("help", help_command, filters=AuthFilter))
application.add_handler(CommandHandler("new", newchat_command, filters=AuthFilter))

# on different commands - answer in Telegram
application.add_handler(CommandHandler("start", start, filters=AuthFilter))
application.add_handler(CommandHandler("help", help_command, filters=AuthFilter))
application.add_handler(CommandHandler("new", newchat_command, filters=AuthFilter))
# Handle text messages
application.add_handler(MessageHandler(MessageFilter, handle_message))

# Any text message is sent to LLM to generate a response
application.add_handler(MessageHandler(MessageFilter, handle_message))
# Handle photo messages
application.add_handler(MessageHandler(PhotoFilter, handle_image))

# Any image is sent to LLM to generate a response
application.add_handler(MessageHandler(PhotoFilter, handle_image))
# Register the model-selection command
application.add_handler(CommandHandler("model", model_command, filters=AuthFilter))

# Run the bot until the user presses Ctrl-C
application.run_polling(allowed_updates=Update.ALL_TYPES)
# Register the callback-query handler
application.add_handler(CallbackQueryHandler(model_callback, pattern="^model_"))
application.run_polling(allowed_updates=Update.ALL_TYPES)
except Exception as e:
logging.error(f"Error while starting the bot: {e}")
raise
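Note: setup_commands above is defined but never registered with the application, and its set_my_commands call is commented out, so the command menu is never actually installed on Telegram's side. A minimal sketch of one way to wire it in, assuming python-telegram-bot v20+, where ApplicationBuilder.post_init() accepts a coroutine that runs once after initialization (the post_init wiring is an assumption, not something this commit does):

# Sketch only: assumes python-telegram-bot v20+; this wiring is not part of the commit.
import os
from telegram import BotCommand
from telegram.ext import Application

async def setup_commands(application: Application) -> None:
    """Install the bot's command menu (mirrors setup_commands above)."""
    commands = [
        BotCommand(command='start', description='Start using the bot'),
        BotCommand(command='help', description='Get help'),
        BotCommand(command='new', description='Start a new chat'),
        BotCommand(command='model', description='Select AI model'),
    ]
    await application.bot.set_my_commands(commands)

application = (
    Application.builder()
    .token(os.getenv("BOT_TOKEN"))
    .post_init(setup_commands)  # runs once, right after Application.initialize()
    .build()
)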
186 changes: 126 additions & 60 deletions gemini_pro_bot/handlers.py
@@ -1,18 +1,21 @@
import asyncio
from gemini_pro_bot.llm import model, img_model
from gemini_pro_bot.llm import model, llm_manager
from google.generativeai.types.generation_types import (
StopCandidateException,
BlockedPromptException,
)
from telegram import Update
import google.generativeai as genai
from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup, BotCommand
from telegram.ext import (
ContextTypes,
ContextTypes, Application
)
from telegram.error import NetworkError, BadRequest
from telegram.constants import ChatAction, ParseMode
from gemini_pro_bot.html_format import format_message
import PIL.Image as load_image
from io import BytesIO
from datetime import datetime
import os


def new_chat(context: ContextTypes.DEFAULT_TYPE) -> None:
Expand All @@ -34,6 +37,7 @@ async def help_command(update: Update, _: ContextTypes.DEFAULT_TYPE) -> None:
Basic commands:
/start - Start the bot
/help - Get help. Shows this message
/model - Select LLM model to use
Chat commands:
/new - Start a new chat session (model will forget previously generated messages)
@@ -65,7 +69,7 @@ async def handle_message(update: Update, context: ContextTypes.DEFAULT_TYPE) ->
new_chat(context)
text = update.message.text
init_msg = await update.message.reply_text(
text="Generating...", reply_to_message_id=update.message.message_id
text="请稍后...", reply_to_message_id=update.message.message_id
)
await update.message.chat.send_action(ChatAction.TYPING)
# Generate a response using the text-generation pipeline
@@ -133,63 +137,125 @@ async def handle_message(update: Update, context: ContextTypes.DEFAULT_TYPE) ->
await asyncio.sleep(0.1)


async def handle_image(update: Update, _: ContextTypes.DEFAULT_TYPE) -> None:
async def handle_image(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
"""Handle incoming images with captions and generate a response."""
init_msg = await update.message.reply_text(
text="Generating...", reply_to_message_id=update.message.message_id
text="请稍后...",
reply_to_message_id=update.message.message_id
)
images = update.message.photo
unique_images: dict = {}
for img in images:
file_id = img.file_id[:-7]
if file_id not in unique_images:
unique_images[file_id] = img
elif img.file_size > unique_images[file_id].file_size:
unique_images[file_id] = img
file_list = list(unique_images.values())
file = await file_list[0].get_file()
a_img = load_image.open(BytesIO(await file.download_as_bytearray()))
prompt = None
if update.message.caption:
prompt = update.message.caption
try:
# Get the photos attached to the message
images = update.message.photo
if not images:
await init_msg.edit_text("No image found in the message.")
return

# Pick the largest available photo size
image = max(images, key=lambda x: x.file_size)
file = await image.get_file()

# Download the image bytes
image_data = await file.download_as_bytearray()

# Upload the image to Gemini
gemini_file = upload_to_gemini(image_data)

# Prepare the file list
files = [gemini_file]

# Get the prompt text
prompt = update.message.caption if update.message.caption else "Analyse this image and generate response"
if context.chat_data.get("chat") is None:
new_chat(context)
# Generate a response
await update.message.chat.send_action(ChatAction.TYPING)
# Generate a response using the text-generation pipeline
chat_session = context.chat_data.get("chat")
chat_session.history.append({
"role": "user",
"parts": [
files[0],
],
})
# Generate the response with Gemini
response = await chat_session.send_message_async(
prompt,
stream=True
)
# Process the response
full_plain_message = ""
async for chunk in response:
try:
if chunk.text:
full_plain_message += chunk.text
message = format_message(full_plain_message)
init_msg = await init_msg.edit_text(
text=message,
parse_mode=ParseMode.HTML,
disable_web_page_preview=True,
)
except Exception as e:
print(f"Error in response streaming: {e}")
if not full_plain_message:
await init_msg.edit_text(f"Error generating response: {str(e)}")
break
await asyncio.sleep(0.1)

except Exception as e:
print(f"Error processing image: {e}")
await init_msg.edit_text(f"Error processing image: {str(e)}")

def upload_to_gemini(image_data, mime_type="image/png"):
"""Uploads the given image data to Gemini."""
# Generate a temporary file name
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
temp_filename = f"temp_image_{timestamp}.png"

try:
# Save the temporary file
with open(temp_filename, 'wb') as f:
f.write(image_data)

# Upload to Gemini
file = genai.upload_file(temp_filename, mime_type=mime_type)
print(f"Uploaded file '{file.display_name}' as: {file.uri}")
return file
finally:
# Delete the temporary file
if os.path.exists(temp_filename):
os.remove(temp_filename)


async def model_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
"""Handle the /model command - show model selection menu."""
keyboard = []
models = llm_manager.get_available_models()

for model_id, model_info in models.items():
# Create a button for each model
keyboard.append([InlineKeyboardButton(
f"{model_info['name']} {'✓' if model_id == llm_manager.current_model else ''}",
callback_data=f"model_{model_id}"
)])

reply_markup = InlineKeyboardMarkup(keyboard)
await update.message.reply_text(
"选择要使用的模型:",
reply_markup=reply_markup
)

async def model_callback(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
"""Handle model selection callback."""
query = update.callback_query
await query.answer()

# Extract the model ID from the callback_data
model_id = query.data.replace("model_", "")

if llm_manager.switch_model(model_id):
models = llm_manager.get_available_models()
await query.edit_message_text(
f"已切换到 {models[model_id]['name']} 模型"
)
else:
prompt = "Analyse this image and generate response"
response = await img_model.generate_content_async([prompt, a_img], stream=True)
full_plain_message = ""
async for chunk in response:
try:
if chunk.text:
full_plain_message += chunk.text
message = format_message(full_plain_message)
init_msg = await init_msg.edit_text(
text=message,
parse_mode=ParseMode.HTML,
disable_web_page_preview=True,
)
except StopCandidateException:
await init_msg.edit_text("The model unexpectedly stopped generating.")
except BadRequest:
await response.resolve()
continue
except NetworkError:
raise NetworkError(
"Looks like you're network is down. Please try again later."
)
except IndexError:
await init_msg.reply_text(
"Some index error occurred. This response is not supported."
)
await response.resolve()
continue
except Exception as e:
print(e)
if chunk.text:
full_plain_message = chunk.text
message = format_message(full_plain_message)
init_msg = await update.message.reply_text(
text=message,
parse_mode=ParseMode.HTML,
reply_to_message_id=init_msg.message_id,
disable_web_page_preview=True,
)
await asyncio.sleep(0.1)
await query.edit_message_text("Failed to switch the model")
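A note on upload_to_gemini above: the temporary filename is derived from a second-resolution timestamp, so two photos handled within the same second would collide on the same path. A sketch of the same helper built on the standard tempfile module, with the upload call and cleanup semantics kept as in the commit:

# Sketch only: same contract as upload_to_gemini above, with a
# collision-safe temporary file from the standard library.
import os
import tempfile

import google.generativeai as genai

def upload_to_gemini(image_data: bytes, mime_type: str = "image/png"):
    """Upload raw image bytes to Gemini via a uniquely named temp file."""
    # delete=False lets upload_file reopen the path after the with-block
    # (necessary on Windows, where an open handle locks the file).
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
        f.write(image_data)
        temp_filename = f.name
    try:
        file = genai.upload_file(temp_filename, mime_type=mime_type)
        print(f"Uploaded file '{file.display_name}' as: {file.uri}")
        return file
    finally:
        # Always remove the temp file, even if the upload raises.
        os.remove(temp_filename)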
46 changes: 43 additions & 3 deletions gemini_pro_bot/llm.py
@@ -13,9 +13,49 @@
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
}

MODELS = {
"gemini-1.5-pro": {
"name": "gemini-1.5-pro",
"model": "gemini-1.5-pro",
"type": "text"
},
"gemini-1.5-flash": {
"name": "gemini-1.5-flash",
"model": "gemini-1.5-flash",
"type": "vision"
},
"gemini-1.5-flash-8b": {
"name": "gemini-1.5-flash-8b",
"model": "gemini-1.5-flash-8b",
"type": "vision"
},
}

class LLMManager:
def __init__(self):
self.current_model = "gemini-1.5-pro"
genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))
self._init_models()

def _init_models(self):
self.models = {}
for model_id, config in MODELS.items():
self.models[model_id] = genai.GenerativeModel(
config["model"],
safety_settings=SAFETY_SETTINGS
)

def get_current_model(self):
return self.models[self.current_model]

genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))
def switch_model(self, model_id):
if model_id in MODELS:
self.current_model = model_id
return True
return False

def get_available_models(self):
return MODELS

model = genai.GenerativeModel("gemini-pro", safety_settings=SAFETY_SETTINGS)
img_model = genai.GenerativeModel("gemini-pro-vision", safety_settings=SAFETY_SETTINGS)
llm_manager = LLMManager()
model = llm_manager.get_current_model()

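One caveat with the new module-level binding: model = llm_manager.get_current_model() resolves once at import time, so code that imports model directly keeps pointing at gemini-1.5-pro even after a /model switch. Code that should follow the switch needs to resolve through the manager at call time. A minimal usage sketch:

# Sketch only: illustrates the LLMManager defined in this commit.
from gemini_pro_bot.llm import llm_manager, model

# switch_model is a plain key lookup into MODELS; unknown IDs return False.
assert llm_manager.switch_model("gemini-1.5-flash") is True
assert llm_manager.switch_model("no-such-model") is False

# `model` still holds the instance captured at import time, while
# resolving through the manager picks up the switch:
current = llm_manager.get_current_model()
chat = current.start_chat(history=[])  # new chats now use gemini-1.5-flash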