diff --git a/app/core/runner/llm_backend.py b/app/core/runner/llm_backend.py index f27a553..63d8b0a 100644 --- a/app/core/runner/llm_backend.py +++ b/app/core/runner/llm_backend.py @@ -16,7 +16,15 @@ def __init__(self, base_url: str, api_key) -> None: self.client = OpenAI(base_url=self.base_url, api_key=self.api_key) def run( - self, messages: List, model: str, tools: List = None, tool_choice="auto", stream=False, extra_body=None + self, + messages: List, + model: str, + tools: List = None, + tool_choice="auto", + stream=False, + extra_body=None, + temperature=None, + top_p=None, ) -> ChatCompletion | Stream[ChatCompletionChunk]: chat_params = { "messages": messages, @@ -29,6 +37,10 @@ def run( if "n" in model_params: raise ValueError("n is not allowed in model_params") chat_params.update(model_params) + if temperature is not None: + chat_params["temperature"] = temperature + if top_p is not None: + chat_params["top_p"] = top_p if tools: chat_params["tools"] = tools chat_params["tool_choice"] = tool_choice if tool_choice else "auto" diff --git a/app/core/runner/thread_runner.py b/app/core/runner/thread_runner.py index 70c99af..f7203d7 100644 --- a/app/core/runner/thread_runner.py +++ b/app/core/runner/thread_runner.py @@ -115,6 +115,8 @@ def __run_step(self, llm: LLMBackend, run: Run, run_steps: List[RunStep], instru tool_choice="auto" if len(run_steps) < self.max_step else "none", stream=True, extra_body=run.extra_body, + temperature=run.temperature, + top_p=run.top_p, ) # create message callback diff --git a/app/models/assistant.py b/app/models/assistant.py index 28d46d5..b5b6952 100644 --- a/app/models/assistant.py +++ b/app/models/assistant.py @@ -15,6 +15,10 @@ class AssistantBase(BaseModel): name: Optional[str] = Field(default=None) tools: Optional[list] = Field(default=None, sa_column=Column(JSON)) extra_body: Optional[dict] = Field(default=None, sa_column=Column(JSON)) + response_format: Optional[str] = Field(default=None) # 响应格式 + tool_resources: Optional[dict] = 
Field(default=None, sa_column=Column(JSON)) # 工具资源 + temperature: Optional[float] = Field(default=None) # 温度 + top_p: Optional[float] = Field(default=None) # top_p class Assistant(AssistantBase, PrimaryKeyMixin, TimeStampMixin, table=True): @@ -34,3 +38,7 @@ class AssistantUpdate(BaseModel): name: Optional[str] = Field(default=None) tools: Optional[list] = Field(default=None, sa_column=Column(JSON)) extra_body: Optional[dict] = Field(default=None, sa_column=Column(JSON)) + response_format: Optional[str] = Field(default=None) # 响应格式 + tool_resources: Optional[dict] = Field(default=None, sa_column=Column(JSON)) # 工具资源 + temperature: Optional[float] = Field(default=None) # 温度 + top_p: Optional[float] = Field(default=None) # top_p diff --git a/app/models/run.py b/app/models/run.py index c06b3d4..ea5c121 100644 --- a/app/models/run.py +++ b/app/models/run.py @@ -46,6 +46,15 @@ class Run(BaseModel, PrimaryKeyMixin, TimeStampMixin, table=True): failed_at: Optional[Timestamp] = Field(default=None) additional_instructions: Optional[str] = Field(default=None, max_length=32768, sa_column=Column(TEXT)) extra_body: Optional[dict] = Field(default={}, sa_column=Column(JSON)) + incomplete_details: Optional[str] = Field(default=None) # 未完成详情 + max_completion_tokens: Optional[int] = Field(default=None) # 最大完成长度 + max_prompt_tokens: Optional[int] = Field(default=None) # 最大提示长度 + response_format: Optional[str] = Field(default=None) # 返回格式 + tool_choice: Optional[str] = Field(default=None) # 工具选择 + truncation_strategy: Optional[dict] = Field(default={}, sa_column=Column(JSON)) # 截断策略 + usage: Optional[dict] = Field(default={}, sa_column=Column(JSON)) # 调用使用情况 + temperature: Optional[float] = Field(default=None) # 温度 + top_p: Optional[float] = Field(default=None) # top_p class RunRead(Run): ... 
@@ -62,6 +71,14 @@ class RunCreate(BaseModel): tools: Optional[list] = [] extra_body: Optional[dict[str, Union[dict[str, Union[Authentication, Any]], Any]]] = {} stream: Optional[bool] = False + additional_messages: Optional[list] = Field(default=[], sa_column=Column(JSON)) # 消息列表 + max_completion_tokens: Optional[int] = None # 最大完成长度 + max_prompt_tokens: Optional[int] = Field(default=None) # 最大提示长度 + truncation_strategy: Optional[dict] = Field(default={}, sa_column=Column(JSON)) # 截断策略 + response_format: Optional[str] = Field(default=None) # 返回格式 + tool_choice: Optional[str] = Field(default=None) # 工具选择 + temperature: Optional[float] = Field(default=None) # 温度 + top_p: Optional[float] = Field(default=None) # top_p @root_validator() def root_validator(cls, data: Any): diff --git a/app/services/message/message.py b/app/services/message/message.py index 4f40f0b..2d35f0a 100644 --- a/app/services/message/message.py +++ b/app/services/message/message.py @@ -124,3 +124,25 @@ async def copy_messages(*, session: AsyncSession, from_thread_id: str, to_thread ) session.add(new_message) await session.commit() + + @staticmethod + async def create_messages(*, session: AsyncSession, thread_id: str, run_id: str, assistant_id: str, messages: list): + for original_message in messages: + content = [ + { + "type": "text", + "text": {"value": original_message["content"], "annotations": []}, + } + ] + + new_message = Message.model_validate( + original_message, + update={ + "thread_id": thread_id, + "run_id": run_id, + "assistant_id": assistant_id, + "content": content, + "role": original_message["role"], + }, + ) + session.add(new_message) diff --git a/app/services/run/run.py b/app/services/run/run.py index 9ee59ff..d960870 100644 --- a/app/services/run/run.py +++ b/app/services/run/run.py @@ -11,6 +11,7 @@ from app.schemas.runs import SubmitToolOutputsRunRequest from app.schemas.threads import CreateThreadAndRun from app.services.assistant.assistant import AssistantService +from 
app.services.message.message import MessageService from app.services.thread.thread import ThreadService @@ -34,9 +35,24 @@ async def create_run( body.tools = db_asst.tools if not body.extra_body and db_asst.extra_body: body.extra_body = db_asst.extra_body + if body.temperature is None and db_asst.temperature is not None: + body.temperature = db_asst.temperature + if body.top_p is None and db_asst.top_p is not None: + body.top_p = db_asst.top_p # create run db_run = Run.model_validate(body.model_dump(), update={"thread_id": thread_id, "file_ids": db_asst.file_ids}) session.add(db_run) + await session.flush() + run_id = db_run.id + if body.additional_messages: + # create messages + await MessageService.create_messages( + session=session, + thread_id=thread_id, + run_id=str(run_id), + assistant_id=body.assistant_id, + messages=body.additional_messages, + ) await session.commit() await session.refresh(db_run) return db_run diff --git a/migrations/versions/2024-04-22-17-19_aa4bda3363e3.py b/migrations/versions/2024-04-22-17-19_aa4bda3363e3.py new file mode 100644 index 0000000..c3f1c33 --- /dev/null +++ b/migrations/versions/2024-04-22-17-19_aa4bda3363e3.py @@ -0,0 +1,56 @@ +"""update models + +Revision ID: aa4bda3363e3 +Revises: 8dbb8f38ef77 +Create Date: 2024-04-22 17:19:59.829072 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +import sqlmodel + + +# revision identifiers, used by Alembic. +revision: str = "aa4bda3363e3" +down_revision: Union[str, None] = "8dbb8f38ef77" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.add_column("assistant", sa.Column("response_format", sqlmodel.sql.sqltypes.AutoString(), nullable=True)) + op.add_column("assistant", sa.Column("tool_resources", sa.JSON(), nullable=True)) + op.add_column("assistant", sa.Column("temperature", sa.Float(), nullable=True)) + op.add_column("assistant", sa.Column("top_p", sa.Float(), nullable=True)) + op.add_column("run", sa.Column("incomplete_details", sqlmodel.sql.sqltypes.AutoString(), nullable=True)) + op.add_column("run", sa.Column("max_completion_tokens", sa.Integer(), nullable=True)) + op.add_column("run", sa.Column("max_prompt_tokens", sa.Integer(), nullable=True)) + op.add_column("run", sa.Column("response_format", sqlmodel.sql.sqltypes.AutoString(), nullable=True)) + op.add_column("run", sa.Column("tool_choice", sqlmodel.sql.sqltypes.AutoString(), nullable=True)) + op.add_column("run", sa.Column("truncation_strategy", sa.JSON(), nullable=True)) + op.add_column("run", sa.Column("usage", sa.JSON(), nullable=True)) + op.add_column("run", sa.Column("temperature", sa.Float(), nullable=True)) + op.add_column("run", sa.Column("top_p", sa.Float(), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_column("run", "top_p") + op.drop_column("run", "temperature") + op.drop_column("run", "usage") + op.drop_column("run", "truncation_strategy") + op.drop_column("run", "tool_choice") + op.drop_column("run", "response_format") + op.drop_column("run", "max_prompt_tokens") + op.drop_column("run", "max_completion_tokens") + op.drop_column("run", "incomplete_details") + op.drop_column("assistant", "top_p") + op.drop_column("assistant", "temperature") + op.drop_column("assistant", "tool_resources") + op.drop_column("assistant", "response_format") + # ### end Alembic commands ### diff --git a/tests/e2e/assistant_test.py b/tests/e2e/assistant_test.py index a0df57d..89af60f 100644 --- a/tests/e2e/assistant_test.py +++ b/tests/e2e/assistant_test.py @@ -2,11 +2,10 @@ from app.models.assistant import Assistant from app.providers.database import session -from app.services.assistant.assistant import AssistantService # 测试创建动作 -def test_create_assistant(): +def test_create_assistant_with_extra_body(): client = openai.OpenAI(base_url="http://localhost:8086/api/v1", api_key="xxx") assistant = client.beta.assistants.create( name="Assistant Demo", @@ -43,3 +42,45 @@ def test_create_assistant(): "top_p": 1, } } + session.close() + + +def test_create_assistant_with_temperature_and_top_p(): + client = openai.OpenAI(base_url="http://localhost:8086/api/v1", api_key="xxx") + assistant = client.beta.assistants.create( + name="Assistant Demo", + instructions="你是一个有用的助手", + temperature=1, + top_p=1, + # https://platform.openai.com/docs/api-reference/chat/create 具体参数看这里 + model="gpt-3.5-turbo-1106", + ) + query = session.query(Assistant).filter(Assistant.id == assistant.id) + assistant = query.one() + assert assistant.name == "Assistant Demo" + assert assistant.instructions == "你是一个有用的助手" + assert assistant.model == "gpt-3.5-turbo-1106" + assert assistant.temperature == 1 + assert assistant.top_p == 1 + session.close() + + +def test_update_assistant_with_temperature_and_top_p(): + 
client = openai.OpenAI(base_url="http://localhost:8086/api/v1", api_key="xxx") + assistant = client.beta.assistants.create( + name="Assistant Demo", + instructions="你是一个有用的助手", + temperature=1, + top_p=1, + # https://platform.openai.com/docs/api-reference/chat/create 具体参数看这里 + model="gpt-3.5-turbo-1106", + ) + assistant = client.beta.assistants.update(assistant.id, temperature=2, top_p=0.9) + query = session.query(Assistant).filter(Assistant.id == assistant.id) + assistant = query.one() + assert assistant.name == "Assistant Demo" + assert assistant.instructions == "你是一个有用的助手" + assert assistant.model == "gpt-3.5-turbo-1106" + assert assistant.temperature == 2 + assert assistant.top_p == 0.9 + session.close() diff --git a/tests/e2e/run_test.py b/tests/e2e/run_test.py new file mode 100644 index 0000000..df53769 --- /dev/null +++ b/tests/e2e/run_test.py @@ -0,0 +1,48 @@ +import openai + +from app.models.message import Message +from app.models.run import Run +from app.providers.database import session + + +# 测试创建动作 +def test_create_run_with_additional_messages_and_other_params(): + client = openai.OpenAI(base_url="http://localhost:8086/api/v1", api_key="xxx") + assistant = client.beta.assistants.create( + name="Assistant Demo", + instructions="你是一个有用的助手", + model="gpt-3.5-turbo-1106", + ) + thread = client.beta.threads.create() + run = client.beta.threads.runs.create( + thread_id=thread.id, + assistant_id=assistant.id, + instructions="", + additional_messages=[ + {"role": "user", "content": "100 + 100 等于多少"}, + {"role": "assistant", "content": "100 + 100 等于200"}, + {"role": "user", "content": "如果是乘是多少呢?"}, + ], + max_completion_tokens=100, + max_prompt_tokens=100, + temperature=0.5, + top_p=0.5, + ) + query = session.query(Run).filter(Run.id == run.id) + run = query.one() + assert run.instructions == "你是一个有用的助手" + assert run.model == "gpt-3.5-turbo-1106" + query = session.query(Message).filter(Message.run_id == run.id).order_by(Message.created_at) + messages = 
query.all() + [message1, message2, message3] = messages + assert message1.content == [{"text": {"value": "100 + 100 等于多少", "annotations": []}, "type": "text"}] + assert message1.role == "user" + assert message2.content == [{"text": {"value": "100 + 100 等于200", "annotations": []}, "type": "text"}] + assert message2.role == "assistant" + assert message3.content == [{"text": {"value": "如果是乘是多少呢?", "annotations": []}, "type": "text"}] + assert message3.role == "user" + assert run.max_completion_tokens == 100 + assert run.max_prompt_tokens == 100 + assert run.temperature == 0.5 + assert run.top_p == 0.5 + session.close()