Skip to content

Commit

Permalink
Merge pull request #58 from YangZhiBoGreenHand/yzb/feat/update-run-as…
Browse files Browse the repository at this point in the history
…sistant-model

feat: update run and assistant model
  • Loading branch information
liuooo authored Apr 25, 2024
2 parents 898f6ad + fc9e3f8 commit 81269ae
Show file tree
Hide file tree
Showing 9 changed files with 225 additions and 3 deletions.
14 changes: 13 additions & 1 deletion app/core/runner/llm_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,15 @@ def __init__(self, base_url: str, api_key) -> None:
self.client = OpenAI(base_url=self.base_url, api_key=self.api_key)

def run(
self, messages: List, model: str, tools: List = None, tool_choice="auto", stream=False, extra_body=None
self,
messages: List,
model: str,
tools: List = None,
tool_choice="auto",
stream=False,
extra_body=None,
temperature=None,
top_p=None,
) -> ChatCompletion | Stream[ChatCompletionChunk]:
chat_params = {
"messages": messages,
Expand All @@ -29,6 +37,10 @@ def run(
if "n" in model_params:
raise ValueError("n is not allowed in model_params")
chat_params.update(model_params)
if temperature:
chat_params["temperature"] = temperature
if top_p:
chat_params["top_p"] = top_p
if tools:
chat_params["tools"] = tools
chat_params["tool_choice"] = tool_choice if tool_choice else "auto"
Expand Down
2 changes: 2 additions & 0 deletions app/core/runner/thread_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,8 @@ def __run_step(self, llm: LLMBackend, run: Run, run_steps: List[RunStep], instru
tool_choice="auto" if len(run_steps) < self.max_step else "none",
stream=True,
extra_body=run.extra_body,
temperature=run.temperature,
top_p=run.top_p,
)

# create message callback
Expand Down
8 changes: 8 additions & 0 deletions app/models/assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ class AssistantBase(BaseModel):
name: Optional[str] = Field(default=None)
tools: Optional[list] = Field(default=None, sa_column=Column(JSON))
extra_body: Optional[dict] = Field(default=None, sa_column=Column(JSON))
response_format: Optional[str] = Field(default=None) # 响应格式
tool_resources: Optional[dict] = Field(default=None, sa_column=Column(JSON)) # 工具资源
temperature: Optional[float] = Field(default=None) # 温度
top_p: Optional[float] = Field(default=None) # top_p


class Assistant(AssistantBase, PrimaryKeyMixin, TimeStampMixin, table=True):
Expand All @@ -34,3 +38,7 @@ class AssistantUpdate(BaseModel):
name: Optional[str] = Field(default=None)
tools: Optional[list] = Field(default=None, sa_column=Column(JSON))
extra_body: Optional[dict] = Field(default=None, sa_column=Column(JSON))
response_format: Optional[str] = Field(default=None) # 响应格式
tool_resources: Optional[dict] = Field(default=None, sa_column=Column(JSON)) # 工具资源
temperature: Optional[float] = Field(default=None) # 温度
top_p: Optional[float] = Field(default=None) # top_p
17 changes: 17 additions & 0 deletions app/models/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,15 @@ class Run(BaseModel, PrimaryKeyMixin, TimeStampMixin, table=True):
failed_at: Optional[Timestamp] = Field(default=None)
additional_instructions: Optional[str] = Field(default=None, max_length=32768, sa_column=Column(TEXT))
extra_body: Optional[dict] = Field(default={}, sa_column=Column(JSON))
incomplete_details: Optional[str] = Field(default=None) # 未完成详情
max_completion_tokens: Optional[int] = Field(default=None) # 最大完成长度
max_prompt_tokens: Optional[int] = Field(default=None) # 最大提示长度
response_format: Optional[str] = Field(default=None) # 返回格式
tool_choice: Optional[str] = Field(default=None) # 工具选择
truncation_strategy: Optional[dict] = Field(default={}, sa_column=Column(JSON)) # 截断策略
usage: Optional[dict] = Field(default={}, sa_column=Column(JSON)) # 调用使用情况
temperature: Optional[float] = Field(default=None) # 温度
top_p: Optional[float] = Field(default=None) # top_p


class RunRead(Run): ...
Expand All @@ -62,6 +71,14 @@ class RunCreate(BaseModel):
tools: Optional[list] = []
extra_body: Optional[dict[str, Union[dict[str, Union[Authentication, Any]], Any]]] = {}
stream: Optional[bool] = False
additional_messages: Optional[list] = Field(default=[], sa_column=Column(JSON)) # 消息列表
max_completion_tokens: Optional[int] = None # 最大完成长度
max_prompt_tokens: Optional[int] = Field(default=None) # 最大提示长度
truncation_strategy: Optional[dict] = Field(default={}, sa_column=Column(JSON)) # 截断策略
response_format: Optional[str] = Field(default=None) # 返回格式
tool_choice: Optional[str] = Field(default=None) # 工具选择
temperature: Optional[float] = Field(default=None) # 温度
top_p: Optional[float] = Field(default=None) # top_p

@root_validator()
def root_validator(cls, data: Any):
Expand Down
22 changes: 22 additions & 0 deletions app/services/message/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,3 +124,25 @@ async def copy_messages(*, session: AsyncSession, from_thread_id: str, to_thread
)
session.add(new_message)
await session.commit()

@staticmethod
async def create_messages(*, session: AsyncSession, thread_id: str, run_id: str, assistant_id: str, messages: list):
for original_message in messages:
content = [
{
"type": "text",
"text": {"value": original_message["content"], "annotations": []},
}
]

new_message = Message.model_validate(
original_message,
update={
"thread_id": thread_id,
"run_id": run_id,
"assistant_id": assistant_id,
"content": content,
"role": original_message["role"],
},
)
session.add(new_message)
16 changes: 16 additions & 0 deletions app/services/run/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from app.schemas.runs import SubmitToolOutputsRunRequest
from app.schemas.threads import CreateThreadAndRun
from app.services.assistant.assistant import AssistantService
from app.services.message.message import MessageService
from app.services.thread.thread import ThreadService


Expand All @@ -34,9 +35,24 @@ async def create_run(
body.tools = db_asst.tools
if not body.extra_body and db_asst.extra_body:
body.extra_body = db_asst.extra_body
if not body.temperature and db_asst.temperature:
body.temperature = db_asst.temperature
if not body.top_p and db_asst.top_p:
body.top_p = db_asst.top_p
# create run
db_run = Run.model_validate(body.model_dump(), update={"thread_id": thread_id, "file_ids": db_asst.file_ids})
session.add(db_run)
session.refresh(db_run)
run_id = db_run.id
if body.additional_messages:
# create messages
await MessageService.create_messages(
session=session,
thread_id=thread_id,
run_id=str(run_id),
assistant_id=body.assistant_id,
messages=body.additional_messages,
)
await session.commit()
await session.refresh(db_run)
return db_run
Expand Down
56 changes: 56 additions & 0 deletions migrations/versions/2024-04-22-17-19_aa4bda3363e3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
"""update models
Revision ID: aa4bda3363e3
Revises: 8dbb8f38ef77
Create Date: 2024-04-22 17:19:59.829072
"""

from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa
import sqlmodel


# revision identifiers, used by Alembic.
revision: str = "aa4bda3363e3"
down_revision: Union[str, None] = "8dbb8f38ef77"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("assistant", sa.Column("response_format", sqlmodel.sql.sqltypes.AutoString(), nullable=True))
op.add_column("assistant", sa.Column("tool_resources", sa.JSON(), nullable=True))
op.add_column("assistant", sa.Column("temperature", sa.Float(), nullable=True))
op.add_column("assistant", sa.Column("top_p", sa.Float(), nullable=True))
op.add_column("run", sa.Column("incomplete_details", sqlmodel.sql.sqltypes.AutoString(), nullable=True))
op.add_column("run", sa.Column("max_completion_tokens", sa.Integer(), nullable=True))
op.add_column("run", sa.Column("max_prompt_tokens", sa.Integer(), nullable=True))
op.add_column("run", sa.Column("response_format", sqlmodel.sql.sqltypes.AutoString(), nullable=True))
op.add_column("run", sa.Column("tool_choice", sqlmodel.sql.sqltypes.AutoString(), nullable=True))
op.add_column("run", sa.Column("truncation_strategy", sa.JSON(), nullable=True))
op.add_column("run", sa.Column("usage", sa.JSON(), nullable=True))
op.add_column("run", sa.Column("temperature", sa.Float(), nullable=True))
op.add_column("run", sa.Column("top_p", sa.Float(), nullable=True))
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("run", "top_p")
op.drop_column("run", "temperature")
op.drop_column("run", "usage")
op.drop_column("run", "truncation_strategy")
op.drop_column("run", "tool_choice")
op.drop_column("run", "response_format")
op.drop_column("run", "max_prompt_tokens")
op.drop_column("run", "max_completion_tokens")
op.drop_column("run", "incomplete_details")
op.drop_column("assistant", "top_p")
op.drop_column("assistant", "temperature")
op.drop_column("assistant", "tool_resources")
op.drop_column("assistant", "response_format")
# ### end Alembic commands ###
45 changes: 43 additions & 2 deletions tests/e2e/assistant_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@

from app.models.assistant import Assistant
from app.providers.database import session
from app.services.assistant.assistant import AssistantService


# 测试创建动作
def test_create_assistant():
def test_create_assistant_with_extra_body():
client = openai.OpenAI(base_url="http://localhost:8086/api/v1", api_key="xxx")
assistant = client.beta.assistants.create(
name="Assistant Demo",
Expand Down Expand Up @@ -43,3 +42,45 @@ def test_create_assistant():
"top_p": 1,
}
}
session.close()


def test_create_assistant_with_temperature_and_top_p():
client = openai.OpenAI(base_url="http://localhost:8086/api/v1", api_key="xxx")
assistant = client.beta.assistants.create(
name="Assistant Demo",
instructions="你是一个有用的助手",
temperature=1,
top_p=1,
# https://platform.openai.com/docs/api-reference/chat/create 具体参数看这里
model="gpt-3.5-turbo-1106",
)
query = session.query(Assistant).filter(Assistant.id == assistant.id)
assistant = query.one()
assert assistant.name == "Assistant Demo"
assert assistant.instructions == "你是一个有用的助手"
assert assistant.model == "gpt-3.5-turbo-1106"
assert assistant.temperature == 1
assert assistant.top_p == 1
session.close()


def test_update_assistant_with_temperature_and_top_p():
client = openai.OpenAI(base_url="http://localhost:8086/api/v1", api_key="xxx")
assistant = client.beta.assistants.create(
name="Assistant Demo",
instructions="你是一个有用的助手",
temperature=1,
top_p=1,
# https://platform.openai.com/docs/api-reference/chat/create 具体参数看这里
model="gpt-3.5-turbo-1106",
)
assistant = client.beta.assistants.update(assistant.id, temperature=2, top_p=0.9)
query = session.query(Assistant).filter(Assistant.id == assistant.id)
assistant = query.one()
assert assistant.name == "Assistant Demo"
assert assistant.instructions == "你是一个有用的助手"
assert assistant.model == "gpt-3.5-turbo-1106"
assert assistant.temperature == 2
assert assistant.top_p == 0.9
session.close()
48 changes: 48 additions & 0 deletions tests/e2e/run_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import openai

from app.models.message import Message
from app.models.run import Run
from app.providers.database import session


# 测试创建动作
def test_create_run_with_additional_messages_and_other_parmas():
client = openai.OpenAI(base_url="http://localhost:8086/api/v1", api_key="xxx")
assistant = client.beta.assistants.create(
name="Assistant Demo",
instructions="你是一个有用的助手",
model="gpt-3.5-turbo-1106",
)
thread = client.beta.threads.create()
run = client.beta.threads.runs.create(
thread_id=thread.id,
assistant_id=assistant.id,
instructions="",
additional_messages=[
{"role": "user", "content": "100 + 100 等于多少"},
{"role": "assistant", "content": "100 + 100 等于200"},
{"role": "user", "content": "如果是乘是多少呢?"},
],
max_completion_tokens=100,
max_prompt_tokens=100,
temperature=0.5,
top_p=0.5,
)
query = session.query(Run).filter(Run.id == run.id)
run = query.one()
assert run.instructions == "你是一个有用的助手"
assert run.model == "gpt-3.5-turbo-1106"
query = session.query(Message).filter(Message.run_id == run.id).order_by(Message.created_at)
messages = query.all()
[messgae1, messgae2, messgae3] = messages
assert messgae1.content == [{"text": {"value": "100 + 100 等于多少", "annotations": []}, "type": "text"}]
assert messgae1.role == "user"
assert messgae2.content == [{"text": {"value": "100 + 100 等于200", "annotations": []}, "type": "text"}]
assert messgae2.role == "assistant"
assert messgae3.content == [{"text": {"value": "如果是乘是多少呢?", "annotations": []}, "type": "text"}]
assert messgae3.role == "user"
assert run.max_completion_tokens == 100
assert run.max_prompt_tokens == 100
assert run.temperature == 0.5
assert run.top_p == 0.5
session.close()

0 comments on commit 81269ae

Please sign in to comment.