From 29e9fced897bef0263d7ded6d3ca2e028010004f Mon Sep 17 00:00:00 2001 From: liuooo Date: Fri, 15 Nov 2024 19:04:30 +0800 Subject: [PATCH] add stream_options in streaming mode --- app/core/runner/llm_backend.py | 5 +++ app/core/runner/llm_callback_handler.py | 4 +++ app/core/runner/pub_handler.py | 18 +++++++++++ app/core/runner/thread_runner.py | 1 + app/models/run.py | 2 ++ examples/run_assistant_stream.py | 10 ++++++ .../versions/2024-11-15-18-30_5b2b73d0fdf6.py | 31 +++++++++++++++++++ 7 files changed, 71 insertions(+) create mode 100644 migrations/versions/2024-11-15-18-30_5b2b73d0fdf6.py diff --git a/app/core/runner/llm_backend.py b/app/core/runner/llm_backend.py index ed9e848..facbc44 100644 --- a/app/core/runner/llm_backend.py +++ b/app/core/runner/llm_backend.py @@ -22,6 +22,7 @@ def run( tools: List = None, tool_choice="auto", stream=False, + stream_options=None, extra_body=None, temperature=None, top_p=None, @@ -38,6 +39,10 @@ def run( if "n" in model_params: raise ValueError("n is not allowed in model_params") chat_params.update(model_params) + if stream_options: + if isinstance(stream_options, dict): + if "include_usage" in stream_options: + chat_params["stream_options"] = {"include_usage": bool(stream_options["include_usage"])} if temperature: chat_params["temperature"] = temperature if top_p: diff --git a/app/core/runner/llm_callback_handler.py b/app/core/runner/llm_callback_handler.py index 878b018..05f3805 100644 --- a/app/core/runner/llm_callback_handler.py +++ b/app/core/runner/llm_callback_handler.py @@ -41,6 +41,10 @@ def handle_llm_response( for chunk in response_stream: logging.debug(chunk) + if chunk.usage: + self.event_handler.pub_message_usage(chunk) + continue + if not chunk.choices: continue diff --git a/app/core/runner/pub_handler.py b/app/core/runner/pub_handler.py index 71d7836..4ddf122 100644 --- a/app/core/runner/pub_handler.py +++ b/app/core/runner/pub_handler.py @@ -214,6 +214,24 @@ def pub_message_in_progress(self, message): events.ThreadMessageInProgress(data=_data_adjust_message(message), event="thread.message.in_progress") ) + def pub_message_usage(self, chunk): + """ + 目前 stream 未有 usage 相关 event,借用 thread.message.in_progress 进行传输,待官方更新 + """ + data = { + "id": chunk.id, + "content": [], + "created_at": 0, + "object": "thread.message", + "role": "assistant", + "status": "in_progress", + "thread_id": "", + "metadata": {"usage": chunk.usage.json()} + } + self.pub_event( + events.ThreadMessageInProgress(data=data, event="thread.message.in_progress") + ) + def pub_message_completed(self, message): self.pub_event( events.ThreadMessageCompleted(data=_data_adjust_message(message), event="thread.message.completed") diff --git a/app/core/runner/thread_runner.py b/app/core/runner/thread_runner.py index 50828fc..5f2f518 100644 --- a/app/core/runner/thread_runner.py +++ b/app/core/runner/thread_runner.py @@ -129,6 +129,7 @@ def __run_step( tools=[tool.openai_function for tool in tools], tool_choice="auto" if len(run_steps) < self.max_step else "none", stream=True, + stream_options=run.stream_options, extra_body=run.extra_body, temperature=run.temperature, top_p=run.top_p, diff --git a/app/models/run.py b/app/models/run.py index 5ebb2ba..bc34bea 100644 --- a/app/models/run.py +++ b/app/models/run.py @@ -49,6 +49,7 @@ class RunBase(BaseModel): failed_at: Optional[datetime] = Field(default=None) additional_instructions: Optional[str] = Field(default=None, max_length=32768, sa_column=Column(TEXT)) extra_body: Optional[dict] = Field(default={}, sa_column=Column(JSON)) + stream_options: Optional[dict] = Field(default=None, sa_column=Column(JSON)) incomplete_details: Optional[str] = Field(default=None) # 未完成详情 max_completion_tokens: Optional[int] = Field(default=None) # 最大完成长度 max_prompt_tokens: Optional[int] = Field(default=None) # 最大提示长度 @@ -74,6 +75,7 @@ class RunCreate(BaseModel): tools: Optional[list] = [] extra_body: Optional[dict[str, Union[dict[str, Union[Authentication, Any]], Any]]] = {} stream: Optional[bool] = False + stream_options: Optional[dict] = Field(default=None, sa_column=Column(JSON)) additional_messages: Optional[list[MessageCreate]] = Field(default=[], sa_column=Column(JSON)) # 消息列表 max_completion_tokens: Optional[int] = None # 最大完成长度 max_prompt_tokens: Optional[int] = Field(default=None) # 最大提示长度 diff --git a/examples/run_assistant_stream.py b/examples/run_assistant_stream.py index bfb9079..c728ed2 100644 --- a/examples/run_assistant_stream.py +++ b/examples/run_assistant_stream.py @@ -4,6 +4,8 @@ import logging from openai import AssistantEventHandler +from openai.types.beta import AssistantStreamEvent +from openai.types.beta.assistant_stream_event import ThreadMessageInProgress from openai.types.beta.threads.message import Message from openai.types.beta.threads.runs import ToolCall, ToolCallDelta @@ -47,6 +49,11 @@ def on_text_delta(self, delta, snapshot) -> None: def on_text_done(self, text) -> None: logging.info("text done: %s\n", text) + def on_event(self, event: AssistantStreamEvent) -> None: + if isinstance(event, ThreadMessageInProgress): + logging.info("event: %s\n", event) + + if __name__ == "__main__": assistant = client.beta.assistants.create( name="Assistant Demo", @@ -70,5 +77,8 @@ def on_text_done(self, text) -> None: thread_id=thread.id, assistant_id=assistant.id, event_handler=event_handler, + extra_body={ + "stream_options": {"include_usage": True} + } ) as stream: stream.until_done() \ No newline at end of file diff --git a/migrations/versions/2024-11-15-18-30_5b2b73d0fdf6.py b/migrations/versions/2024-11-15-18-30_5b2b73d0fdf6.py new file mode 100644 index 0000000..bd2d7c4 --- /dev/null +++ b/migrations/versions/2024-11-15-18-30_5b2b73d0fdf6.py @@ -0,0 +1,31 @@ +"""empty message + +Revision ID: 5b2b73d0fdf6 +Revises: b217fafdb5f0 +Create Date: 2024-11-15 18:30:43.391344 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +import sqlmodel + + +# revision identifiers, used by Alembic. +revision: str = '5b2b73d0fdf6' +down_revision: Union[str, None] = 'b217fafdb5f0' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('run', sa.Column('stream_options', sa.JSON(), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('run', 'stream_options') + # ### end Alembic commands ###